// SPDX-FileCopyrightText: 2024 Himbeer // // SPDX-License-Identifier: GPL-3.0-or-later package main import ( "fmt" "io" "strconv" "strings" ) var currentFunc string var funcRegIndex int func allocReg() string { defer func() { funcRegIndex++ }() return fmt.Sprintf("%%local%d", funcRegIndex) } func resetRegs() { funcRegIndex = 0 } var globalRegIndex int var stringLiterals = map[string]string{} func allocGlobal() string { defer func() { globalRegIndex++ }() return fmt.Sprintf("$global%d", globalRegIndex) } func allocString(s string, w io.Writer) string { if reg, ok := stringLiterals[s]; ok { return reg } reg := allocGlobal() stringLiterals[s] = reg return reg } func generateStringLiterals(w io.Writer) { for s, reg := range stringLiterals { fmt.Fprintf(w, "data %s = { b %q }\n", reg, s) } } type invalidToplevel struct { got expression } func (i invalidToplevel) Error() string { return fmt.Sprintf("%d: expected top-level declaration, got %s", i.got.line(), i.got) } func generate(root *rootExpr, w io.Writer, errs chan<- error) { defer wg.Done() for _, toplevel := range root.toplevels { if err := generateToplevel(toplevel, w); err != nil { errs <- err select {} } } generateStringLiterals(w) } func generateToplevel(toplevel expression, w io.Writer) error { switch expr := toplevel.(type) { case *functionExpr: return generateFunction(expr, w) } return invalidToplevel{got: toplevel} } func generateFunction(function *functionExpr, w io.Writer) error { if isDeclared(function.name, true) { return errAlreadyDeclared{name: function.name, line: function.line()} } resetRegs() currentFunc = function.name defer func() { currentFunc = "" }() funcs[currentFunc] = struct{}{} localConsts[currentFunc] = map[string]string{} localMuts[currentFunc] = map[string]struct{}{} if function.link != defaultLinkage { fmt.Fprintf(w, "%s ", function.link) } fmt.Fprintf(w, "function %s $%s", function.returnType, function.name) fmt.Fprintf(w, "(%s) {\n", function.params) fmt.Fprintln(w, "@start") if err := generateBlock(function.blk, w); err != nil { return err } fmt.Fprintf(w, "}\n") return nil } func generateBlock(blk *blockExpr, w io.Writer) error { for _, stmt := range blk.stmts { if err := generateStatement(stmt, w); err != nil { return err } } return nil } func generateStatement(stmt statementExpr, w io.Writer) error { switch s := stmt.(type) { case *returnStmt: return generateReturnStmt(s, w) case *constStmt: return generateConstStmt(s, false, w) case *mutStmt: return generateMutStmt(s, false, w) } return nil } func generateReturnStmt(stmt *returnStmt, w io.Writer) error { value, err := stmt.value.generate(w) if err != nil { return err } fmt.Fprintln(w, "ret", value) return nil } func generateConstStmt(stmt *constStmt, toplevel bool, w io.Writer) error { if isDeclared(stmt.name, toplevel) { return errAlreadyDeclared{name: stmt.name, line: stmt.line()} } value, err := stmt.initial.generate(w) if err != nil { return err } localConsts[currentFunc][stmt.name] = value fmt.Fprintf(w, "%%%s =w add %s, 0\n", stmt.name, value) return nil } func generateMutStmt(stmt *mutStmt, toplevel bool, w io.Writer) error { if isDeclared(stmt.name, toplevel) { return errAlreadyDeclared{name: stmt.name, line: stmt.line()} } value, err := stmt.initial.generate(w) if err != nil { return err } localMuts[currentFunc][stmt.name] = struct{}{} fmt.Fprintf(w, "%%%s =l alloc4 4\n", stmt.name) fmt.Fprintf(w, "storew %s, %%%s\n", value, stmt.name) return nil } func (e *equalityExpr) generate(w io.Writer) (string, error) { lhs, err := e.lhs.generate(w) if err != nil { return "", err } out := lhs for _, rhs := range e.rhs { out, err = rhs.generate(out, w) if err != nil { return "", err } } return out, nil } func (e *equalityRhs) generate(in string, w io.Writer) (string, error) { instruction := "c" switch e.op { case equalTo: instruction += "eq" case notEqualTo: instruction += "ne" } instruction += "w" rhs, err := e.value.generate(w) if err != nil { return "", err } out := allocReg() fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs) return out, nil } func (c *comparisonExpr) generate(w io.Writer) (string, error) { lhs, err := c.lhs.generate(w) if err != nil { return "", err } out := lhs if c.rhs != nil { out, err = c.rhs.generate(out, w) if err != nil { return "", err } } return out, nil } func (c *comparisonRhs) generate(in string, w io.Writer) (string, error) { instruction := "cs" switch c.op { case lessThan: instruction += "lt" case lessThanOrEqualTo: instruction += "le" case greaterThan: instruction += "gt" case greaterThanOrEqualTo: instruction += "ge" } instruction += "w" rhs, err := c.value.generate(w) if err != nil { return "", err } out := allocReg() fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs) return out, nil } func (t *termExpr) generate(w io.Writer) (string, error) { lhs, err := t.lhs.generate(w) if err != nil { return "", err } out := lhs for _, rhs := range t.rhs { out, err = rhs.generate(out, w) if err != nil { return "", err } } return out, nil } func (t *termRhs) generate(in string, w io.Writer) (string, error) { instruction := "" switch t.op { case shiftLeft: instruction += "shl" case shiftRight: instruction += "shr" } rhs, err := t.value.generate(w) if err != nil { return "", err } out := allocReg() fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs) return out, nil } func (n *numeralExpr) generate(w io.Writer) (string, error) { lhs, err := n.lhs.generate(w) if err != nil { return "", err } out := lhs for _, rhs := range n.rhs { out, err = rhs.generate(out, w) if err != nil { return "", err } } return out, nil } func (n *numeralRhs) generate(in string, w io.Writer) (string, error) { instruction := "" switch n.op { case add: instruction += "add" case subtract: instruction += "sub" } rhs, err := n.value.generate(w) if err != nil { return "", err } out := allocReg() fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs) return out, nil } func (f *factorExpr) generate(w io.Writer) (string, error) { lhs, err := f.lhs.generate(w) if err != nil { return "", err } out := lhs for _, rhs := range f.rhs { out, err = rhs.generate(out, w) if err != nil { return "", err } } return out, nil } func (f *factorRhs) generate(in string, w io.Writer) (string, error) { instruction := "" switch f.op { case multiply: instruction += "mul" case divide: instruction += "div" } rhs, err := f.value.generate(w) if err != nil { return "", err } out := allocReg() fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs) return out, nil } func (u *unaryExpr) generate(w io.Writer) (string, error) { value, err := u.value.generate(w) if err != nil { return "", err } if u.op == unaryIdentity { return value, nil } instruction, rhs := "", "" switch u.op { case negate: instruction += "neg" case invertLogical: instruction += "ceqw" rhs = ", 0" case invertBits: instruction += "xor" rhs = ", 0xffffffff" } out := allocReg() fmt.Fprintf(w, "%s =w %s %s%s\n", out, instruction, value, rhs) return out, nil } func (g *groupingExpr) generate(w io.Writer) (string, error) { return g.inner.generate(w) } func (s *stringExpr) generate(w io.Writer) (string, error) { value := strings.Join(s.segments, "") out := allocString(value, w) return out, nil } func (n *numberExpr) generate(w io.Writer) (string, error) { base := 10 switch { case strings.HasPrefix(n.s, "0x"): base = 16 case strings.HasPrefix(n.s, "0o"): base = 8 case strings.HasPrefix(n.s, "0b"): base = 2 } s := n.s if base != 10 { s = s[2:] } num, err := strconv.ParseInt(s, base, 64) if err != nil { return "", err } return fmt.Sprintf("%d", num), nil } func (c *callExpr) generate(w io.Writer) (string, error) { ret := allocReg() fmt.Fprintf(w, "%s =w call $%s(", ret, c.funcName) delim := "" for arg := c.args; arg != nil; arg = arg.next { reg, err := arg.value.generate(w) if err != nil { return "", err } fmt.Fprintf(w, "w %s%s", reg, delim) delim = "," } fmt.Fprintf(w, ")\n") return ret, nil }