github.com/google/skylark@v0.0.0-20181101142754-a5f7082aabed/interp.go (about)

     1  package skylark
     2  
     3  // This file defines the bytecode interpreter.
     4  
     5  import (
     6  	"fmt"
     7  	"os"
     8  
     9  	"github.com/google/skylark/internal/compile"
    10  	"github.com/google/skylark/syntax"
    11  )
    12  
    13  const vmdebug = false // TODO(adonovan): use a bitfield of specific kinds of error.
    14  
    15  // TODO(adonovan):
    16  // - optimize position table.
    17  // - opt: reduce allocations by preallocating a large stack, saving it
    18  //   in the thread, and slicing it.
    19  // - opt: record MaxIterStack during compilation and preallocate the stack.
    20  
    21  func (fn *Function) CallInternal(thread *Thread, args Tuple, kwargs []Tuple) (Value, error) {
    22  	if debug {
    23  		fmt.Printf("call of %s %v %v\n", fn.Name(), args, kwargs)
    24  	}
    25  
    26  	// detect recursion
    27  	for fr := thread.frame.parent; fr != nil; fr = fr.parent {
    28  		// We look for the same function code,
    29  		// not function value, otherwise the user could
    30  		// defeat the check by writing the Y combinator.
    31  		if frfn, ok := fr.Callable().(*Function); ok && frfn.funcode == fn.funcode {
    32  			return nil, fmt.Errorf("function %s called recursively", fn.Name())
    33  		}
    34  	}
    35  
    36  	return call(thread, args, kwargs)
    37  }
    38  
    39  func call(thread *Thread, args Tuple, kwargs []Tuple) (Value, error) {
    40  	fr := thread.frame
    41  	fn := fr.callable.(*Function)
    42  	f := fn.funcode
    43  	nlocals := len(f.Locals)
    44  	stack := make([]Value, nlocals+f.MaxStack)
    45  	locals := stack[:nlocals:nlocals] // local variables, starting with parameters
    46  	stack = stack[nlocals:]
    47  
    48  	err := setArgs(locals, fn, args, kwargs)
    49  	if err != nil {
    50  		return nil, fr.errorf(fr.Position(), "%v", err)
    51  	}
    52  
    53  	if vmdebug {
    54  		fmt.Printf("Entering %s @ %s\n", f.Name, f.Position(0))
    55  		fmt.Printf("%d stack, %d locals\n", len(stack), len(locals))
    56  		defer fmt.Println("Leaving ", f.Name)
    57  	}
    58  
    59  	// TODO(adonovan): add static check that beneath this point
    60  	// - there is exactly one return statement
    61  	// - there is no redefinition of 'err'.
    62  
    63  	var iterstack []Iterator // stack of active iterators
    64  
    65  	sp := 0
    66  	var pc, savedpc uint32
    67  	var result Value
    68  	code := f.Code
    69  loop:
    70  	for {
    71  		savedpc = pc
    72  
    73  		op := compile.Opcode(code[pc])
    74  		pc++
    75  		var arg uint32
    76  		if op >= compile.OpcodeArgMin {
    77  			// TODO(adonovan): opt: profile this.
    78  			// Perhaps compiling big endian would be less work to decode?
    79  			for s := uint(0); ; s += 7 {
    80  				b := code[pc]
    81  				pc++
    82  				arg |= uint32(b&0x7f) << s
    83  				if b < 0x80 {
    84  					break
    85  				}
    86  			}
    87  		}
    88  		if vmdebug {
    89  			fmt.Fprintln(os.Stderr, stack[:sp]) // very verbose!
    90  			compile.PrintOp(f, savedpc, op, arg)
    91  		}
    92  
    93  		switch op {
    94  		case compile.NOP:
    95  			// nop
    96  
    97  		case compile.DUP:
    98  			stack[sp] = stack[sp-1]
    99  			sp++
   100  
   101  		case compile.DUP2:
   102  			stack[sp] = stack[sp-2]
   103  			stack[sp+1] = stack[sp-1]
   104  			sp += 2
   105  
   106  		case compile.POP:
   107  			sp--
   108  
   109  		case compile.EXCH:
   110  			stack[sp-2], stack[sp-1] = stack[sp-1], stack[sp-2]
   111  
   112  		case compile.EQL, compile.NEQ, compile.GT, compile.LT, compile.LE, compile.GE:
   113  			op := syntax.Token(op-compile.EQL) + syntax.EQL
   114  			y := stack[sp-1]
   115  			x := stack[sp-2]
   116  			sp -= 2
   117  			ok, err2 := Compare(op, x, y)
   118  			if err2 != nil {
   119  				err = err2
   120  				break loop
   121  			}
   122  			stack[sp] = Bool(ok)
   123  			sp++
   124  
   125  		case compile.PLUS,
   126  			compile.MINUS,
   127  			compile.STAR,
   128  			compile.SLASH,
   129  			compile.SLASHSLASH,
   130  			compile.PERCENT,
   131  			compile.AMP,
   132  			compile.PIPE,
   133  			compile.CIRCUMFLEX,
   134  			compile.LTLT,
   135  			compile.GTGT,
   136  			compile.IN:
   137  			binop := syntax.Token(op-compile.PLUS) + syntax.PLUS
   138  			if op == compile.IN {
   139  				binop = syntax.IN // IN token is out of order
   140  			}
   141  			y := stack[sp-1]
   142  			x := stack[sp-2]
   143  			sp -= 2
   144  			z, err2 := Binary(binop, x, y)
   145  			if err2 != nil {
   146  				err = err2
   147  				break loop
   148  			}
   149  			stack[sp] = z
   150  			sp++
   151  
   152  		case compile.UPLUS, compile.UMINUS, compile.TILDE:
   153  			var unop syntax.Token
   154  			if op == compile.TILDE {
   155  				unop = syntax.TILDE
   156  			} else {
   157  				unop = syntax.Token(op-compile.UPLUS) + syntax.PLUS
   158  			}
   159  			x := stack[sp-1]
   160  			y, err2 := Unary(unop, x)
   161  			if err2 != nil {
   162  				err = err2
   163  				break loop
   164  			}
   165  			stack[sp-1] = y
   166  
   167  		case compile.INPLACE_ADD:
   168  			y := stack[sp-1]
   169  			x := stack[sp-2]
   170  			sp -= 2
   171  
   172  			// It's possible that y is not Iterable but
   173  			// nonetheless defines x+y, in which case we
   174  			// should fall back to the general case.
   175  			var z Value
   176  			if xlist, ok := x.(*List); ok {
   177  				if yiter, ok := y.(Iterable); ok {
   178  					if err = xlist.checkMutable("apply += to", true); err != nil {
   179  						break loop
   180  					}
   181  					listExtend(xlist, yiter)
   182  					z = xlist
   183  				}
   184  			}
   185  			if z == nil {
   186  				z, err = Binary(syntax.PLUS, x, y)
   187  				if err != nil {
   188  					break loop
   189  				}
   190  			}
   191  
   192  			stack[sp] = z
   193  			sp++
   194  
   195  		case compile.NONE:
   196  			stack[sp] = None
   197  			sp++
   198  
   199  		case compile.TRUE:
   200  			stack[sp] = True
   201  			sp++
   202  
   203  		case compile.FALSE:
   204  			stack[sp] = False
   205  			sp++
   206  
   207  		case compile.JMP:
   208  			pc = arg
   209  
   210  		case compile.CALL, compile.CALL_VAR, compile.CALL_KW, compile.CALL_VAR_KW:
   211  			var kwargs Value
   212  			if op == compile.CALL_KW || op == compile.CALL_VAR_KW {
   213  				kwargs = stack[sp-1]
   214  				sp--
   215  			}
   216  
   217  			var args Value
   218  			if op == compile.CALL_VAR || op == compile.CALL_VAR_KW {
   219  				args = stack[sp-1]
   220  				sp--
   221  			}
   222  
   223  			// named args (pairs)
   224  			var kvpairs []Tuple
   225  			if nkvpairs := int(arg & 0xff); nkvpairs > 0 {
   226  				kvpairs = make([]Tuple, 0, nkvpairs)
   227  				kvpairsAlloc := make(Tuple, 2*nkvpairs) // allocate a single backing array
   228  				sp -= 2 * nkvpairs
   229  				for i := 0; i < nkvpairs; i++ {
   230  					pair := kvpairsAlloc[:2:2]
   231  					kvpairsAlloc = kvpairsAlloc[2:]
   232  					pair[0] = stack[sp+2*i]   // name
   233  					pair[1] = stack[sp+2*i+1] // value
   234  					kvpairs = append(kvpairs, pair)
   235  				}
   236  			}
   237  			if kwargs != nil {
   238  				// Add key/value items from **kwargs dictionary.
   239  				dict, ok := kwargs.(*Dict)
   240  				if !ok {
   241  					err = fmt.Errorf("argument after ** must be a mapping, not %s", kwargs.Type())
   242  					break loop
   243  				}
   244  				items := dict.Items()
   245  				for _, item := range items {
   246  					if _, ok := item[0].(String); !ok {
   247  						err = fmt.Errorf("keywords must be strings, not %s", item[0].Type())
   248  						break loop
   249  					}
   250  				}
   251  				if len(kvpairs) == 0 {
   252  					kvpairs = items
   253  				} else {
   254  					kvpairs = append(kvpairs, items...)
   255  				}
   256  			}
   257  
   258  			// positional args
   259  			var positional Tuple
   260  			if npos := int(arg >> 8); npos > 0 {
   261  				positional = make(Tuple, npos)
   262  				sp -= npos
   263  				copy(positional, stack[sp:])
   264  			}
   265  			if args != nil {
   266  				// Add elements from *args sequence.
   267  				iter := Iterate(args)
   268  				if iter == nil {
   269  					err = fmt.Errorf("argument after * must be iterable, not %s", args.Type())
   270  					break loop
   271  				}
   272  				var elem Value
   273  				for iter.Next(&elem) {
   274  					positional = append(positional, elem)
   275  				}
   276  				iter.Done()
   277  			}
   278  
   279  			function := stack[sp-1]
   280  
   281  			if vmdebug {
   282  				fmt.Printf("VM call %s args=%s kwargs=%s @%s\n",
   283  					function, positional, kvpairs, f.Position(fr.callpc))
   284  			}
   285  
   286  			fr.callpc = savedpc
   287  			z, err2 := Call(thread, function, positional, kvpairs)
   288  			if err2 != nil {
   289  				err = err2
   290  				break loop
   291  			}
   292  			if vmdebug {
   293  				fmt.Printf("Resuming %s @ %s\n", f.Name, f.Position(0))
   294  			}
   295  			stack[sp-1] = z
   296  
   297  		case compile.ITERPUSH:
   298  			x := stack[sp-1]
   299  			sp--
   300  			iter := Iterate(x)
   301  			if iter == nil {
   302  				err = fmt.Errorf("%s value is not iterable", x.Type())
   303  				break loop
   304  			}
   305  			iterstack = append(iterstack, iter)
   306  
   307  		case compile.ITERJMP:
   308  			iter := iterstack[len(iterstack)-1]
   309  			if iter.Next(&stack[sp]) {
   310  				sp++
   311  			} else {
   312  				pc = arg
   313  			}
   314  
   315  		case compile.ITERPOP:
   316  			n := len(iterstack) - 1
   317  			iterstack[n].Done()
   318  			iterstack = iterstack[:n]
   319  
   320  		case compile.NOT:
   321  			stack[sp-1] = !stack[sp-1].Truth()
   322  
   323  		case compile.RETURN:
   324  			result = stack[sp-1]
   325  			break loop
   326  
   327  		case compile.SETINDEX:
   328  			z := stack[sp-1]
   329  			y := stack[sp-2]
   330  			x := stack[sp-3]
   331  			sp -= 3
   332  			err = setIndex(fr, x, y, z)
   333  			if err != nil {
   334  				break loop
   335  			}
   336  
   337  		case compile.INDEX:
   338  			y := stack[sp-1]
   339  			x := stack[sp-2]
   340  			sp -= 2
   341  			z, err2 := getIndex(fr, x, y)
   342  			if err2 != nil {
   343  				err = err2
   344  				break loop
   345  			}
   346  			stack[sp] = z
   347  			sp++
   348  
   349  		case compile.ATTR:
   350  			x := stack[sp-1]
   351  			name := f.Prog.Names[arg]
   352  			y, err2 := getAttr(fr, x, name)
   353  			if err2 != nil {
   354  				err = err2
   355  				break loop
   356  			}
   357  			stack[sp-1] = y
   358  
   359  		case compile.SETFIELD:
   360  			y := stack[sp-1]
   361  			x := stack[sp-2]
   362  			sp -= 2
   363  			name := f.Prog.Names[arg]
   364  			if err2 := setField(fr, x, name, y); err2 != nil {
   365  				err = err2
   366  				break loop
   367  			}
   368  
   369  		case compile.MAKEDICT:
   370  			stack[sp] = new(Dict)
   371  			sp++
   372  
   373  		case compile.SETDICT, compile.SETDICTUNIQ:
   374  			dict := stack[sp-3].(*Dict)
   375  			k := stack[sp-2]
   376  			v := stack[sp-1]
   377  			sp -= 3
   378  			oldlen := dict.Len()
   379  			if err2 := dict.SetKey(k, v); err2 != nil {
   380  				err = err2
   381  				break loop
   382  			}
   383  			if op == compile.SETDICTUNIQ && dict.Len() == oldlen {
   384  				err = fmt.Errorf("duplicate key: %v", k)
   385  				break loop
   386  			}
   387  
   388  		case compile.APPEND:
   389  			elem := stack[sp-1]
   390  			list := stack[sp-2].(*List)
   391  			sp -= 2
   392  			list.elems = append(list.elems, elem)
   393  
   394  		case compile.SLICE:
   395  			x := stack[sp-4]
   396  			lo := stack[sp-3]
   397  			hi := stack[sp-2]
   398  			step := stack[sp-1]
   399  			sp -= 4
   400  			res, err2 := slice(x, lo, hi, step)
   401  			if err2 != nil {
   402  				err = err2
   403  				break loop
   404  			}
   405  			stack[sp] = res
   406  			sp++
   407  
   408  		case compile.UNPACK:
   409  			n := int(arg)
   410  			iterable := stack[sp-1]
   411  			sp--
   412  			iter := Iterate(iterable)
   413  			if iter == nil {
   414  				err = fmt.Errorf("got %s in sequence assignment", iterable.Type())
   415  				break loop
   416  			}
   417  			i := 0
   418  			sp += n
   419  			for i < n && iter.Next(&stack[sp-1-i]) {
   420  				i++
   421  			}
   422  			var dummy Value
   423  			if iter.Next(&dummy) {
   424  				// NB: Len may return -1 here in obscure cases.
   425  				err = fmt.Errorf("too many values to unpack (got %d, want %d)", Len(iterable), n)
   426  				break loop
   427  			}
   428  			iter.Done()
   429  			if i < n {
   430  				err = fmt.Errorf("too few values to unpack (got %d, want %d)", i, n)
   431  				break loop
   432  			}
   433  
   434  		case compile.CJMP:
   435  			if stack[sp-1].Truth() {
   436  				pc = arg
   437  			}
   438  			sp--
   439  
   440  		case compile.CONSTANT:
   441  			stack[sp] = fn.constants[arg]
   442  			sp++
   443  
   444  		case compile.MAKETUPLE:
   445  			n := int(arg)
   446  			tuple := make(Tuple, n)
   447  			sp -= n
   448  			copy(tuple, stack[sp:])
   449  			stack[sp] = tuple
   450  			sp++
   451  
   452  		case compile.MAKELIST:
   453  			n := int(arg)
   454  			elems := make([]Value, n)
   455  			sp -= n
   456  			copy(elems, stack[sp:])
   457  			stack[sp] = NewList(elems)
   458  			sp++
   459  
   460  		case compile.MAKEFUNC:
   461  			funcode := f.Prog.Functions[arg]
   462  			freevars := stack[sp-1].(Tuple)
   463  			defaults := stack[sp-2].(Tuple)
   464  			sp -= 2
   465  			stack[sp] = &Function{
   466  				funcode:     funcode,
   467  				predeclared: fn.predeclared,
   468  				globals:     fn.globals,
   469  				constants:   fn.constants,
   470  				defaults:    defaults,
   471  				freevars:    freevars,
   472  			}
   473  			sp++
   474  
   475  		case compile.LOAD:
   476  			n := int(arg)
   477  			module := string(stack[sp-1].(String))
   478  			sp--
   479  
   480  			if thread.Load == nil {
   481  				err = fmt.Errorf("load not implemented by this application")
   482  				break loop
   483  			}
   484  
   485  			dict, err2 := thread.Load(thread, module)
   486  			if err2 != nil {
   487  				err = fmt.Errorf("cannot load %s: %v", module, err2)
   488  				break loop
   489  			}
   490  
   491  			for i := 0; i < n; i++ {
   492  				from := string(stack[sp-1-i].(String))
   493  				v, ok := dict[from]
   494  				if !ok {
   495  					err = fmt.Errorf("load: name %s not found in module %s", from, module)
   496  					break loop
   497  				}
   498  				stack[sp-1-i] = v
   499  			}
   500  
   501  		case compile.SETLOCAL:
   502  			locals[arg] = stack[sp-1]
   503  			sp--
   504  
   505  		case compile.SETGLOBAL:
   506  			fn.globals[arg] = stack[sp-1]
   507  			sp--
   508  
   509  		case compile.LOCAL:
   510  			x := locals[arg]
   511  			if x == nil {
   512  				err = fmt.Errorf("local variable %s referenced before assignment", f.Locals[arg].Name)
   513  				break loop
   514  			}
   515  			stack[sp] = x
   516  			sp++
   517  
   518  		case compile.FREE:
   519  			stack[sp] = fn.freevars[arg]
   520  			sp++
   521  
   522  		case compile.GLOBAL:
   523  			x := fn.globals[arg]
   524  			if x == nil {
   525  				err = fmt.Errorf("global variable %s referenced before assignment", f.Prog.Globals[arg].Name)
   526  				break loop
   527  			}
   528  			stack[sp] = x
   529  			sp++
   530  
   531  		case compile.PREDECLARED:
   532  			name := f.Prog.Names[arg]
   533  			x := fn.predeclared[name]
   534  			if x == nil {
   535  				err = fmt.Errorf("internal error: predeclared variable %s is uninitialized", name)
   536  				break loop
   537  			}
   538  			stack[sp] = x
   539  			sp++
   540  
   541  		case compile.UNIVERSAL:
   542  			stack[sp] = Universe[f.Prog.Names[arg]]
   543  			sp++
   544  
   545  		default:
   546  			err = fmt.Errorf("unimplemented: %s", op)
   547  			break loop
   548  		}
   549  	}
   550  
   551  	// ITERPOP the rest of the iterator stack.
   552  	for _, iter := range iterstack {
   553  		iter.Done()
   554  	}
   555  
   556  	if err != nil {
   557  		if _, ok := err.(*EvalError); !ok {
   558  			err = fr.errorf(f.Position(savedpc), "%s", err.Error())
   559  		}
   560  	}
   561  	return result, err
   562  }