go.starlark.net@v0.0.0-20231101134539-556fd59b42f6/starlark/interp.go (about)

     1  package starlark
     2  
     3  // This file defines the bytecode interpreter.
     4  
     5  import (
     6  	"fmt"
     7  	"os"
     8  	"sync/atomic"
     9  	"unsafe"
    10  
    11  	"go.starlark.net/internal/compile"
    12  	"go.starlark.net/internal/spell"
    13  	"go.starlark.net/syntax"
    14  )
    15  
    16  const vmdebug = false // TODO(adonovan): use a bitfield of specific kinds of error.
    17  
    18  // TODO(adonovan):
    19  // - optimize position table.
    20  // - opt: record MaxIterStack during compilation and preallocate the stack.
    21  
    22  func (fn *Function) CallInternal(thread *Thread, args Tuple, kwargs []Tuple) (Value, error) {
    23  	// Postcondition: args is not mutated. This is stricter than required by Callable,
    24  	// but allows CALL to avoid a copy.
    25  
    26  	f := fn.funcode
    27  	if !f.Prog.Recursion {
    28  		// detect recursion
    29  		for _, fr := range thread.stack[:len(thread.stack)-1] {
    30  			// We look for the same function code,
    31  			// not function value, otherwise the user could
    32  			// defeat the check by writing the Y combinator.
    33  			if frfn, ok := fr.Callable().(*Function); ok && frfn.funcode == f {
    34  				return nil, fmt.Errorf("function %s called recursively", fn.Name())
    35  			}
    36  		}
    37  	}
    38  
    39  	fr := thread.frameAt(0)
    40  
    41  	// Allocate space for stack and locals.
    42  	// Logically these do not escape from this frame
    43  	// (See https://github.com/golang/go/issues/20533.)
    44  	//
    45  	// This heap allocation looks expensive, but I was unable to get
    46  	// more than 1% real time improvement in a large alloc-heavy
    47  	// benchmark (in which this alloc was 8% of alloc-bytes)
    48  	// by allocating space for 8 Values in each frame, or
    49  	// by allocating stack by slicing an array held by the Thread
    50  	// that is expanded in chunks of min(k, nspace), for k=256 or 1024.
    51  	nlocals := len(f.Locals)
    52  	nspace := nlocals + f.MaxStack
    53  	space := make([]Value, nspace)
    54  	locals := space[:nlocals:nlocals] // local variables, starting with parameters
    55  	stack := space[nlocals:]          // operand stack
    56  
    57  	// Digest arguments and set parameters.
    58  	err := setArgs(locals, fn, args, kwargs)
    59  	if err != nil {
    60  		return nil, thread.evalError(err)
    61  	}
    62  
    63  	fr.locals = locals
    64  
    65  	if vmdebug {
    66  		fmt.Printf("Entering %s @ %s\n", f.Name, f.Position(0))
    67  		fmt.Printf("%d stack, %d locals\n", len(stack), len(locals))
    68  		defer fmt.Println("Leaving ", f.Name)
    69  	}
    70  
    71  	// Spill indicated locals to cells.
    72  	// Each cell is a separate alloc to avoid spurious liveness.
    73  	for _, index := range f.Cells {
    74  		locals[index] = &cell{locals[index]}
    75  	}
    76  
    77  	// TODO(adonovan): add static check that beneath this point
    78  	// - there is exactly one return statement
    79  	// - there is no redefinition of 'err'.
    80  
    81  	var iterstack []Iterator // stack of active iterators
    82  
    83  	// Use defer so that application panics can pass through
    84  	// interpreter without leaving thread in a bad state.
    85  	defer func() {
    86  		// ITERPOP the rest of the iterator stack.
    87  		for _, iter := range iterstack {
    88  			iter.Done()
    89  		}
    90  
    91  		fr.locals = nil
    92  	}()
    93  
    94  	sp := 0
    95  	var pc uint32
    96  	var result Value
    97  	code := f.Code
    98  loop:
    99  	for {
   100  		thread.Steps++
   101  		if thread.Steps >= thread.maxSteps {
   102  			if thread.OnMaxSteps != nil {
   103  				thread.OnMaxSteps(thread)
   104  			} else {
   105  				thread.Cancel("too many steps")
   106  			}
   107  		}
   108  		if reason := atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&thread.cancelReason))); reason != nil {
   109  			err = fmt.Errorf("Starlark computation cancelled: %s", *(*string)(reason))
   110  			break loop
   111  		}
   112  
   113  		fr.pc = pc
   114  
   115  		op := compile.Opcode(code[pc])
   116  		pc++
   117  		var arg uint32
   118  		if op >= compile.OpcodeArgMin {
   119  			// TODO(adonovan): opt: profile this.
   120  			// Perhaps compiling big endian would be less work to decode?
   121  			for s := uint(0); ; s += 7 {
   122  				b := code[pc]
   123  				pc++
   124  				arg |= uint32(b&0x7f) << s
   125  				if b < 0x80 {
   126  					break
   127  				}
   128  			}
   129  		}
   130  		if vmdebug {
   131  			fmt.Fprintln(os.Stderr, stack[:sp]) // very verbose!
   132  			compile.PrintOp(f, fr.pc, op, arg)
   133  		}
   134  
   135  		switch op {
   136  		case compile.NOP:
   137  			// nop
   138  
   139  		case compile.DUP:
   140  			stack[sp] = stack[sp-1]
   141  			sp++
   142  
   143  		case compile.DUP2:
   144  			stack[sp] = stack[sp-2]
   145  			stack[sp+1] = stack[sp-1]
   146  			sp += 2
   147  
   148  		case compile.POP:
   149  			sp--
   150  
   151  		case compile.EXCH:
   152  			stack[sp-2], stack[sp-1] = stack[sp-1], stack[sp-2]
   153  
   154  		case compile.EQL, compile.NEQ, compile.GT, compile.LT, compile.LE, compile.GE:
   155  			op := syntax.Token(op-compile.EQL) + syntax.EQL
   156  			y := stack[sp-1]
   157  			x := stack[sp-2]
   158  			sp -= 2
   159  			ok, err2 := Compare(op, x, y)
   160  			if err2 != nil {
   161  				err = err2
   162  				break loop
   163  			}
   164  			stack[sp] = Bool(ok)
   165  			sp++
   166  
   167  		case compile.PLUS,
   168  			compile.MINUS,
   169  			compile.STAR,
   170  			compile.SLASH,
   171  			compile.SLASHSLASH,
   172  			compile.PERCENT,
   173  			compile.AMP,
   174  			compile.PIPE,
   175  			compile.CIRCUMFLEX,
   176  			compile.LTLT,
   177  			compile.GTGT,
   178  			compile.IN:
   179  			binop := syntax.Token(op-compile.PLUS) + syntax.PLUS
   180  			if op == compile.IN {
   181  				binop = syntax.IN // IN token is out of order
   182  			}
   183  			y := stack[sp-1]
   184  			x := stack[sp-2]
   185  			sp -= 2
   186  			z, err2 := Binary(binop, x, y)
   187  			if err2 != nil {
   188  				err = err2
   189  				break loop
   190  			}
   191  			stack[sp] = z
   192  			sp++
   193  
   194  		case compile.UPLUS, compile.UMINUS, compile.TILDE:
   195  			var unop syntax.Token
   196  			if op == compile.TILDE {
   197  				unop = syntax.TILDE
   198  			} else {
   199  				unop = syntax.Token(op-compile.UPLUS) + syntax.PLUS
   200  			}
   201  			x := stack[sp-1]
   202  			y, err2 := Unary(unop, x)
   203  			if err2 != nil {
   204  				err = err2
   205  				break loop
   206  			}
   207  			stack[sp-1] = y
   208  
   209  		case compile.INPLACE_ADD:
   210  			y := stack[sp-1]
   211  			x := stack[sp-2]
   212  			sp -= 2
   213  
   214  			// It's possible that y is not Iterable but
   215  			// nonetheless defines x+y, in which case we
   216  			// should fall back to the general case.
   217  			var z Value
   218  			if xlist, ok := x.(*List); ok {
   219  				if yiter, ok := y.(Iterable); ok {
   220  					if err = xlist.checkMutable("apply += to"); err != nil {
   221  						break loop
   222  					}
   223  					listExtend(xlist, yiter)
   224  					z = xlist
   225  				}
   226  			}
   227  			if z == nil {
   228  				z, err = Binary(syntax.PLUS, x, y)
   229  				if err != nil {
   230  					break loop
   231  				}
   232  			}
   233  
   234  			stack[sp] = z
   235  			sp++
   236  
   237  		case compile.INPLACE_PIPE:
   238  			y := stack[sp-1]
   239  			x := stack[sp-2]
   240  			sp -= 2
   241  
   242  			// It's possible that y is not Dict but
   243  			// nonetheless defines x|y, in which case we
   244  			// should fall back to the general case.
   245  			var z Value
   246  			if xdict, ok := x.(*Dict); ok {
   247  				if ydict, ok := y.(*Dict); ok {
   248  					if err = xdict.ht.checkMutable("apply |= to"); err != nil {
   249  						break loop
   250  					}
   251  					xdict.ht.addAll(&ydict.ht) // can't fail
   252  					z = xdict
   253  				}
   254  			}
   255  			if z == nil {
   256  				z, err = Binary(syntax.PIPE, x, y)
   257  				if err != nil {
   258  					break loop
   259  				}
   260  			}
   261  
   262  			stack[sp] = z
   263  			sp++
   264  
   265  		case compile.NONE:
   266  			stack[sp] = None
   267  			sp++
   268  
   269  		case compile.TRUE:
   270  			stack[sp] = True
   271  			sp++
   272  
   273  		case compile.FALSE:
   274  			stack[sp] = False
   275  			sp++
   276  
   277  		case compile.MANDATORY:
   278  			stack[sp] = mandatory{}
   279  			sp++
   280  
   281  		case compile.JMP:
   282  			pc = arg
   283  
   284  		case compile.CALL, compile.CALL_VAR, compile.CALL_KW, compile.CALL_VAR_KW:
   285  			var kwargs Value
   286  			if op == compile.CALL_KW || op == compile.CALL_VAR_KW {
   287  				kwargs = stack[sp-1]
   288  				sp--
   289  			}
   290  
   291  			var args Value
   292  			if op == compile.CALL_VAR || op == compile.CALL_VAR_KW {
   293  				args = stack[sp-1]
   294  				sp--
   295  			}
   296  
   297  			// named args (pairs)
   298  			var kvpairs []Tuple
   299  			if nkvpairs := int(arg & 0xff); nkvpairs > 0 {
   300  				kvpairs = make([]Tuple, 0, nkvpairs)
   301  				kvpairsAlloc := make(Tuple, 2*nkvpairs) // allocate a single backing array
   302  				sp -= 2 * nkvpairs
   303  				for i := 0; i < nkvpairs; i++ {
   304  					pair := kvpairsAlloc[:2:2]
   305  					kvpairsAlloc = kvpairsAlloc[2:]
   306  					pair[0] = stack[sp+2*i]   // name
   307  					pair[1] = stack[sp+2*i+1] // value
   308  					kvpairs = append(kvpairs, pair)
   309  				}
   310  			}
   311  			if kwargs != nil {
   312  				// Add key/value items from **kwargs dictionary.
   313  				dict, ok := kwargs.(IterableMapping)
   314  				if !ok {
   315  					err = fmt.Errorf("argument after ** must be a mapping, not %s", kwargs.Type())
   316  					break loop
   317  				}
   318  				items := dict.Items()
   319  				for _, item := range items {
   320  					if _, ok := item[0].(String); !ok {
   321  						err = fmt.Errorf("keywords must be strings, not %s", item[0].Type())
   322  						break loop
   323  					}
   324  				}
   325  				if len(kvpairs) == 0 {
   326  					kvpairs = items
   327  				} else {
   328  					kvpairs = append(kvpairs, items...)
   329  				}
   330  			}
   331  
   332  			// positional args
   333  			var positional Tuple
   334  			if npos := int(arg >> 8); npos > 0 {
   335  				positional = stack[sp-npos : sp]
   336  				sp -= npos
   337  
   338  				// Copy positional arguments into a new array,
   339  				// unless the callee is another Starlark function,
   340  				// in which case it can be trusted not to mutate them.
   341  				if _, ok := stack[sp-1].(*Function); !ok || args != nil {
   342  					positional = append(Tuple(nil), positional...)
   343  				}
   344  			}
   345  			if args != nil {
   346  				// Add elements from *args sequence.
   347  				iter := Iterate(args)
   348  				if iter == nil {
   349  					err = fmt.Errorf("argument after * must be iterable, not %s", args.Type())
   350  					break loop
   351  				}
   352  				var elem Value
   353  				for iter.Next(&elem) {
   354  					positional = append(positional, elem)
   355  				}
   356  				iter.Done()
   357  			}
   358  
   359  			function := stack[sp-1]
   360  
   361  			if vmdebug {
   362  				fmt.Printf("VM call %s args=%s kwargs=%s @%s\n",
   363  					function, positional, kvpairs, f.Position(fr.pc))
   364  			}
   365  
   366  			thread.endProfSpan()
   367  			z, err2 := Call(thread, function, positional, kvpairs)
   368  			thread.beginProfSpan()
   369  			if err2 != nil {
   370  				err = err2
   371  				break loop
   372  			}
   373  			if vmdebug {
   374  				fmt.Printf("Resuming %s @ %s\n", f.Name, f.Position(0))
   375  			}
   376  			stack[sp-1] = z
   377  
   378  		case compile.ITERPUSH:
   379  			x := stack[sp-1]
   380  			sp--
   381  			iter := Iterate(x)
   382  			if iter == nil {
   383  				err = fmt.Errorf("%s value is not iterable", x.Type())
   384  				break loop
   385  			}
   386  			iterstack = append(iterstack, iter)
   387  
   388  		case compile.ITERJMP:
   389  			iter := iterstack[len(iterstack)-1]
   390  			if iter.Next(&stack[sp]) {
   391  				sp++
   392  			} else {
   393  				pc = arg
   394  			}
   395  
   396  		case compile.ITERPOP:
   397  			n := len(iterstack) - 1
   398  			iterstack[n].Done()
   399  			iterstack = iterstack[:n]
   400  
   401  		case compile.NOT:
   402  			stack[sp-1] = !stack[sp-1].Truth()
   403  
   404  		case compile.RETURN:
   405  			result = stack[sp-1]
   406  			break loop
   407  
   408  		case compile.SETINDEX:
   409  			z := stack[sp-1]
   410  			y := stack[sp-2]
   411  			x := stack[sp-3]
   412  			sp -= 3
   413  			err = setIndex(x, y, z)
   414  			if err != nil {
   415  				break loop
   416  			}
   417  
   418  		case compile.INDEX:
   419  			y := stack[sp-1]
   420  			x := stack[sp-2]
   421  			sp -= 2
   422  			z, err2 := getIndex(x, y)
   423  			if err2 != nil {
   424  				err = err2
   425  				break loop
   426  			}
   427  			stack[sp] = z
   428  			sp++
   429  
   430  		case compile.ATTR:
   431  			x := stack[sp-1]
   432  			name := f.Prog.Names[arg]
   433  			y, err2 := getAttr(x, name)
   434  			if err2 != nil {
   435  				err = err2
   436  				break loop
   437  			}
   438  			stack[sp-1] = y
   439  
   440  		case compile.SETFIELD:
   441  			y := stack[sp-1]
   442  			x := stack[sp-2]
   443  			sp -= 2
   444  			name := f.Prog.Names[arg]
   445  			if err2 := setField(x, name, y); err2 != nil {
   446  				err = err2
   447  				break loop
   448  			}
   449  
   450  		case compile.MAKEDICT:
   451  			stack[sp] = new(Dict)
   452  			sp++
   453  
   454  		case compile.SETDICT, compile.SETDICTUNIQ:
   455  			dict := stack[sp-3].(*Dict)
   456  			k := stack[sp-2]
   457  			v := stack[sp-1]
   458  			sp -= 3
   459  			oldlen := dict.Len()
   460  			if err2 := dict.SetKey(k, v); err2 != nil {
   461  				err = err2
   462  				break loop
   463  			}
   464  			if op == compile.SETDICTUNIQ && dict.Len() == oldlen {
   465  				err = fmt.Errorf("duplicate key: %v", k)
   466  				break loop
   467  			}
   468  
   469  		case compile.APPEND:
   470  			elem := stack[sp-1]
   471  			list := stack[sp-2].(*List)
   472  			sp -= 2
   473  			list.elems = append(list.elems, elem)
   474  
   475  		case compile.SLICE:
   476  			x := stack[sp-4]
   477  			lo := stack[sp-3]
   478  			hi := stack[sp-2]
   479  			step := stack[sp-1]
   480  			sp -= 4
   481  			res, err2 := slice(x, lo, hi, step)
   482  			if err2 != nil {
   483  				err = err2
   484  				break loop
   485  			}
   486  			stack[sp] = res
   487  			sp++
   488  
   489  		case compile.UNPACK:
   490  			n := int(arg)
   491  			iterable := stack[sp-1]
   492  			sp--
   493  			iter := Iterate(iterable)
   494  			if iter == nil {
   495  				err = fmt.Errorf("got %s in sequence assignment", iterable.Type())
   496  				break loop
   497  			}
   498  			i := 0
   499  			sp += n
   500  			for i < n && iter.Next(&stack[sp-1-i]) {
   501  				i++
   502  			}
   503  			var dummy Value
   504  			if iter.Next(&dummy) {
   505  				// NB: Len may return -1 here in obscure cases.
   506  				err = fmt.Errorf("too many values to unpack (got %d, want %d)", Len(iterable), n)
   507  				break loop
   508  			}
   509  			iter.Done()
   510  			if i < n {
   511  				err = fmt.Errorf("too few values to unpack (got %d, want %d)", i, n)
   512  				break loop
   513  			}
   514  
   515  		case compile.CJMP:
   516  			if stack[sp-1].Truth() {
   517  				pc = arg
   518  			}
   519  			sp--
   520  
   521  		case compile.CONSTANT:
   522  			stack[sp] = fn.module.constants[arg]
   523  			sp++
   524  
   525  		case compile.MAKETUPLE:
   526  			n := int(arg)
   527  			tuple := make(Tuple, n)
   528  			sp -= n
   529  			copy(tuple, stack[sp:])
   530  			stack[sp] = tuple
   531  			sp++
   532  
   533  		case compile.MAKELIST:
   534  			n := int(arg)
   535  			elems := make([]Value, n)
   536  			sp -= n
   537  			copy(elems, stack[sp:])
   538  			stack[sp] = NewList(elems)
   539  			sp++
   540  
   541  		case compile.MAKEFUNC:
   542  			funcode := f.Prog.Functions[arg]
   543  			tuple := stack[sp-1].(Tuple)
   544  			n := len(tuple) - len(funcode.Freevars)
   545  			defaults := tuple[:n:n]
   546  			freevars := tuple[n:]
   547  			stack[sp-1] = &Function{
   548  				funcode:  funcode,
   549  				module:   fn.module,
   550  				defaults: defaults,
   551  				freevars: freevars,
   552  			}
   553  
   554  		case compile.LOAD:
   555  			n := int(arg)
   556  			module := string(stack[sp-1].(String))
   557  			sp--
   558  
   559  			if thread.Load == nil {
   560  				err = fmt.Errorf("load not implemented by this application")
   561  				break loop
   562  			}
   563  
   564  			thread.endProfSpan()
   565  			dict, err2 := thread.Load(thread, module)
   566  			thread.beginProfSpan()
   567  			if err2 != nil {
   568  				err = wrappedError{
   569  					msg:   fmt.Sprintf("cannot load %s: %v", module, err2),
   570  					cause: err2,
   571  				}
   572  				break loop
   573  			}
   574  
   575  			for i := 0; i < n; i++ {
   576  				from := string(stack[sp-1-i].(String))
   577  				v, ok := dict[from]
   578  				if !ok {
   579  					err = fmt.Errorf("load: name %s not found in module %s", from, module)
   580  					if n := spell.Nearest(from, dict.Keys()); n != "" {
   581  						err = fmt.Errorf("%s (did you mean %s?)", err, n)
   582  					}
   583  					break loop
   584  				}
   585  				stack[sp-1-i] = v
   586  			}
   587  
   588  		case compile.SETLOCAL:
   589  			locals[arg] = stack[sp-1]
   590  			sp--
   591  
   592  		case compile.SETLOCALCELL:
   593  			locals[arg].(*cell).v = stack[sp-1]
   594  			sp--
   595  
   596  		case compile.SETGLOBAL:
   597  			fn.module.globals[arg] = stack[sp-1]
   598  			sp--
   599  
   600  		case compile.LOCAL:
   601  			x := locals[arg]
   602  			if x == nil {
   603  				err = fmt.Errorf("local variable %s referenced before assignment", f.Locals[arg].Name)
   604  				break loop
   605  			}
   606  			stack[sp] = x
   607  			sp++
   608  
   609  		case compile.FREE:
   610  			stack[sp] = fn.freevars[arg]
   611  			sp++
   612  
   613  		case compile.LOCALCELL:
   614  			v := locals[arg].(*cell).v
   615  			if v == nil {
   616  				err = fmt.Errorf("local variable %s referenced before assignment", f.Locals[arg].Name)
   617  				break loop
   618  			}
   619  			stack[sp] = v
   620  			sp++
   621  
   622  		case compile.FREECELL:
   623  			v := fn.freevars[arg].(*cell).v
   624  			if v == nil {
   625  				err = fmt.Errorf("local variable %s referenced before assignment", f.Freevars[arg].Name)
   626  				break loop
   627  			}
   628  			stack[sp] = v
   629  			sp++
   630  
   631  		case compile.GLOBAL:
   632  			x := fn.module.globals[arg]
   633  			if x == nil {
   634  				err = fmt.Errorf("global variable %s referenced before assignment", f.Prog.Globals[arg].Name)
   635  				break loop
   636  			}
   637  			stack[sp] = x
   638  			sp++
   639  
   640  		case compile.PREDECLARED:
   641  			name := f.Prog.Names[arg]
   642  			x := fn.module.predeclared[name]
   643  			if x == nil {
   644  				err = fmt.Errorf("internal error: predeclared variable %s is uninitialized", name)
   645  				break loop
   646  			}
   647  			stack[sp] = x
   648  			sp++
   649  
   650  		case compile.UNIVERSAL:
   651  			stack[sp] = Universe[f.Prog.Names[arg]]
   652  			sp++
   653  
   654  		default:
   655  			err = fmt.Errorf("unimplemented: %s", op)
   656  			break loop
   657  		}
   658  	}
   659  	// (deferred cleanup runs here)
   660  	return result, err
   661  }
   662  
   663  type wrappedError struct {
   664  	msg   string
   665  	cause error
   666  }
   667  
   668  func (e wrappedError) Error() string {
   669  	return e.msg
   670  }
   671  
   672  // Implements the xerrors.Wrapper interface
   673  // https://godoc.org/golang.org/x/xerrors#Wrapper
   674  func (e wrappedError) Unwrap() error {
   675  	return e.cause
   676  }
   677  
   678  // mandatory is a sentinel value used in a function's defaults tuple
   679  // to indicate that a (keyword-only) parameter is mandatory.
   680  type mandatory struct{}
   681  
   682  func (mandatory) String() string        { return "mandatory" }
   683  func (mandatory) Type() string          { return "mandatory" }
   684  func (mandatory) Freeze()               {} // immutable
   685  func (mandatory) Truth() Bool           { return False }
   686  func (mandatory) Hash() (uint32, error) { return 0, nil }
   687  
   688  // A cell is a box containing a Value.
   689  // Local variables marked as cells hold their value indirectly
   690  // so that they may be shared by outer and inner nested functions.
   691  // Cells are always accessed using indirect {FREE,LOCAL,SETLOCAL}CELL instructions.
   692  // The FreeVars tuple contains only cells.
   693  // The FREE instruction always yields a cell.
   694  type cell struct{ v Value }
   695  
   696  func (c *cell) String() string { return "cell" }
   697  func (c *cell) Type() string   { return "cell" }
   698  func (c *cell) Freeze() {
   699  	if c.v != nil {
   700  		c.v.Freeze()
   701  	}
   702  }
   703  func (c *cell) Truth() Bool           { panic("unreachable") }
   704  func (c *cell) Hash() (uint32, error) { panic("unreachable") }