aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHimbeer <himbeer@disroot.org>2024-08-01 18:16:27 +0200
committerHimbeer <himbeer@disroot.org>2024-08-01 18:23:50 +0200
commitf3480ccf9c3bedfb5951f052bb221b9c1074568f (patch)
treedd796b975bc33158371359ca1fe2ff043fa0afec
parent191edb1dd53552c46959a8ab2f3155dfbaef6d11 (diff)
channel: Support receiver filtering
-rw-r--r--src/channel.zig23
-rw-r--r--src/syscall.zig37
2 files changed, 36 insertions, 24 deletions
diff --git a/src/channel.zig b/src/channel.zig
index 6367463..b3084b8 100644
--- a/src/channel.zig
+++ b/src/channel.zig
@@ -14,6 +14,7 @@ pub const Message = struct {
bytes: []const u8,
refcount: usize = 1,
process_filter: u16,
+ sender: u16,
fn addReference(self: *Message) !void {
self.refcount = try std.math.add(usize, self.refcount, 1);
@@ -33,10 +34,10 @@ var alloc: Allocator = undefined;
const Queues = std.AutoArrayHashMap(usize, Messages);
-const Processes = std.AutoArrayHashMap(usize, Queues);
+const Processes = std.AutoArrayHashMap(u16, Queues);
var joined: Processes = undefined;
-pub fn join(pid: usize, id: usize) !void {
+pub fn join(pid: u16, id: usize) !void {
const queues = try joined.getOrPut(pid);
if (!queues.found_existing) {
initProcess(queues.value_ptr);
@@ -48,7 +49,7 @@ pub fn join(pid: usize, id: usize) !void {
}
}
-pub fn leave(pid: usize, id: usize) void {
+pub fn leave(pid: u16, id: usize) void {
const queues = joined.getPtr(pid) orelse return;
freeQueues(queues);
queues.clearAndFree();
@@ -56,11 +57,15 @@ pub fn leave(pid: usize, id: usize) void {
}
// The channel takes ownership of `bytes`.
-pub fn pass(id: usize, receiver: u16, bytes: []const u8) !void {
+pub fn pass(pid: u16, id: usize, receiver: u16, bytes: []const u8) !void {
const message = try alloc.create(Message);
defer message.dropReference();
- message.* = .{ .bytes = bytes, .process_filter = receiver };
+ message.* = .{
+ .bytes = bytes,
+ .process_filter = receiver,
+ .sender = pid,
+ };
var it = joined.iterator();
while (it.next()) |queues| {
@@ -73,7 +78,7 @@ pub fn pass(id: usize, receiver: u16, bytes: []const u8) !void {
}
}
-pub fn receive(pid: usize, id: usize, buffer: []u8) !usize {
+pub fn receive(pid: u16, id: usize, sender: ?*u16, buffer: []u8) !usize {
const queues = joined.getPtr(pid) orelse return Error.NotJoined;
const messages = queues.getPtr(id) orelse return Error.NotJoined;
const message = messages.popFirst() orelse return Error.WouldBlock;
@@ -81,6 +86,12 @@ pub fn receive(pid: usize, id: usize, buffer: []u8) !usize {
defer alloc.destroy(message);
defer message.data.dropReference();
+ if (message.data.process_filter != pid and message.data.process_filter != 0) {
+ return Error.WouldBlock;
+ }
+
+ if (sender) |sender_id| sender_id.* = message.data.sender;
+
const len = @min(buffer.len, message.data.bytes.len);
@memcpy(buffer[0..len], message.data.bytes[0..len]);
diff --git a/src/syscall.zig b/src/syscall.zig
index 46ac631..dd894e8 100644
--- a/src/syscall.zig
+++ b/src/syscall.zig
@@ -12,8 +12,9 @@ const paging = @import("paging.zig");
const process = @import("process.zig");
const riscv = @import("riscv.zig");
-pub const Error = error{
+pub const ArgumentError = error{
ZeroAddressSupplied,
+ OutOfRange,
};
pub const HandleError = error{
@@ -32,24 +33,20 @@ pub fn handler(proc: *process.Info, trap_frame: *TrapFrame) !void {
100007 => trap_frame.setReturnValue(devicesByKind(trap_frame)),
100008 => trap_frame.setReturnValue(join(proc, trap_frame)),
100009 => trap_frame.setReturnValue(leave(proc, trap_frame)),
- 100010 => trap_frame.setReturnValue(pass(trap_frame)),
+ 100010 => trap_frame.setReturnValue(pass(proc, trap_frame)),
100011 => trap_frame.setReturnValue(receive(proc, trap_frame)),
else => return HandleError.UnknownSyscall,
}
}
-pub const ErrorNameError = error{ErrorCodeOutOfRange};
-
// errorName(code: u16, buffer: [*]u8, len: usize) !usize
fn errorName(trap_frame: *const TrapFrame) !usize {
const code_wide = trap_frame.general_purpose_registers[10];
const buffer_opt: ?[*]u8 = @ptrFromInt(trap_frame.general_purpose_registers[11]);
- const buffer_ptr = buffer_opt orelse return Error.ZeroAddressSupplied;
+ const buffer_ptr = buffer_opt orelse return ArgumentError.ZeroAddressSupplied;
const len = trap_frame.general_purpose_registers[12];
- const code = std.math.cast(u16, code_wide) orelse {
- return ErrorNameError.ErrorCodeOutOfRange;
- };
+ const code = std.math.cast(u16, code_wide) orelse return ArgumentError.OutOfRange;
const buffer = buffer_ptr[0..len];
if (code == 0) return 0;
@@ -91,7 +88,7 @@ fn launch(trap_frame: *const TrapFrame) !usize {
const alignment = @alignOf(std.elf.Elf64_Ehdr);
const bytes_addr = trap_frame.general_purpose_registers[10];
const bytes_opt: ?[*]const u8 = @ptrFromInt(bytes_addr);
- const bytes_noalign = bytes_opt orelse return Error.ZeroAddressSupplied;
+ const bytes_noalign = bytes_opt orelse return ArgumentError.ZeroAddressSupplied;
const bytes_ptr = try std.math.alignCast(alignment, bytes_noalign);
const len = trap_frame.general_purpose_registers[11];
@@ -177,24 +174,28 @@ fn leave(proc: *const process.Info, trap_frame: *const TrapFrame) void {
channel.leave(proc.id, id);
}
-// pass(channel: usize, bytes: [*]const u8, len: usize) !void
-fn pass(trap_frame: *const TrapFrame) !void {
+// pass(channel: usize, receiver: u16, bytes: [*]const u8, len: usize) !void
+fn pass(proc: *const process.Info, trap_frame: *const TrapFrame) !void {
const id = trap_frame.general_purpose_registers[10];
- const bytes_ptr: [*]const u8 = @ptrFromInt(trap_frame.general_purpose_registers[11]);
- const len = trap_frame.general_purpose_registers[12];
+ const receiver_wide = trap_frame.general_purpose_registers[11];
+ const bytes_ptr: [*]const u8 = @ptrFromInt(trap_frame.general_purpose_registers[12]);
+ const len = trap_frame.general_purpose_registers[13];
+ const receiver = std.math.cast(u16, receiver_wide) orelse return ArgumentError.OutOfRange;
const bytes = bytes_ptr[0..len];
+
const copy = try channel.allocator().alloc(u8, bytes.len);
@memcpy(copy, bytes);
- try channel.pass(id, copy);
+ try channel.pass(proc.id, id, receiver, copy);
}
-// receive(channel: usize, buffer: [*]u8, len: usize) !usize
+// receive(channel: usize, sender: ?*u16, buffer: [*]u8, len: usize) !usize
fn receive(proc: *const process.Info, trap_frame: *const TrapFrame) !usize {
const id = trap_frame.general_purpose_registers[10];
- const buffer_ptr: [*]u8 = @ptrFromInt(trap_frame.general_purpose_registers[11]);
- const len = trap_frame.general_purpose_registers[12];
+ const sender: ?*u16 = @ptrFromInt(trap_frame.general_purpose_registers[11]);
+ const buffer_ptr: [*]u8 = @ptrFromInt(trap_frame.general_purpose_registers[12]);
+ const len = trap_frame.general_purpose_registers[13];
const buffer = buffer_ptr[0..len];
- return channel.receive(proc.id, id, buffer);
+ return channel.receive(proc.id, id, sender, buffer);
}