github.com/expr-lang/expr@v1.16.9/vm/vm.go (about)

     1  package vm
     2  
     3  //go:generate sh -c "go run ./func_types > ./func_types[generated].go"
     4  
     5  import (
     6  	"fmt"
     7  	"reflect"
     8  	"regexp"
     9  	"sort"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/expr-lang/expr/builtin"
    14  	"github.com/expr-lang/expr/file"
    15  	"github.com/expr-lang/expr/internal/deref"
    16  	"github.com/expr-lang/expr/vm/runtime"
    17  )
    18  
    19  func Run(program *Program, env any) (any, error) {
    20  	if program == nil {
    21  		return nil, fmt.Errorf("program is nil")
    22  	}
    23  
    24  	vm := VM{}
    25  	return vm.Run(program, env)
    26  }
    27  
    28  func Debug() *VM {
    29  	vm := &VM{
    30  		debug: true,
    31  		step:  make(chan struct{}, 0),
    32  		curr:  make(chan int, 0),
    33  	}
    34  	return vm
    35  }
    36  
    37  type VM struct {
    38  	Stack        []any
    39  	Scopes       []*Scope
    40  	Variables    []any
    41  	ip           int
    42  	memory       uint
    43  	memoryBudget uint
    44  	debug        bool
    45  	step         chan struct{}
    46  	curr         chan int
    47  }
    48  
    49  func (vm *VM) Run(program *Program, env any) (_ any, err error) {
    50  	defer func() {
    51  		if r := recover(); r != nil {
    52  			var location file.Location
    53  			if vm.ip-1 < len(program.locations) {
    54  				location = program.locations[vm.ip-1]
    55  			}
    56  			f := &file.Error{
    57  				Location: location,
    58  				Message:  fmt.Sprintf("%v", r),
    59  			}
    60  			if err, ok := r.(error); ok {
    61  				f.Wrap(err)
    62  			}
    63  			err = f.Bind(program.source)
    64  		}
    65  	}()
    66  
    67  	if vm.Stack == nil {
    68  		vm.Stack = make([]any, 0, 2)
    69  	} else {
    70  		vm.Stack = vm.Stack[0:0]
    71  	}
    72  	if vm.Scopes != nil {
    73  		vm.Scopes = vm.Scopes[0:0]
    74  	}
    75  	if len(vm.Variables) < program.variables {
    76  		vm.Variables = make([]any, program.variables)
    77  	}
    78  
    79  	vm.memoryBudget = MemoryBudget
    80  	vm.memory = 0
    81  	vm.ip = 0
    82  
    83  	for vm.ip < len(program.Bytecode) {
    84  		if debug && vm.debug {
    85  			<-vm.step
    86  		}
    87  
    88  		op := program.Bytecode[vm.ip]
    89  		arg := program.Arguments[vm.ip]
    90  		vm.ip += 1
    91  
    92  		switch op {
    93  
    94  		case OpInvalid:
    95  			panic("invalid opcode")
    96  
    97  		case OpPush:
    98  			vm.push(program.Constants[arg])
    99  
   100  		case OpInt:
   101  			vm.push(arg)
   102  
   103  		case OpPop:
   104  			vm.pop()
   105  
   106  		case OpStore:
   107  			vm.Variables[arg] = vm.pop()
   108  
   109  		case OpLoadVar:
   110  			vm.push(vm.Variables[arg])
   111  
   112  		case OpLoadConst:
   113  			vm.push(runtime.Fetch(env, program.Constants[arg]))
   114  
   115  		case OpLoadField:
   116  			vm.push(runtime.FetchField(env, program.Constants[arg].(*runtime.Field)))
   117  
   118  		case OpLoadFast:
   119  			vm.push(env.(map[string]any)[program.Constants[arg].(string)])
   120  
   121  		case OpLoadMethod:
   122  			vm.push(runtime.FetchMethod(env, program.Constants[arg].(*runtime.Method)))
   123  
   124  		case OpLoadFunc:
   125  			vm.push(program.functions[arg])
   126  
   127  		case OpFetch:
   128  			b := vm.pop()
   129  			a := vm.pop()
   130  			vm.push(runtime.Fetch(a, b))
   131  
   132  		case OpFetchField:
   133  			a := vm.pop()
   134  			vm.push(runtime.FetchField(a, program.Constants[arg].(*runtime.Field)))
   135  
   136  		case OpLoadEnv:
   137  			vm.push(env)
   138  
   139  		case OpMethod:
   140  			a := vm.pop()
   141  			vm.push(runtime.FetchMethod(a, program.Constants[arg].(*runtime.Method)))
   142  
   143  		case OpTrue:
   144  			vm.push(true)
   145  
   146  		case OpFalse:
   147  			vm.push(false)
   148  
   149  		case OpNil:
   150  			vm.push(nil)
   151  
   152  		case OpNegate:
   153  			v := runtime.Negate(vm.pop())
   154  			vm.push(v)
   155  
   156  		case OpNot:
   157  			v := vm.pop().(bool)
   158  			vm.push(!v)
   159  
   160  		case OpEqual:
   161  			b := vm.pop()
   162  			a := vm.pop()
   163  			vm.push(runtime.Equal(a, b))
   164  
   165  		case OpEqualInt:
   166  			b := vm.pop()
   167  			a := vm.pop()
   168  			vm.push(a.(int) == b.(int))
   169  
   170  		case OpEqualString:
   171  			b := vm.pop()
   172  			a := vm.pop()
   173  			vm.push(a.(string) == b.(string))
   174  
   175  		case OpJump:
   176  			vm.ip += arg
   177  
   178  		case OpJumpIfTrue:
   179  			if vm.current().(bool) {
   180  				vm.ip += arg
   181  			}
   182  
   183  		case OpJumpIfFalse:
   184  			if !vm.current().(bool) {
   185  				vm.ip += arg
   186  			}
   187  
   188  		case OpJumpIfNil:
   189  			if runtime.IsNil(vm.current()) {
   190  				vm.ip += arg
   191  			}
   192  
   193  		case OpJumpIfNotNil:
   194  			if !runtime.IsNil(vm.current()) {
   195  				vm.ip += arg
   196  			}
   197  
   198  		case OpJumpIfEnd:
   199  			scope := vm.scope()
   200  			if scope.Index >= scope.Len {
   201  				vm.ip += arg
   202  			}
   203  
   204  		case OpJumpBackward:
   205  			vm.ip -= arg
   206  
   207  		case OpIn:
   208  			b := vm.pop()
   209  			a := vm.pop()
   210  			vm.push(runtime.In(a, b))
   211  
   212  		case OpLess:
   213  			b := vm.pop()
   214  			a := vm.pop()
   215  			vm.push(runtime.Less(a, b))
   216  
   217  		case OpMore:
   218  			b := vm.pop()
   219  			a := vm.pop()
   220  			vm.push(runtime.More(a, b))
   221  
   222  		case OpLessOrEqual:
   223  			b := vm.pop()
   224  			a := vm.pop()
   225  			vm.push(runtime.LessOrEqual(a, b))
   226  
   227  		case OpMoreOrEqual:
   228  			b := vm.pop()
   229  			a := vm.pop()
   230  			vm.push(runtime.MoreOrEqual(a, b))
   231  
   232  		case OpAdd:
   233  			b := vm.pop()
   234  			a := vm.pop()
   235  			vm.push(runtime.Add(a, b))
   236  
   237  		case OpSubtract:
   238  			b := vm.pop()
   239  			a := vm.pop()
   240  			vm.push(runtime.Subtract(a, b))
   241  
   242  		case OpMultiply:
   243  			b := vm.pop()
   244  			a := vm.pop()
   245  			vm.push(runtime.Multiply(a, b))
   246  
   247  		case OpDivide:
   248  			b := vm.pop()
   249  			a := vm.pop()
   250  			vm.push(runtime.Divide(a, b))
   251  
   252  		case OpModulo:
   253  			b := vm.pop()
   254  			a := vm.pop()
   255  			vm.push(runtime.Modulo(a, b))
   256  
   257  		case OpExponent:
   258  			b := vm.pop()
   259  			a := vm.pop()
   260  			vm.push(runtime.Exponent(a, b))
   261  
   262  		case OpRange:
   263  			b := vm.pop()
   264  			a := vm.pop()
   265  			min := runtime.ToInt(a)
   266  			max := runtime.ToInt(b)
   267  			size := max - min + 1
   268  			if size <= 0 {
   269  				size = 0
   270  			}
   271  			vm.memGrow(uint(size))
   272  			vm.push(runtime.MakeRange(min, max))
   273  
   274  		case OpMatches:
   275  			b := vm.pop()
   276  			a := vm.pop()
   277  			if runtime.IsNil(a) || runtime.IsNil(b) {
   278  				vm.push(false)
   279  				break
   280  			}
   281  			match, err := regexp.MatchString(b.(string), a.(string))
   282  			if err != nil {
   283  				panic(err)
   284  			}
   285  			vm.push(match)
   286  
   287  		case OpMatchesConst:
   288  			a := vm.pop()
   289  			if runtime.IsNil(a) {
   290  				vm.push(false)
   291  				break
   292  			}
   293  			r := program.Constants[arg].(*regexp.Regexp)
   294  			vm.push(r.MatchString(a.(string)))
   295  
   296  		case OpContains:
   297  			b := vm.pop()
   298  			a := vm.pop()
   299  			if runtime.IsNil(a) || runtime.IsNil(b) {
   300  				vm.push(false)
   301  				break
   302  			}
   303  			vm.push(strings.Contains(a.(string), b.(string)))
   304  
   305  		case OpStartsWith:
   306  			b := vm.pop()
   307  			a := vm.pop()
   308  			if runtime.IsNil(a) || runtime.IsNil(b) {
   309  				vm.push(false)
   310  				break
   311  			}
   312  			vm.push(strings.HasPrefix(a.(string), b.(string)))
   313  
   314  		case OpEndsWith:
   315  			b := vm.pop()
   316  			a := vm.pop()
   317  			if runtime.IsNil(a) || runtime.IsNil(b) {
   318  				vm.push(false)
   319  				break
   320  			}
   321  			vm.push(strings.HasSuffix(a.(string), b.(string)))
   322  
   323  		case OpSlice:
   324  			from := vm.pop()
   325  			to := vm.pop()
   326  			node := vm.pop()
   327  			vm.push(runtime.Slice(node, from, to))
   328  
   329  		case OpCall:
   330  			fn := reflect.ValueOf(vm.pop())
   331  			size := arg
   332  			in := make([]reflect.Value, size)
   333  			for i := int(size) - 1; i >= 0; i-- {
   334  				param := vm.pop()
   335  				if param == nil && reflect.TypeOf(param) == nil {
   336  					// In case of nil value and nil type use this hack,
   337  					// otherwise reflect.Call will panic on zero value.
   338  					in[i] = reflect.ValueOf(&param).Elem()
   339  				} else {
   340  					in[i] = reflect.ValueOf(param)
   341  				}
   342  			}
   343  			out := fn.Call(in)
   344  			if len(out) == 2 && out[1].Type() == errorType && !out[1].IsNil() {
   345  				panic(out[1].Interface().(error))
   346  			}
   347  			vm.push(out[0].Interface())
   348  
   349  		case OpCall0:
   350  			out, err := program.functions[arg]()
   351  			if err != nil {
   352  				panic(err)
   353  			}
   354  			vm.push(out)
   355  
   356  		case OpCall1:
   357  			a := vm.pop()
   358  			out, err := program.functions[arg](a)
   359  			if err != nil {
   360  				panic(err)
   361  			}
   362  			vm.push(out)
   363  
   364  		case OpCall2:
   365  			b := vm.pop()
   366  			a := vm.pop()
   367  			out, err := program.functions[arg](a, b)
   368  			if err != nil {
   369  				panic(err)
   370  			}
   371  			vm.push(out)
   372  
   373  		case OpCall3:
   374  			c := vm.pop()
   375  			b := vm.pop()
   376  			a := vm.pop()
   377  			out, err := program.functions[arg](a, b, c)
   378  			if err != nil {
   379  				panic(err)
   380  			}
   381  			vm.push(out)
   382  
   383  		case OpCallN:
   384  			fn := vm.pop().(Function)
   385  			size := arg
   386  			in := make([]any, size)
   387  			for i := int(size) - 1; i >= 0; i-- {
   388  				in[i] = vm.pop()
   389  			}
   390  			out, err := fn(in...)
   391  			if err != nil {
   392  				panic(err)
   393  			}
   394  			vm.push(out)
   395  
   396  		case OpCallFast:
   397  			fn := vm.pop().(func(...any) any)
   398  			size := arg
   399  			in := make([]any, size)
   400  			for i := int(size) - 1; i >= 0; i-- {
   401  				in[i] = vm.pop()
   402  			}
   403  			vm.push(fn(in...))
   404  
   405  		case OpCallSafe:
   406  			fn := vm.pop().(SafeFunction)
   407  			size := arg
   408  			in := make([]any, size)
   409  			for i := int(size) - 1; i >= 0; i-- {
   410  				in[i] = vm.pop()
   411  			}
   412  			out, mem, err := fn(in...)
   413  			if err != nil {
   414  				panic(err)
   415  			}
   416  			vm.memGrow(mem)
   417  			vm.push(out)
   418  
   419  		case OpCallTyped:
   420  			vm.push(vm.call(vm.pop(), arg))
   421  
   422  		case OpCallBuiltin1:
   423  			vm.push(builtin.Builtins[arg].Fast(vm.pop()))
   424  
   425  		case OpArray:
   426  			size := vm.pop().(int)
   427  			vm.memGrow(uint(size))
   428  			array := make([]any, size)
   429  			for i := size - 1; i >= 0; i-- {
   430  				array[i] = vm.pop()
   431  			}
   432  			vm.push(array)
   433  
   434  		case OpMap:
   435  			size := vm.pop().(int)
   436  			vm.memGrow(uint(size))
   437  			m := make(map[string]any)
   438  			for i := size - 1; i >= 0; i-- {
   439  				value := vm.pop()
   440  				key := vm.pop()
   441  				m[key.(string)] = value
   442  			}
   443  			vm.push(m)
   444  
   445  		case OpLen:
   446  			vm.push(runtime.Len(vm.current()))
   447  
   448  		case OpCast:
   449  			switch arg {
   450  			case 0:
   451  				vm.push(runtime.ToInt(vm.pop()))
   452  			case 1:
   453  				vm.push(runtime.ToInt64(vm.pop()))
   454  			case 2:
   455  				vm.push(runtime.ToFloat64(vm.pop()))
   456  			}
   457  
   458  		case OpDeref:
   459  			a := vm.pop()
   460  			vm.push(deref.Deref(a))
   461  
   462  		case OpIncrementIndex:
   463  			vm.scope().Index++
   464  
   465  		case OpDecrementIndex:
   466  			scope := vm.scope()
   467  			scope.Index--
   468  
   469  		case OpIncrementCount:
   470  			scope := vm.scope()
   471  			scope.Count++
   472  
   473  		case OpGetIndex:
   474  			vm.push(vm.scope().Index)
   475  
   476  		case OpGetCount:
   477  			scope := vm.scope()
   478  			vm.push(scope.Count)
   479  
   480  		case OpGetLen:
   481  			scope := vm.scope()
   482  			vm.push(scope.Len)
   483  
   484  		case OpGetAcc:
   485  			vm.push(vm.scope().Acc)
   486  
   487  		case OpSetAcc:
   488  			vm.scope().Acc = vm.pop()
   489  
   490  		case OpSetIndex:
   491  			scope := vm.scope()
   492  			scope.Index = vm.pop().(int)
   493  
   494  		case OpPointer:
   495  			scope := vm.scope()
   496  			vm.push(scope.Array.Index(scope.Index).Interface())
   497  
   498  		case OpThrow:
   499  			panic(vm.pop().(error))
   500  
   501  		case OpCreate:
   502  			switch arg {
   503  			case 1:
   504  				vm.push(make(groupBy))
   505  			case 2:
   506  				scope := vm.scope()
   507  				var desc bool
   508  				switch vm.pop().(string) {
   509  				case "asc":
   510  					desc = false
   511  				case "desc":
   512  					desc = true
   513  				default:
   514  					panic("unknown order, use asc or desc")
   515  				}
   516  				vm.push(&runtime.SortBy{
   517  					Desc:   desc,
   518  					Array:  make([]any, 0, scope.Len),
   519  					Values: make([]any, 0, scope.Len),
   520  				})
   521  			default:
   522  				panic(fmt.Sprintf("unknown OpCreate argument %v", arg))
   523  			}
   524  
   525  		case OpGroupBy:
   526  			scope := vm.scope()
   527  			key := vm.pop()
   528  			item := scope.Array.Index(scope.Index).Interface()
   529  			scope.Acc.(groupBy)[key] = append(scope.Acc.(groupBy)[key], item)
   530  
   531  		case OpSortBy:
   532  			scope := vm.scope()
   533  			value := vm.pop()
   534  			item := scope.Array.Index(scope.Index).Interface()
   535  			sortable := scope.Acc.(*runtime.SortBy)
   536  			sortable.Array = append(sortable.Array, item)
   537  			sortable.Values = append(sortable.Values, value)
   538  
   539  		case OpSort:
   540  			scope := vm.scope()
   541  			sortable := scope.Acc.(*runtime.SortBy)
   542  			sort.Sort(sortable)
   543  			vm.memGrow(uint(scope.Len))
   544  			vm.push(sortable.Array)
   545  
   546  		case OpProfileStart:
   547  			span := program.Constants[arg].(*Span)
   548  			span.start = time.Now()
   549  
   550  		case OpProfileEnd:
   551  			span := program.Constants[arg].(*Span)
   552  			span.Duration += time.Since(span.start).Nanoseconds()
   553  
   554  		case OpBegin:
   555  			a := vm.pop()
   556  			array := reflect.ValueOf(a)
   557  			vm.Scopes = append(vm.Scopes, &Scope{
   558  				Array: array,
   559  				Len:   array.Len(),
   560  			})
   561  
   562  		case OpEnd:
   563  			vm.Scopes = vm.Scopes[:len(vm.Scopes)-1]
   564  
   565  		default:
   566  			panic(fmt.Sprintf("unknown bytecode %#x", op))
   567  		}
   568  
   569  		if debug && vm.debug {
   570  			vm.curr <- vm.ip
   571  		}
   572  	}
   573  
   574  	if debug && vm.debug {
   575  		close(vm.curr)
   576  		close(vm.step)
   577  	}
   578  
   579  	if len(vm.Stack) > 0 {
   580  		return vm.pop(), nil
   581  	}
   582  
   583  	return nil, nil
   584  }
   585  
   586  func (vm *VM) push(value any) {
   587  	vm.Stack = append(vm.Stack, value)
   588  }
   589  
   590  func (vm *VM) current() any {
   591  	return vm.Stack[len(vm.Stack)-1]
   592  }
   593  
   594  func (vm *VM) pop() any {
   595  	value := vm.Stack[len(vm.Stack)-1]
   596  	vm.Stack = vm.Stack[:len(vm.Stack)-1]
   597  	return value
   598  }
   599  
   600  func (vm *VM) memGrow(size uint) {
   601  	vm.memory += size
   602  	if vm.memory >= vm.memoryBudget {
   603  		panic("memory budget exceeded")
   604  	}
   605  }
   606  
   607  func (vm *VM) scope() *Scope {
   608  	return vm.Scopes[len(vm.Scopes)-1]
   609  }
   610  
   611  func (vm *VM) Step() {
   612  	vm.step <- struct{}{}
   613  }
   614  
   615  func (vm *VM) Position() chan int {
   616  	return vm.curr
   617  }