// SPDX-FileCopyrightText: 2024 Himbeer // // SPDX-License-Identifier: GPL-3.0-or-later package main import ( "fmt" "io" "strconv" "strings" ) var funcRegIndex int func allocReg() string { defer func() { funcRegIndex++ }() return fmt.Sprintf("%%local%d", funcRegIndex) } func resetRegs() { funcRegIndex = 0 } var globalRegIndex int var stringLiterals = map[string]string{} func allocGlobal() string { defer func() { globalRegIndex++ }() return fmt.Sprintf("$global%d", globalRegIndex) } func allocString(s string, w io.Writer) string { if reg, ok := stringLiterals[s]; ok { return reg } reg := allocGlobal() stringLiterals[s] = reg fmt.Fprintf(w, "data %s = { %q }", reg, s) return reg } type invalidToplevel struct { got expression } func (i invalidToplevel) Error() string { return fmt.Sprintf("%d: expected top-level declaration, got %s", i.got.line(), i.got) } func generate(root *rootExpr, w io.Writer, errs chan<- error) { defer wg.Done() for _, toplevel := range root.toplevels { if err := generateToplevel(toplevel, w); err != nil { errs <- err } } } func generateToplevel(toplevel expression, w io.Writer) error { switch expr := toplevel.(type) { case *functionExpr: return generateFunction(expr, w) } return invalidToplevel{got: toplevel} } func generateFunction(function *functionExpr, w io.Writer) error { resetRegs() if function.link != defaultLinkage { fmt.Fprintf(w, "%s ", function.link) } fmt.Fprintf(w, "function %s $%s", function.returnType, function.name) fmt.Fprintf(w, "(%s) {\n", function.params) fmt.Fprintln(w, "@start") if err := generateBlock(function.blk, w); err != nil { return err } fmt.Fprintf(w, "}\n") return nil } func generateBlock(blk *blockExpr, w io.Writer) error { for _, stmt := range blk.stmts { if err := generateStatement(stmt, w); err != nil { return err } } return nil } func generateStatement(stmt statementExpr, w io.Writer) error { switch s := stmt.(type) { case *returnStmt: return generateReturnStmt(s, w) } return nil } func generateReturnStmt(stmt *returnStmt, w io.Writer) error { value, err := stmt.value.generate(w) if err != nil { return err } fmt.Fprintln(w, "ret", value) return nil } func (e *equalityExpr) generate(w io.Writer) (string, error) { lhs, err := e.lhs.generate(w) if err != nil { return "", err } out := lhs for _, rhs := range e.rhs { out, err = rhs.generate(out, w) if err != nil { return "", err } } return out, nil } func (e *equalityRhs) generate(in string, w io.Writer) (string, error) { instruction := "c" switch e.op { case equalTo: instruction += "eq" case notEqualTo: instruction += "ne" } instruction += "w" rhs, err := e.value.generate(w) if err != nil { return "", err } out := allocReg() fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs) return out, nil } func (c *comparisonExpr) generate(w io.Writer) (string, error) { lhs, err := c.lhs.generate(w) if err != nil { return "", err } out := lhs if c.rhs != nil { out, err = c.rhs.generate(out, w) if err != nil { return "", err } } return out, nil } func (c *comparisonRhs) generate(in string, w io.Writer) (string, error) { instruction := "cs" switch c.op { case lessThan: instruction += "lt" case lessThanOrEqualTo: instruction += "le" case greaterThan: instruction += "gt" case greaterThanOrEqualTo: instruction += "ge" } instruction += "w" rhs, err := c.value.generate(w) if err != nil { return "", err } out := allocReg() fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs) return out, nil } func (t *termExpr) generate(w io.Writer) (string, error) { lhs, err := t.lhs.generate(w) if err != nil { return "", err } out := lhs for _, rhs := range t.rhs { out, err = rhs.generate(out, w) if err != nil { return "", err } } return out, nil } func (t *termRhs) generate(in string, w io.Writer) (string, error) { instruction := "" switch t.op { case shiftLeft: instruction += "shl" case shiftRight: instruction += "shr" } rhs, err := t.value.generate(w) if err != nil { return "", err } out := allocReg() fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs) return out, nil } func (n *numeralExpr) generate(w io.Writer) (string, error) { lhs, err := n.lhs.generate(w) if err != nil { return "", err } out := lhs for _, rhs := range n.rhs { out, err = rhs.generate(out, w) if err != nil { return "", err } } return out, nil } func (n *numeralRhs) generate(in string, w io.Writer) (string, error) { instruction := "" switch n.op { case add: instruction += "add" case subtract: instruction += "sub" } rhs, err := n.value.generate(w) if err != nil { return "", err } out := allocReg() fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs) return out, nil } func (f *factorExpr) generate(w io.Writer) (string, error) { lhs, err := f.lhs.generate(w) if err != nil { return "", err } out := lhs for _, rhs := range f.rhs { out, err = rhs.generate(out, w) if err != nil { return "", err } } return out, nil } func (f *factorRhs) generate(in string, w io.Writer) (string, error) { instruction := "" switch f.op { case multiply: instruction += "mul" case divide: instruction += "div" } rhs, err := f.value.generate(w) if err != nil { return "", err } out := allocReg() fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs) return out, nil } func (u *unaryExpr) generate(w io.Writer) (string, error) { value, err := u.value.generate(w) if err != nil { return "", err } if u.op == unaryIdentity { return value, nil } instruction, rhs := "", "" switch u.op { case negate: instruction += "neg" case invertLogical: instruction += "ceqw" rhs = ", 0" case invertBits: instruction += "xor" rhs = ", 0xffffffff" } out := allocReg() fmt.Fprintf(w, "%s =w %s %s%s\n", out, instruction, value, rhs) return out, nil } func (g *groupingExpr) generate(w io.Writer) (string, error) { return g.inner.generate(w) } func (s *stringExpr) generate(w io.Writer) (string, error) { value := strings.Join(s.segments, "") out := allocString(value, w) return out, nil } func (n *numberExpr) generate(w io.Writer) (string, error) { base := 10 switch { case strings.HasPrefix(n.s, "0x"): base = 16 case strings.HasPrefix(n.s, "0o"): base = 8 case strings.HasPrefix(n.s, "0b"): base = 2 } s := n.s if base != 10 { s = s[2:] } num, err := strconv.ParseInt(s, base, 64) if err != nil { return "", err } return fmt.Sprintf("%d", num), nil }