diff options
Diffstat (limited to 'generate.go')
-rw-r--r-- | generate.go | 311 |
1 files changed, 301 insertions, 10 deletions
diff --git a/generate.go b/generate.go index 75da82e..5ff27c8 100644 --- a/generate.go +++ b/generate.go @@ -3,8 +3,42 @@ package main import ( "fmt" "io" + "strconv" + "strings" ) +var funcRegIndex int + +func allocReg() string { + defer func() { funcRegIndex++ }() + return fmt.Sprintf("%%local%d", funcRegIndex) +} + +func resetRegs() { + funcRegIndex = 0 +} + +var globalRegIndex int +var stringLiterals = map[string]string{} + +func allocGlobal() string { + defer func() { globalRegIndex++ }() + return fmt.Sprintf("$global%d", globalRegIndex) +} + +func allocString(s string, w io.Writer) string { + if reg, ok := stringLiterals[s]; ok { + return reg + } + + reg := allocGlobal() + stringLiterals[s] = reg + + fmt.Fprintf(w, "data %s = { %q }", reg, s) + + return reg +} + type invalidToplevel struct { got expression } @@ -33,6 +67,8 @@ func generateToplevel(toplevel expression, w io.Writer) error { } func generateFunction(function *functionExpr, w io.Writer) error { + resetRegs() + if function.link != defaultLinkage { fmt.Fprintf(w, "%s ", function.link) } @@ -68,18 +104,273 @@ func generateStatement(stmt statementExpr, w io.Writer) error { } func generateReturnStmt(stmt *returnStmt, w io.Writer) error { - eq := stmt.value.(*equalityExpr) - cmp := eq.lhs - trm := cmp.lhs - num := trm.lhs - fct := num.lhs - un := fct.lhs - prm := un.value - lit := prm.(literalExpr) - n := lit.(*numberExpr) - value := n.s + value, err := stmt.value.generate(w) + if err != nil { + return err + } fmt.Fprintln(w, "ret", value) return nil } + +func (e *equalityExpr) generate(w io.Writer) (string, error) { + lhs, err := e.lhs.generate(w) + if err != nil { + return "", err + } + + out := lhs + for _, rhs := range e.rhs { + out, err = rhs.generate(out, w) + if err != nil { + return "", err + } + } + + return out, nil +} + +func (e *equalityRhs) generate(in string, w io.Writer) (string, error) { + instruction := "c" + + switch e.op { + case equalTo: + instruction += "eq" + case notEqualTo: + instruction += "ne" + } + + instruction += "w" + + rhs, err := e.value.generate(w) + if err != nil { + return "", err + } + + out := allocReg() + + fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs) + return out, nil +} + +func (c *comparisonExpr) generate(w io.Writer) (string, error) { + lhs, err := c.lhs.generate(w) + if err != nil { + return "", err + } + + out := lhs + if c.rhs != nil { + out, err = c.rhs.generate(out, w) + if err != nil { + return "", err + } + } + + return out, nil +} + +func (c *comparisonRhs) generate(in string, w io.Writer) (string, error) { + instruction := "cs" + + switch c.op { + case lessThan: + instruction += "lt" + case lessThanOrEqualTo: + instruction += "le" + case greaterThan: + instruction += "gt" + case greaterThanOrEqualTo: + instruction += "ge" + } + + instruction += "w" + + rhs, err := c.value.generate(w) + if err != nil { + return "", err + } + + out := allocReg() + + fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs) + return out, nil +} + +func (t *termExpr) generate(w io.Writer) (string, error) { + lhs, err := t.lhs.generate(w) + if err != nil { + return "", err + } + + out := lhs + for _, rhs := range t.rhs { + out, err = rhs.generate(out, w) + if err != nil { + return "", err + } + } + + return out, nil +} + +func (t *termRhs) generate(in string, w io.Writer) (string, error) { + instruction := "" + + switch t.op { + case shiftLeft: + instruction += "shl" + case shiftRight: + instruction += "shr" + } + + rhs, err := t.value.generate(w) + if err != nil { + return "", err + } + + out := allocReg() + + fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs) + return out, nil +} + +func (n *numeralExpr) generate(w io.Writer) (string, error) { + lhs, err := n.lhs.generate(w) + if err != nil { + return "", err + } + + out := lhs + for _, rhs := range n.rhs { + out, err = rhs.generate(out, w) + if err != nil { + return "", err + } + } + + return out, nil +} + +func (n *numeralRhs) generate(in string, w io.Writer) (string, error) { + instruction := "" + + switch n.op { + case add: + instruction += "add" + case subtract: + instruction += "sub" + } + + rhs, err := n.value.generate(w) + if err != nil { + return "", err + } + + out := allocReg() + + fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs) + return out, nil +} + +func (f *factorExpr) generate(w io.Writer) (string, error) { + lhs, err := f.lhs.generate(w) + if err != nil { + return "", err + } + + out := lhs + for _, rhs := range f.rhs { + out, err = rhs.generate(out, w) + if err != nil { + return "", err + } + } + + return out, nil +} + +func (f *factorRhs) generate(in string, w io.Writer) (string, error) { + instruction := "" + + switch f.op { + case multiply: + instruction += "mul" + case divide: + instruction += "div" + } + + rhs, err := f.value.generate(w) + if err != nil { + return "", err + } + + out := allocReg() + + fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs) + return out, nil +} + +func (u *unaryExpr) generate(w io.Writer) (string, error) { + value, err := u.value.generate(w) + if err != nil { + return "", err + } + + if u.op == unaryIdentity { + return value, nil + } + + instruction, rhs := "", "" + + switch u.op { + case negate: + instruction += "neg" + case invertLogical: + instruction += "ceqw" + rhs = ", 0" + case invertBits: + instruction += "xor" + rhs = ", 0xffffffff" + } + + out := allocReg() + + fmt.Fprintf(w, "%s =w %s %s%s\n", out, instruction, value, rhs) + return out, nil +} + +func (g *groupingExpr) generate(w io.Writer) (string, error) { + return g.inner.generate(w) +} + +func (s *stringExpr) generate(w io.Writer) (string, error) { + value := strings.Join(s.segments, "") + out := allocString(value, w) + return out, nil +} + +func (n *numberExpr) generate(w io.Writer) (string, error) { + base := 10 + switch { + case strings.HasPrefix(n.s, "0x"): + base = 16 + case strings.HasPrefix(n.s, "0o"): + base = 8 + case strings.HasPrefix(n.s, "0b"): + base = 2 + } + + s := n.s + if base != 10 { + s = s[2:] + } + + num, err := strconv.ParseInt(s, base, 64) + if err != nil { + return "", err + } + + return fmt.Sprintf("%d", num), nil +} |