aboutsummaryrefslogtreecommitdiff
path: root/src/channel.zig
blob: 061331e1198b5bb3f1c650f6fc4f92fdb96d47de (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
// SPDX-FileCopyrightText: 2024 Himbeer <himbeer@disroot.org>
//
// SPDX-License-Identifier: AGPL-3.0-or-later

const std = @import("std");
const Allocator = std.mem.Allocator;

/// Errors reported by `receive`: `NotJoined` when the process has not joined
/// the requested channel, `WouldBlock` when nothing deliverable is queued.
pub const Error = error{ NotJoined, WouldBlock };

/// A reference-counted message shared by every queue it was delivered to.
/// The message and its `bytes` are freed with the channel allocator when the
/// last reference is dropped.
pub const Message = struct {
    /// Payload; owned by the message and freed together with it.
    bytes: []const u8,
    /// Number of holders (queues plus any in-flight sender reference).
    refcount: usize = 1,
    /// Pid of the only process allowed to receive this message, or 0 for any.
    process_filter: u16,
    /// Pid of the sender, or 0 if the sender chose not to identify itself.
    sender: u16,

    /// Register one more holder. Fails with `error.Overflow` instead of
    /// wrapping the counter.
    fn addReference(self: *Message) !void {
        self.refcount = try std.math.add(usize, self.refcount, 1);
    }

    /// Release one holder; the last release frees the payload and the message.
    fn dropReference(self: *Message) void {
        self.refcount -= 1;
        if (self.refcount == 0) {
            // Free the payload first: `self.bytes` must not be read after
            // `self` itself has been destroyed.
            alloc.free(self.bytes);
            alloc.destroy(self);
        }
    }
};

/// Queue of message references pending for one process on one channel;
/// `enqueue` appends at the tail and `receive` pops from the head.
pub const Messages = std.TailQueue(*Message);
// Allocator used for queue nodes, `Message` structs and message payloads; set by `init`.
var alloc: Allocator = undefined;

// Per-process map from channel id to that process's pending-message queue.
const Queues = std.AutoArrayHashMap(usize, Messages);

// Map from process id to the channels that process has joined; set by `init`.
const Processes = std.AutoArrayHashMap(u16, Queues);
var joined: Processes = undefined;

/// Subscribe process `pid` to channel `id`, creating the per-process channel
/// map and the channel's queue on first use. Joining an already-joined
/// channel leaves its queue untouched. Fails only on allocation failure.
pub fn join(pid: u16, id: usize) !void {
    const process = try joined.getOrPut(pid);
    if (process.found_existing == false) initProcess(process.value_ptr);

    const channel = try process.value_ptr.getOrPut(id);
    if (channel.found_existing == false) initQueue(channel.value_ptr);
}

/// Unsubscribe process `pid` from channel `id`, discarding any messages still
/// queued for it on that channel. Other channels the process has joined are
/// left intact. Once the process has no channels left, its bookkeeping is
/// released too. Does nothing if the process never joined the channel.
pub fn leave(pid: u16, id: usize) void {
    const queues = joined.getPtr(pid) orelse return;
    const messages = queues.getPtr(id) orelse return;

    freeMessages(messages);
    _ = queues.swapRemove(id);

    // Drop the now-empty per-process map; `join` recreates it on demand.
    if (queues.count() == 0) {
        queues.deinit();
        _ = joined.swapRemove(pid);
    }
}

/// Send `bytes` on channel `id` from process `pid`.
///
/// `receiver` restricts delivery to a single process pid; 0 delivers to every
/// process that joined the channel. If `identify` is set, receivers see `pid`
/// as the sender, otherwise 0.
///
/// The channel takes ownership of `bytes` (they must come from the channel
/// allocator). The caller must not free them afterwards, even if this returns
/// an error. They are freed once the last queued copy is received or
/// discarded, or before returning if no queue took a copy.
pub fn pass(pid: u16, id: usize, receiver: u16, identify: bool, bytes: []const u8) !void {
    const message = alloc.create(Message) catch |err| {
        // We own `bytes` even on failure; release them so they don't leak.
        alloc.free(bytes);
        return err;
    };
    message.* = .{
        .bytes = bytes,
        .process_filter = receiver,
        .sender = if (identify) pid else 0,
    };
    // Release this function's own reference on exit; the message survives
    // only if at least one queue took a reference below.
    defer message.dropReference();

    var it = joined.iterator();
    while (it.next()) |entry| {
        // Don't queue copies for processes the message is not addressed to:
        // they could never receive it and would only accumulate it.
        if (receiver != 0 and entry.key_ptr.* != receiver) continue;
        const messages = entry.value_ptr.getPtr(id) orelse continue;

        try message.addReference();
        errdefer message.dropReference();

        try enqueue(messages, message);
    }
}

/// Receive the next message deliverable to process `pid` on channel `id`.
///
/// Queued messages addressed to a different process are discarded while
/// searching. At most `buffer.len` bytes are copied. The message is removed
/// from the queue either way, so the tail of a longer message is lost. If
/// `sender` is non-null, it is set to the sender's pid, or 0 if the sender
/// did not identify itself.
///
/// Returns the number of bytes written to `buffer`.
/// Errors: `Error.NotJoined` if `pid` has not joined channel `id`;
/// `Error.WouldBlock` if no deliverable message is queued.
pub fn receive(pid: u16, id: usize, sender: ?*u16, buffer: []u8) !usize {
    const queues = joined.getPtr(pid) orelse return Error.NotJoined;
    const messages = queues.getPtr(id) orelse return Error.NotJoined;

    while (messages.popFirst()) |node| {
        // Runs on every exit from this iteration (skip or return): release
        // this queue's reference, then free the list node.
        defer alloc.destroy(node);
        defer node.data.dropReference();

        const message = node.data;
        // Skip messages meant for another process instead of reporting
        // WouldBlock while deliverable messages may still be queued behind.
        if (message.process_filter != pid and message.process_filter != 0) continue;

        if (sender) |sender_id| sender_id.* = message.sender;

        const len = @min(buffer.len, message.bytes.len);
        @memcpy(buffer[0..len], message.bytes[0..len]);
        return len;
    }

    return Error.WouldBlock;
}

/// Reset `messages` to an empty queue.
fn initQueue(messages: *Messages) void {
    messages.* = Messages{};
}

/// Initialize `queues` as an empty channel map backed by the channel allocator.
fn initProcess(queues: *Queues) void {
    const empty = Queues.init(alloc);
    queues.* = empty;
}

/// Drain every channel queue in `queues`; the map itself is left allocated.
fn freeQueues(queues: *Queues) void {
    for (queues.values()) |*pending| freeMessages(pending);
}

/// Empty `messages`. This releases the queue's reference to each message and
/// frees the list nodes that `enqueue` allocated for them.
fn freeMessages(messages: *Messages) void {
    while (messages.popFirst()) |node| {
        node.data.dropReference();
        // The node was allocated by `enqueue`; without this it would leak.
        alloc.destroy(node);
    }
}

/// Append `message` to the tail of `messages`, allocating a list node for it.
/// The caller is responsible for the message's reference count.
fn enqueue(messages: *Messages, message: *Message) !void {
    const link = try alloc.create(Messages.Node);
    link.* = .{ .data = message };
    messages.append(link);
}

/// Set up the channel module. `with_allocator` backs all channel bookkeeping
/// and message memory. Call this before any other function in this file.
pub fn init(with_allocator: Allocator) void {
    alloc = with_allocator;
    joined = Processes.init(alloc);
}

/// Return the allocator the channel module was initialized with.
/// NOTE(review): presumably exposed so callers can allocate the `bytes` handed
/// to `pass`, since the channel frees those with this allocator — confirm.
pub fn allocator() Allocator {
    return alloc;
}