github.com/arnodel/golua@v0.0.0-20230215163904-e0b5347eaaa1/runtime/luacont.go (about)

     1  package runtime
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"unsafe"
     7  
     8  	"github.com/arnodel/golua/code"
     9  )
    10  
// LuaCont is a Lua continuation, made from a closure, values for registers and
// some state.
//
// Field order matters beyond readability: unsafe.Sizeof(LuaCont{}) is used for
// memory accounting in NewLuaCont and release, and reordering fields changes
// the struct's padding and therefore that size.
type LuaCont struct {
	// The closure being executed; its code, consts, lines, name and source
	// are read directly through this embedding.
	*Closure
	// Value registers.  Register 0 holds the continuation that receives the
	// results (stored by NewLuaCont, read back by Next).
	registers      []Value
	// Cells backing cell registers.  The first UpvalueCount entries are the
	// closure's upvalues; the rest are fresh nil cells (see NewLuaCont).
	cells          []Cell
	// Index into code of the next instruction to execute.  After a call it
	// points just past the call instruction, which is where Push looks for
	// receive instructions.
	pc             int16
	// Values pushed into an "etc" receive instruction, accumulated by Push.
	acc            []Value
	// Set to true when RunInThread starts and back to false when c hands
	// control to another continuation at a call; DebugInfo uses it to decide
	// whether pc has already moved past the calling instruction.
	running        bool
	// True when cells is the closure's own Upvalues slice rather than a slice
	// taken from the cell pool, in which case release must not return it to
	// the pool.
	borrowedCells  bool
	// Size of the thread's close stack when c was created.  Close-stack
	// truncation offsets are relative to it, and leaving c via a tail call or
	// return cleans the close stack down to it.
	closeStackBase int
}

