github.com/coyove/nj@v0.0.0-20221110084952-c7f8db1065c3/compile.go

package nj

import (
	"math"

	"github.com/coyove/nj/bas"
	"github.com/coyove/nj/internal"
	"github.com/coyove/nj/parser"
	"github.com/coyove/nj/typ"
)

// [prog expr1 expr2 ...]
func compileProgBlock(table *symTable, node *parser.Prog) uint16 {
	if node.DoBlock {
		table.addMaskedSymTable()
	}

	yx := uint16(typ.RegA)
	for _, a := range node.Stats {
		if a == nil {
			continue
		}
		switch a.(type) {
		case parser.Address, parser.Primitive, *parser.Symbol:
			// e.g.: [prog "a string"] will be transformed into: [prog [set $a "a string"]]
			yx = table.compileNode(a)
			table.codeSeg.WriteInst(typ.OpSet, typ.RegA, yx)
		default:
			yx = table.compileNode(a)
		}

		table.releaseAddr(table.pendingReleases)
		table.pendingReleases = table.pendingReleases[:0]
	}

	if node.DoBlock {
		table.removeMaskedSymTable()
	}
	return yx
}

// local a = b
func compileDeclare(table *symTable, node *parser.Declare) uint16 {
	dest := node.Name.Name
	if bas.GetTopIndex(dest) > 0 || dest == staticTrue || dest == staticFalse || dest == staticThis || dest == staticSelf {
		table.panicnode(node.Name, "can't bind to a global static name")
	}

	destAddr := table.borrowAddress()
	defer table.put(dest, destAddr) // deferred, so in `a = 1 do local a = a end` the value still resolves the outer a
	table.codeSeg.WriteInst(typ.OpSet, destAddr, table.compileNode(node.Value))
	table.codeSeg.WriteLineNum(node.Line)
	return destAddr
}

// a = b
func compileAssign(table *symTable, node *parser.Assign) uint16 {
	dest := node.Name.Name
	if bas.GetTopIndex(dest) > 0 || dest == staticTrue || dest == staticFalse || dest == staticThis || dest == staticSelf {
		table.panicnode(node.Name, "can't assign to a global static name")
	}
	destAddr, declared := table.get(dest)
	if !declared {
		// a is not declared yet
		destAddr = table.borrowAddress()

		// Do not use table.put() because it may put the symbol into masked tables
		// e.g.: do a = 1 end
		table.sym.Set(dest, bas.Int64(int64(destAddr)))
	}
	table.codeSeg.WriteInst(typ.OpSet, destAddr, table.compileNode(node.Value))
	table.codeSeg.WriteLineNum(node.Line)
	return destAddr
}

func compileUnary(table *symTable, node *parser.Unary) uint16 {
	nodes := table.collapse(true, node.A)
	table.compileOpcode1Node(node.Op, nodes[0])
	table.releaseAddr(nodes)
	table.codeSeg.WriteLineNum(node.Line)
	return typ.RegA
}

func compileBinary(table *symTable, node *parser.Binary) uint16 {
	if node.Op >= typ.OpExtBitAnd && node.Op <= typ.OpExtBitURsh {
		return compileBitwise(table, node)
	}
	nodes := table.collapse(true, node.A, node.B)
	table.compileOpcode2Node(node.Op, nodes[0], nodes[1])
	table.releaseAddr(nodes)
	table.codeSeg.WriteLineNum(node.Line)
	return typ.RegA
}

func compileTenary(table *symTable, node *parser.Tenary) uint16 {
	nodes := table.collapse(true, node.A, node.B, node.C)
	table.compileOpcode3Node(node.Op, nodes[0], nodes[1], nodes[2])
	table.releaseAddr(nodes)
	table.codeSeg.WriteLineNum(node.Line)
	return typ.RegA
}
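// compileBitwise lowers the extended bitwise operators. As a rough sketch of the
// intent (illustrative only; the mnemonics below are shorthand, not the encoded
// opcode names): for the commutative ops (and/or/xor) an operand that fits into an
// int16 is folded into the instruction itself,
//
//	a & 0xff  =>  [ext-bitand16 a 255]   // 255 travels as an inline immediate
//	a & b     =>  [ext-bitand a b]       // both operands stay as addresses
//
// while the shift ops only fold the right-hand side, since only the shift amount may
// be an immediate. The concrete opcode is remapped from the OpExtBitAnd..OpExtBitURsh
// range onto its *16 counterpart below.
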
func compileBitwise(table *symTable, node *parser.Binary) uint16 {
	nodes := table.collapse(true, node.A, node.B)
	a, b := nodes[0], nodes[1]
	switch node.Op {
	case typ.OpExtBitAnd, typ.OpExtBitOr, typ.OpExtBitXor:
		if a16, ok := toInt16(a); ok {
			table.compileOpcode2Node(typ.OpExt, b, parser.Address(a16))
			node.Op = typ.OpExtBitAnd16 + node.Op - typ.OpExtBitAnd
		} else if b16, ok := toInt16(b); ok {
			table.compileOpcode2Node(typ.OpExt, a, parser.Address(b16))
			node.Op = typ.OpExtBitAnd16 + node.Op - typ.OpExtBitAnd
		} else {
			table.compileOpcode2Node(typ.OpExt, a, b)
		}
	case typ.OpExtBitRsh, typ.OpExtBitLsh, typ.OpExtBitURsh:
		if b16, ok := toInt16(b); ok {
			table.compileOpcode2Node(typ.OpExt, a, parser.Address(b16))
			node.Op = typ.OpExtBitAnd16 + node.Op - typ.OpExtBitAnd
		} else {
			table.compileOpcode2Node(typ.OpExt, a, b)
		}
	}
	table.releaseAddr(nodes)
	table.codeSeg.WriteLineNum(node.Line)
	table.codeSeg.Code[len(table.codeSeg.Code)-1].OpcodeExt = node.Op
	return typ.RegA
}

// [and a b] => $a = a if not a then goto out else $a = b end ::out::
func compileAnd(table *symTable, node *parser.And) uint16 {
	table.compileOpcode2Node(typ.OpSet, parser.Address(typ.RegA), node.A)

	table.codeSeg.WriteJmpInst(typ.OpJmpFalse, 0)
	part1 := table.codeSeg.Len()

	table.compileOpcode2Node(typ.OpSet, parser.Address(typ.RegA), node.B)
	part2 := table.codeSeg.Len()

	table.codeSeg.Code[part1-1] = typ.JmpInst(typ.OpJmpFalse, part2-part1)
	return typ.RegA
}

// [or a b] => $a = a if not a then $a = b end
func compileOr(table *symTable, node *parser.Or) uint16 {
	table.compileOpcode2Node(typ.OpSet, parser.Address(typ.RegA), node.A)

	table.codeSeg.WriteJmpInst(typ.OpJmpFalse, 1)
	table.codeSeg.WriteJmpInst(typ.OpJmp, 0)
	part1 := table.codeSeg.Len()

	table.compileOpcode2Node(typ.OpSet, parser.Address(typ.RegA), node.B)
	part2 := table.codeSeg.Len()

	table.codeSeg.Code[part1-1] = typ.JmpInst(typ.OpJmp, part2-part1)
	return typ.RegA
}

func compileIf(table *symTable, node *parser.If) uint16 {
	condyx := table.compileNode(node.Cond)
	if condyx != typ.RegA {
		table.codeSeg.WriteInst(typ.OpSet, typ.RegA, condyx)
	}

	table.codeSeg.WriteJmpInst(typ.OpJmpFalse, 0)
	init := table.codeSeg.Len()

	table.addMaskedSymTable()
	table.compileNode(node.True)
	part1 := table.codeSeg.Len()

	table.codeSeg.WriteJmpInst(typ.OpJmp, 0)

	if node.False != nil {
		table.compileNode(node.False)
		part2 := table.codeSeg.Len()

		table.removeMaskedSymTable()

		table.codeSeg.Code[init-1] = typ.JmpInst(typ.OpJmpFalse, part1-init+1)
		table.codeSeg.Code[part1] = typ.JmpInst(typ.OpJmp, part2-part1-1)
	} else {
		table.removeMaskedSymTable()

		// The last inst was emitted to skip the false branch; since there is none, the jmp is unnecessary.
		table.codeSeg.TruncLast()
		table.codeSeg.Code[init-1] = typ.JmpInst(typ.OpJmpFalse, part1-init)
	}
	return typ.RegA
}

// [object [k1, v1, k2, v2, ...]]
func compileObject(table *symTable, node parser.ExprAssignList) uint16 {
	tmp := table.collapse(true, node.ExpandAsExprList()...)
	for i := 0; i < len(tmp); i += 2 {
		table.compileOpcode1Node(typ.OpPush, tmp[i])
		table.compileOpcode1Node(typ.OpPush, tmp[i+1])
	}
	table.codeSeg.WriteInst(typ.OpCreateObject, 0, 0)
	return typ.RegA
}
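// Object and array literals share the same construction scheme: every element is
// pushed onto the VM stack, then a single create instruction consumes the whole run.
// An illustrative sketch (mnemonics are shorthand; exact register numbering is up to
// the allocator):
//
//	{x: 1, y: 2}  =>  push "x"; push 1; push "y"; push 2; createobject
//	[1, 2, 3]     =>  push 1; push 2; push 3; createarray
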
// [array [a, b, c, ...]]
func compileArray(table *symTable, node parser.ExprList) uint16 {
	nodes := table.collapse(true, node...)
	for _, x := range nodes {
		table.compileOpcode1Node(typ.OpPush, x)
	}
	table.codeSeg.WriteInst(typ.OpCreateArray, 0, 0)
	return typ.RegA
}

func compileCall(table *symTable, node *parser.Call) uint16 {
	tmp := table.collapse(true, append(node.Args, node.Callee)...)
	callee := tmp[len(tmp)-1]
	args := tmp[:len(tmp)-1]

	switch len(args) {
	case 0:
		table.compileOpcode1Node(node.Op, callee)
	case 1:
		if node.Vararg {
			table.compileOpcode1Node(typ.OpPushUnpack, args[0])
			table.compileOpcode1Node(node.Op, callee)
		} else {
			table.compileOpcode2Node(node.Op, callee, args[0])
			table.codeSeg.Code[len(table.codeSeg.Code)-1].OpcodeExt = 1
		}
	default:
		for i := 0; i < len(args)-2; i++ {
			table.compileOpcode1Node(typ.OpPush, args[i])
		}
		if node.Vararg {
			table.compileOpcode1Node(typ.OpPush, args[len(args)-2])
			table.compileOpcode1Node(typ.OpPushUnpack, args[len(args)-1])
			table.compileOpcode1Node(node.Op, callee)
		} else {
			table.compileOpcode3Node(node.Op, callee, args[len(args)-2], args[len(args)-1])
			table.codeSeg.Code[len(table.codeSeg.Code)-1].OpcodeExt = 2
		}
	}

	table.codeSeg.WriteLineNum(node.Line)
	table.releaseAddr(tmp)
	return typ.RegA
}
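// A rough summary of the calling convention that compileCall above produces
// (illustrative; "call", "push" and "ext" are shorthand for node.Op, OpPush and
// OpcodeExt, not the encoded mnemonics):
//
//	f()          =>  call f                        // no arguments encoded
//	f(a)         =>  call f a          (ext=1)     // one arg rides in the instruction
//	f(a, b)      =>  call f a b        (ext=2)     // last two args ride in the instruction
//	f(a, b, c)   =>  push a; call f b c (ext=2)    // the rest are pushed beforehand
//	f(a, xs...)  =>  push a; pushunpack xs; call f // vararg tail is unpacked onto the stack
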
func compileFunction(table *symTable, node *parser.Function) uint16 {
	newtable := newSymTable(table.options)
	newtable.name = table.name
	newtable.codeSeg.Pos.Name = table.name
	newtable.top = table.getTopTable()
	newtable.parent = table

	for i, p := range node.Args {
		name := p.(*parser.Symbol).Name
		if newtable.sym.Contains(name) {
			table.panicnode(node, "duplicated parameter %q", name)
		}
		newtable.put(name, uint16(i))
	}

	if ln := newtable.sym.Len(); ln > 255 {
		table.panicnode(node, "too many parameters (%d > 255)", ln)
	}

	newtable.vp = uint16(newtable.sym.Len())

	if len(node.VargExpand) > 0 {
		src := uint16(len(node.Args) - 1)
		for i, dest := range node.VargExpand {
			idx := newtable.borrowAddress()
			newtable.put(dest.(*parser.Symbol).Name, idx)
			newtable.codeSeg.WriteInst3(typ.OpLoad, src, table.loadConst(bas.Int(i)), idx)
		}
	}
	newtable.compileNode(node.Body)
	newtable.patchGoto()

	if a, ok := newtable.sym.Get(staticSelf); ok {
		newtable.codeSeg.Code = append([]typ.Inst{
			{Opcode: typ.OpFunction, A: typ.RegA},
			{Opcode: typ.OpSet, A: uint16(a.Int64()), B: typ.RegA},
		}, newtable.codeSeg.Code...)
		newtable.codeSeg.Pos.Offset += 2
	}

	if a, ok := newtable.sym.Get(staticThis); ok {
		newtable.codeSeg.Code = append([]typ.Inst{
			{Opcode: typ.OpSet, A: uint16(a.Int64()), B: typ.RegA},
		}, newtable.codeSeg.Code...)
		newtable.codeSeg.Pos.Offset += 1
	}

	code := newtable.codeSeg
	code.WriteInst(typ.OpRet, typ.RegNil, 0) // return nil

	localDeclare := table.borrowAddress()
	table.put(bas.Str(node.Name), localDeclare)

	var captureList []string
	if table.top != nil {
		captureList = table.symbolsToDebugLocals()
	}

	obj := bas.NewBareFunc(
		node.Name,
		node.Vararg,
		byte(len(node.Args)),
		newtable.vp,
		newtable.symbolsToDebugLocals(),
		captureList,
		newtable.labelPos,
		code,
	)

	fm := &table.getTopTable().funcsMap
	fidx, _ := fm.Get(bas.Str(node.Name))
	// Put the function into constMap; it will be moved onto coreStack once all compilation is done.
	table.getTopTable().constMap.Set(obj.ToValue(), fidx)

	table.codeSeg.WriteInst3(typ.OpFunction, uint16(fidx.Int()),
		uint16(internal.IfInt(table.top == nil, 0, 1)),
		typ.RegA,
	)
	table.codeSeg.WriteInst(typ.OpSet, localDeclare, typ.RegA)
	table.codeSeg.WriteLineNum(node.Line)
	return typ.RegA
}

func compileBreakContinue(table *symTable, node *parser.BreakContinue) uint16 {
	if len(table.forLoops) == 0 {
		table.panicnode(node, "outside loop")
	}
	bl := table.forLoops[len(table.forLoops)-1]
	if !node.Break {
		table.compileNode(bl.continueNode)
		table.codeSeg.WriteJmpInst(typ.OpJmp, bl.continueGoto-len(table.codeSeg.Code)-1)
	} else {
		bl.breakContinuePos = append(bl.breakContinuePos, table.codeSeg.Len())
		table.codeSeg.WriteJmpInst(typ.OpJmp, 0)
	}
	return typ.RegA
}

func compileLoop(table *symTable, node *parser.Loop) uint16 {
	init := table.codeSeg.Len()
	breaks := &breakLabel{
		continueNode: node.Continue,
		continueGoto: init,
	}

	table.forLoops = append(table.forLoops, breaks)
	table.addMaskedSymTable()
	table.compileNode(node.Body)
	table.removeMaskedSymTable()
	table.forLoops = table.forLoops[:len(table.forLoops)-1]

	table.codeSeg.WriteJmpInst(typ.OpJmp, -(table.codeSeg.Len()-init)-1)
	for _, idx := range breaks.breakContinuePos {
		table.codeSeg.Code[idx] = typ.JmpInst(typ.OpJmp, table.codeSeg.Len()-idx-1)
	}
	return typ.RegA
}

func compileGotoLabel(table *symTable, node *parser.GotoLabel) uint16 {
	if !node.Goto {
		if table.labelPos == nil {
			table.labelPos = map[string]int{}
		}
		if _, ok := table.labelPos[node.Label]; ok {
			table.panicnode(node, "duplicated label")
		}
		table.labelPos[node.Label] = table.codeSeg.Len()
		return typ.RegA
	}

	if pos, ok := table.labelPos[node.Label]; ok {
		table.codeSeg.WriteJmpInst(typ.OpJmp, pos-(table.codeSeg.Len()+1))
	} else {
		table.codeSeg.WriteJmpInst(typ.OpJmp, 0)
		if table.forwardGoto == nil {
			table.forwardGoto = map[int]*parser.GotoLabel{}
		}
		table.forwardGoto[table.codeSeg.Len()-1] = node
	}
	return typ.RegA
}

func (table *symTable) patchGoto() {
	code := table.codeSeg.Code
	for ipos, node := range table.forwardGoto {
		pos, ok := table.labelPos[node.Label]
		if !ok {
			table.panicnode(node, "label not found")
		}
		code[ipos] = typ.JmpInst(typ.OpJmp, pos-(ipos+1))
	}
	for i := range code {
		if code[i].Opcode == typ.OpJmp && code[i].D() != 0 {
			// Collapse chains of consecutive jumps into a single jump.
			dest := int32(i) + code[i].D() + 1
			for int(dest) < len(code) {
				if c2 := code[dest]; c2.Opcode == typ.OpJmp && c2.D() != 0 {
					dest += c2.D() + 1
					continue
				}
				break
			}
			code[i] = code[i].SetD(dest - int32(i) - 1)
		}
		if code[i].Opcode == typ.OpJmp && i > 0 && (code[i-1].Opcode == typ.OpInc || code[i-1].OpcodeExt == typ.OpExtInc16) {
			// Inc-then-small-jump, see OpInc in eval.go
			if d := code[i].D() + 1; d >= math.MinInt16 && d <= math.MaxInt16 {
				code[i-1].C = uint16(int16(d))
			}
		}
	}
}
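// To make the jump-threading pass in patchGoto above concrete, a hedged sketch
// (instruction indices and offsets are invented for illustration):
//
//	before:  [5] jmp +3 --> [9] jmp +4 --> [14] ...
//	after:   [5] jmp +8 ---------------->  [14] ...
//
// A jump whose target is itself an unconditional jump is rewritten to point at that
// jump's final destination, so the VM takes one hop instead of two. The inc-then-jump
// case additionally copies the small jump delta into the preceding OpInc/OpExtInc16
// instruction so the interpreter can fuse the pair; see OpInc in eval.go.
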
func compileRelease(table *symTable, node parser.Release) uint16 {
	for _, s := range node {
		s := s.Name
		yx, _ := table.get(s)
		table.releaseAddr(yx)
		t := table.sym
		if len(table.maskedSym) > 0 {
			t = table.maskedSym[len(table.maskedSym)-1]
		}
		if !t.Contains(s) {
			internal.ShouldNotHappen(node)
		}
		t.Delete(s)
	}
	return typ.RegA
}

// collapse takes a list of expressions and collapses each of them into a temporary
// variable, turning it into an Address node. If optLast is true, the last expression
// may use regA directly instead of a temporary.
func (table *symTable) collapse(optLast bool, nodes ...parser.Node) []parser.Node {
	var lastNode parser.Node
	var lastNodeIndex int

	for i, n := range nodes {
		switch n.(type) {
		case parser.Address, parser.Primitive, *parser.Symbol:
			// No need to collapse
		case *parser.If:
			// 'if' is special because it can be used as an expression; we can't optimize just one branch.
			// e.g.: if(cond, a[0], a[1])
			tmp := table.borrowAddress()
			res := compileIf(table, n.(*parser.If))
			table.codeSeg.Code = append(table.codeSeg.Code, typ.Inst{Opcode: typ.OpSet, A: tmp, B: res})
			nodes[i] = parser.Address(tmp)
			lastNode, lastNodeIndex = n, i
		default:
			tmp := table.borrowAddress()
			table.codeSeg.WriteInst(typ.OpSet, tmp, table.compileNode(n))
			nodes[i] = parser.Address(tmp)
			lastNode, lastNodeIndex = n, i
		}
	}

	if optLast && lastNode != nil {
		if i := table.codeSeg.LastInst(); i.Opcode == typ.OpSet && i.B == typ.RegA {
			// [set something $a]
			table.codeSeg.TruncLast()
			table.releaseAddr(i.A)
			nodes[lastNodeIndex] = parser.Address(uint16(i.B))
		}
	}
	return nodes
}
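// A hedged example of the optLast elision above (register names are made up for
// illustration): compiling `g(f(x))` first collapses `f(x)` into a temporary,
//
//	call f x      ; result lands in $a
//	set $t0 $a    ; collapse writes the temp
//	call g $t0
//
// but because the freshly written instruction is exactly [set $t0 $a], collapse with
// optLast=true truncates it, releases $t0, and lets the caller read regA directly:
//
//	call f x
//	call g $a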