// SPDX-FileCopyrightText: 2024 Himbeer
//
// SPDX-License-Identifier: AGPL-3.0-or-later

const builtin = @import("builtin");
const std = @import("std");
const TrapFrame = @import("TrapFrame.zig");
const paging = @import("paging.zig");
const riscv = @import("riscv.zig");
const time = @import("sbi/time.zig");

const Allocator = std.mem.Allocator;
const elf = std.elf;

/// Length of one scheduling time slice in milliseconds.
pub const schedule_interval_millis = 1;

/// Every thread of every process, in round-robin scheduling order.
pub var list = std.mem.zeroInit(std.DoublyLinkedList(Info), .{});

// NOTE(review): `next_pid += 1` traps once 65535 processes have been created;
// PIDs are never reused.
var next_pid: u16 = 1;

/// Number of pages allocated for each thread's stack.
const num_stack_pages = 2;

pub const Error = error{
    EmptySchedule,
    NoInit,
    TooManyThreads,
};

pub const ExeError = error{
    TooSmall,
    BadEndian,
    BadArch,
    BadBitLen,
    NotStaticExe,
    LengthOutOfBounds,
    BranchPerms,
    WritableCode,
};

pub const State = enum(u8) {
    waiting,
    active,
    sleeping,
    terminated,
};

/// One thread of a process. Threads of the same process share `id`,
/// `page_table` and `sections`; thread 0 is the main thread.
pub const Info = struct {
    allocator: Allocator,
    id: u16,
    thread_id: usize,
    trap_frame: TrapFrame,
    sections: std.ArrayList([]align(paging.page_size) u8),
    stack: []align(paging.page_size) u8,
    pc: usize,
    page_table: *paging.Table,
    state: State,

    /// Returns the `satp` value that activates this process's address space,
    /// passing the process ID (presumably used as the ASID) to the page table.
    pub fn satp(self: *const Info) paging.Satp {
        return self.page_table.satp(self.id);
    }

    /// Creates a new thread in the same process as `self` that starts
    /// executing at `entry` on a freshly allocated, zeroed stack.
    ///
    /// The thread shares the page table and loaded sections of `self` and is
    /// prepended to the scheduler list in the `.waiting` state. Its list node
    /// is allocated with `allocator`, or with `self.allocator` if that is null.
    /// The new thread ID is one greater than the highest thread ID currently
    /// known for the process (or of `self`, whichever is larger).
    ///
    /// Returns `Error.TooManyThreads` if that ID would overflow.
    pub fn createThread(self: *const Info, allocator: ?Allocator, entry: usize) !*Info {
        const alloc = allocator orelse self.allocator;

        // Derive the ID from the newest thread instead of `self` so that
        // spawning several threads from the same parent never reuses an ID.
        const newest_id = if (latestThread(self.id)) |newest| newest.thread_id else self.thread_id;
        const thread_id = std.math.add(usize, @max(newest_id, self.thread_id), 1) catch {
            return Error.TooManyThreads;
        };

        const stack = try paging.zeroedAlloc(num_stack_pages);
        errdefer paging.free(stack);

        const stack_top = @intFromPtr(stack.ptr) + num_stack_pages * paging.page_size;
        // NOTE(review): if a later step fails, this mapping stays in the shared
        // page table although the stack pages are freed.
        try self.page_table.identityMapRange(@intFromPtr(stack.ptr), stack_top, paging.EntryFlags.userReadWrite);

        var trap_frame = std.mem.zeroInit(TrapFrame, .{});
        trap_frame.general_purpose_registers[2] = stack_top; // x2 is sp.

        const proc_node = try alloc.create(std.DoublyLinkedList(Info).Node);
        proc_node.data = .{
            .allocator = alloc,
            .id = self.id,
            .thread_id = thread_id,
            .trap_frame = trap_frame,
            .sections = self.sections,
            .stack = @ptrCast(stack),
            .pc = entry,
            .page_table = self.page_table,
            .state = .waiting,
        };
        list.prepend(proc_node);

        return &proc_node.data;
    }

    /// Terminates this thread. If it is the main thread (`thread_id == 0`),
    /// every thread of the process is terminated and the process's page table
    /// and loaded sections are released too.
    ///
    /// Terminated threads are removed from the scheduler list and their nodes
    /// are destroyed with the allocator that created them. `self` normally
    /// points into one of those nodes, so it must not be used afterwards.
    pub fn terminate(
        self: *Info,
    ) void {
        // Leave the process's address space before its page table is freed.
        riscv.satp.write(paging.kmem.satp(0));

        // `self` usually lives inside a node destroyed below, so work on a copy.
        var target = self.*;

        var node = list.first;
        while (node) |proc_node| {
            // Advance first: `proc_node` may be destroyed in this iteration.
            node = proc_node.next;

            const candidate = &proc_node.data;
            if (!target.shouldTerminate(candidate)) continue;

            // The stack of `target` itself is released exactly once, below.
            if (candidate.thread_id != target.thread_id) {
                paging.free(candidate.stack);
            }

            const node_allocator = candidate.allocator;
            list.remove(proc_node);
            node_allocator.destroy(proc_node);
        }

        // NOTE(review): for a non-main thread the freed stack remains mapped in
        // the process's page table.
        paging.free(target.stack);

        if (target.thread_id == 0) {
            target.page_table.unmap();
            paging.free(target.page_table);
            target.freeSections();
        }
    }

    /// Frees every loaded section and the list holding them.
    fn freeSections(self: *Info) void {
        defer self.sections.deinit();

        for (self.sections.items) |section| {
            paging.free(section);
        }
    }

    /// Makes a thread blocked in a system call runnable again, resuming after
    /// the `ecall` instruction.
    pub fn allowResume(self: *Info) void {
        self.pc += 4; // Skip ecall instruction
        self.state = .waiting;
    }

    /// Reports whether terminating `self` also terminates `candidate`: it must
    /// belong to the same process and be either `self` or, if `self` is the
    /// main thread, any thread of that process.
    pub fn shouldTerminate(self: *const Info, candidate: *const Info) bool {
        return candidate.id == self.id and self.shouldTerminateThread(candidate);
    }

    fn shouldTerminateThread(self: *const Info, candidate: *const Info) bool {
        return candidate.thread_id == self.thread_id or self.thread_id == 0;
    }
};

