diff options
author | Himbeer <himbeer@disroot.org> | 2024-08-01 18:16:27 +0200 |
---|---|---|
committer | Himbeer <himbeer@disroot.org> | 2024-08-01 18:23:50 +0200 |
commit | f3480ccf9c3bedfb5951f052bb221b9c1074568f (patch) | |
tree | dd796b975bc33158371359ca1fe2ff043fa0afec | |
parent | 191edb1dd53552c46959a8ab2f3155dfbaef6d11 (diff) |
channel: Support receiver filtering
-rw-r--r-- | src/channel.zig | 23 | ||||
-rw-r--r-- | src/syscall.zig | 37 |
2 files changed, 36 insertions, 24 deletions
diff --git a/src/channel.zig b/src/channel.zig index 6367463..b3084b8 100644 --- a/src/channel.zig +++ b/src/channel.zig @@ -14,6 +14,7 @@ pub const Message = struct { bytes: []const u8, refcount: usize = 1, process_filter: u16, + sender: u16, fn addReference(self: *Message) !void { self.refcount = try std.math.add(usize, self.refcount, 1); @@ -33,10 +34,10 @@ var alloc: Allocator = undefined; const Queues = std.AutoArrayHashMap(usize, Messages); -const Processes = std.AutoArrayHashMap(usize, Queues); +const Processes = std.AutoArrayHashMap(u16, Queues); var joined: Processes = undefined; -pub fn join(pid: usize, id: usize) !void { +pub fn join(pid: u16, id: usize) !void { const queues = try joined.getOrPut(pid); if (!queues.found_existing) { initProcess(queues.value_ptr); @@ -48,7 +49,7 @@ pub fn join(pid: usize, id: usize) !void { } } -pub fn leave(pid: usize, id: usize) void { +pub fn leave(pid: u16, id: usize) void { const queues = joined.getPtr(pid) orelse return; freeQueues(queues); queues.clearAndFree(); @@ -56,11 +57,15 @@ pub fn leave(pid: usize, id: usize) void { } // The channel takes ownership of `bytes`. -pub fn pass(id: usize, receiver: u16, bytes: []const u8) !void { +pub fn pass(pid: u16, id: usize, receiver: u16, bytes: []const u8) !void { const message = try alloc.create(Message); defer message.dropReference(); - message.* = .{ .bytes = bytes, .process_filter = receiver }; + message.* = .{ + .bytes = bytes, + .process_filter = receiver, + .sender = pid, + }; var it = joined.iterator(); while (it.next()) |queues| { @@ -73,7 +78,7 @@ pub fn pass(id: usize, receiver: u16, bytes: []const u8) !void { } } -pub fn receive(pid: usize, id: usize, buffer: []u8) !usize { +pub fn receive(pid: u16, id: usize, sender: ?*u16, buffer: []u8) !usize { const queues = joined.getPtr(pid) orelse return Error.NotJoined; const messages = queues.getPtr(id) orelse return Error.NotJoined; const message = messages.popFirst() orelse return Error.WouldBlock; @@ -81,6 +86,12 @@ pub fn receive(pid: usize, id: usize, buffer: []u8) !usize { defer alloc.destroy(message); defer message.data.dropReference(); + if (message.data.process_filter != pid and message.data.process_filter != 0) { + return Error.WouldBlock; + } + + if (sender) |sender_id| sender_id.* = message.data.sender; + const len = @min(buffer.len, message.data.bytes.len); @memcpy(buffer[0..len], message.data.bytes[0..len]); diff --git a/src/syscall.zig b/src/syscall.zig index 46ac631..dd894e8 100644 --- a/src/syscall.zig +++ b/src/syscall.zig @@ -12,8 +12,9 @@ const paging = @import("paging.zig"); const process = @import("process.zig"); const riscv = @import("riscv.zig"); -pub const Error = error{ +pub const ArgumentError = error{ ZeroAddressSupplied, + OutOfRange, }; pub const HandleError = error{ @@ -32,24 +33,20 @@ pub fn handler(proc: *process.Info, trap_frame: *TrapFrame) !void { 100007 => trap_frame.setReturnValue(devicesByKind(trap_frame)), 100008 => trap_frame.setReturnValue(join(proc, trap_frame)), 100009 => trap_frame.setReturnValue(leave(proc, trap_frame)), - 100010 => trap_frame.setReturnValue(pass(trap_frame)), + 100010 => trap_frame.setReturnValue(pass(proc, trap_frame)), 100011 => trap_frame.setReturnValue(receive(proc, trap_frame)), else => return HandleError.UnknownSyscall, } } -pub const ErrorNameError = error{ErrorCodeOutOfRange}; - // errorName(code: u16, buffer: [*]u8, len: usize) !usize fn errorName(trap_frame: *const TrapFrame) !usize { const code_wide = trap_frame.general_purpose_registers[10]; const buffer_opt: ?[*]u8 = @ptrFromInt(trap_frame.general_purpose_registers[11]); - const buffer_ptr = buffer_opt orelse return Error.ZeroAddressSupplied; + const buffer_ptr = buffer_opt orelse return ArgumentError.ZeroAddressSupplied; const len = trap_frame.general_purpose_registers[12]; - const code = std.math.cast(u16, code_wide) orelse { - return ErrorNameError.ErrorCodeOutOfRange; - }; + const code = std.math.cast(u16, code_wide) orelse return ArgumentError.OutOfRange; const buffer = buffer_ptr[0..len]; if (code == 0) return 0; @@ -91,7 +88,7 @@ fn launch(trap_frame: *const TrapFrame) !usize { const alignment = @alignOf(std.elf.Elf64_Ehdr); const bytes_addr = trap_frame.general_purpose_registers[10]; const bytes_opt: ?[*]const u8 = @ptrFromInt(bytes_addr); - const bytes_noalign = bytes_opt orelse return Error.ZeroAddressSupplied; + const bytes_noalign = bytes_opt orelse return ArgumentError.ZeroAddressSupplied; const bytes_ptr = try std.math.alignCast(alignment, bytes_noalign); const len = trap_frame.general_purpose_registers[11]; @@ -177,24 +174,28 @@ fn leave(proc: *const process.Info, trap_frame: *const TrapFrame) void { channel.leave(proc.id, id); } -// pass(channel: usize, bytes: [*]const u8, len: usize) !void -fn pass(trap_frame: *const TrapFrame) !void { +// pass(channel: usize, receiver: u16, bytes: [*]const u8, len: usize) !void +fn pass(proc: *const process.Info, trap_frame: *const TrapFrame) !void { const id = trap_frame.general_purpose_registers[10]; - const bytes_ptr: [*]const u8 = @ptrFromInt(trap_frame.general_purpose_registers[11]); - const len = trap_frame.general_purpose_registers[12]; + const receiver_wide = trap_frame.general_purpose_registers[11]; + const bytes_ptr: [*]const u8 = @ptrFromInt(trap_frame.general_purpose_registers[12]); + const len = trap_frame.general_purpose_registers[13]; + const receiver = std.math.cast(u16, receiver_wide) orelse return ArgumentError.OutOfRange; const bytes = bytes_ptr[0..len]; + const copy = try channel.allocator().alloc(u8, bytes.len); @memcpy(copy, bytes); - try channel.pass(id, copy); + try channel.pass(proc.id, id, receiver, copy); } -// receive(channel: usize, buffer: [*]u8, len: usize) !usize +// receive(channel: usize, sender: ?*u16, buffer: [*]u8, len: usize) !usize fn receive(proc: *const process.Info, trap_frame: *const TrapFrame) !usize { const id = trap_frame.general_purpose_registers[10]; - const buffer_ptr: [*]u8 = @ptrFromInt(trap_frame.general_purpose_registers[11]); - const len = trap_frame.general_purpose_registers[12]; + const sender: ?*u16 = @ptrFromInt(trap_frame.general_purpose_registers[11]); + const buffer_ptr: [*]u8 = @ptrFromInt(trap_frame.general_purpose_registers[12]); + const len = trap_frame.general_purpose_registers[13]; const buffer = buffer_ptr[0..len]; - return channel.receive(proc.id, id, buffer); + return channel.receive(proc.id, id, sender, buffer); } |