github.com/coyove/nj@v0.0.0-20221110084952-c7f8db1065c3/table.go (about)

     1  package nj
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"math"
     7  
     8  	"github.com/coyove/nj/bas"
     9  	"github.com/coyove/nj/internal"
    10  	"github.com/coyove/nj/parser"
    11  	"github.com/coyove/nj/typ"
    12  )
    13  
// breakLabel is the per-loop record kept on symTable.forLoops.
//
// NOTE(review): the code that fills and reads these fields is not in this
// part of the file; the field notes below are inferred from the names only.
type breakLabel struct {
	continueNode     parser.Node // presumably the node compiled when 'continue' is hit -- TODO confirm
	continueGoto     int         // presumably the code position 'continue' jumps to -- TODO confirm
	breakContinuePos []int       // presumably positions of jump instructions to patch later -- TODO confirm
}
    19  
// symTable is responsible for recording the state of compilation.
// One table holds the symbols, allocated addresses and emitted code of a
// single compilation scope.
type symTable struct {
	name    string       // source name, used in panicnode messages
	options *LoadOptions // options handed to newSymTable, may be nil

	// toplevel symtable
	// top is consulted by get() for top-level variables and by loadConst();
	// it is nil when this table is itself the top-level one.
	// NOTE(review): parent is not read anywhere in this part of the file.
	top, parent *symTable

	// codeSeg receives every instruction emitted for this scope.
	codeSeg internal.Packet

	// variable lookup
	// maskedSym holds one map per open do-block; the last element is the
	// innermost block and is searched first by get().
	sym       bas.Map   // str -> address: uint16
	maskedSym []bas.Map // str -> address: uint16

	forLoops []*breakLabel

	// pendingReleases collects temporaries borrowed by compileStaticNode;
	// they are released by code outside this part of the file.
	pendingReleases []uint16

	// vp is the next never-used address; borrowAddressNoReuse hands it out
	// and increments it.
	vp uint16

	constMap bas.Map // value -> address: uint16
	funcsMap bas.Map // func name -> address: uint16

	// reusableTmps maps an address to true when it has been released and may
	// be borrowed again, and to false while it is in use.
	// reusableTmpsArray lists the released addresses in release order;
	// borrowAddress takes from the front.
	reusableTmps      bas.Map // address: uint16 -> available for reuse: bool
	reusableTmpsArray []uint16

	forwardGoto map[int]*parser.GotoLabel // position to goto label node
	labelPos    map[string]int            // label name to position
}
    49  
    50  func newSymTable(opt *LoadOptions) *symTable {
    51  	t := &symTable{
    52  		options: opt,
    53  	}
    54  	return t
    55  }
    56  
    57  func (table *symTable) panicnode(node parser.GetLine, msg string, args ...interface{}) {
    58  	who, line := node.GetLine()
    59  	panic(fmt.Sprintf("%q at %s:%d\t", who, table.name, line) + fmt.Sprintf(msg, args...))
    60  }
    61  
    62  func (table *symTable) symbolsToDebugLocals() []string {
    63  	x := make([]string, table.vp)
    64  	table.sym.Foreach(func(sym bas.Value, addr *bas.Value) bool {
    65  		x[addr.Int64()] = sym.Str()
    66  		return true
    67  	})
    68  	for _, s := range table.maskedSym {
    69  		s.Foreach(func(sym bas.Value, addr *bas.Value) bool {
    70  			x[addr.Int64()] = sym.Str()
    71  			return true
    72  		})
    73  	}
    74  	return x
    75  }
    76  
    77  func (table *symTable) borrowAddress() uint16 {
    78  	if len(table.reusableTmpsArray) > 0 {
    79  		tmp := bas.Int64(int64(table.reusableTmpsArray[0]))
    80  		table.reusableTmpsArray = table.reusableTmpsArray[1:]
    81  		if v, _ := table.reusableTmps.Get(tmp); v.IsFalse() {
    82  			internal.ShouldNotHappen()
    83  		}
    84  		table.reusableTmps.Set(tmp, bas.False)
    85  		return uint16(tmp.Int64())
    86  	}
    87  	if table.vp > typ.RegMaxAddress {
    88  		panic("too many variables in a single scope")
    89  	}
    90  	return table.borrowAddressNoReuse()
    91  }
    92  
    93  func (table *symTable) borrowAddressNoReuse() uint16 {
    94  	table.reusableTmps.Set(bas.Int64(int64(table.vp)), bas.False)
    95  	table.vp++
    96  	return table.vp - 1
    97  }
    98  
    99  func (table *symTable) releaseAddr(a interface{}) {
   100  	switch a := a.(type) {
   101  	case []parser.Node:
   102  		for _, n := range a {
   103  			if a, ok := n.(parser.Address); ok {
   104  				table.releaseAddr(uint16(a))
   105  			}
   106  		}
   107  	case []uint16:
   108  		for _, n := range a {
   109  			table.releaseAddr(n)
   110  		}
   111  	case uint16:
   112  		if a == typ.RegA {
   113  			return
   114  		}
   115  		if a > typ.RegLocalMask {
   116  			// We don't free global variables
   117  			return
   118  		}
   119  		if available, ok := table.reusableTmps.Get(bas.Int64(int64(a))); ok && available.IsFalse() {
   120  			table.reusableTmpsArray = append(table.reusableTmpsArray, a)
   121  			table.reusableTmps.Set(bas.Int64(int64(a)), bas.True)
   122  		}
   123  	default:
   124  		internal.ShouldNotHappen()
   125  	}
   126  }
   127  