/// Rotates the scheduler list and returns the first thread in the `.waiting`
/// state, or null if no thread is waiting (or the list is empty).
///
/// Every inspected thread is moved to the back of the list, so successive
/// calls cycle through the threads in round-robin order.
pub fn next() ?*Info {
    // Bounded by the list length; an unbounded search never ends when no
    // thread is waiting.
    var remaining = list.len;
    while (remaining > 0) : (remaining -= 1) {
        const node = list.popFirst() orelse return null;
        list.append(node);

        if (node.data.state == .waiting) return &node.data;
    }

    return null;
}

/// Arms the scheduling timer and switches to the next waiting thread.
/// Returns `Error.EmptySchedule` if there is no thread to run.
pub fn schedule() !noreturn {
    if (next()) |proc| {
        try time.interruptInMillis(schedule_interval_millis);
        switchTo(proc);
    }

    return Error.EmptySchedule;
}

/// Marks `proc` active and enters user mode at `proc.pc` in its address
/// space, restoring x1..x31 from its trap frame (xN at byte offset 8 * N).
///
/// Interrupts are disabled in `sstatus` until `sret`, which restores them
/// from the "prior interrupt enable" bits set here.
pub fn switchTo(proc: *Info) noreturn {
    proc.state = .active;

    var sstatus = riscv.sstatus.read();
    sstatus.previous_privilege = .user;
    sstatus.user_interrupt_enable = 0;
    sstatus.supervisor_interrupt_enable = 0;
    sstatus.user_prior_interrupt_enable = 1;
    sstatus.supervisor_prior_interrupt_enable = 1;
    riscv.sstatus.write(sstatus);

    riscv.sscratch.write(@intFromPtr(&proc.trap_frame));
    riscv.sepc.write(proc.pc);
    riscv.satp.write(proc.satp());

    // Probably not always needed. Let's not take the risk for now.
    asm volatile (
        \\ sfence.vma
    );

    // t6 (x31) holds the trap frame address and is therefore restored last.
    asm volatile (
        \\ csrr t6, sscratch
        \\
        \\ ld x1, 8(t6)
        \\ ld x2, 16(t6)
        \\ ld x3, 24(t6)
        \\ ld x4, 32(t6)
        \\ ld x5, 40(t6)
        \\ ld x6, 48(t6)
        \\ ld x7, 56(t6)
        \\ ld x8, 64(t6)
        \\ ld x9, 72(t6)
        \\ ld x10, 80(t6)
        \\ ld x11, 88(t6)
        \\ ld x12, 96(t6)
        \\ ld x13, 104(t6)
        \\ ld x14, 112(t6)
        \\ ld x15, 120(t6)
        \\ ld x16, 128(t6)
        \\ ld x17, 136(t6)
        \\ ld x18, 144(t6)
        \\ ld x19, 152(t6)
        \\ ld x20, 160(t6)
        \\ ld x21, 168(t6)
        \\ ld x22, 176(t6)
        \\ ld x23, 184(t6)
        \\ ld x24, 192(t6)
        \\ ld x25, 200(t6)
        \\ ld x26, 208(t6)
        \\ ld x27, 216(t6)
        \\ ld x28, 224(t6)
        \\ ld x29, 232(t6)
        \\ ld x30, 240(t6)
        \\ ld x31, 248(t6)
        \\
        \\ sret
    );

    unreachable;
}

