github.com/coyove/nj@v0.0.0-20221110084952-c7f8db1065c3/compile.go (about)

     1  package nj
     2  
     3  import (
     4  	"math"
     5  
     6  	"github.com/coyove/nj/bas"
     7  	"github.com/coyove/nj/internal"
     8  	"github.com/coyove/nj/parser"
     9  	"github.com/coyove/nj/typ"
    10  )
    11  
    12  // [prog expr1 expr2 ...]
    13  func compileProgBlock(table *symTable, node *parser.Prog) uint16 {
    14  	if node.DoBlock {
    15  		table.addMaskedSymTable()
    16  	}
    17  
    18  	yx := uint16(typ.RegA)
    19  	for _, a := range node.Stats {
    20  		if a == nil {
    21  			continue
    22  		}
    23  		switch a.(type) {
    24  		case parser.Address, parser.Primitive, *parser.Symbol:
    25  			// e.g.: [prog "a string"] will be transformed into: [prog [set $a "a string"]]
    26  			yx = table.compileNode(a)
    27  			table.codeSeg.WriteInst(typ.OpSet, typ.RegA, yx)
    28  		default:
    29  			yx = table.compileNode(a)
    30  		}
    31  
    32  		table.releaseAddr(table.pendingReleases)
    33  		table.pendingReleases = table.pendingReleases[:0]
    34  	}
    35  
    36  	if node.DoBlock {
    37  		table.removeMaskedSymTable()
    38  	}
    39  	return yx
    40  }
    41  
    42  // local a = b
    43  func compileDeclare(table *symTable, node *parser.Declare) uint16 {
    44  	dest := node.Name.Name
    45  	if bas.GetTopIndex(dest) > 0 || dest == staticTrue || dest == staticFalse || dest == staticThis || dest == staticSelf {
    46  		table.panicnode(node.Name, "can't bound to a global static name")
    47  	}
    48  
    49  	destAddr := table.borrowAddress()
    50  	defer table.put(dest, destAddr) // execute in defer in case of: a = 1 do local a = a end
    51  	table.codeSeg.WriteInst(typ.OpSet, destAddr, table.compileNode(node.Value))
    52  	table.codeSeg.WriteLineNum(node.Line)
    53  	return destAddr
    54  }
    55  
    56  // a = b
    57  func compileAssign(table *symTable, node *parser.Assign) uint16 {
    58  	dest := node.Name.Name
    59  	if bas.GetTopIndex(dest) > 0 || dest == staticTrue || dest == staticFalse || dest == staticThis || dest == staticSelf {
    60  		table.panicnode(node.Name, "can't assign to a global static name")
    61  	}
    62  	destAddr, declared := table.get(dest)
    63  	if !declared {
    64  		// a is not declared yet
    65  		destAddr = table.borrowAddress()
    66  
    67  		// Do not use t.put() because it may put the symbol into masked tables
    68  		// e.g.: do a = 1 end
    69  		table.sym.Set(dest, bas.Int64(int64(destAddr)))
    70  	} else {
    71  	}
    72  	table.codeSeg.WriteInst(typ.OpSet, destAddr, table.compileNode(node.Value))
    73  	table.codeSeg.WriteLineNum(node.Line)
    74  	return destAddr
    75  }
    76  
    77  func compileUnary(table *symTable, node *parser.Unary) uint16 {
    78  	nodes := table.collapse(true, node.A)
    79  	table.compileOpcode1Node(node.Op, nodes[0])
    80  	table.releaseAddr(nodes)
    81  	table.codeSeg.WriteLineNum(node.Line)
    82  	return typ.RegA
    83  }
    84  
    85  func compileBinary(table *symTable, node *parser.Binary) uint16 {
    86  	if node.Op >= typ.OpExtBitAnd && node.Op <= typ.OpExtBitURsh {
    87  		return compileBitwise(table, node)
    88  	}
    89  	nodes := table.collapse(true, node.A, node.B)
    90  	table.compileOpcode2Node(node.Op, nodes[0], nodes[1])
    91  	table.releaseAddr(nodes)
    92  	table.codeSeg.WriteLineNum(node.Line)
    93  	return typ.RegA
    94  }
    95  
    96  func compileTenary(table *symTable, node *parser.Tenary) uint16 {
    97  	nodes := table.collapse(true, node.A, node.B, node.C)
    98  	table.compileOpcode3Node(node.Op, nodes[0], nodes[1], nodes[2])
    99  	table.releaseAddr(nodes)
   100  	table.codeSeg.WriteLineNum(node.Line)
   101  	return typ.RegA
   102  }
   103  
   104  func compileBitwise(table *symTable, node *parser.Binary) uint16 {
   105  	nodes := table.collapse(true, node.A, node.B)
   106  	a, b := nodes[0], nodes[1]
   107  	switch node.Op {
   108  	case typ.OpExtBitAnd, typ.OpExtBitOr, typ.OpExtBitXor:
   109  		if a16, ok := toInt16(a); ok {
   110  			table.compileOpcode2Node(typ.OpExt, b, parser.Address(a16))
   111  			node.Op = typ.OpExtBitAnd16 + node.Op - typ.OpExtBitAnd
   112  		} else if b16, ok := toInt16(b); ok {
   113  			table.compileOpcode2Node(typ.OpExt, a, parser.Address(b16))
   114  			node.Op = typ.OpExtBitAnd16 + node.Op - typ.OpExtBitAnd
   115  		} else {
   116  			table.compileOpcode2Node(typ.OpExt, a, b)
   117  		}
   118  	case typ.OpExtBitRsh, typ.OpExtBitLsh, typ.OpExtBitURsh:
   119  		if b16, ok := toInt16(b); ok {
   120  			table.compileOpcode2Node(typ.OpExt, a, parser.Address(b16))
   121  			node.Op = typ.OpExtBitAnd16 + node.Op - typ.OpExtBitAnd
   122  		} else {
   123  			table.compileOpcode2Node(typ.OpExt, a, b)
   124  		}
   125  	}
   126  	table.releaseAddr(nodes)
   127  	table.codeSeg.WriteLineNum(node.Line)
   128  	table.codeSeg.Code[len(table.codeSeg.Code)-1].OpcodeExt = node.Op
   129  	return typ.RegA
   130  }
   131  
   132  // [and a b] => $a = a if not a then goto out else $a = b end ::out::
   133  func compileAnd(table *symTable, node *parser.And) uint16 {
   134  	table.compileOpcode2Node(typ.OpSet, parser.Address(typ.RegA), node.A)
   135  
   136  	table.codeSeg.WriteJmpInst(typ.OpJmpFalse, 0)
   137  	part1 := table.codeSeg.Len()
   138  
   139  	table.compileOpcode2Node(typ.OpSet, parser.Address(typ.RegA), node.B)
   140  	part2 := table.codeSeg.Len()
   141  
   142  	table.codeSeg.Code[part1-1] = typ.JmpInst(typ.OpJmpFalse, part2-part1)
   143  	return typ.RegA
   144  }
   145  
   146  // [or a b]  => $a = a if not a then $a = b end
   147  func compileOr(table *symTable, node *parser.Or) uint16 {
   148  	table.compileOpcode2Node(typ.OpSet, parser.Address(typ.RegA), node.A)
   149  
   150  	table.codeSeg.WriteJmpInst(typ.OpJmpFalse, 1)
   151  	table.codeSeg.WriteJmpInst(typ.OpJmp, 0)
   152  	part1 := table.codeSeg.Len()
   153  
   154  	table.compileOpcode2Node(typ.OpSet, parser.Address(typ.RegA), node.B)
   155  	part2 := table.codeSeg.Len()
   156  
   157  	table.codeSeg.Code[part1-1] = typ.JmpInst(typ.OpJmp, part2-part1)
   158  	return typ.RegA
   159  }
   160  
   161  func compileIf(table *symTable, node *parser.If) uint16 {
   162  	condyx := table.compileNode(node.Cond)
   163  	if condyx != typ.RegA {
   164  		table.codeSeg.WriteInst(typ.OpSet, typ.RegA, condyx)
   165  	}
   166  
   167  	table.codeSeg.WriteJmpInst(typ.OpJmpFalse, 0)
   168  	init := table.codeSeg.Len()
   169  
   170  	table.addMaskedSymTable()
   171  	table.compileNode(node.True)
   172  	part1 := table.codeSeg.Len()
   173  
   174  	table.codeSeg.WriteJmpInst(typ.OpJmp, 0)
   175  
   176  	if node.False != nil {
   177  		table.compileNode(node.False)
   178  		part2 := table.codeSeg.Len()
   179  
   180  		table.removeMaskedSymTable()
   181  
   182  		table.codeSeg.Code[init-1] = typ.JmpInst(typ.OpJmpFalse, part1-init+1)
   183  		table.codeSeg.Code[part1] = typ.JmpInst(typ.OpJmp, part2-part1-1)
   184  	} else {
   185  		table.removeMaskedSymTable()
   186  
   187  		// The last inst is used to skip the false branch, since we don't have one, we don't need this jmp
   188  		table.codeSeg.TruncLast()
   189  		table.codeSeg.Code[init-1] = typ.JmpInst(typ.OpJmpFalse, part1-init)
   190  	}
   191  	return typ.RegA
   192  }
   193  
   194  // [object [k1, v1, k2, v2, ...]]
   195  func compileObject(table *symTable, node parser.ExprAssignList) uint16 {
   196  	tmp := table.collapse(true, node.ExpandAsExprList()...)
   197  	for i := 0; i < len(tmp); i += 2 {
   198  		table.compileOpcode1Node(typ.OpPush, tmp[i])
   199  		table.compileOpcode1Node(typ.OpPush, tmp[i+1])
   200  	}
   201  	table.codeSeg.WriteInst(typ.OpCreateObject, 0, 0)
   202  	return typ.RegA
   203  }
   204  
   205  // [array [a, b, c, ...]]
   206  func compileArray(table *symTable, node parser.ExprList) uint16 {
   207  	nodes := table.collapse(true, node...)
   208  	for _, x := range nodes {
   209  		table.compileOpcode1Node(typ.OpPush, x)
   210  	}
   211  	table.codeSeg.WriteInst(typ.OpCreateArray, 0, 0)
   212  	return typ.RegA
   213  }
   214  
   215  func compileCall(table *symTable, node *parser.Call) uint16 {
   216  	tmp := table.collapse(true, append(node.Args, node.Callee)...)
   217  	callee := tmp[len(tmp)-1]
   218  	args := tmp[:len(tmp)-1]
   219  
   220  	switch len(args) {
   221  	case 0:
   222  		table.compileOpcode1Node(node.Op, callee)
   223  	case 1:
   224  		if node.Vararg {
   225  			table.compileOpcode1Node(typ.OpPushUnpack, args[0])
   226  			table.compileOpcode1Node(node.Op, callee)
   227  		} else {
   228  			table.compileOpcode2Node(node.Op, callee, args[0])
   229  			table.codeSeg.Code[len(table.codeSeg.Code)-1].OpcodeExt = 1
   230  		}
   231  	default:
   232  		for i := 0; i < len(args)-2; i++ {
   233  			table.compileOpcode1Node(typ.OpPush, args[i])
   234  		}
   235  		if node.Vararg {
   236  			table.compileOpcode1Node(typ.OpPush, args[len(args)-2])
   237  			table.compileOpcode1Node(typ.OpPushUnpack, args[len(args)-1])
   238  			table.compileOpcode1Node(node.Op, callee)
   239  		} else {
   240  			table.compileOpcode3Node(node.Op, callee, args[len(args)-2], args[len(args)-1])
   241  			table.codeSeg.Code[len(table.codeSeg.Code)-1].OpcodeExt = 2
   242  		}
   243  	}
   244  
   245  	table.codeSeg.WriteLineNum(node.Line)
   246  	table.releaseAddr(tmp)
   247  	return typ.RegA
   248  }
   249  
   250  func compileFunction(table *symTable, node *parser.Function) uint16 {
   251  	newtable := newSymTable(table.options)
   252  	newtable.name = table.name
   253  	newtable.codeSeg.Pos.Name = table.name
   254  	newtable.top = table.getTopTable()
   255  	newtable.parent = table
   256  
   257  	for i, p := range node.Args {
   258  		name := p.(*parser.Symbol).Name
   259  		if newtable.sym.Contains(name) {
   260  			table.panicnode(node, "duplicated parameter %q", name)
   261  		}
   262  		newtable.put(name, uint16(i))
   263  	}
   264  
   265  	if ln := newtable.sym.Len(); ln > 255 {
   266  		table.panicnode(node, "too many parameters (%d > 255)", ln)
   267  	}
   268  
   269  	newtable.vp = uint16(newtable.sym.Len())
   270  
   271  	if len(node.VargExpand) > 0 {
   272  		src := uint16(len(node.Args) - 1)
   273  		for i, dest := range node.VargExpand {
   274  			idx := newtable.borrowAddress()
   275  			newtable.put(dest.(*parser.Symbol).Name, idx)
   276  			newtable.codeSeg.WriteInst3(typ.OpLoad, src, table.loadConst(bas.Int(i)), idx)
   277  		}
   278  	}
   279  	newtable.compileNode(node.Body)
   280  	newtable.patchGoto()
   281  
   282  	if a, ok := newtable.sym.Get(staticSelf); ok {
   283  		newtable.codeSeg.Code = append([]typ.Inst{
   284  			{Opcode: typ.OpFunction, A: typ.RegA},
   285  			{Opcode: typ.OpSet, A: uint16(a.Int64()), B: typ.RegA},
   286  		}, newtable.codeSeg.Code...)
   287  		newtable.codeSeg.Pos.Offset += 2
   288  	}
   289  
   290  	if a, ok := newtable.sym.Get(staticThis); ok {
   291  		newtable.codeSeg.Code = append([]typ.Inst{
   292  			{Opcode: typ.OpSet, A: uint16(a.Int64()), B: typ.RegA},
   293  		}, newtable.codeSeg.Code...)
   294  		newtable.codeSeg.Pos.Offset += 1
   295  	}
   296  
   297  	code := newtable.codeSeg
   298  	code.WriteInst(typ.OpRet, typ.RegNil, 0) // return nil
   299  
   300  	localDeclare := table.borrowAddress()
   301  	table.put(bas.Str(node.Name), localDeclare)
   302  
   303  	var captureList []string
   304  	if table.top != nil {
   305  		captureList = table.symbolsToDebugLocals()
   306  	}
   307  
   308  	obj := bas.NewBareFunc(
   309  		node.Name,
   310  		node.Vararg,
   311  		byte(len(node.Args)),
   312  		newtable.vp,
   313  		newtable.symbolsToDebugLocals(),
   314  		captureList,
   315  		newtable.labelPos,
   316  		code,
   317  	)
   318  
   319  	fm := &table.getTopTable().funcsMap
   320  	fidx, _ := fm.Get(bas.Str(node.Name))
   321  	// Put function into constMap, it will then be put into coreStack after all compilings are done.
   322  	table.getTopTable().constMap.Set(obj.ToValue(), fidx)
   323  
   324  	table.codeSeg.WriteInst3(typ.OpFunction, uint16(fidx.Int()),
   325  		uint16(internal.IfInt(table.top == nil, 0, 1)),
   326  		typ.RegA,
   327  	)
   328  	table.codeSeg.WriteInst(typ.OpSet, localDeclare, typ.RegA)
   329  	table.codeSeg.WriteLineNum(node.Line)
   330  	return typ.RegA
   331  }
   332  
   333  func compileBreakContinue(table *symTable, node *parser.BreakContinue) uint16 {
   334  	if len(table.forLoops) == 0 {
   335  		table.panicnode(node, "outside loop")
   336  	}
   337  	bl := table.forLoops[len(table.forLoops)-1]
   338  	if !node.Break {
   339  		table.compileNode(bl.continueNode)
   340  		table.codeSeg.WriteJmpInst(typ.OpJmp, bl.continueGoto-len(table.codeSeg.Code)-1)
   341  	} else {
   342  		bl.breakContinuePos = append(bl.breakContinuePos, table.codeSeg.Len())
   343  		table.codeSeg.WriteJmpInst(typ.OpJmp, 0)
   344  	}
   345  	return typ.RegA
   346  }
   347  
   348  func compileLoop(table *symTable, node *parser.Loop) uint16 {
   349  	init := table.codeSeg.Len()
   350  	breaks := &breakLabel{
   351  		continueNode: node.Continue,
   352  		continueGoto: init,
   353  	}
   354  
   355  	table.forLoops = append(table.forLoops, breaks)
   356  	table.addMaskedSymTable()
   357  	table.compileNode(node.Body)
   358  	table.removeMaskedSymTable()
   359  	table.forLoops = table.forLoops[:len(table.forLoops)-1]
   360  
   361  	table.codeSeg.WriteJmpInst(typ.OpJmp, -(table.codeSeg.Len()-init)-1)
   362  	for _, idx := range breaks.breakContinuePos {
   363  		table.codeSeg.Code[idx] = typ.JmpInst(typ.OpJmp, table.codeSeg.Len()-idx-1)
   364  	}
   365  	return typ.RegA
   366  }
   367  
   368  func compileGotoLabel(table *symTable, node *parser.GotoLabel) uint16 {
   369  	if !node.Goto {
   370  		if table.labelPos == nil {
   371  			table.labelPos = map[string]int{}
   372  		}
   373  		if _, ok := table.labelPos[node.Label]; ok {
   374  			table.panicnode(node, "duplicated label")
   375  		}
   376  		table.labelPos[node.Label] = table.codeSeg.Len()
   377  		return typ.RegA
   378  	}
   379  
   380  	if pos, ok := table.labelPos[node.Label]; ok {
   381  		table.codeSeg.WriteJmpInst(typ.OpJmp, pos-(table.codeSeg.Len()+1))
   382  	} else {
   383  		table.codeSeg.WriteJmpInst(typ.OpJmp, 0)
   384  		if table.forwardGoto == nil {
   385  			table.forwardGoto = map[int]*parser.GotoLabel{}
   386  		}
   387  		table.forwardGoto[table.codeSeg.Len()-1] = node
   388  	}
   389  	return typ.RegA
   390  }
   391  
   392  func (table *symTable) patchGoto() {
   393  	code := table.codeSeg.Code
   394  	for ipos, node := range table.forwardGoto {
   395  		pos, ok := table.labelPos[node.Label]
   396  		if !ok {
   397  			table.panicnode(node, "label not found")
   398  		}
   399  		code[ipos] = typ.JmpInst(typ.OpJmp, pos-(ipos+1))
   400  	}
   401  	for i := range code {
   402  		if code[i].Opcode == typ.OpJmp && code[i].D() != 0 {
   403  			// Group continuous jumps into one single jump
   404  			dest := int32(i) + code[i].D() + 1
   405  			for int(dest) < len(code) {
   406  				if c2 := code[dest]; c2.Opcode == typ.OpJmp && c2.D() != 0 {
   407  					dest += c2.D() + 1
   408  					continue
   409  				}
   410  				break
   411  			}
   412  			code[i] = code[i].SetD(dest - int32(i) - 1)
   413  		}
   414  		if code[i].Opcode == typ.OpJmp && i > 0 && (code[i-1].Opcode == typ.OpInc || code[i-1].OpcodeExt == typ.OpExtInc16) {
   415  			// Inc-then-small-jump, see OpInc in eval.go
   416  			if d := code[i].D() + 1; d >= math.MinInt16 && d <= math.MaxInt16 {
   417  				code[i-1].C = uint16(int16(d))
   418  			}
   419  		}
   420  	}
   421  }
   422  
   423  func compileRelease(table *symTable, node parser.Release) uint16 {
   424  	for _, s := range node {
   425  		s := s.Name
   426  		yx, _ := table.get(s)
   427  		table.releaseAddr(yx)
   428  		t := table.sym
   429  		if len(table.maskedSym) > 0 {
   430  			t = table.maskedSym[len(table.maskedSym)-1]
   431  		}
   432  		if !t.Contains(s) {
   433  			internal.ShouldNotHappen(node)
   434  		}
   435  		t.Delete(s)
   436  	}
   437  	return typ.RegA
   438  }
   439  
   440  // collapse will accept a list of expressions, each of them will be collapsed into a temporal variable
   441  // and become an Address node. If optLast is true, the last expression will be directly using regA.
   442  func (table *symTable) collapse(optLast bool, nodes ...parser.Node) []parser.Node {
   443  	var lastNode parser.Node
   444  	var lastNodeIndex int
   445  
   446  	for i, n := range nodes {
   447  		switch n.(type) {
   448  		case parser.Address, parser.Primitive, *parser.Symbol:
   449  			// No need to collapse
   450  		case *parser.If:
   451  			// 'if' is special because it can be used as an expresison, we can't optimize just one branch.
   452  			// e.g.: if(cond, a[0], a[1])
   453  			tmp := table.borrowAddress()
   454  			res := compileIf(table, n.(*parser.If))
   455  			table.codeSeg.Code = append(table.codeSeg.Code, typ.Inst{Opcode: typ.OpSet, A: tmp, B: res})
   456  			nodes[i] = parser.Address(tmp)
   457  			lastNode, lastNodeIndex = n, i
   458  		default:
   459  			tmp := table.borrowAddress()
   460  			table.codeSeg.WriteInst(typ.OpSet, tmp, table.compileNode(n))
   461  			nodes[i] = parser.Address(tmp)
   462  			lastNode, lastNodeIndex = n, i
   463  		}
   464  	}
   465  
   466  	if optLast && lastNode != nil {
   467  		if i := table.codeSeg.LastInst(); i.Opcode == typ.OpSet && i.B == typ.RegA {
   468  			// [set something $a]
   469  			table.codeSeg.TruncLast()
   470  			table.releaseAddr(i.A)
   471  			nodes[lastNodeIndex] = parser.Address(uint16(i.B))
   472  		}
   473  	}
   474  	return nodes
   475  }