// Names that symTable.get resolves specially before consulting any symbol map.
var (
	staticNil   = parser.SNil.Name  // resolves to typ.RegNil
	staticTrue  = bas.Str("true")   // resolves to the constant bas.True
	staticFalse = bas.Str("false")  // resolves to the constant bas.False
	staticThis  = bas.Str("this")   // allocated in table.sym on first use
	staticSelf  = bas.Str("self")   // allocated in table.sym on first use
	staticA     = parser.Sa.Name    // resolves to typ.RegA
)
   136  
   137  func (table *symTable) get(name bas.Value) (uint16, bool) {
   138  	stubLoad := func(name bas.Value) uint16 {
   139  		k, _ := table.sym.Get(name)
   140  		if k.Type() == typ.Number {
   141  			return uint16(k.Int64())
   142  		}
   143  		k = bas.Int64(int64(table.borrowAddressNoReuse()))
   144  		table.sym.Set(name, k)
   145  		return uint16(k.Int64())
   146  	}
   147  
   148  	switch name {
   149  	case staticNil:
   150  		return typ.RegNil, true
   151  	case staticTrue:
   152  		return table.loadConst(bas.True), true
   153  	case staticFalse:
   154  		return table.loadConst(bas.False), true
   155  	case staticThis, staticSelf:
   156  		return stubLoad(name), true
   157  	case staticA:
   158  		return typ.RegA, true
   159  	}
   160  
   161  	calc := func(k uint16, depth uint16) (uint16, bool) {
   162  		addr := (depth << 15) | (k & typ.RegLocalMask)
   163  		return addr, true
   164  	}
   165  
   166  	// Firstly we will iterate local masked symbols,
   167  	// which are local variables inside do-blocks, like "if then .. end" and "do ... end".
   168  	// The rightmost map of this slice is the innermost do-block.
   169  	for i := len(table.maskedSym) - 1; i >= 0; i-- {
   170  		if k, ok := table.maskedSym[i].Get(name); ok {
   171  			return calc(uint16(k.Int64()), 0)
   172  		}
   173  	}
   174  
   175  	// Then local variables.
   176  	if k, ok := table.sym.Get(name); ok {
   177  		return calc(uint16(k.Int64()), 0)
   178  	}
   179  
   180  	// // If parent exists and parent != top, it means we are inside a closure.
   181  	// for p := table.parent; p != nil && p != table.top; p = p.parent {
   182  	// 	if _, ok := p.sym.Get(name); ok {
   183  	// 		// self := stubLoad(staticSelf)
   184  	// 		table.codeSeg.WriteInst3(typ.OpLoad, self)
   185  	// 	}
   186  	// }
   187  
   188  	// Finally top variables.
   189  	if table.top != nil {
   190  		if k, ok := table.top.sym.Get(name); ok {
   191  			return calc(uint16(k.Int64()), 1)
   192  		}
   193  	}
   194  
   195  	return typ.RegNil, false
   196  }
   197  
   198  func (table *symTable) put(name bas.Value, addr uint16) {
   199  	if addr == typ.RegA {
   200  		internal.ShouldNotHappen()
   201  	}
   202  	sym := bas.Int64(int64(addr))
   203  	if len(table.maskedSym) > 0 {
   204  		table.maskedSym[len(table.maskedSym)-1].Set(name, sym)
   205  	} else {
   206  		table.sym.Set(name, sym)
   207  	}
   208  }
   209  
   210  func (table *symTable) addMaskedSymTable() {
   211  	table.maskedSym = append(table.maskedSym, bas.Map{})
   212  }
   213  
   214  func (table *symTable) removeMaskedSymTable() {
   215  	table.maskedSym[len(table.maskedSym)-1].Foreach(func(sym bas.Value, addr *bas.Value) bool {
   216  		table.releaseAddr(uint16(addr.Int64()))
   217  		return true
   218  	})
   219  	table.maskedSym = table.maskedSym[:len(table.maskedSym)-1]
   220  }
   221  
   222  func (table *symTable) loadConst(v bas.Value) uint16 {
   223  	if table.top != nil {
   224  		return table.top.loadConst(v)
   225  	}
   226  	if i, ok := table.constMap.Get(v); ok {
   227  		return uint16(i.Int64())
   228  	}
   229  	panic("loadConst: shouldn't happen")
   230  }
   231  
   232  func (table *symTable) compileOpcode1Node(op byte, n parser.Node) {
   233  	addr, ok := table.compileStaticNode(n)
   234  	if !ok {
   235  		table.codeSeg.WriteInst(op, table.compileNode(n), 0)
   236  	} else {
   237  		table.codeSeg.WriteInst(op, addr, 0)
   238  	}
   239  }
   240  
   241  func (table *symTable) compileAtom(n parser.Node, releases *[]uint16) uint16 {
   242  	addr, ok := table.compileStaticNode(n)
   243  	if !ok {
   244  		addr := table.borrowAddress()
   245  		table.codeSeg.WriteInst(typ.OpSet, addr, table.compileNode(n))
   246  		*releases = append(*releases, addr)
   247  		return addr
   248  	}
   249  	return addr
   250  }
   251  
   252  func (table *symTable) compileOpcode2Node(op byte, n0, n1 parser.Node) {
   253  	var r []uint16
   254  	var __n0__16, __n0__ = toInt16(n0)
   255  	var __n1__16, __n1__ = toInt16(n1)
   256  	switch {
   257  	case op == typ.OpAdd && __n1__:
   258  		table.codeSeg.WriteInst2Ext(typ.OpExtAdd16, table.compileAtom(n0, &r), __n1__16)
   259  	case op == typ.OpAdd && __n0__:
   260  		table.codeSeg.WriteInst2Ext(typ.OpExtAdd16, table.compileAtom(n1, &r), __n0__16)
   261  	case op == typ.OpSub && __n0__:
   262  		table.codeSeg.WriteInst2Ext(typ.OpExtRSub16, table.compileAtom(n1, &r), __n0__16)
   263  	case op == typ.OpSub && __n1__:
   264  		table.codeSeg.WriteInst2Ext(typ.OpExtAdd16, table.compileAtom(n0, &r), uint16(-int16(__n1__16)))
   265  	case op == typ.OpEq && __n1__:
   266  		table.codeSeg.WriteInst2Ext(typ.OpExtEq16, table.compileAtom(n0, &r), __n1__16)
   267  	case op == typ.OpEq && __n0__:
   268  		table.codeSeg.WriteInst2Ext(typ.OpExtEq16, table.compileAtom(n1, &r), __n0__16)
   269  	case op == typ.OpNeq && __n1__:
   270  		table.codeSeg.WriteInst2Ext(typ.OpExtNeq16, table.compileAtom(n0, &r), __n1__16)
   271  	case op == typ.OpNeq && __n0__:
   272  		table.codeSeg.WriteInst2Ext(typ.OpExtNeq16, table.compileAtom(n1, &r), __n0__16)
   273  	case op == typ.OpLess && __n1__:
   274  		table.codeSeg.WriteInst2Ext(typ.OpExtLess16, table.compileAtom(n0, &r), __n1__16)
   275  	case op == typ.OpLess && __n0__:
   276  		table.codeSeg.WriteInst2Ext(typ.OpExtGreat16, table.compileAtom(n1, &r), __n0__16)
   277  	case op == typ.OpLessEq && __n1__ && int16(__n1__16) <= math.MaxInt16-1:
   278  		table.codeSeg.WriteInst2Ext(typ.OpExtLess16, table.compileAtom(n0, &r), uint16(int16(__n1__16+1)))
   279  	case op == typ.OpLessEq && __n0__ && int16(__n0__16) >= math.MinInt16+1:
   280  		table.codeSeg.WriteInst2Ext(typ.OpExtGreat16, table.compileAtom(n1, &r), uint16(int16(__n0__16-1)))
   281  	case op == typ.OpInc && __n1__:
   282  		table.codeSeg.WriteInst2Ext(typ.OpExtInc16, table.compileAtom(n0, &r), __n1__16)
   283  	default:
   284  		table.codeSeg.WriteInst(op, table.compileAtom(n0, &r), table.compileAtom(n1, &r))
   285  	}
   286  	table.releaseAddr(r)
   287  }
   288  
   289  func (table *symTable) compileOpcode3Node(op byte, n0, n1, n2 parser.Node) {
   290  	var r []uint16
   291  	var __n1__16, __n1__ = toInt16(n1)
   292  	switch {
   293  	case op == typ.OpLoad && __n1__:
   294  		table.codeSeg.WriteInst3Ext(typ.OpExtLoad16, table.compileAtom(n0, &r), __n1__16, table.compileAtom(n2, &r))
   295  	case op == typ.OpStore && __n1__:
   296  		table.codeSeg.WriteInst3Ext(typ.OpExtStore16, table.compileAtom(n0, &r), __n1__16, table.compileAtom(n2, &r))
   297  	default:
   298  		table.codeSeg.WriteInst3(op, table.compileAtom(n0, &r), table.compileAtom(n1, &r), table.compileAtom(n2, &r))
   299  	}
   300  	table.releaseAddr(r)
   301  }
   302  
   303  func (table *symTable) compileStaticNode(node parser.Node) (uint16, bool) {
   304  	switch v := node.(type) {
   305  	case parser.Address:
   306  		return uint16(v), true
   307  	case parser.Primitive:
   308  		return table.loadConst(bas.Value(v)), true
   309  	case *parser.Symbol:
   310  		idx, ok := table.get(v.Name)
   311  		if !ok {
   312  			if idx := bas.GetTopIndex(v.Name); idx > 0 {
   313  				c := table.borrowAddress()
   314  				table.codeSeg.WriteInst3(typ.OpLoadTop, uint16(idx), typ.RegPhantom, c)
   315  				table.pendingReleases = append(table.pendingReleases, c)
   316  				return c, true
   317  			}
   318  			table.panicnode(v, "symbol not defined")
   319  		}
   320  		return idx, true
   321  	}
   322  	return 0, false
   323  }
   324  
   325  func (table *symTable) compileNode(node parser.Node) uint16 {
   326  	if addr, ok := table.compileStaticNode(node); ok {
   327  		return addr
   328  	}
   329  
   330  	switch v := node.(type) {
   331  	case *parser.LoadConst:
   332  		table.constMap = v.Table
   333  		table.constMap.Foreach(func(k bas.Value, v *bas.Value) bool {
   334  			addr := int(typ.RegA | table.borrowAddressNoReuse())
   335  			*v = bas.Int(addr)
   336  			return true
   337  		})
   338  		table.funcsMap = v.Funcs
   339  		table.funcsMap.Foreach(func(k bas.Value, v *bas.Value) bool {
   340  			addr := int(table.borrowAddressNoReuse())
   341  			*v = bas.Int(typ.RegA | addr)
   342  			table.sym.Set(k, bas.Int(addr))
   343  			return true
   344  		})
   345  		return typ.RegA
   346  	case *parser.Prog:
   347  		return compileProgBlock(table, v)
   348  	case *parser.Declare:
   349  		return compileDeclare(table, v)
   350  	case *parser.Assign:
   351  		return compileAssign(table, v)
   352  	case parser.Release:
   353  		return compileRelease(table, v)
   354  	case *parser.Unary:
   355  		return compileUnary(table, v)
   356  	case *parser.Binary:
   357  		return compileBinary(table, v)
   358  	case *parser.Tenary:
   359  		return compileTenary(table, v)
   360  	case *parser.And:
   361  		return compileAnd(table, v)
   362  	case *parser.Or:
   363  		return compileOr(table, v)
   364  	case parser.ExprList:
   365  		return compileArray(table, v)
   366  	case parser.ExprAssignList:
   367  		return compileObject(table, v)
   368  	case *parser.GotoLabel:
   369  		return compileGotoLabel(table, v)
   370  	case *parser.Call:
   371  		return compileCall(table, v)
   372  	case *parser.If:
   373  		return compileIf(table, v)
   374  	case *parser.Loop:
   375  		return compileLoop(table, v)
   376  	case *parser.BreakContinue:
   377  		return compileBreakContinue(table, v)
   378  	case *parser.Function:
   379  		return compileFunction(table, v)
   380  	}
   381  
   382  	panic("compileNode: shouldn't happen")
   383  }
   384  
   385  func (table *symTable) getTopTable() *symTable {
   386  	if table.top != nil {
   387  		return table.top
   388  	}
   389  	return table
   390  }
   391  
   392  func compileNodeTopLevel(name, source string, n parser.Node, opt *LoadOptions) (cls *bas.Program, err error) {
   393  	defer internal.CatchError(&err)
   394  
   395  	table := newSymTable(opt)
   396  	table.name = name
   397  	table.codeSeg.Pos.Name = name
   398  	// Load nil first to ensure its address == 0
   399  	table.borrowAddress()
   400  
   401  	coreStack := []bas.Value{bas.Nil, bas.Nil}
   402  
   403  	push := func(k, v bas.Value) uint16 {
   404  		idx, ok := table.get(k)
   405  		if ok {
   406  			coreStack[idx] = v
   407  		} else {
   408  			idx = uint16(len(coreStack))
   409  			table.put(k, idx)
   410  			coreStack = append(coreStack, v)
   411  		}
   412  		return idx
   413  	}
   414  
   415  	if opt != nil {
   416  		opt.Globals.Foreach(func(k bas.Value, v *bas.Value) bool { push(k, *v); return true })
   417  	}
   418  
   419  	gi := push(bas.Str("Program"), bas.Nil)
   420  
   421  	table.vp = uint16(len(coreStack))
   422  
   423  	table.compileNode(n)
   424  	table.codeSeg.WriteInst(typ.OpRet, typ.RegA, 0)
   425  	table.patchGoto()
   426  
   427  	coreStack = append(coreStack, make([]bas.Value, int(table.vp)-len(coreStack))...)
   428  	table.constMap.Foreach(func(konst bas.Value, addr *bas.Value) bool {
   429  		coreStack[addr.Int64()&typ.RegLocalMask] = konst
   430  		return true
   431  	})
   432  
   433  	cls = bas.NewBareProgram(
   434  		coreStack,
   435  		bas.NewBareFunc(
   436  			"main",
   437  			false,
   438  			0,
   439  			table.vp,
   440  			table.symbolsToDebugLocals(),
   441  			nil,
   442  			table.labelPos,
   443  			table.codeSeg,
   444  		),
   445  		&table.sym,
   446  		&table.funcsMap)
   447  	cls.File = name
   448  	cls.Source = source
   449  	if opt != nil {
   450  		cls.MaxStackSize = opt.MaxStackSize
   451  		cls.Globals = opt.Globals
   452  		cls.Stdout = internal.Or(opt.Stdout, cls.Stdout).(io.Writer)
   453  		cls.Stderr = internal.Or(opt.Stderr, cls.Stderr).(io.Writer)
   454  		cls.Stdin = internal.Or(opt.Stdin, cls.Stdin).(io.Reader)
   455  	}
   456  
   457  	coreStack[gi] = bas.ValueOf(cls)
   458  	return cls, err
   459  }
   460  
   461  func toInt16(n parser.Node) (uint16, bool) {
   462  	if a, ok := n.(parser.Primitive); ok && bas.Value(a).IsInt64() {
   463  		a := bas.Value(a).UnsafeInt64()
   464  		if a >= math.MinInt16+1 && a <= math.MaxInt16 {
   465  			return uint16(int16(a)), true // don't take -1<<15 into consideration because we may negate n.
   466  		}
   467  	}
   468  	return 0, false
   469  }
   470  
   471  func isStrNode(n parser.Node) bool {
   472  	a, ok := n.(parser.Primitive)
   473  	return ok && bas.Value(a).IsString()
   474  }