github.com/lab47/exprcore@v0.0.0-20210525052339-fb7d6bd9331e/exprcore/interp.go (about)

     1  package exprcore
     2  
     3  // This file defines the bytecode interpreter.
     4  
     5  import (
     6  	"fmt"
     7  	"os"
     8  
     9  	"github.com/lab47/exprcore/internal/compile"
    10  	"github.com/lab47/exprcore/internal/spell"
    11  	"github.com/lab47/exprcore/resolve"
    12  	"github.com/lab47/exprcore/syntax"
    13  )
    14  
// vmdebug turns on verbose tracing in CallInternal: when true, the
// interpreter prints function entry and exit, the operand stack and each
// decoded instruction (to stderr), and every call it makes and resumes from.
const vmdebug = false // TODO(adonovan): use a bitfield of specific kinds of error.
    16  
    17  // TODO(adonovan):
    18  // - optimize position table.
    19  // - opt: record MaxIterStack during compilation and preallocate the stack.
    20  
    21  func (fn *Function) CallInternal(thread *Thread, args Tuple, kwargs []Tuple) (Value, error) {
    22  	if !resolve.AllowRecursion {
    23  		// detect recursion
    24  		for _, fr := range thread.stack[:len(thread.stack)-1] {
    25  			// We look for the same function code,
    26  			// not function value, otherwise the user could
    27  			// defeat the check by writing the Y combinator.
    28  			if frfn, ok := fr.Callable().(*Function); ok && frfn.funcode == fn.funcode {
    29  				return nil, fmt.Errorf("function %s called recursively", fn.Name())
    30  			}
    31  		}
    32  	}
    33  
    34  	f := fn.funcode
    35  	fr := thread.frameAt(0)
    36  
    37  	// Allocate space for stack and locals.
    38  	// Logically these do not escape from this frame
    39  	// (See https://github.com/golang/go/issues/20533.)
    40  	//
    41  	// This heap allocation looks expensive, but I was unable to get
    42  	// more than 1% real time improvement in a large alloc-heavy
    43  	// benchmark (in which this alloc was 8% of alloc-bytes)
    44  	// by allocating space for 8 Values in each frame, or
    45  	// by allocating stack by slicing an array held by the Thread
    46  	// that is expanded in chunks of min(k, nspace), for k=256 or 1024.
    47  	nlocals := len(f.Locals)
    48  	nspace := nlocals + f.MaxStack
    49  	space := make([]Value, nspace)
    50  	locals := space[:nlocals:nlocals] // local variables, starting with parameters
    51  	stack := space[nlocals:]          // operand stack
    52  
    53  	// Digest arguments and set parameters.
    54  	err := setArgs(locals, fn, args, kwargs)
    55  	if err != nil {
    56  		return nil, thread.evalError(err)
    57  	}
    58  
    59  	fr.locals = locals
    60  
    61  	if vmdebug {
    62  		fmt.Printf("Entering %s @ %s\n", f.Name, f.Position(0))
    63  		fmt.Printf("%d stack, %d locals\n", len(stack), len(locals))
    64  		defer fmt.Println("Leaving ", f.Name)
    65  	}
    66  
    67  	// Spill indicated locals to cells.
    68  	// Each cell is a separate alloc to avoid spurious liveness.
    69  	for _, index := range f.Cells {
    70  		locals[index] = &cell{locals[index]}
    71  	}
    72  
    73  	// TODO(adonovan): add static check that beneath this point
    74  	// - there is exactly one return statement
    75  	// - there is no redefinition of 'err'.
    76  
    77  	var iterstack []Iterator // stack of active iterators
    78  
    79  	sp := 0
    80  	var pc uint32
    81  	var result Value
    82  	code := f.Code
    83  loop:
    84  	for {
    85  		fr.pc = pc
    86  
    87  		op := compile.Opcode(code[pc])
    88  		pc++
    89  		var arg uint32
    90  		if op >= compile.OpcodeArgMin {
    91  			// TODO(adonovan): opt: profile this.
    92  			// Perhaps compiling big endian would be less work to decode?
    93  			for s := uint(0); ; s += 7 {
    94  				b := code[pc]
    95  				pc++
    96  				arg |= uint32(b&0x7f) << s
    97  				if b < 0x80 {
    98  					break
    99  				}
   100  			}
   101  		}
   102  		if vmdebug {
   103  			fmt.Fprintln(os.Stderr, stack[:sp]) // very verbose!
   104  			compile.PrintOp(f, fr.pc, op, arg)
   105  		}
   106  
   107  		switch op {
   108  		case compile.NOP:
   109  			// nop
   110  
   111  		case compile.DUP:
   112  			stack[sp] = stack[sp-1]
   113  			sp++
   114  
   115  		case compile.DUP2:
   116  			stack[sp] = stack[sp-2]
   117  			stack[sp+1] = stack[sp-1]
   118  			sp += 2
   119  
   120  		case compile.POP:
   121  			sp--
   122  
   123  		case compile.EXCH:
   124  			stack[sp-2], stack[sp-1] = stack[sp-1], stack[sp-2]
   125  
   126  		case compile.EQL, compile.NEQ, compile.GT, compile.LT, compile.LE, compile.GE:
   127  			op := syntax.Token(op-compile.EQL) + syntax.EQL
   128  			y := stack[sp-1]
   129  			x := stack[sp-2]
   130  			sp -= 2
   131  			ok, err2 := Compare(op, x, y)
   132  			if err2 != nil {
   133  				err = err2
   134  				break loop
   135  			}
   136  			stack[sp] = Bool(ok)
   137  			sp++
   138  
   139  		case compile.PLUS,
   140  			compile.MINUS,
   141  			compile.STAR,
   142  			compile.SLASH,
   143  			compile.SLASHSLASH,
   144  			compile.PERCENT,
   145  			compile.AMP,
   146  			compile.PIPE,
   147  			compile.CIRCUMFLEX,
   148  			compile.LTLT,
   149  			compile.GTGT,
   150  			compile.IN:
   151  			binop := syntax.Token(op-compile.PLUS) + syntax.PLUS
   152  			if op == compile.IN {
   153  				binop = syntax.IN // IN token is out of order
   154  			}
   155  			y := stack[sp-1]
   156  			x := stack[sp-2]
   157  			sp -= 2
   158  			z, err2 := Binary(binop, x, y)
   159  			if err2 != nil {
   160  				err = err2
   161  				break loop
   162  			}
   163  			stack[sp] = z
   164  			sp++
   165  
   166  		case compile.UPLUS, compile.UMINUS, compile.TILDE:
   167  			var unop syntax.Token
   168  			if op == compile.TILDE {
   169  				unop = syntax.TILDE
   170  			} else {
   171  				unop = syntax.Token(op-compile.UPLUS) + syntax.PLUS
   172  			}
   173  			x := stack[sp-1]
   174  			y, err2 := Unary(unop, x)
   175  			if err2 != nil {
   176  				err = err2
   177  				break loop
   178  			}
   179  			stack[sp-1] = y
   180  
   181  		case compile.INPLACE_ADD:
   182  			y := stack[sp-1]
   183  			x := stack[sp-2]
   184  			sp -= 2
   185  
   186  			// It's possible that y is not Iterable but
   187  			// nonetheless defines x+y, in which case we
   188  			// should fall back to the general case.
   189  			var z Value
   190  			if xlist, ok := x.(*List); ok {
   191  				if yiter, ok := y.(Iterable); ok {
   192  					if err = xlist.checkMutable("apply += to"); err != nil {
   193  						break loop
   194  					}
   195  					listExtend(xlist, yiter)
   196  					z = xlist
   197  				}
   198  			}
   199  			if z == nil {
   200  				z, err = Binary(syntax.PLUS, x, y)
   201  				if err != nil {
   202  					break loop
   203  				}
   204  			}
   205  
   206  			stack[sp] = z
   207  			sp++
   208  
   209  		case compile.NONE:
   210  			stack[sp] = None
   211  			sp++
   212  
   213  		case compile.TRUE:
   214  			stack[sp] = True
   215  			sp++
   216  
   217  		case compile.FALSE:
   218  			stack[sp] = False
   219  			sp++
   220  
   221  		case compile.MANDATORY:
   222  			stack[sp] = mandatory{}
   223  			sp++
   224  
   225  		case compile.JMP:
   226  			pc = arg
   227  
   228  		case compile.CALL, compile.CALL_VAR, compile.CALL_KW, compile.CALL_VAR_KW:
   229  			var kwargs Value
   230  			if op == compile.CALL_KW || op == compile.CALL_VAR_KW {
   231  				kwargs = stack[sp-1]
   232  				sp--
   233  			}
   234  
   235  			var args Value
   236  			if op == compile.CALL_VAR || op == compile.CALL_VAR_KW {
   237  				args = stack[sp-1]
   238  				sp--
   239  			}
   240  
   241  			// named args (pairs)
   242  			var kvpairs []Tuple
   243  			if nkvpairs := int(arg & 0xff); nkvpairs > 0 {
   244  				kvpairs = make([]Tuple, 0, nkvpairs)
   245  				kvpairsAlloc := make(Tuple, 2*nkvpairs) // allocate a single backing array
   246  				sp -= 2 * nkvpairs
   247  				for i := 0; i < nkvpairs; i++ {
   248  					pair := kvpairsAlloc[:2:2]
   249  					kvpairsAlloc = kvpairsAlloc[2:]
   250  					pair[0] = stack[sp+2*i]   // name
   251  					pair[1] = stack[sp+2*i+1] // value
   252  					kvpairs = append(kvpairs, pair)
   253  				}
   254  			}
   255  			if kwargs != nil {
   256  				// Add key/value items from **kwargs dictionary.
   257  				dict, ok := kwargs.(IterableMapping)
   258  				if !ok {
   259  					err = fmt.Errorf("argument after ** must be a mapping, not %s", kwargs.Type())
   260  					break loop
   261  				}
   262  				items := dict.Items()
   263  				for _, item := range items {
   264  					if _, ok := item[0].(String); !ok {
   265  						err = fmt.Errorf("keywords must be strings, not %s", item[0].Type())
   266  						break loop
   267  					}
   268  				}
   269  				if len(kvpairs) == 0 {
   270  					kvpairs = items
   271  				} else {
   272  					kvpairs = append(kvpairs, items...)
   273  				}
   274  			}
   275  
   276  			// positional args
   277  			var positional Tuple
   278  			if npos := int(arg >> 8); npos > 0 {
   279  				positional = make(Tuple, npos)
   280  				sp -= npos
   281  				copy(positional, stack[sp:])
   282  			}
   283  			if args != nil {
   284  				// Add elements from *args sequence.
   285  				iter := Iterate(args)
   286  				if iter == nil {
   287  					err = fmt.Errorf("argument after * must be iterable, not %s", args.Type())
   288  					break loop
   289  				}
   290  				var elem Value
   291  				for iter.Next(&elem) {
   292  					positional = append(positional, elem)
   293  				}
   294  				iter.Done()
   295  			}
   296  
   297  			function := stack[sp-1]
   298  
   299  			if vmdebug {
   300  				fmt.Printf("VM call %s args=%s kwargs=%s @%s\n",
   301  					function, positional, kvpairs, f.Position(fr.pc))
   302  			}
   303  
   304  			thread.endProfSpan()
   305  			z, err2 := Call(thread, function, positional, kvpairs)
   306  			thread.beginProfSpan()
   307  			if err2 != nil {
   308  				err = err2
   309  				break loop
   310  			}
   311  			if vmdebug {
   312  				fmt.Printf("Resuming %s @ %s\n", f.Name, f.Position(0))
   313  			}
   314  			stack[sp-1] = z
   315  
   316  		case compile.ITERPUSH:
   317  			x := stack[sp-1]
   318  			sp--
   319  			iter := Iterate(x)
   320  			if iter == nil {
   321  				err = fmt.Errorf("%s value is not iterable", x.Type())
   322  				break loop
   323  			}
   324  			iterstack = append(iterstack, iter)
   325  
   326  		case compile.ITERJMP:
   327  			iter := iterstack[len(iterstack)-1]
   328  			if iter.Next(&stack[sp]) {
   329  				sp++
   330  			} else {
   331  				pc = arg
   332  			}
   333  
   334  		case compile.ITERPOP:
   335  			n := len(iterstack) - 1
   336  			iterstack[n].Done()
   337  			iterstack = iterstack[:n]
   338  
   339  		case compile.NOT:
   340  			stack[sp-1] = !stack[sp-1].Truth()
   341  
   342  		case compile.RETURN:
   343  			result = stack[sp-1]
   344  			break loop
   345  
   346  		case compile.SETINDEX:
   347  			z := stack[sp-1]
   348  			y := stack[sp-2]
   349  			x := stack[sp-3]
   350  			sp -= 3
   351  			err = setIndex(x, y, z)
   352  			if err != nil {
   353  				break loop
   354  			}
   355  
   356  		case compile.INDEX:
   357  			y := stack[sp-1]
   358  			x := stack[sp-2]
   359  			sp -= 2
   360  			z, err2 := getIndex(x, y)
   361  			if err2 != nil {
   362  				err = err2
   363  				break loop
   364  			}
   365  			stack[sp] = z
   366  			sp++
   367  
   368  		case compile.ATTR:
   369  			x := stack[sp-1]
   370  			name := f.Prog.Names[arg]
   371  			y, err2 := getAttr(x, name)
   372  			if err2 != nil {
   373  				err = err2
   374  				break loop
   375  			}
   376  			stack[sp-1] = y
   377  
   378  		case compile.SETFIELD:
   379  			y := stack[sp-1]
   380  			x := stack[sp-2]
   381  			sp -= 2
   382  			name := f.Prog.Names[arg]
   383  			if err2 := setField(x, name, y); err2 != nil {
   384  				err = err2
   385  				break loop
   386  			}
   387  
   388  		case compile.MAKEDICT:
   389  			stack[sp] = NewDict(0)
   390  			sp++
   391  
   392  		case compile.MAKEPROTO:
   393  			stack[sp] = FromKeywords(Root, nil)
   394  			sp++
   395  
   396  		case compile.SETDICT, compile.SETDICTUNIQ:
   397  			dict := stack[sp-3].(*Dict)
   398  			k := stack[sp-2]
   399  			v := stack[sp-1]
   400  			sp -= 3
   401  			oldlen := dict.Len()
   402  			if err2 := dict.SetKey(k, v); err2 != nil {
   403  				err = err2
   404  				break loop
   405  			}
   406  			if op == compile.SETDICTUNIQ && dict.Len() == oldlen {
   407  				err = fmt.Errorf("duplicate key: %v", k)
   408  				break loop
   409  			}
   410  
   411  		case compile.SETPROTOKEY:
   412  			proto := stack[sp-3].(*Prototype)
   413  			k := stack[sp-2]
   414  			v := stack[sp-1]
   415  			sp -= 3
   416  
   417  			fmt.Printf("set-proto: %v (%s: %v)\n", proto, k, v)
   418  
   419  			err2 := proto.SetField(string(k.(String)), v)
   420  			if err2 != nil {
   421  				err = err2
   422  			}
   423  
   424  		case compile.APPEND:
   425  			elem := stack[sp-1]
   426  			list := stack[sp-2].(*List)
   427  			sp -= 2
   428  			list.elems = append(list.elems, elem)
   429  
   430  		case compile.SLICE:
   431  			x := stack[sp-4]
   432  			lo := stack[sp-3]
   433  			hi := stack[sp-2]
   434  			step := stack[sp-1]
   435  			sp -= 4
   436  			res, err2 := slice(x, lo, hi, step)
   437  			if err2 != nil {
   438  				err = err2
   439  				break loop
   440  			}
   441  			stack[sp] = res
   442  			sp++
   443  
   444  		case compile.UNPACK:
   445  			n := int(arg)
   446  			iterable := stack[sp-1]
   447  			sp--
   448  			iter := Iterate(iterable)
   449  			if iter == nil {
   450  				err = fmt.Errorf("got %s in sequence assignment", iterable.Type())
   451  				break loop
   452  			}
   453  			i := 0
   454  			sp += n
   455  			for i < n && iter.Next(&stack[sp-1-i]) {
   456  				i++
   457  			}
   458  			var dummy Value
   459  			if iter.Next(&dummy) {
   460  				// NB: Len may return -1 here in obscure cases.
   461  				err = fmt.Errorf("too many values to unpack (got %d, want %d)", Len(iterable), n)
   462  				break loop
   463  			}
   464  			iter.Done()
   465  			if i < n {
   466  				err = fmt.Errorf("too few values to unpack (got %d, want %d)", i, n)
   467  				break loop
   468  			}
   469  
   470  		case compile.CJMP:
   471  			if stack[sp-1].Truth() {
   472  				pc = arg
   473  			}
   474  			sp--
   475  
   476  		case compile.CONSTANT:
   477  			stack[sp] = fn.module.constants[arg]
   478  			sp++
   479  
   480  		case compile.MAKETUPLE:
   481  			n := int(arg)
   482  			tuple := make(Tuple, n)
   483  			sp -= n
   484  			copy(tuple, stack[sp:])
   485  			stack[sp] = tuple
   486  			sp++
   487  
   488  		case compile.MAKELIST:
   489  			n := int(arg)
   490  			elems := make([]Value, n)
   491  			sp -= n
   492  			copy(elems, stack[sp:])
   493  			stack[sp] = NewList(elems)
   494  			sp++
   495  
   496  		case compile.MAKEFUNC:
   497  			funcode := f.Prog.Functions[arg]
   498  			tuple := stack[sp-1].(Tuple)
   499  			n := len(tuple) - len(funcode.Freevars)
   500  			defaults := tuple[:n:n]
   501  			freevars := tuple[n:]
   502  			fn := &Function{
   503  				funcode:  funcode,
   504  				module:   fn.module,
   505  				defaults: defaults,
   506  				freevars: freevars,
   507  			}
   508  			fn.setParent("parent", funcPrototype)
   509  			stack[sp-1] = fn
   510  
   511  		case compile.LOAD:
   512  			n := int(arg)
   513  			module := string(stack[sp-1].(String))
   514  			sp--
   515  
   516  			if thread.Load == nil {
   517  				err = fmt.Errorf("load not implemented by this application")
   518  				break loop
   519  			}
   520  
   521  			thread.endProfSpan()
   522  			dict, err2 := thread.Load(thread, module)
   523  			thread.beginProfSpan()
   524  			if err2 != nil {
   525  				err = wrappedError{
   526  					msg:   fmt.Sprintf("cannot load %s: %v", module, err2),
   527  					cause: err2,
   528  				}
   529  				break loop
   530  			}
   531  
   532  			for i := 0; i < n; i++ {
   533  				from := string(stack[sp-1-i].(String))
   534  				v, ok := dict[from]
   535  				if !ok {
   536  					err = fmt.Errorf("load: name %s not found in module %s", from, module)
   537  					if n := spell.Nearest(from, dict.Keys()); n != "" {
   538  						err = fmt.Errorf("%s (did you mean %s?)", err, n)
   539  					}
   540  					break loop
   541  				}
   542  				stack[sp-1-i] = v
   543  			}
   544  
   545  		case compile.IMPORT:
   546  			ns := string(stack[sp-3].(String))
   547  			name := string(stack[sp-2].(String))
   548  
   549  			if thread.Import == nil {
   550  				err = fmt.Errorf("import not implemented by this application")
   551  				break loop
   552  			}
   553  
   554  			var args *Dict
   555  
   556  			switch x := stack[sp-1].(type) {
   557  			case *Dict:
   558  				args = x
   559  			}
   560  
   561  			val, err2 := thread.Import(thread, ns, name, args)
   562  			if err2 != nil {
   563  				err = wrappedError{
   564  					msg:   fmt.Sprintf("cannot import %s: %v", name, err2),
   565  					cause: err2,
   566  				}
   567  				break loop
   568  			}
   569  			stack[sp-3] = val
   570  
   571  			sp -= 2
   572  
   573  		case compile.SHELL:
   574  			var parts []string
   575  
   576  			for i := uint32(0); i < arg; i++ {
   577  				s, ok := stack[sp-1-int(i)].(String)
   578  				if !ok {
   579  					err = fmt.Errorf("shell only accepts strings")
   580  					break loop
   581  				}
   582  
   583  				parts = append(parts, string(s))
   584  			}
   585  
   586  			sp -= int(arg)
   587  
   588  			if thread.Shell == nil {
   589  				err = fmt.Errorf("shell not implemented by this application")
   590  				break loop
   591  			}
   592  
   593  			val, err2 := thread.Shell(thread, parts)
   594  			if err2 != nil {
   595  				err = wrappedError{
   596  					msg:   fmt.Sprintf("error executing shell: %v", err2),
   597  					cause: err2,
   598  				}
   599  				break loop
   600  			}
   601  
   602  			stack[sp] = val
   603  			sp++
   604  
   605  		case compile.SETLOCAL:
   606  			locals[arg] = stack[sp-1]
   607  			sp--
   608  
   609  		case compile.SETCELL:
   610  			x := stack[sp-2]
   611  			y := stack[sp-1]
   612  			sp -= 2
   613  			y.(*cell).v = x
   614  
   615  		case compile.SETGLOBAL:
   616  			fn.module.globals[arg] = stack[sp-1]
   617  			sp--
   618  
   619  		case compile.LOCAL:
   620  			x := locals[arg]
   621  			if x == nil {
   622  				err = fmt.Errorf("local variable %s referenced before assignment", f.Locals[arg].Name)
   623  				break loop
   624  			}
   625  			stack[sp] = x
   626  			sp++
   627  
   628  		case compile.FREE:
   629  			stack[sp] = fn.freevars[arg]
   630  			sp++
   631  
   632  		case compile.CELL:
   633  			x := stack[sp-1]
   634  			stack[sp-1] = x.(*cell).v
   635  
   636  		case compile.GLOBAL:
   637  			x := fn.module.globals[arg]
   638  			if x == nil {
   639  				err = fmt.Errorf("global variable %s referenced before assignment", f.Prog.Globals[arg].Name)
   640  				break loop
   641  			}
   642  			stack[sp] = x
   643  			sp++
   644  
   645  		case compile.PREDECLARED:
   646  			name := f.Prog.Names[arg]
   647  			x := fn.module.predeclared[name]
   648  			if x == nil {
   649  				err = fmt.Errorf("internal error: predeclared variable %s is uninitialized", name)
   650  				break loop
   651  			}
   652  			stack[sp] = x
   653  			sp++
   654  
   655  		case compile.UNIVERSAL:
   656  			stack[sp] = Universe[f.Prog.Names[arg]]
   657  			sp++
   658  
   659  		case compile.AT:
   660  			name := f.Prog.Names[arg]
   661  
   662  			val, err := getIndex(fn.recv, String(name))
   663  			if err != nil {
   664  				break loop
   665  			}
   666  
   667  			if val == nil {
   668  				val = None
   669  			}
   670  
   671  			stack[sp] = val
   672  			sp++
   673  
   674  		default:
   675  			err = fmt.Errorf("unimplemented: %s", op)
   676  			break loop
   677  		}
   678  	}
   679  
   680  	// ITERPOP the rest of the iterator stack.
   681  	for _, iter := range iterstack {
   682  		iter.Done()
   683  	}
   684  
   685  	fr.locals = nil
   686  
   687  	return result, err
   688  }
   689  
   690  type wrappedError struct {
   691  	msg   string
   692  	cause error
   693  }
   694  
   695  func (e wrappedError) Error() string {
   696  	return e.msg
   697  }
   698  
   699  // Implements the xerrors.Wrapper interface
   700  // https://godoc.org/golang.org/x/xerrors#Wrapper
   701  func (e wrappedError) Unwrap() error {
   702  	return e.cause
   703  }
   704  
   705  // mandatory is a sentinel value used in a function's defaults tuple
   706  // to indicate that a (keyword-only) parameter is mandatory.
   707  type mandatory struct{}
   708  
   709  func (mandatory) String() string        { return "mandatory" }
   710  func (mandatory) Type() string          { return "mandatory" }
   711  func (mandatory) Freeze()               {} // immutable
   712  func (mandatory) Truth() Bool           { return False }
   713  func (mandatory) Hash() (uint32, error) { return 0, nil }
   714  
   715  // A cell is a box containing a Value.
   716  // Local variables marked as cells hold their value indirectly
   717  // so that they may be shared by outer and inner nested functions.
   718  // Cells are always accessed using indirect CELL/SETCELL instructions.
   719  // The FreeVars tuple contains only cells.
   720  // The FREE instruction always yields a cell.
   721  type cell struct{ v Value }
   722  
   723  func (c *cell) String() string { return "cell" }
   724  func (c *cell) Type() string   { return "cell" }
   725  func (c *cell) Freeze() {
   726  	if c.v != nil {
   727  		c.v.Freeze()
   728  	}
   729  }
   730  func (c *cell) Truth() Bool           { panic("unreachable") }
   731  func (c *cell) Hash() (uint32, error) { panic("unreachable") }