// Compile-time check that *LuaCont implements Cont.
var _ Cont = (*LuaCont)(nil)
    25  
    26  // NewLuaCont returns a new LuaCont from a closure and next, a continuation to
    27  // push results into.
    28  func NewLuaCont(t *Thread, clos *Closure, next Cont) *LuaCont {
    29  	if clos.upvalueIndex < len(clos.Upvalues) {
    30  		panic("Closure not ready")
    31  	}
    32  	var cells []Cell
    33  	borrowCells := clos.UpvalueCount == clos.CellCount
    34  	if borrowCells {
    35  		cells = clos.Upvalues
    36  	} else {
    37  		cells = t.cellPool.get(int(clos.CellCount))
    38  		copy(cells, clos.Upvalues)
    39  		t.RequireArrSize(unsafe.Sizeof(Cell{}), int(clos.CellCount))
    40  		for i := clos.UpvalueCount; i < clos.CellCount; i++ {
    41  			cells[i] = newCell(NilValue)
    42  		}
    43  	}
    44  	t.RequireArrSize(unsafe.Sizeof(Value{}), int(clos.RegCount))
    45  	registers := t.regPool.get(int(clos.RegCount))
    46  	registers[0] = ContValue(next)
    47  	cont := t.luaContPool.get()
    48  	t.RequireSize(unsafe.Sizeof(LuaCont{}))
    49  	*cont = LuaCont{
    50  		Closure:        clos,
    51  		registers:      registers,
    52  		cells:          cells,
    53  		borrowedCells:  borrowCells,
    54  		closeStackBase: t.closeStack.size(),
    55  	}
    56  	return cont
    57  }
    58  
    59  func (c *LuaCont) release(r *Runtime) {
    60  	r.regPool.release(c.registers)
    61  	r.ReleaseArrSize(unsafe.Sizeof(Value{}), int(c.RegCount))
    62  	if !c.borrowedCells {
    63  		r.ReleaseArrSize(unsafe.Sizeof(Cell{}), int(c.CellCount))
    64  		r.cellPool.release(c.cells)
    65  	}
    66  	r.luaContPool.release(c)
    67  	r.ReleaseSize(unsafe.Sizeof(LuaCont{}))
    68  }
    69  
    70  // Push implements Cont.Push.
    71  func (c *LuaCont) Push(r *Runtime, val Value) {
    72  	opcode := c.code[c.pc]
    73  	if opcode.HasType0() {
    74  		r.RequireCPU(1)
    75  		dst := opcode.GetA()
    76  		if opcode.GetF() {
    77  			// It's an etc
    78  			r.RequireSize(unsafe.Sizeof(Value{}))
    79  			c.acc = append(c.acc, val)
    80  		} else {
    81  			c.pc++
    82  			setReg(c.registers, c.cells, dst, val)
    83  		}
    84  	}
    85  }
    86  
    87  // PushEtc implements Cont.PushEtc.  TODO: optimise.
    88  func (c *LuaCont) PushEtc(r *Runtime, vals []Value) {
    89  	for _, val := range vals {
    90  		c.Push(r, val)
    91  	}
    92  }
    93  
    94  // Next implements Cont.Next.
    95  func (c *LuaCont) Next() Cont {
    96  	next, ok := c.registers[0].TryCont()
    97  	if !ok {
    98  		return nil
    99  	}
   100  	return next
   101  }
   102  
   103  func (c *LuaCont) Parent() Cont {
   104  	return c.Next()
   105  }
   106  
   107  // RunInThread implements Cont.RunInThread.
   108  func (c *LuaCont) RunInThread(t *Thread) (Cont, error) {
   109  	pc := c.pc
   110  	consts := c.consts
   111  	lines := c.lines
   112  	var lastLine int32
   113  	c.running = true
   114  	opcodes := c.code
   115  	regs := c.registers
   116  	cells := c.cells
   117  RunLoop:
   118  	for {
   119  		t.RequireCPU(1)
   120  
   121  		if t.DebugHooks.areFlagsEnabled(HookFlagLine) {
   122  			line := lines[pc]
   123  			if line > 0 && line != lastLine {
   124  				lastLine = line
   125  				if err := t.triggerLine(t, c, line); err != nil {
   126  					return nil, err
   127  				}
   128  			}
   129  		}
   130  		opcode := opcodes[pc]
   131  		if opcode.HasType1() {
   132  			dst := opcode.GetA()
   133  			x := getReg(regs, cells, opcode.GetB())
   134  			y := getReg(regs, cells, opcode.GetC())
   135  			var res Value
   136  			var err error
   137  			var ok bool
   138  			switch opcode.GetX() {
   139  
   140  			// Arithmetic
   141  
   142  			case code.OpAdd:
   143  				res, ok = Add(x, y)
   144  				if !ok {
   145  					res, err = binaryArithFallback(t, "__add", x, y)
   146  				}
   147  			case code.OpSub:
   148  				res, ok = Sub(x, y)
   149  				if !ok {
   150  					res, err = binaryArithFallback(t, "__sub", x, y)
   151  				}
   152  			case code.OpMul:
   153  				res, ok = Mul(x, y)
   154  				if !ok {
   155  					res, err = binaryArithFallback(t, "__mul", x, y)
   156  				}
   157  			case code.OpDiv:
   158  				res, ok = Div(x, y)
   159  				if !ok {
   160  					res, err = binaryArithFallback(t, "__div", x, y)
   161  				}
   162  			case code.OpFloorDiv:
   163  				res, ok, err = Idiv(x, y)
   164  				if !ok {
   165  					res, err = binaryArithFallback(t, "__idiv", x, y)
   166  				}
   167  			case code.OpMod:
   168  				res, ok, err = Mod(x, y)
   169  				if !ok {
   170  					res, err = binaryArithFallback(t, "__mod", x, y)
   171  				}
   172  			case code.OpPow:
   173  				res, ok = Pow(x, y)
   174  				if !ok {
   175  					res, err = binaryArithFallback(t, "__pow", x, y)
   176  				}
   177  
   178  			// Bitwise
   179  
   180  			case code.OpBitAnd:
   181  				res, err = band(t, x, y)
   182  			case code.OpBitOr:
   183  				res, err = bor(t, x, y)
   184  			case code.OpBitXor:
   185  				res, err = bxor(t, x, y)
   186  			case code.OpShiftL:
   187  				res, err = shl(t, x, y)
   188  			case code.OpShiftR:
   189  				res, err = shr(t, x, y)
   190  
   191  			// Comparison
   192  
   193  			case code.OpEq:
   194  				var r bool
   195  				r, err = eq(t, x, y)
   196  				res = BoolValue(r)
   197  			case code.OpLt:
   198  				var r bool
   199  				r, err = Lt(t, x, y)
   200  				res = BoolValue(r)
   201  			case code.OpLeq:
   202  				var r bool
   203  				r, err = le(t, x, y)
   204  				res = BoolValue(r)
   205  
   206  			// Concatenation
   207  
   208  			case code.OpConcat:
   209  				res, err = Concat(t, x, y)
   210  			default:
   211  				panic("unsupported")
   212  			}
   213  			if err != nil {
   214  				c.pc = pc
   215  				return nil, err
   216  			}
   217  			setReg(regs, cells, dst, res)
   218  			pc++
   219  			continue RunLoop
   220  		}
   221  		switch opcode.TypePfx() {
   222  		case code.Type0Pfx:
   223  			dst := opcode.GetA()
   224  			if opcode.GetF() {
   225  				// It's an etc
   226  				setReg(regs, cells, dst, ArrayValue(c.acc))
   227  			} else {
   228  				setReg(regs, cells, dst, NilValue)
   229  			}
   230  			pc++
   231  			continue RunLoop
   232  		case code.Type2Pfx:
   233  			reg := opcode.GetA()
   234  			coll := getReg(regs, cells, opcode.GetB())
   235  			idx := getReg(regs, cells, opcode.GetC())
   236  			if !opcode.GetF() {
   237  				val, err := Index(t, coll, idx)
   238  				if err != nil {
   239  					c.pc = pc
   240  					return nil, err
   241  				}
   242  				setReg(regs, cells, reg, val)
   243  			} else {
   244  				err := SetIndex(t, coll, idx, getReg(regs, cells, reg))
   245  				if err != nil {
   246  					c.pc = pc
   247  					return nil, err
   248  				}
   249  			}
   250  			pc++
   251  			continue RunLoop
   252  		case code.Type3Pfx:
   253  			n := opcode.GetN()
   254  			var val Value
   255  			switch opcode.GetY() {
   256  			case code.OpInt16:
   257  				val = IntValue(int64(int16(n)))
   258  			case code.OpStr2:
   259  				val = StringValue(string(code.Lit16(n).ToStr2()))
   260  			case code.OpK:
   261  				val = consts[n]
   262  			case code.OpClosureK:
   263  				val = FunctionValue(NewClosure(t.Runtime, consts[n].AsCode()))
   264  			default:
   265  				panic("Unsupported opcode")
   266  			}
   267  			dst := opcode.GetA()
   268  			if opcode.GetF() {
   269  				// dst must contain a continuation
   270  				cont := getReg(regs, cells, dst).AsCont()
   271  				cont.Push(t.Runtime, val)
   272  			} else {
   273  				setReg(regs, cells, dst, val)
   274  			}
   275  			pc++
   276  			continue RunLoop
   277  		case code.Type4Pfx:
   278  			dst := opcode.GetA()
   279  			var res Value
   280  			var ok bool
   281  			var err error
   282  			if opcode.HasType4a() {
   283  				val := getReg(regs, cells, opcode.GetB())
   284  				switch opcode.GetUnOp() {
   285  				case code.OpNeg:
   286  					res, ok = Unm(val)
   287  					if !ok {
   288  						res, err = unaryArithFallback(t, "__unm", val)
   289  					}
   290  				case code.OpBitNot:
   291  					res, err = bnot(t, val)
   292  				case code.OpLen:
   293  					res, err = Len(t, val)
   294  				case code.OpCont:
   295  					var cont Cont
   296  					cont, err = Continue(t, val, c)
   297  					res = ContValue(cont)
   298  				case code.OpTailCont:
   299  					var cont Cont
   300  					cont, err = Continue(t, val, c.Next())
   301  					res = ContValue(cont)
   302  				case code.OpId:
   303  					res = val
   304  				case code.OpEtcId:
   305  					// We assume it's a push?
   306  					cont := getReg(regs, cells, dst).AsCont()
   307  					cont.PushEtc(t.Runtime, val.AsArray())
   308  					pc++
   309  					continue RunLoop
   310  				case code.OpTruth:
   311  					res = BoolValue(Truth(val))
   312  				case code.OpNot:
   313  					res = BoolValue(!Truth(val))
   314  				case code.OpUpvalue:
   315  					// TODO: wasteful as we already have got getReg
   316  					cell := c.getRegCell(opcode.GetB())
   317  					getReg(regs, cells, dst).AsClosure().AddUpvalue(cell)
   318  					pc++
   319  					continue RunLoop
   320  				default:
   321  					panic("unsupported")
   322  				}
   323  			} else {
   324  				// Type 4b
   325  				switch code.UnOpK(opcode.GetUnOp()) {
   326  				case code.OpCC:
   327  					res = ContValue(c)
   328  				case code.OpTable:
   329  					res = TableValue(NewTable())
   330  				case code.OpStr0:
   331  					res = StringValue("")
   332  				case code.OpStr1:
   333  					res = StringValue(string(opcode.GetL().ToStr1()))
   334  				case code.OpBool:
   335  					res = BoolValue(opcode.GetL().ToBool())
   336  				case code.OpNil:
   337  					res = NilValue
   338  				case code.OpClear:
   339  					// Special case: clear reg
   340  					c.clearReg(dst)
   341  					pc++
   342  					continue RunLoop
   343  				default:
   344  					panic("unsupported")
   345  				}
   346  			}
   347  			if err != nil {
   348  				c.pc = pc
   349  				return nil, err
   350  			}
   351  			if opcode.GetF() {
   352  				getReg(regs, cells, dst).AsCont().Push(t.Runtime, res)
   353  			} else {
   354  				setReg(regs, cells, dst, res)
   355  			}
   356  			pc++
   357  			continue RunLoop
   358  		case code.Type5Pfx:
   359  			switch opcode.GetJ() {
   360  			case code.OpJump:
   361  				pc += int16(opcode.GetOffset())
   362  				continue RunLoop
   363  			case code.OpJumpIf:
   364  				test := Truth(getReg(regs, cells, opcode.GetA()))
   365  				if test == opcode.GetF() {
   366  					pc += int16(opcode.GetOffset())
   367  				} else {
   368  					pc++
   369  				}
   370  				continue RunLoop
   371  			case code.OpCall:
   372  				pc++
   373  				c.pc = pc
   374  				c.acc = nil
   375  				c.running = false
   376  				contReg := opcode.GetA()
   377  				isTail := opcode.GetF() // Can mean tail call or simple return
   378  				next := getReg(regs, cells, contReg).AsCont()
   379  
   380  				// We clear the register containing the continuation to allow
   381  				// garbage collection.  A continuation can only be called once
   382  				// anyway, so that's ok semantically.
   383  				c.clearReg(contReg)
   384  
   385  				if isTail {
   386  					// As we're leaving this continuation for good, perform all
   387  					// the pending close actions.  It must be done before debug
   388  					// hooks are called.
   389  					if err := t.cleanupCloseStack(c, c.closeStackBase, nil); err != nil {
   390  						return nil, err
   391  					}
   392  				}
   393  
   394  				if t.areFlagsEnabled(HookFlagCall | HookFlagReturn) {
   395  					switch {
   396  					case contReg == code.ValueReg(0):
   397  						_ = t.triggerReturn(t, c)
   398  					case isTail:
   399  						_ = t.triggerTailCall(t, next)
   400  					default:
   401  						_ = t.triggerCall(t, next)
   402  					}
   403  				}
   404  
   405  				if isTail {
   406  					// It's a tail call.  There is no error, so nothing will
   407  					// reference c anymore, therefore we are safe to give it to
   408  					// the pool for reuse.  It must be done after debug hooks
   409  					// are called because they may use c.
   410  					c.release(t.Runtime)
   411  				}
   412  				return next, nil
   413  			case code.OpClStack:
   414  				if opcode.GetF() {
   415  					// Push to close stack
   416  					v := getReg(regs, cells, opcode.GetA())
   417  					if Truth(v) && t.metaGetS(v, "__close").IsNil() {
   418  						c.pc = pc
   419  						return nil, errors.New("to be closed value missing a __close metamethod")
   420  					}
   421  					t.closeStack.push(v)
   422  				} else {
   423  					// Truncate close stack
   424  					h := c.closeStackBase + int(opcode.GetClStackOffset())
   425  					if err := t.cleanupCloseStack(c, h, nil); err != nil {
   426  						c.pc = pc
   427  						return nil, err
   428  					}
   429  				}
   430  				pc++
   431  				continue RunLoop
   432  			default:
   433  				panic("unsupported")
   434  			}
   435  		case code.Type6Pfx:
   436  			dst := opcode.GetA()
   437  			etc := getReg(regs, cells, opcode.GetB()).AsArray()
   438  			idx := int(opcode.GetM())
   439  			var val Value
   440  			if idx < len(etc) {
   441  				val = etc[idx]
   442  			}
   443  			if opcode.GetF() {
   444  				tbl := getReg(regs, cells, dst).AsTable()
   445  				for i, v := range etc {
   446  					t.SetTable(tbl, IntValue(int64(i+idx)), v)
   447  				}
   448  			} else {
   449  				setReg(regs, cells, dst, val)
   450  			}
   451  			pc++
   452  			continue RunLoop
   453  		case code.Type7Pfx:
   454  			startReg, stopReg, stepReg := opcode.GetA(), opcode.GetB(), opcode.GetC()
   455  			start := getReg(regs, cells, startReg)
   456  			stop := getReg(regs, cells, stopReg)
   457  			step := getReg(regs, cells, stepReg)
   458  			if opcode.GetF() {
   459  				// Advance for loop.  All registers are assumed to contain
   460  				// numeric values because they have been prepared previously.
   461  				nextStart, _ := Add(start, step)
   462  
   463  				// Check if the loop is done.  It can be done if we have gone
   464  				// over the stop value or if there has been overflow /
   465  				// underflow.
   466  				var done bool
   467  				if isPositive(step) {
   468  					done = numIsLessThan(stop, nextStart) || numIsLessThan(nextStart, start)
   469  				} else {
   470  					done = numIsLessThan(nextStart, stop) || numIsLessThan(start, nextStart)
   471  				}
   472  				if done {
   473  					nextStart = NilValue
   474  				}
   475  				setReg(regs, cells, startReg, nextStart)
   476  			} else {
   477  				// Prepare for loop
   478  				start, tstart := ToNumberValue(start)
   479  				stop, tstop := ToNumberValue(stop)
   480  				step, tstep := ToNumberValue(step)
   481  				if tstart == NaN || tstop == NaN || tstep == NaN {
   482  					c.pc = pc
   483  					var (
   484  						role string
   485  						val  Value
   486  					)
   487  					switch {
   488  					case tstart == NaN:
   489  						role, val = "initial value", start
   490  					case tstop == NaN:
   491  						role, val = "limit", stop
   492  					default:
   493  						role, val = "step", step
   494  					}
   495  					return nil, fmt.Errorf("'for' %s: expected number, got %s", role, val.CustomTypeName())
   496  				}
   497  				// Make sure start and step have the same numeric type
   498  				if tstart != tstep {
   499  					// One is a float, one is an int, turn them both to floats
   500  					if tstart == IsInt {
   501  						start = FloatValue(float64(start.AsInt()))
   502  					} else {
   503  						step = FloatValue(float64(step.AsInt()))
   504  					}
   505  				}
   506  				// A 0 step is an error
   507  				if isZero(step) {
   508  					c.pc = pc
   509  					return nil, errors.New("'for' step is zero")
   510  				}
   511  				// Check the loop is not already finished. If so, startReg is
   512  				// set to nil.
   513  				var done bool
   514  				if isPositive(step) {
   515  					done, _ = isLessThan(stop, start)
   516  				} else {
   517  					done, _ = isLessThan(start, stop)
   518  				}
   519  				if done {
   520  					start = NilValue
   521  				}
   522  				setReg(regs, cells, startReg, start)
   523  				setReg(regs, cells, stopReg, stop)
   524  				setReg(regs, cells, stepReg, step)
   525  			}
   526  			pc++
   527  			continue RunLoop
   528  		}
   529  	}
   530  	// return nil, errors.New("Invalid PC")
   531  }
   532  
   533  // DebugInfo implements Cont.DebugInfo.
   534  func (c *LuaCont) DebugInfo() *DebugInfo {
   535  	pc := c.pc
   536  	if !c.running {
   537  		pc--
   538  	}
   539  	var currentLine int32 = -1
   540  	if pc >= 0 && int(pc) < len(c.lines) {
   541  		currentLine = c.lines[pc]
   542  	}
   543  	name := c.name
   544  	if name == "" {
   545  		name = "<lua function>"
   546  	}
   547  	return &DebugInfo{
   548  		Source:      c.source,
   549  		Name:        name,
   550  		CurrentLine: currentLine,
   551  	}
   552  }
   553  
   554  func (c *LuaCont) getRegCell(reg code.Reg) Cell {
   555  	if reg.IsCell() {
   556  		return c.cells[reg.Idx()]
   557  	}
   558  	panic("should be a cell")
   559  }
   560  
   561  func (c *LuaCont) clearReg(reg code.Reg) {
   562  	if reg.IsCell() {
   563  		c.cells[reg.Idx()] = newCell(NilValue)
   564  	} else {
   565  		c.registers[reg.Idx()] = NilValue
   566  	}
   567  }
   568  
   569  func setReg(regs []Value, cells []Cell, reg code.Reg, val Value) {
   570  	idx := reg.Idx()
   571  	if reg.IsCell() {
   572  		cells[idx].set(val)
   573  	} else {
   574  		regs[idx] = val
   575  	}
   576  }
   577  
   578  func getReg(regs []Value, cells []Cell, reg code.Reg) Value {
   579  	if reg.IsCell() {
   580  		return *cells[reg.Idx()].ref
   581  	}
   582  	return regs[reg.Idx()]
   583  }