aboutsummaryrefslogtreecommitdiff
path: root/generate.go
diff options
context:
space:
mode:
Diffstat (limited to 'generate.go')
-rw-r--r--generate.go312
1 files changed, 282 insertions, 30 deletions
diff --git a/generate.go b/generate.go
index a32040a..15319b3 100644
--- a/generate.go
+++ b/generate.go
@@ -94,17 +94,30 @@ func generateFunction(function *functionExpr, w io.Writer) error {
currentFunc = function.name
defer func() { currentFunc = "" }()
- funcs[currentFunc] = struct{}{}
+ returnType := resolveType(function.returnType)
+ if returnType == nil {
+ return errUndeclared{
+ name: function.returnType,
+ kind: undeclaredType,
+ line: function.line(),
+ }
+ }
- localConsts[currentFunc] = map[string]string{}
- localMuts[currentFunc] = map[string]struct{}{}
+ funcs[currentFunc] = &funcInfo{returnType: returnType}
+
+ localConsts[currentFunc] = map[string]*localConst{}
+ localMuts[currentFunc] = map[string]*localMut{}
if function.link != defaultLinkage {
fmt.Fprintf(w, "%s ", function.link)
}
- fmt.Fprintf(w, "function %s $%s", function.returnType, function.name)
- fmt.Fprintf(w, "(%s) {\n", function.params)
+ fmt.Fprintf(w, "function %s $%s", returnType.qbeABIType(), function.name)
+ fmt.Fprintf(w, "(")
+ if err := generateParam(function.params, function.line(), w); err != nil {
+ return err
+ }
+ fmt.Fprintf(w, ") {\n")
fmt.Fprintln(w, "@start")
if err := generateBlock(function.blk, w); err != nil {
return err
@@ -114,12 +127,70 @@ func generateFunction(function *functionExpr, w io.Writer) error {
return nil
}
+func generateParam(p *paramExpr, line int, w io.Writer) error {
+ if p == nil {
+ return nil
+ }
+
+ typ := resolveType(p.typ)
+ if typ == nil {
+ return errUndeclared{name: p.typ, kind: undeclaredType, line: line}
+ }
+
+ funcs[currentFunc].params = append(funcs[currentFunc].params, paramInfo{
+ typ: typ,
+ })
+
+ fmt.Fprintf(w, "%s %%%s", typ.qbeABIType(), p.name)
+ if p.next != nil {
+ fmt.Fprintf(w, ", ")
+ if err := generateParam(p.next, line, w); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
func generateExternFunc(e *externFuncExpr, w io.Writer) error {
if ok, _ := isDeclared(e.name, true); ok {
return errAlreadyDeclared{name: e.name, line: e.line()}
}
- funcs[e.name] = struct{}{}
+ currentFunc = e.name
+ defer func() { currentFunc = "" }()
+
+ returnType := resolveType(e.returnType)
+
+ funcs[e.name] = &funcInfo{returnType: returnType}
+
+ if err := generateSignature(e.params, e.line(), w); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func generateSignature(s *signatureExpr, line int, w io.Writer) error {
+ if s == nil {
+ return nil
+ }
+
+ typ := resolveType(s.typ)
+ if typ == nil {
+ return errUndeclared{name: s.typ, kind: undeclaredType, line: line}
+ }
+
+ funcs[currentFunc].params = append(funcs[currentFunc].params, paramInfo{
+ typ: typ,
+ })
+
+ if s.next != nil {
+ if err := generateSignature(s.next, line, w); err != nil {
+ return err
+ }
+ }
+
return nil
}
@@ -159,6 +230,16 @@ func generateStatement(stmt statementExpr, w io.Writer) error {
}
func generateReturnStmt(stmt *returnStmt, w io.Writer) error {
+ returnType := funcs[currentFunc].returnType
+ valueType := stmt.value.cerType()
+ if valueType != returnType {
+ return errTypeMismatch{
+ expected: returnType,
+ got: valueType,
+ line: stmt.line(),
+ }
+ }
+
value, err := stmt.value.generate(w)
if err != nil {
return err
@@ -179,9 +260,12 @@ func generateConstStmt(stmt *constStmt, toplevel bool, w io.Writer) error {
return err
}
- localConsts[currentFunc][stmt.name] = value
+ typ := stmt.initial.cerType()
+ localConsts[currentFunc][stmt.name] = &localConst{
+ typ: typ,
+ }
- fmt.Fprintf(w, "%%%s =w add %s, 0\n", stmt.name, value)
+ fmt.Fprintf(w, "%%%s =%s add %s, 0\n", stmt.name, typ.qbeABIType(), value)
return nil
}
@@ -195,7 +279,10 @@ func generateMutStmt(stmt *mutStmt, toplevel bool, w io.Writer) error {
return err
}
- localMuts[currentFunc][stmt.name] = struct{}{}
+ typ := stmt.initial.cerType()
+ localMuts[currentFunc][stmt.name] = &localMut{
+ typ: typ,
+ }
fmt.Fprintf(w, "%%%s =l alloc4 4\n", stmt.name)
fmt.Fprintf(w, "storew %s, %%%s\n", value, stmt.name)
@@ -211,6 +298,16 @@ func generateAssignStmt(stmt *assignStmt, w io.Writer) error {
return errImmutable{name: stmt.name, line: stmt.line()}
}
+ varType := localMuts[currentFunc][stmt.name].typ
+ valueType := stmt.value.cerType()
+ if varType != valueType {
+ return errTypeMismatch{
+ expected: varType,
+ got: valueType,
+ line: stmt.line(),
+ }
+ }
+
value, err := stmt.value.generate(w)
if err != nil {
return err
@@ -229,6 +326,16 @@ func generateAddAssignStmt(stmt *addAssignStmt, w io.Writer) error {
return errImmutable{name: stmt.name, line: stmt.line()}
}
+ varType := localMuts[currentFunc][stmt.name].typ
+ valueType := stmt.value.cerType()
+ if varType != valueType {
+ return errTypeMismatch{
+ expected: varType,
+ got: valueType,
+ line: stmt.line(),
+ }
+ }
+
value, err := stmt.value.generate(w)
if err != nil {
return err
@@ -236,9 +343,13 @@ func generateAddAssignStmt(stmt *addAssignStmt, w io.Writer) error {
in, out := allocReg(), allocReg()
- fmt.Fprintf(w, "%s =w loadl %%%s\n", in, stmt.name)
- fmt.Fprintf(w, "%s =w add %s, %s\n", out, in, value)
- fmt.Fprintf(w, "storew %s, %%%s\n", out, stmt.name)
+ typ := varType.qbeBaseType()
+ loadType := varType.qbeABIType()
+ storeType := varType.qbeExtType()
+
+ fmt.Fprintf(w, "%s =%s load%s %%%s\n", in, typ, loadType, stmt.name)
+ fmt.Fprintf(w, "%s =%s add %s, %s\n", out, typ, in, value)
+ fmt.Fprintf(w, "store%s %s, %%%s\n", storeType, out, stmt.name)
return nil
}
@@ -251,6 +362,16 @@ func generateSubAssignStmt(stmt *subAssignStmt, w io.Writer) error {
return errImmutable{name: stmt.name, line: stmt.line()}
}
+ varType := localMuts[currentFunc][stmt.name].typ
+ valueType := stmt.value.cerType()
+ if varType != valueType {
+ return errTypeMismatch{
+ expected: varType,
+ got: valueType,
+ line: stmt.line(),
+ }
+ }
+
value, err := stmt.value.generate(w)
if err != nil {
return err
@@ -258,9 +379,13 @@ func generateSubAssignStmt(stmt *subAssignStmt, w io.Writer) error {
in, out := allocReg(), allocReg()
- fmt.Fprintf(w, "%s =w loadl %%%s\n", in, stmt.name)
- fmt.Fprintf(w, "%s =w sub %s, %s\n", out, in, value)
- fmt.Fprintf(w, "storew %s, %%%s\n", out, stmt.name)
+ typ := varType.qbeBaseType()
+ loadType := varType.qbeABIType()
+ storeType := varType.qbeExtType()
+
+ fmt.Fprintf(w, "%s =%s load%s %%%s\n", in, typ, loadType, stmt.name)
+ fmt.Fprintf(w, "%s =%s sub %s, %s\n", out, typ, in, value)
+ fmt.Fprintf(w, "store%s %s, %%%s\n", storeType, out, stmt.name)
return nil
}
@@ -273,6 +398,16 @@ func generateMulAssignStmt(stmt *mulAssignStmt, w io.Writer) error {
return errImmutable{name: stmt.name, line: stmt.line()}
}
+ varType := localMuts[currentFunc][stmt.name].typ
+ valueType := stmt.value.cerType()
+ if varType != valueType {
+ return errTypeMismatch{
+ expected: varType,
+ got: valueType,
+ line: stmt.line(),
+ }
+ }
+
value, err := stmt.value.generate(w)
if err != nil {
return err
@@ -280,9 +415,13 @@ func generateMulAssignStmt(stmt *mulAssignStmt, w io.Writer) error {
in, out := allocReg(), allocReg()
- fmt.Fprintf(w, "%s =w loadl %%%s\n", in, stmt.name)
- fmt.Fprintf(w, "%s =w mul %s, %s\n", out, in, value)
- fmt.Fprintf(w, "storew %s, %%%s\n", out, stmt.name)
+ typ := varType.qbeBaseType()
+ loadType := varType.qbeABIType()
+ storeType := varType.qbeExtType()
+
+ fmt.Fprintf(w, "%s =%s load%s %%%s\n", in, typ, loadType, stmt.name)
+ fmt.Fprintf(w, "%s =%s mul %s, %s\n", out, typ, in, value)
+ fmt.Fprintf(w, "store%s %s, %%%s\n", storeType, out, stmt.name)
return nil
}
@@ -295,6 +434,16 @@ func generateDivAssignStmt(stmt *divAssignStmt, w io.Writer) error {
return errImmutable{name: stmt.name, line: stmt.line()}
}
+ varType := localMuts[currentFunc][stmt.name].typ
+ valueType := stmt.value.cerType()
+ if varType != valueType {
+ return errTypeMismatch{
+ expected: varType,
+ got: valueType,
+ line: stmt.line(),
+ }
+ }
+
value, err := stmt.value.generate(w)
if err != nil {
return err
@@ -302,9 +451,13 @@ func generateDivAssignStmt(stmt *divAssignStmt, w io.Writer) error {
in, out := allocReg(), allocReg()
- fmt.Fprintf(w, "%s =w loadl %%%s\n", in, stmt.name)
- fmt.Fprintf(w, "%s =w div %s, %s\n", out, in, value)
- fmt.Fprintf(w, "storew %s, %%%s\n", out, stmt.name)
+ typ := varType.qbeBaseType()
+ loadType := varType.qbeABIType()
+ storeType := varType.qbeExtType()
+
+ fmt.Fprintf(w, "%s =%s load%s %%%s\n", in, typ, loadType, stmt.name)
+ fmt.Fprintf(w, "%s =%s div %s, %s\n", out, typ, in, value)
+ fmt.Fprintf(w, "store%s %s, %%%s\n", storeType, out, stmt.name)
return nil
}
@@ -317,6 +470,16 @@ func generateRemAssignStmt(stmt *remAssignStmt, w io.Writer) error {
return errImmutable{name: stmt.name, line: stmt.line()}
}
+ varType := localMuts[currentFunc][stmt.name].typ
+ valueType := stmt.value.cerType()
+ if varType != valueType {
+ return errTypeMismatch{
+ expected: varType,
+ got: valueType,
+ line: stmt.line(),
+ }
+ }
+
value, err := stmt.value.generate(w)
if err != nil {
return err
@@ -324,9 +487,13 @@ func generateRemAssignStmt(stmt *remAssignStmt, w io.Writer) error {
in, out := allocReg(), allocReg()
- fmt.Fprintf(w, "%s =w loadl %%%s\n", in, stmt.name)
- fmt.Fprintf(w, "%s =w rem %s, %s\n", out, in, value)
- fmt.Fprintf(w, "storew %s, %%%s\n", out, stmt.name)
+ typ := varType.qbeBaseType()
+ loadType := varType.qbeABIType()
+
+ storeType := varType.qbeExtType()
+ fmt.Fprintf(w, "%s =%s load%s %%%s\n", in, typ, loadType, stmt.name)
+ fmt.Fprintf(w, "%s =%s rem %s, %s\n", out, typ, in, value)
+ fmt.Fprintf(w, "store%s %s, %%%s\n", storeType, out, stmt.name)
return nil
}
@@ -338,6 +505,16 @@ func (e *equalityExpr) generate(w io.Writer) (string, error) {
out := lhs
for _, rhs := range e.rhs {
+ lhsType := e.lhs.cerType()
+ rhsType := rhs.value.cerType()
+ if rhsType != lhsType {
+ return "", errTypeMismatch{
+ expected: lhsType,
+ got: rhsType,
+ line: e.line(),
+ }
+ }
+
out, err = rhs.generate(out, w)
if err != nil {
return "", err
@@ -378,6 +555,16 @@ func (c *comparisonExpr) generate(w io.Writer) (string, error) {
out := lhs
if c.rhs != nil {
+ lhsType := c.lhs.cerType()
+ rhsType := c.rhs.value.cerType()
+ if rhsType != lhsType {
+ return "", errTypeMismatch{
+ expected: lhsType,
+ got: rhsType,
+ line: c.line(),
+ }
+ }
+
out, err = c.rhs.generate(out, w)
if err != nil {
return "", err
@@ -422,6 +609,16 @@ func (t *termExpr) generate(w io.Writer) (string, error) {
out := lhs
for _, rhs := range t.rhs {
+ lhsType := t.lhs.cerType()
+ rhsType := rhs.value.cerType()
+ if rhsType != lhsType {
+ return "", errTypeMismatch{
+ expected: lhsType,
+ got: rhsType,
+ line: t.line(),
+ }
+ }
+
out, err = rhs.generate(out, w)
if err != nil {
return "", err
@@ -460,6 +657,16 @@ func (n *numeralExpr) generate(w io.Writer) (string, error) {
out := lhs
for _, rhs := range n.rhs {
+ lhsType := n.lhs.cerType()
+ rhsType := rhs.value.cerType()
+ if rhsType != lhsType {
+ return "", errTypeMismatch{
+ expected: lhsType,
+ got: rhsType,
+ line: n.line(),
+ }
+ }
+
out, err = rhs.generate(out, w)
if err != nil {
return "", err
@@ -498,6 +705,16 @@ func (f *factorExpr) generate(w io.Writer) (string, error) {
out := lhs
for _, rhs := range f.rhs {
+ lhsType := f.lhs.cerType()
+ rhsType := rhs.value.cerType()
+ if rhsType != lhsType {
+ return "", errTypeMismatch{
+ expected: lhsType,
+ got: rhsType,
+ line: f.line(),
+ }
+ }
+
out, err = rhs.generate(out, w)
if err != nil {
return "", err
@@ -570,6 +787,15 @@ func (s *stringExpr) generate(w io.Writer) (string, error) {
}
func (n *numberExpr) generate(w io.Writer) (string, error) {
+ typ := resolveType(n.typ)
+ if typ == nil {
+ return "", errUndeclared{
+ name: n.typ,
+ kind: undeclaredType,
+ line: n.line(),
+ }
+ }
+
base := 10
switch {
case strings.HasPrefix(n.s, "0x"):
@@ -597,27 +823,53 @@ func (c *callExpr) generate(w io.Writer) (string, error) {
_, ok := funcs[c.funcName]
if !ok {
return "", errUndeclared{
- name: c.funcName,
- isFunc: true,
- line: c.line(),
+ name: c.funcName,
+ kind: undeclaredFunction,
+ line: c.line(),
}
}
+ returnType := funcs[c.funcName].returnType
+
ret := allocReg()
- fmt.Fprintf(w, "%s =w call $%s(", ret, c.funcName)
+ if c.numArgs != len(funcs[c.funcName].params) {
+ return "", errArgNumMismatch{
+ expected: len(funcs[c.funcName].params),
+ got: c.numArgs,
+ line: c.line(),
+ }
+ }
+ args := make([]string, 0)
delim := ""
+ i := 0
for arg := c.args; arg != nil; arg = arg.next {
reg, err := arg.value.generate(w)
if err != nil {
return "", err
}
- fmt.Fprintf(w, "w %s%s", reg, delim)
+ typ := arg.value.cerType()
+ neededType := funcs[c.funcName].params[i].typ
+ if typ != neededType {
+ return "", errTypeMismatch{
+ expected: neededType,
+ got: typ,
+ line: c.line(),
+ }
+ }
+
+ s := fmt.Sprintf("%s %s%s", typ.qbeABIType(), reg, delim)
+ args = append(args, s)
delim = ","
+ i++
}
+ fmt.Fprintf(w, "%s =%s call $%s(", ret, returnType.qbeABIType(), c.funcName)
+ for _, arg := range args {
+ fmt.Fprintf(w, "%s", arg)
+ }
fmt.Fprintf(w, ")\n")
return ret, nil
@@ -635,5 +887,5 @@ func (v *varExpr) generate(w io.Writer) (string, error) {
return reg, nil
}
- return localConsts[currentFunc][v.name], nil
+ return "%" + v.name, nil
}