aboutsummaryrefslogtreecommitdiff
path: root/src/lib/channel.zig
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/channel.zig')
-rw-r--r--src/lib/channel.zig90
1 files changed, 74 insertions, 16 deletions
diff --git a/src/lib/channel.zig b/src/lib/channel.zig
index de02f2b..30d8a5a 100644
--- a/src/lib/channel.zig
+++ b/src/lib/channel.zig
@@ -9,33 +9,73 @@ pub const Error = error{
WouldBlock,
};
-pub const Messages = std.TailQueue([]const u8);
-var message_allocator: Allocator = undefined;
+pub const Message = struct {
+ bytes: []const u8,
+ refcount: usize = 0,
+
+ fn clone(self: *Message) !void {
+ self.refcount = try std.math.add(self.refcount, 1);
+ }
+
+ fn deinit(self: *Message) void {
+ self.refcount -= 1;
+ if (self.refcount == 0) {
+ defer alloc.free(self.bytes);
+ defer alloc.destroy(self);
+ }
+ }
+};
+
+pub const Messages = std.TailQueue(*Message);
+var alloc: Allocator = undefined;
const Queues = std.AutoArrayHashMap(usize, Messages);
-var queues: Queues = undefined;
+var unjoined_queues: Queues = undefined;
+
+const Processes = std.AutoArrayHashMap(usize, Queues);
+var joined: Processes = undefined;
+
+pub fn join(pid: usize, id: usize) !void {
+ const queues = try joined.getOrPut(pid);
+ if (!queues.found_existing) {
+ initProcess(queues.value_ptr);
+ }
+
+ const messages = try queues.value_ptr.getOrPut(id);
+ if (!messages.found_existing) {
+ initQueue(messages.value_ptr);
+ }
+}
+
+pub fn leave(pid: usize, id: usize) void {
+ const queues = joined.getPtr(pid) orelse return;
+ freeQueues(queues);
+ queues.clearAndFree();
+ _ = queues.swapRemove(id);
+}
// The channel takes ownership of `bytes`.
pub fn pass(channel: usize, bytes: []const u8) !void {
- const entry = try queues.getOrPut(channel);
+ const entry = try unjoined_queues.getOrPut(channel);
if (!entry.found_existing) {
initQueue(entry.value_ptr);
}
- const node = try message_allocator.create(Messages.Node);
- node.data = bytes;
+ const node = try alloc.create(Messages.Node);
+ node.data = try alloc.create(Message);
+ node.data.* = .{ .bytes = bytes };
entry.value_ptr.append(node);
}
pub fn receive(channel: usize, buffer: []u8) !usize {
- const messages = queues.getPtr(channel) orelse return Error.WouldBlock;
+ const messages = unjoined_queues.getPtr(channel) orelse return Error.WouldBlock;
const message = messages.popFirst() orelse return Error.WouldBlock;
- defer message_allocator.free(message.data);
- defer message_allocator.destroy(message);
+ defer alloc.destroy(message);
+ defer message.data.deinit();
- const len = @min(buffer.len, message.data.len);
- @memcpy(buffer[0..len], message.data[0..len]);
+ const len = @min(buffer.len, message.data.bytes.len);
+ @memcpy(buffer[0..len], message.data.bytes[0..len]);
return len;
}
@@ -44,11 +84,29 @@ fn initQueue(messages: *Messages) void {
messages.* = .{};
}
-pub fn init(allocator: Allocator) void {
- queues = Queues.init(allocator);
- message_allocator = allocator;
+fn initProcess(queues: *Queues) void {
+ queues.* = Queues.init(alloc);
+}
+
+fn freeQueues(queues: *Queues) void {
+ var it = queues.iterator();
+ while (it.next()) |messages| {
+ freeMessages(messages.value_ptr);
+ }
+}
+
+fn freeMessages(messages: *Messages) void {
+ while (messages.popFirst()) |message| {
+ message.data.deinit();
+ }
+}
+
+pub fn init(with_allocator: Allocator) void {
+ unjoined_queues = Queues.init(with_allocator);
+ joined = Processes.init(with_allocator);
+ alloc = with_allocator;
}
-pub fn messageAllocator() Allocator {
- return message_allocator;
+pub fn allocator() Allocator {
+ return alloc;
}