aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHimbeer <himbeer@disroot.org>2024-07-30 17:41:27 +0200
committerHimbeer <himbeer@disroot.org>2024-07-30 17:41:27 +0200
commit48d6fa3e80193a3cc735f5e0b0390a18a7bf5a83 (patch)
tree23468ff08749507d70153e63f91a47cbe8926be2
parentee824544423fa55884cf0f716a977cd123c6fedf (diff)
channel: Implement joining and leaving
-rw-r--r--src/lib/channel.zig90
-rw-r--r--src/lib/syscall.zig16
2 files changed, 89 insertions, 17 deletions
diff --git a/src/lib/channel.zig b/src/lib/channel.zig
index de02f2b..30d8a5a 100644
--- a/src/lib/channel.zig
+++ b/src/lib/channel.zig
@@ -9,33 +9,73 @@ pub const Error = error{
WouldBlock,
};
-pub const Messages = std.TailQueue([]const u8);
-var message_allocator: Allocator = undefined;
+pub const Message = struct {
+ bytes: []const u8,
+ refcount: usize = 0,
+
+ fn clone(self: *Message) !void {
+ self.refcount = try std.math.add(self.refcount, 1);
+ }
+
+ fn deinit(self: *Message) void {
+ self.refcount -= 1;
+ if (self.refcount == 0) {
+ defer alloc.free(self.bytes);
+ defer alloc.destroy(self);
+ }
+ }
+};
+
+pub const Messages = std.TailQueue(*Message);
+var alloc: Allocator = undefined;
const Queues = std.AutoArrayHashMap(usize, Messages);
-var queues: Queues = undefined;
+var unjoined_queues: Queues = undefined;
+
+const Processes = std.AutoArrayHashMap(usize, Queues);
+var joined: Processes = undefined;
+
+pub fn join(pid: usize, id: usize) !void {
+ const queues = try joined.getOrPut(pid);
+ if (!queues.found_existing) {
+ initProcess(queues.value_ptr);
+ }
+
+ const messages = try queues.value_ptr.getOrPut(id);
+ if (!messages.found_existing) {
+ initQueue(messages.value_ptr);
+ }
+}
+
+pub fn leave(pid: usize, id: usize) void {
+ const queues = joined.getPtr(pid) orelse return;
+ freeQueues(queues);
+ queues.clearAndFree();
+ _ = queues.swapRemove(id);
+}
// The channel takes ownership of `bytes`.
pub fn pass(channel: usize, bytes: []const u8) !void {
- const entry = try queues.getOrPut(channel);
+ const entry = try unjoined_queues.getOrPut(channel);
if (!entry.found_existing) {
initQueue(entry.value_ptr);
}
- const node = try message_allocator.create(Messages.Node);
- node.data = bytes;
+ const node = try alloc.create(Messages.Node);
+ node.data = try alloc.create(Message);
+ node.data.* = .{ .bytes = bytes };
entry.value_ptr.append(node);
}
pub fn receive(channel: usize, buffer: []u8) !usize {
- const messages = queues.getPtr(channel) orelse return Error.WouldBlock;
+ const messages = unjoined_queues.getPtr(channel) orelse return Error.WouldBlock;
const message = messages.popFirst() orelse return Error.WouldBlock;
- defer message_allocator.free(message.data);
- defer message_allocator.destroy(message);
+ defer alloc.destroy(message);
+ defer message.data.deinit();
- const len = @min(buffer.len, message.data.len);
- @memcpy(buffer[0..len], message.data[0..len]);
+ const len = @min(buffer.len, message.data.bytes.len);
+ @memcpy(buffer[0..len], message.data.bytes[0..len]);
return len;
}
@@ -44,11 +84,29 @@ fn initQueue(messages: *Messages) void {
messages.* = .{};
}
-pub fn init(allocator: Allocator) void {
- queues = Queues.init(allocator);
- message_allocator = allocator;
+fn initProcess(queues: *Queues) void {
+ queues.* = Queues.init(alloc);
+}
+
+fn freeQueues(queues: *Queues) void {
+ var it = queues.iterator();
+ while (it.next()) |messages| {
+ freeMessages(messages.value_ptr);
+ }
+}
+
+fn freeMessages(messages: *Messages) void {
+ while (messages.popFirst()) |message| {
+ message.data.deinit();
+ }
+}
+
+pub fn init(with_allocator: Allocator) void {
+ unjoined_queues = Queues.init(with_allocator);
+ joined = Processes.init(with_allocator);
+ alloc = with_allocator;
}
-pub fn messageAllocator() Allocator {
- return message_allocator;
+pub fn allocator() Allocator {
+ return alloc;
}
diff --git a/src/lib/syscall.zig b/src/lib/syscall.zig
index 03a7979..97dc811 100644
--- a/src/lib/syscall.zig
+++ b/src/lib/syscall.zig
@@ -32,6 +32,8 @@ pub fn handler(proc: *process.Info, trap_frame: *TrapFrame) !void {
100006 => trap_frame.setReturnValue(threadId(proc)),
100007 => trap_frame.setReturnValue(rawUserinit(trap_frame)),
100008 => trap_frame.setReturnValue(devicesByKind(trap_frame)),
+ 100009 => trap_frame.setReturnValue(join(proc, trap_frame)),
+ 100010 => trap_frame.setReturnValue(leave(proc, trap_frame)),
100011 => trap_frame.setReturnValue(pass(trap_frame)),
100012 => trap_frame.setReturnValue(receive(trap_frame)),
else => return HandleError.UnknownSyscall,
@@ -171,6 +173,18 @@ fn devicesByKind(trap_frame: *const TrapFrame) !usize {
return i;
}
+// join(channel_id: usize) !void
+fn join(proc: *process.Info, trap_frame: *const TrapFrame) !void {
+ const id = trap_frame.general_purpose_registers[10];
+ return channel.join(proc.id, id);
+}
+
+// leave(channel_id: usize) void
+fn leave(proc: *process.Info, trap_frame: *const TrapFrame) void {
+ const id = trap_frame.general_purpose_registers[10];
+ channel.leave(proc.id, id);
+}
+
// pass(channel_id: usize, bytes: [*]const u8, len: usize) !void
fn pass(trap_frame: *const TrapFrame) !void {
const id = trap_frame.general_purpose_registers[10];
@@ -178,7 +192,7 @@ fn pass(trap_frame: *const TrapFrame) !void {
const len = trap_frame.general_purpose_registers[12];
const bytes = bytes_ptr[0..len];
- const copy = try channel.messageAllocator().alloc(u8, bytes.len);
+ const copy = try channel.allocator().alloc(u8, bytes.len);
@memcpy(copy, bytes);
try channel.pass(id, copy);
}