// SPDX-FileCopyrightText: 2024 Himbeer // // SPDX-License-Identifier: GPL-3.0-or-later package main import ( "fmt" "io" "strconv" "strings" ) var currentFunc string var funcRegIndex int func allocReg() string { defer func() { funcRegIndex++ }() return fmt.Sprintf("%%local%d", funcRegIndex) } func resetRegs() { funcRegIndex = 0 } var globalRegIndex int var stringLiterals = map[string]string{} func allocGlobal() string { defer func() { globalRegIndex++ }() return fmt.Sprintf("$global%d", globalRegIndex) } func allocString(s string, w io.Writer) string { if reg, ok := stringLiterals[s]; ok { return reg } reg := allocGlobal() stringLiterals[s] = reg return reg } func generateStringLiterals(w io.Writer) { for s, reg := range stringLiterals { fmt.Fprintf(w, "data %s = { b %q }\n", reg, s) } } type invalidToplevel struct { got expression } func (i invalidToplevel) Error() string { return fmt.Sprintf("%d: expected top-level declaration, got %s", i.got.line(), i.got) } func generate(root *rootExpr, w io.Writer, errs chan<- error) { defer close(errs) for _, toplevel := range root.toplevels { if err := generateToplevel(toplevel, w); err != nil { errs <- err select {} } } generateStringLiterals(w) } func generateToplevel(toplevel expression, w io.Writer) error { switch expr := toplevel.(type) { case *functionExpr: return generateFunction(expr, w) case *externFuncExpr: return generateExternFunc(expr, w) } return invalidToplevel{got: toplevel} } func generateFunction(function *functionExpr, w io.Writer) error { if ok, _ := isDeclared(function.name, true); ok { return errAlreadyDeclared{ name: function.name, line: function.line(), } } resetRegs() currentFunc = function.name defer func() { currentFunc = "" }() returnType := resolveType(function.returnType) if returnType == nil { return errUndeclared{ name: function.returnType, kind: undeclaredType, line: function.line(), } } funcs[currentFunc] = &funcInfo{returnType: returnType} localConsts[currentFunc] = map[string]*localConst{} localMuts[currentFunc] = map[string]*localMut{} if function.link != defaultLinkage { fmt.Fprintf(w, "%s ", function.link) } fmt.Fprintf(w, "function %s $%s", returnType.qbeABIType(), function.name) fmt.Fprintf(w, "(") if err := generateParam(function.params, function.line(), w); err != nil { return err } fmt.Fprintf(w, ") {\n") fmt.Fprintln(w, "@start") if err := generateBlock(function.blk, w); err != nil { return err } fmt.Fprintf(w, "}\n") return nil } func generateParam(p *paramExpr, line int, w io.Writer) error { if p == nil { return nil } typ := resolveType(p.typ) if typ == nil { return errUndeclared{name: p.typ, kind: undeclaredType, line: line} } funcs[currentFunc].params = append(funcs[currentFunc].params, paramInfo{ typ: typ, }) localConsts[currentFunc][p.name] = &localConst{ typ: typ, } fmt.Fprintf(w, "%s %%%s", typ.qbeABIType(), p.name) if p.next != nil { fmt.Fprintf(w, ", ") if err := generateParam(p.next, line, w); err != nil { return err } } return nil } func generateExternFunc(e *externFuncExpr, w io.Writer) error { if ok, _ := isDeclared(e.name, true); ok { return errAlreadyDeclared{name: e.name, line: e.line()} } currentFunc = e.name defer func() { currentFunc = "" }() returnType := resolveType(e.returnType) funcs[e.name] = &funcInfo{returnType: returnType} if err := generateSignature(e.params, e.line(), w); err != nil { return err } return nil } func generateSignature(s *signatureExpr, line int, w io.Writer) error { if s == nil { return nil } typ := resolveType(s.typ) if typ == nil { return errUndeclared{name: s.typ, kind: undeclaredType, line: line} } funcs[currentFunc].params = append(funcs[currentFunc].params, paramInfo{ typ: typ, }) if s.next != nil { if err := generateSignature(s.next, line, w); err != nil { return err } } return nil } func generateBlock(blk *blockExpr, w io.Writer) error { for _, stmt := range blk.stmts { if err := generateStatement(stmt, w); err != nil { return err } } return nil } func generateStatement(stmt statementExpr, w io.Writer) error { switch s := stmt.(type) { case *returnStmt: return generateReturnStmt(s, w) case *constStmt: return generateConstStmt(s, false, w) case *mutStmt: return generateMutStmt(s, false, w) case *assignStmt: return generateAssignStmt(s, w) case *addAssignStmt: return generateAddAssignStmt(s, w) case *subAssignStmt: return generateSubAssignStmt(s, w) case *mulAssignStmt: return generateMulAssignStmt(s, w) case *divAssignStmt: return generateDivAssignStmt(s, w) case *remAssignStmt: return generateRemAssignStmt(s, w) } return nil } func generateReturnStmt(stmt *returnStmt, w io.Writer) error { returnType := funcs[currentFunc].returnType valueType := stmt.value.cerType() if valueType != returnType { return errTypeMismatch{ expected: returnType, got: valueType, line: stmt.line(), } } value, err := stmt.value.generate(w) if err != nil { return err } fmt.Fprintln(w, "ret", value) return nil } func generateConstStmt(stmt *constStmt, toplevel bool, w io.Writer) error { if ok, _ := isDeclared(stmt.name, toplevel); ok { return errAlreadyDeclared{name: stmt.name, line: stmt.line()} } value, err := stmt.initial.generate(w) if err != nil { return err } typ := stmt.initial.cerType() localConsts[currentFunc][stmt.name] = &localConst{ typ: typ, } fmt.Fprintf(w, "%%%s =%s add %s, 0\n", stmt.name, typ.qbeABIType(), value) return nil } func generateMutStmt(stmt *mutStmt, toplevel bool, w io.Writer) error { if ok, _ := isDeclared(stmt.name, toplevel); ok { return errAlreadyDeclared{name: stmt.name, line: stmt.line()} } value, err := stmt.initial.generate(w) if err != nil { return err } typ := stmt.initial.cerType() localMuts[currentFunc][stmt.name] = &localMut{ typ: typ, } fmt.Fprintf(w, "%%%s =l alloc4 4\n", stmt.name) fmt.Fprintf(w, "storew %s, %%%s\n", value, stmt.name) return nil } func generateAssignStmt(stmt *assignStmt, w io.Writer) error { ok, mutable := isDeclared(stmt.name, false) if !ok { return errUndeclared{name: stmt.name, line: stmt.line()} } if !mutable { return errImmutable{name: stmt.name, line: stmt.line()} } varType := localMuts[currentFunc][stmt.name].typ valueType := stmt.value.cerType() if varType != valueType { return errTypeMismatch{ expected: varType, got: valueType, line: stmt.line(), } } value, err := stmt.value.generate(w) if err != nil { return err } fmt.Fprintf(w, "storew %s, %%%s\n", value, stmt.name) return nil } func generateAddAssignStmt(stmt *addAssignStmt, w io.Writer) error { ok, mutable := isDeclared(stmt.name, false) if !ok { return errUndeclared{name: stmt.name, line: stmt.line()} } if !mutable { return errImmutable{name: stmt.name, line: stmt.line()} } varType := localMuts[currentFunc][stmt.name].typ valueType := stmt.value.cerType() if varType != valueType { return errTypeMismatch{ expected: varType, got: valueType, line: stmt.line(), } } value, err := stmt.value.generate(w) if err != nil { return err } in, out := allocReg(), allocReg() typ := varType.qbeBaseType() loadType := varType.qbeABIType() storeType := varType.qbeExtType() fmt.Fprintf(w, "%s =%s load%s %%%s\n", in, typ, loadType, stmt.name) fmt.Fprintf(w, "%s =%s add %s, %s\n", out, typ, in, value) fmt.Fprintf(w, "store%s %s, %%%s\n", storeType, out, stmt.name) return nil } func generateSubAssignStmt(stmt *subAssignStmt, w io.Writer) error { ok, mutable := isDeclared(stmt.name, false) if !ok { return errUndeclared{name: stmt.name, line: stmt.line()} } if !mutable { return errImmutable{name: stmt.name, line: stmt.line()} } varType := localMuts[currentFunc][stmt.name].typ valueType := stmt.value.cerType() if varType != valueType { return errTypeMismatch{ expected: varType, got: valueType, line: stmt.line(), } } value, err := stmt.value.generate(w) if err != nil { return err } in, out := allocReg(), allocReg() typ := varType.qbeBaseType() loadType := varType.qbeABIType() storeType := varType.qbeExtType() fmt.Fprintf(w, "%s =%s load%s %%%s\n", in, typ, loadType, stmt.name) fmt.Fprintf(w, "%s =%s sub %s, %s\n", out, typ, in, value) fmt.Fprintf(w, "store%s %s, %%%s\n", storeType, out, stmt.name) return nil } func generateMulAssignStmt(stmt *mulAssignStmt, w io.Writer) error { ok, mutable := isDeclared(stmt.name, false) if !ok { return errUndeclared{name: stmt.name, line: stmt.line()} } if !mutable { return errImmutable{name: stmt.name, line: stmt.line()} } varType := localMuts[currentFunc][stmt.name].typ valueType := stmt.value.cerType() if varType != valueType { return errTypeMismatch{ expected: varType, got: valueType, line: stmt.line(), } } value, err := stmt.value.generate(w) if err != nil { return err } in, out := allocReg(), allocReg() typ := varType.qbeBaseType() loadType := varType.qbeABIType() storeType := varType.qbeExtType() fmt.Fprintf(w, "%s =%s load%s %%%s\n", in, typ, loadType, stmt.name) fmt.Fprintf(w, "%s =%s mul %s, %s\n", out, typ, in, value) fmt.Fprintf(w, "store%s %s, %%%s\n", storeType, out, stmt.name) return nil } func generateDivAssignStmt(stmt *divAssignStmt, w io.Writer) error { ok, mutable := isDeclared(stmt.name, false) if !ok { return errUndeclared{name: stmt.name, line: stmt.line()} } if !mutable { return errImmutable{name: stmt.name, line: stmt.line()} } varType := localMuts[currentFunc][stmt.name].typ valueType := stmt.value.cerType() if varType != valueType { return errTypeMismatch{ expected: varType, got: valueType, line: stmt.line(), } } value, err := stmt.value.generate(w) if err != nil { return err } in, out := allocReg(), allocReg() typ := varType.qbeBaseType() loadType := varType.qbeABIType() storeType := varType.qbeExtType() fmt.Fprintf(w, "%s =%s load%s %%%s\n", in, typ, loadType, stmt.name) fmt.Fprintf(w, "%s =%s div %s, %s\n", out, typ, in, value) fmt.Fprintf(w, "store%s %s, %%%s\n", storeType, out, stmt.name) return nil } func generateRemAssignStmt(stmt *remAssignStmt, w io.Writer) error { ok, mutable := isDeclared(stmt.name, false) if !ok { return errUndeclared{name: stmt.name, line: stmt.line()} } if !mutable { return errImmutable{name: stmt.name, line: stmt.line()} } varType := localMuts[currentFunc][stmt.name].typ valueType := stmt.value.cerType() if varType != valueType { return errTypeMismatch{ expected: varType, got: valueType, line: stmt.line(), } } value, err := stmt.value.generate(w) if err != nil { return err } in, out := allocReg(), allocReg() typ := varType.qbeBaseType() loadType := varType.qbeABIType() storeType := varType.qbeExtType() fmt.Fprintf(w, "%s =%s load%s %%%s\n", in, typ, loadType, stmt.name) fmt.Fprintf(w, "%s =%s rem %s, %s\n", out, typ, in, value) fmt.Fprintf(w, "store%s %s, %%%s\n", storeType, out, stmt.name) return nil } func (e *equalityExpr) generate(w io.Writer) (string, error) { lhs, err := e.lhs.generate(w) if err != nil { return "", err } out := lhs for _, rhs := range e.rhs { lhsType := e.lhs.cerType() rhsType := rhs.value.cerType() if rhsType != lhsType { return "", errTypeMismatch{ expected: lhsType, got: rhsType, line: e.line(), } } out, err = rhs.generate(out, w) if err != nil { return "", err } } return out, nil } func (e *equalityRhs) generate(in string, w io.Writer) (string, error) { instruction := "c" switch e.op { case equalTo: instruction += "eq" case notEqualTo: instruction += "ne" } instruction += "w" rhs, err := e.value.generate(w) if err != nil { return "", err } out := allocReg() fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs) return out, nil } func (c *comparisonExpr) generate(w io.Writer) (string, error) { lhs, err := c.lhs.generate(w) if err != nil { return "", err } out := lhs if c.rhs != nil { lhsType := c.lhs.cerType() rhsType := c.rhs.value.cerType() if rhsType != lhsType { return "", errTypeMismatch{ expected: lhsType, got: rhsType, line: c.line(), } } out, err = c.rhs.generate(out, w) if err != nil { return "", err } } return out, nil } func (c *comparisonRhs) generate(in string, w io.Writer) (string, error) { instruction := "cs" switch c.op { case lessThan: instruction += "lt" case lessThanOrEqualTo: instruction += "le" case greaterThan: instruction += "gt" case greaterThanOrEqualTo: instruction += "ge" } instruction += "w" rhs, err := c.value.generate(w) if err != nil { return "", err } out := allocReg() fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs) return out, nil } func (t *termExpr) generate(w io.Writer) (string, error) { lhs, err := t.lhs.generate(w) if err != nil { return "", err } out := lhs for _, rhs := range t.rhs { lhsType := t.lhs.cerType() rhsType := rhs.value.cerType() if rhsType != lhsType { return "", errTypeMismatch{ expected: lhsType, got: rhsType, line: t.line(), } } out, err = rhs.generate(out, w) if err != nil { return "", err } } return out, nil } func (t *termRhs) generate(in string, w io.Writer) (string, error) { instruction := "" switch t.op { case shiftLeft: instruction += "shl" case shiftRight: instruction += "shr" } rhs, err := t.value.generate(w) if err != nil { return "", err } out := allocReg() fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs) return out, nil } func (n *numeralExpr) generate(w io.Writer) (string, error) { lhs, err := n.lhs.generate(w) if err != nil { return "", err } out := lhs for _, rhs := range n.rhs { lhsType := n.lhs.cerType() rhsType := rhs.value.cerType() if rhsType != lhsType { return "", errTypeMismatch{ expected: lhsType, got: rhsType, line: n.line(), } } out, err = rhs.generate(out, w) if err != nil { return "", err } } return out, nil } func (n *numeralRhs) generate(in string, w io.Writer) (string, error) { instruction := "" switch n.op { case add: instruction += "add" case subtract: instruction += "sub" } rhs, err := n.value.generate(w) if err != nil { return "", err } out := allocReg() fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs) return out, nil } func (f *factorExpr) generate(w io.Writer) (string, error) { lhs, err := f.lhs.generate(w) if err != nil { return "", err } out := lhs for _, rhs := range f.rhs { lhsType := f.lhs.cerType() rhsType := rhs.value.cerType() if rhsType != lhsType { return "", errTypeMismatch{ expected: lhsType, got: rhsType, line: f.line(), } } out, err = rhs.generate(out, w) if err != nil { return "", err } } return out, nil } func (f *factorRhs) generate(in string, w io.Writer) (string, error) { instruction := "" switch f.op { case multiply: instruction += "mul" case divide: instruction += "div" case remainder: instruction += "rem" } rhs, err := f.value.generate(w) if err != nil { return "", err } out := allocReg() fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs) return out, nil } func (u *unaryExpr) generate(w io.Writer) (string, error) { value, err := u.value.generate(w) if err != nil { return "", err } if u.op == unaryIdentity { return value, nil } instruction, rhs := "", "" switch u.op { case negate: instruction += "neg" case invertLogical: instruction += "ceqw" rhs = ", 0" case invertBits: instruction += "xor" rhs = ", 0xffffffff" } out := allocReg() fmt.Fprintf(w, "%s =w %s %s%s\n", out, instruction, value, rhs) return out, nil } func (g *groupingExpr) generate(w io.Writer) (string, error) { return g.inner.generate(w) } func (s *stringExpr) generate(w io.Writer) (string, error) { value := strings.Join(s.segments, "") out := allocString(value, w) return out, nil } func (n *numberExpr) generate(w io.Writer) (string, error) { typ := resolveType(n.typ) if typ == nil { return "", errUndeclared{ name: n.typ, kind: undeclaredType, line: n.line(), } } base := 10 switch { case strings.HasPrefix(n.s, "0x"): base = 16 case strings.HasPrefix(n.s, "0o"): base = 8 case strings.HasPrefix(n.s, "0b"): base = 2 } s := n.s if base != 10 { s = s[2:] } num, err := strconv.ParseInt(s, base, 64) if err != nil { return "", err } return fmt.Sprintf("%d", num), nil } func (c *callExpr) generate(w io.Writer) (string, error) { _, ok := funcs[c.funcName] if !ok { return "", errUndeclared{ name: c.funcName, kind: undeclaredFunction, line: c.line(), } } returnType := funcs[c.funcName].returnType ret := allocReg() if c.numArgs != len(funcs[c.funcName].params) { return "", errArgNumMismatch{ expected: len(funcs[c.funcName].params), got: c.numArgs, line: c.line(), } } args := make([]string, 0) delim := "" i := 0 for arg := c.args; arg != nil; arg = arg.next { reg, err := arg.value.generate(w) if err != nil { return "", err } typ := arg.value.cerType() neededType := funcs[c.funcName].params[i].typ if typ != neededType { return "", errTypeMismatch{ expected: neededType, got: typ, line: c.line(), } } s := fmt.Sprintf("%s %s%s", typ.qbeABIType(), reg, delim) args = append(args, s) delim = "," i++ } fmt.Fprintf(w, "%s =%s call $%s(", ret, returnType.qbeABIType(), c.funcName) for _, arg := range args { fmt.Fprintf(w, "%s", arg) } fmt.Fprintf(w, ")\n") return ret, nil } func (v *varExpr) generate(w io.Writer) (string, error) { ok, mutable := isDeclared(v.name, false) if !ok { return "", errUndeclared{name: v.name, line: v.line()} } if mutable { reg := allocReg() fmt.Fprintf(w, "%s =w loadl %%%s\n", reg, v.name) return reg, nil } return "%" + v.name, nil }