const HdrBuf = *align(@alignOf(elf.Elf64_Ehdr)) const [@sizeOf(elf.Elf64_Ehdr)]u8;

/// Loads the static ELF executable in `elf_buf` into a new address space and
/// registers its main thread (thread ID 0) with the scheduler as `.waiting`.
///
/// Every non-empty PT_LOAD segment is copied into freshly allocated, zeroed
/// pages and mapped at its virtual address with the permissions from its
/// flags; writable+executable and permissionless segments are rejected.
/// On error, everything allocated so far is released. `elf_buf` is not
/// referenced after this returns.
pub fn create(allocator: Allocator, elf_buf: []align(@alignOf(elf.Elf64_Ehdr)) const u8) !*Info {
    if (elf_buf.len < @sizeOf(elf.Elf64_Ehdr)) return ExeError.TooSmall;

    const hdr_buf: HdrBuf = elf_buf[0..@sizeOf(elf.Elf64_Ehdr)];
    const hdr = try elf.Header.parse(@ptrCast(hdr_buf));
    try validateElfHeader(hdr, hdr_buf);

    const procmem: *paging.Table = @ptrCast(try paging.zeroedAlloc(1));
    errdefer paging.free(procmem);
    // Same teardown order as `Info.terminate`: unmap before freeing the table.
    errdefer procmem.unmap();

    try procmem.mapKernel();

    var sections = std.ArrayList([]align(paging.page_size) u8).init(allocator);
    errdefer {
        for (sections.items) |section| paging.free(section);
        sections.deinit();
    }

    const parse_source = std.io.fixedBufferStream(elf_buf);
    var it = hdr.program_header_iterator(parse_source);
    while (try it.next()) |phdr| {
        if (phdr.p_type != elf.PT_LOAD) continue;
        if (phdr.p_filesz == 0 or phdr.p_memsz == 0) continue;

        // Overflow-safe form of `p_offset + p_filesz > elf_buf.len`; a segment
        // ending exactly at the end of the file is valid.
        if (phdr.p_filesz > elf_buf.len or phdr.p_offset > elf_buf.len - phdr.p_filesz) {
            return ExeError.LengthOutOfBounds;
        }

        // Permissions are per segment, so validate them before allocating.
        const readable = (phdr.p_flags & elf.PF_R) != 0;
        const writable = (phdr.p_flags & elf.PF_W) != 0;
        const executable = (phdr.p_flags & elf.PF_X) != 0;

        if (!readable and !writable and !executable) {
            return ExeError.BranchPerms;
        }
        if (writable and executable) {
            return ExeError.WritableCode;
        }

        const flags = paging.EntryFlags{
            .valid = 1,
            .read = @bitCast(readable),
            .write = @bitCast(writable),
            .exec = @bitCast(executable),
            .user = 1,
            .global = 0,
            .accessed = 1,
            .dirty = @bitCast(writable),
        };

        const offset = paging.offsetOf(phdr.p_vaddr);
        const memsz_aligned = std.mem.alignForward(usize, offset + phdr.p_memsz, paging.page_size);
        const num_pages = @divExact(memsz_aligned, paging.page_size);

        const pages = try paging.zeroedAlloc(num_pages);
        // Once appended, `pages` is owned by `sections` and freed by its errdefer.
        sections.append(pages) catch |err| {
            paging.free(pages);
            return err;
        };

        const sz = @min(phdr.p_filesz, phdr.p_memsz);
        @memcpy(pages[offset .. offset + sz], elf_buf[phdr.p_offset .. phdr.p_offset + sz]);

        for (0..num_pages) |page| {
            const vaddr = phdr.p_vaddr + page * paging.page_size;
            const paddr = @intFromPtr(pages.ptr) + page * paging.page_size;

            try procmem.map(vaddr, paddr, flags, 0);
        }
    }

    const stack = try paging.zeroedAlloc(num_stack_pages);
    errdefer paging.free(stack);

    const stack_top = @intFromPtr(stack.ptr) + num_stack_pages * paging.page_size;
    try procmem.identityMapRange(@intFromPtr(stack.ptr), stack_top, paging.EntryFlags.userReadWrite);

    const proc_node = try allocator.create(std.DoublyLinkedList(Info).Node);
    proc_node.data = .{
        .allocator = allocator,
        .id = next_pid,
        .thread_id = 0,
        .trap_frame = std.mem.zeroInit(TrapFrame, .{}),
        .sections = sections,
        .stack = @ptrCast(stack),
        .pc = hdr.entry,
        .page_table = procmem,
        .state = .waiting,
    };
    proc_node.data.trap_frame.general_purpose_registers[2] = stack_top; // x2 is sp.

    // Only consume the PID once nothing can fail anymore.
    next_pid += 1;
    list.prepend(proc_node);

    return &proc_node.data;
}

