aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHimbeer <himbeer@disroot.org>2024-08-30 10:47:31 +0200
committerHimbeer <himbeer@disroot.org>2024-08-30 10:47:31 +0200
commite5bba0f77dbed61b41131fa04a5f52eb5a454a01 (patch)
tree21bc5f032a889a8f0cbce4b3b591752b22de2933
parent3adfc2d52c29fd90257813f33b8286225f7adc04 (diff)
Implement IL generation for expression evaluation
-rw-r--r--expression.go6
-rw-r--r--generate.go311
2 files changed, 306 insertions, 11 deletions
diff --git a/expression.go b/expression.go
index 0498d35..6ae3399 100644
--- a/expression.go
+++ b/expression.go
@@ -1,6 +1,9 @@
package main
-import "fmt"
+import (
+ "fmt"
+ "io"
+)
type expression interface {
markExpr()
@@ -97,6 +100,7 @@ func (r *returnStmt) line() int { return r.ln }
type exprExpr interface {
expression
markExprExpr()
+ generate(io.Writer) (string, error)
}
type equalityExpr struct {
diff --git a/generate.go b/generate.go
index 75da82e..5ff27c8 100644
--- a/generate.go
+++ b/generate.go
@@ -3,8 +3,42 @@ package main
import (
"fmt"
"io"
+ "strconv"
+ "strings"
)
+var funcRegIndex int
+
+func allocReg() string {
+ defer func() { funcRegIndex++ }()
+ return fmt.Sprintf("%%local%d", funcRegIndex)
+}
+
+func resetRegs() {
+ funcRegIndex = 0
+}
+
+var globalRegIndex int
+var stringLiterals = map[string]string{}
+
+func allocGlobal() string {
+ defer func() { globalRegIndex++ }()
+ return fmt.Sprintf("$global%d", globalRegIndex)
+}
+
+func allocString(s string, w io.Writer) string {
+ if reg, ok := stringLiterals[s]; ok {
+ return reg
+ }
+
+ reg := allocGlobal()
+ stringLiterals[s] = reg
+
+ fmt.Fprintf(w, "data %s = { %q }", reg, s)
+
+ return reg
+}
+
type invalidToplevel struct {
got expression
}
@@ -33,6 +67,8 @@ func generateToplevel(toplevel expression, w io.Writer) error {
}
func generateFunction(function *functionExpr, w io.Writer) error {
+ resetRegs()
+
if function.link != defaultLinkage {
fmt.Fprintf(w, "%s ", function.link)
}
@@ -68,18 +104,273 @@ func generateStatement(stmt statementExpr, w io.Writer) error {
}
func generateReturnStmt(stmt *returnStmt, w io.Writer) error {
- eq := stmt.value.(*equalityExpr)
- cmp := eq.lhs
- trm := cmp.lhs
- num := trm.lhs
- fct := num.lhs
- un := fct.lhs
- prm := un.value
- lit := prm.(literalExpr)
- n := lit.(*numberExpr)
- value := n.s
+ value, err := stmt.value.generate(w)
+ if err != nil {
+ return err
+ }
fmt.Fprintln(w, "ret", value)
return nil
}
+
+func (e *equalityExpr) generate(w io.Writer) (string, error) {
+ lhs, err := e.lhs.generate(w)
+ if err != nil {
+ return "", err
+ }
+
+ out := lhs
+ for _, rhs := range e.rhs {
+ out, err = rhs.generate(out, w)
+ if err != nil {
+ return "", err
+ }
+ }
+
+ return out, nil
+}
+
+func (e *equalityRhs) generate(in string, w io.Writer) (string, error) {
+ instruction := "c"
+
+ switch e.op {
+ case equalTo:
+ instruction += "eq"
+ case notEqualTo:
+ instruction += "ne"
+ }
+
+ instruction += "w"
+
+ rhs, err := e.value.generate(w)
+ if err != nil {
+ return "", err
+ }
+
+ out := allocReg()
+
+ fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs)
+ return out, nil
+}
+
+func (c *comparisonExpr) generate(w io.Writer) (string, error) {
+ lhs, err := c.lhs.generate(w)
+ if err != nil {
+ return "", err
+ }
+
+ out := lhs
+ if c.rhs != nil {
+ out, err = c.rhs.generate(out, w)
+ if err != nil {
+ return "", err
+ }
+ }
+
+ return out, nil
+}
+
+func (c *comparisonRhs) generate(in string, w io.Writer) (string, error) {
+ instruction := "cs"
+
+ switch c.op {
+ case lessThan:
+ instruction += "lt"
+ case lessThanOrEqualTo:
+ instruction += "le"
+ case greaterThan:
+ instruction += "gt"
+ case greaterThanOrEqualTo:
+ instruction += "ge"
+ }
+
+ instruction += "w"
+
+ rhs, err := c.value.generate(w)
+ if err != nil {
+ return "", err
+ }
+
+ out := allocReg()
+
+ fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs)
+ return out, nil
+}
+
+func (t *termExpr) generate(w io.Writer) (string, error) {
+ lhs, err := t.lhs.generate(w)
+ if err != nil {
+ return "", err
+ }
+
+ out := lhs
+ for _, rhs := range t.rhs {
+ out, err = rhs.generate(out, w)
+ if err != nil {
+ return "", err
+ }
+ }
+
+ return out, nil
+}
+
+func (t *termRhs) generate(in string, w io.Writer) (string, error) {
+ instruction := ""
+
+ switch t.op {
+ case shiftLeft:
+ instruction += "shl"
+ case shiftRight:
+ instruction += "shr"
+ }
+
+ rhs, err := t.value.generate(w)
+ if err != nil {
+ return "", err
+ }
+
+ out := allocReg()
+
+ fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs)
+ return out, nil
+}
+
+func (n *numeralExpr) generate(w io.Writer) (string, error) {
+ lhs, err := n.lhs.generate(w)
+ if err != nil {
+ return "", err
+ }
+
+ out := lhs
+ for _, rhs := range n.rhs {
+ out, err = rhs.generate(out, w)
+ if err != nil {
+ return "", err
+ }
+ }
+
+ return out, nil
+}
+
+func (n *numeralRhs) generate(in string, w io.Writer) (string, error) {
+ instruction := ""
+
+ switch n.op {
+ case add:
+ instruction += "add"
+ case subtract:
+ instruction += "sub"
+ }
+
+ rhs, err := n.value.generate(w)
+ if err != nil {
+ return "", err
+ }
+
+ out := allocReg()
+
+ fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs)
+ return out, nil
+}
+
+func (f *factorExpr) generate(w io.Writer) (string, error) {
+ lhs, err := f.lhs.generate(w)
+ if err != nil {
+ return "", err
+ }
+
+ out := lhs
+ for _, rhs := range f.rhs {
+ out, err = rhs.generate(out, w)
+ if err != nil {
+ return "", err
+ }
+ }
+
+ return out, nil
+}
+
+func (f *factorRhs) generate(in string, w io.Writer) (string, error) {
+ instruction := ""
+
+ switch f.op {
+ case multiply:
+ instruction += "mul"
+ case divide:
+ instruction += "div"
+ }
+
+ rhs, err := f.value.generate(w)
+ if err != nil {
+ return "", err
+ }
+
+ out := allocReg()
+
+ fmt.Fprintf(w, "%s =w %s %s, %s\n", out, instruction, in, rhs)
+ return out, nil
+}
+
+func (u *unaryExpr) generate(w io.Writer) (string, error) {
+ value, err := u.value.generate(w)
+ if err != nil {
+ return "", err
+ }
+
+ if u.op == unaryIdentity {
+ return value, nil
+ }
+
+ instruction, rhs := "", ""
+
+ switch u.op {
+ case negate:
+ instruction += "neg"
+ case invertLogical:
+ instruction += "ceqw"
+ rhs = ", 0"
+ case invertBits:
+ instruction += "xor"
+ rhs = ", 0xffffffff"
+ }
+
+ out := allocReg()
+
+ fmt.Fprintf(w, "%s =w %s %s%s\n", out, instruction, value, rhs)
+ return out, nil
+}
+
+func (g *groupingExpr) generate(w io.Writer) (string, error) {
+ return g.inner.generate(w)
+}
+
+func (s *stringExpr) generate(w io.Writer) (string, error) {
+ value := strings.Join(s.segments, "")
+ out := allocString(value, w)
+ return out, nil
+}
+
+func (n *numberExpr) generate(w io.Writer) (string, error) {
+ base := 10
+ switch {
+ case strings.HasPrefix(n.s, "0x"):
+ base = 16
+ case strings.HasPrefix(n.s, "0o"):
+ base = 8
+ case strings.HasPrefix(n.s, "0b"):
+ base = 2
+ }
+
+ s := n.s
+ if base != 10 {
+ s = s[2:]
+ }
+
+ num, err := strconv.ParseInt(s, base, 64)
+ if err != nil {
+ return "", err
+ }
+
+ return fmt.Sprintf("%d", num), nil
+}