// SPDX-FileCopyrightText: 2024 Himbeer
//
// SPDX-License-Identifier: AGPL-3.0-or-later

// Channel-based message passing between processes.
//
// Processes (identified by a `u16` PID) join channels (identified by a
// `usize` ID). A message passed to a channel is queued for the joined
// receivers; its payload is shared between those queues through a reference
// count and freed when the last reference is dropped.

const std = @import("std");

const Allocator = std.mem.Allocator;

pub const Error = error{
    NotJoined,
    WouldBlock,
};

/// A reference-counted message that may sit in several queues at once.
pub const Message = struct {
    /// Payload; owned by the message and freed with the channel allocator
    /// together with it.
    bytes: []const u8,
    /// Number of live references to this message.
    refcount: usize = 1,
    /// PID of the sender, or 0 when the sender did not identify itself.
    sender: u16,

    /// Takes an additional reference. Fails with `error.Overflow` instead of
    /// wrapping the counter.
    fn addReference(self: *Message) !void {
        self.refcount = try std.math.add(usize, self.refcount, 1);
    }

    /// Drops one reference and frees the payload and the message itself once
    /// the count reaches zero. The caller must not touch `self` afterwards if
    /// it might have held the last reference.
    fn dropReference(self: *Message) void {
        self.refcount -= 1;
        if (self.refcount == 0) {
            // Order matters: `self.bytes` has to be read before `self` is
            // destroyed, so free the payload first.
            alloc.free(self.bytes);
            alloc.destroy(self);
        }
    }
};

/// FIFO of pending messages for one (process, channel) pair.
pub const Messages = std.TailQueue(*Message);

var alloc: Allocator = undefined;

/// Channel ID -> pending messages, for a single process.
const Queues = std.AutoArrayHashMap(usize, Messages);
/// PID -> the channels that process has joined.
const Processes = std.AutoArrayHashMap(u16, Queues);

var joined: Processes = undefined;

/// Subscribes process `pid` to channel `id`. Joining a channel that is
/// already joined is a no-op and keeps any pending messages.
pub fn join(pid: u16, id: usize) !void {
    const queues = try joined.getOrPut(pid);
    if (!queues.found_existing) initProcess(queues.value_ptr);

    const messages = try queues.value_ptr.getOrPut(id);
    if (!messages.found_existing) initQueue(messages.value_ptr);
}

/// Unsubscribes process `pid` from channel `id`, discarding any messages still
/// pending in that queue. Does nothing if `pid` has not joined `id`.
pub fn leave(pid: u16, id: usize) void {
    const queues = joined.getPtr(pid) orelse return;
    freeQueue(queues, id);
    _ = queues.swapRemove(id);
}

/// Unsubscribes process `pid` from every channel and discards all of its
/// pending messages. Does nothing if `pid` has no entry.
pub fn leaveAll(pid: u16) void {
    const queues = joined.getPtr(pid) orelse return;
    freeQueues(queues);
    // Remove the process entry entirely instead of keeping an empty map, so
    // `joined` does not retain one entry for every process that ever joined.
    // `queues` points into `joined`, so release it before removing the entry.
    queues.deinit();
    _ = joined.swapRemove(pid);
}

// The channel takes ownership of `bytes`.
/// Passes `bytes` on channel `id`.
///
/// If `receiver` is nonzero, the message goes only to that process. Otherwise
/// it is broadcast to every process that has joined `id`, including `pid`
/// itself if it joined. Processes that have not joined `id` are skipped. When
/// `identify` is set, receivers see `pid` as the sender; otherwise they see 0.
///
/// The channel takes ownership of `bytes`, which is freed with the channel
/// allocator once the last reference is dropped. That happens immediately if
/// no queue accepted the message, including on every error path.
pub fn pass(pid: u16, id: usize, receiver: u16, identify: bool, bytes: []const u8) !void {
    const message = alloc.create(Message) catch |err| {
        // Honor the ownership contract even if the payload cannot be wrapped.
        alloc.free(bytes);
        return err;
    };
    message.* = .{
        .bytes = bytes,
        .sender = if (identify) pid else 0,
    };
    // Release the creation reference on exit. From then on, the message lives
    // only through the per-queue references taken in `passTo`.
    defer message.dropReference();

    if (receiver != 0) {
        const queues = joined.getPtr(receiver) orelse return;
        return passTo(queues, id, message);
    }

    var it = joined.iterator();
    while (it.next()) |entry| {
        try passTo(entry.value_ptr, id, message);
    }
}

/// Enqueues `message` on channel `id` of one process and takes a new reference
/// for that queue. Does nothing (and takes no reference) if the process has
/// not joined `id`.
fn passTo(queues: *Queues, id: usize, message: *Message) !void {
    const messages = queues.getPtr(id) orelse return;
    try message.addReference();
    // `pass` still holds its own reference, so this rollback never frees.
    errdefer message.dropReference();
    try enqueue(messages, message);
}

/// Pops the oldest pending message for process `pid` on channel `id` and
/// copies its payload into `buffer`. Returns the number of bytes copied.
/// Payload bytes beyond `buffer.len` are dropped along with the message. If
/// `sender` is non-null, it receives the sender PID (0 if anonymous).
///
/// Returns `Error.NotJoined` if `pid` has not joined `id`, or
/// `Error.WouldBlock` if nothing is pending.
pub fn receive(pid: u16, id: usize, sender: ?*u16, buffer: []u8) Error!usize {
    const queues = joined.getPtr(pid) orelse return Error.NotJoined;
    const messages = queues.getPtr(id) orelse return Error.NotJoined;
    const node = messages.popFirst() orelse return Error.WouldBlock;
    defer alloc.destroy(node);
    const message = node.data;
    defer message.dropReference();

    if (sender) |sender_id| sender_id.* = message.sender;

    const len = @min(buffer.len, message.bytes.len);
    @memcpy(buffer[0..len], message.bytes[0..len]);
    return len;
}

fn initQueue(messages: *Messages) void {
    messages.* = .{};
}

fn initProcess(queues: *Queues) void {
    queues.* = Queues.init(alloc);
}

/// Discards the pending messages of every channel in `queues`.
fn freeQueues(queues: *Queues) void {
    for (queues.values()) |*messages| freeMessages(messages);
}

/// Discards the pending messages of channel `id`, if present in `queues`.
fn freeQueue(queues: *Queues, id: usize) void {
    const messages = queues.getPtr(id) orelse return;
    freeMessages(messages);
}

/// Empties `messages`. For each node, it drops the queue's message reference
/// and frees the node itself.
fn freeMessages(messages: *Messages) void {
    while (messages.popFirst()) |node| {
        node.data.dropReference();
        alloc.destroy(node);
    }
}

/// Appends `message` to `messages` in a freshly allocated node. The caller is
/// responsible for the message reference the node represents.
fn enqueue(messages: *Messages, message: *Message) !void {
    const node = try alloc.create(Messages.Node);
    node.data = message;
    messages.append(node);
}

/// Initializes channel state; must run before any other function here, since
/// the allocator and process map start out `undefined`. `with_allocator` backs
/// all bookkeeping. It is also the allocator used to free payloads handed to
/// `pass`.
pub fn init(with_allocator: Allocator) void {
    alloc = with_allocator;
    joined = Processes.init(with_allocator);
}

/// Returns the allocator the channel was initialized with.
pub fn allocator() Allocator {
    return alloc;
}