/// Finds `./init` in the tar archive read from `reader`, loads it with
/// `create`, arms the scheduling timer and switches to it.
/// Only returns on error (`Error.NoInit` if the archive has no `./init`).
pub fn runInit(allocator: Allocator, reader: anytype) !noreturn {
    var file_name_buffer: [4096]u8 = undefined;
    var link_name_buffer: [4096]u8 = undefined;

    var it = std.tar.iterator(reader, .{
        .file_name_buffer = file_name_buffer[0..],
        .link_name_buffer = link_name_buffer[0..],
    });

    const exe = while (try it.next()) |file| {
        if (std.mem.eql(u8, file.name, "./init")) {
            break file;
        }
    } else return Error.NoInit;

    const proc = blk: {
        const alignment = @alignOf(elf.Elf64_Ehdr);
        var exe_list = std.ArrayListAligned(u8, alignment).init(allocator);
        // `switchTo` never returns, so release the ELF image before switching;
        // `create` copies everything it needs out of it.
        defer exe_list.deinit();

        try exe.reader().readAllArrayListAligned(alignment, &exe_list, exe.size);
        break :blk try create(allocator, exe_list.items);
    };

    try time.interruptInMillis(schedule_interval_millis);
    switchTo(proc);
}

/// Rejects ELF files whose endianness, machine or class does not match the
/// target, or that are not `ET_EXEC` (statically linked) executables.
fn validateElfHeader(hdr: elf.Header, hdr_buf: HdrBuf) !void {
    const arch = builtin.cpu.arch;

    if (hdr.endian != arch.endian()) return ExeError.BadEndian;
    if (hdr.machine != arch.toElfMachine()) return ExeError.BadArch;
    if (!hdr.is_64) return ExeError.BadBitLen;

    const hdr64 = @as(*const elf.Elf64_Ehdr, @ptrCast(hdr_buf));

    if (hdr64.e_type != .EXEC) return ExeError.NotStaticExe;
}

/// Converts a pointer or slice to its address; any other value is returned
/// as-is (it must coerce to `usize`).
fn usizeFromArg(arg: anytype) usize {
    return switch (@typeInfo(@TypeOf(arg))) {
        .Pointer => |ptr| switch (ptr.size) {
            .Slice => @intFromPtr(arg.ptr),
            else => @intFromPtr(arg),
        },
        else => arg,
    };
}

/// Returns the thread `thread_id` of process `pid`, or null if it is unknown.
pub fn findThread(pid: u16, thread_id: usize) ?*Info {
    var node = list.first;
    while (node) |proc_node| : (node = proc_node.next) {
        if (proc_node.data.id == pid and proc_node.data.thread_id == thread_id) {
            return &proc_node.data;
        }
    }

    return null;
}

/// Returns the main thread (thread ID 0) of process `pid`, if any.
pub fn mainThread(pid: u16) ?*Info {
    return findThread(pid, 0);
}

/// Returns the thread of process `pid` with the highest thread ID, or null if
/// the process has no threads in the scheduler list.
pub fn latestThread(pid: u16) ?*Info {
    var latest: ?*Info = null;

    var node = list.first;
    while (node) |proc_node| : (node = proc_node.next) {
        if (proc_node.data.id != pid) continue;

        if (latest == null or proc_node.data.thread_id > latest.?.thread_id) {
            latest = &proc_node.data;
        }
    }

    return latest;
}