github.com/expr-lang/expr@v1.16.9/checker/checker.go (about)

     1  package checker
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"regexp"
     7  
     8  	"github.com/expr-lang/expr/ast"
     9  	"github.com/expr-lang/expr/builtin"
    10  	"github.com/expr-lang/expr/conf"
    11  	"github.com/expr-lang/expr/file"
    12  	"github.com/expr-lang/expr/internal/deref"
    13  	"github.com/expr-lang/expr/parser"
    14  )
    15  
    16  // ParseCheck parses input expression and checks its types. Also, it applies
    17  // all provided patchers. In case of error, it returns error with a tree.
    18  func ParseCheck(input string, config *conf.Config) (*parser.Tree, error) {
    19  	tree, err := parser.ParseWithConfig(input, config)
    20  	if err != nil {
    21  		return tree, err
    22  	}
    23  
    24  	if len(config.Visitors) > 0 {
    25  		for i := 0; i < 1000; i++ {
    26  			more := false
    27  			for _, v := range config.Visitors {
    28  				// We need to perform types check, because some visitors may rely on
    29  				// types information available in the tree.
    30  				_, _ = Check(tree, config)
    31  
    32  				ast.Walk(&tree.Node, v)
    33  
    34  				if v, ok := v.(interface {
    35  					ShouldRepeat() bool
    36  				}); ok {
    37  					more = more || v.ShouldRepeat()
    38  				}
    39  			}
    40  			if !more {
    41  				break
    42  			}
    43  		}
    44  	}
    45  	_, err = Check(tree, config)
    46  	if err != nil {
    47  		return tree, err
    48  	}
    49  
    50  	return tree, nil
    51  }
    52  
    53  // Check checks types of the expression tree. It returns type of the expression
    54  // and error if any. If config is nil, then default configuration will be used.
    55  func Check(tree *parser.Tree, config *conf.Config) (t reflect.Type, err error) {
    56  	if config == nil {
    57  		config = conf.New(nil)
    58  	}
    59  
    60  	v := &checker{config: config}
    61  
    62  	t, _ = v.visit(tree.Node)
    63  
    64  	if v.err != nil {
    65  		return t, v.err.Bind(tree.Source)
    66  	}
    67  
    68  	if v.config.Expect != reflect.Invalid {
    69  		if v.config.ExpectAny {
    70  			if isAny(t) {
    71  				return t, nil
    72  			}
    73  		}
    74  
    75  		switch v.config.Expect {
    76  		case reflect.Int, reflect.Int64, reflect.Float64:
    77  			if !isNumber(t) {
    78  				return nil, fmt.Errorf("expected %v, but got %v", v.config.Expect, t)
    79  			}
    80  		default:
    81  			if t != nil {
    82  				if t.Kind() == v.config.Expect {
    83  					return t, nil
    84  				}
    85  			}
    86  			return nil, fmt.Errorf("expected %v, but got %v", v.config.Expect, t)
    87  		}
    88  	}
    89  
    90  	return t, nil
    91  }
    92  
// checker holds the state of one type-checking pass over an AST (see Check).
type checker struct {
	config          *conf.Config
	predicateScopes []predicateScope // pushed by begin, popped by end
	varScopes       []varScope
	err             *file.Error // first error reported via (*checker).error
}

// predicateScope is the scope opened by begin around a builtin's predicate
// closure: vtype is the collection type passed to begin, and vars holds any
// extra named variables (BuiltinNode uses "index" and "acc").
type predicateScope struct {
	vtype reflect.Type
	vars  map[string]reflect.Type
}

// varScope binds a variable name to its type and info. IdentifierNode
// consults these through lookupVariable before the environment.
// NOTE(review): presumably pushed for variable declarations
// (VariableDeclaratorNode), whose handler is outside this view — confirm.
type varScope struct {
	name  string
	vtype reflect.Type
	info  info
}

// info carries extra facts about a visited node alongside its reflect.Type.
type info struct {
	// method is set when the type is a method whose first parameter is the
	// receiver, so call checking expects one argument fewer.
	method bool
	// fn is non-nil when the identifier resolved to a user-defined function
	// or a builtin (see ident).
	fn *builtin.Function

	// elem is element type of array or map.
	// Arrays created with type []any, but
	// we would like to detect expressions
	// like `42 in ["a"]` as invalid.
	elem reflect.Type
}
   121  
   122  func (v *checker) visit(node ast.Node) (reflect.Type, info) {
   123  	var t reflect.Type
   124  	var i info
   125  	switch n := node.(type) {
   126  	case *ast.NilNode:
   127  		t, i = v.NilNode(n)
   128  	case *ast.IdentifierNode:
   129  		t, i = v.IdentifierNode(n)
   130  	case *ast.IntegerNode:
   131  		t, i = v.IntegerNode(n)
   132  	case *ast.FloatNode:
   133  		t, i = v.FloatNode(n)
   134  	case *ast.BoolNode:
   135  		t, i = v.BoolNode(n)
   136  	case *ast.StringNode:
   137  		t, i = v.StringNode(n)
   138  	case *ast.ConstantNode:
   139  		t, i = v.ConstantNode(n)
   140  	case *ast.UnaryNode:
   141  		t, i = v.UnaryNode(n)
   142  	case *ast.BinaryNode:
   143  		t, i = v.BinaryNode(n)
   144  	case *ast.ChainNode:
   145  		t, i = v.ChainNode(n)
   146  	case *ast.MemberNode:
   147  		t, i = v.MemberNode(n)
   148  	case *ast.SliceNode:
   149  		t, i = v.SliceNode(n)
   150  	case *ast.CallNode:
   151  		t, i = v.CallNode(n)
   152  	case *ast.BuiltinNode:
   153  		t, i = v.BuiltinNode(n)
   154  	case *ast.ClosureNode:
   155  		t, i = v.ClosureNode(n)
   156  	case *ast.PointerNode:
   157  		t, i = v.PointerNode(n)
   158  	case *ast.VariableDeclaratorNode:
   159  		t, i = v.VariableDeclaratorNode(n)
   160  	case *ast.ConditionalNode:
   161  		t, i = v.ConditionalNode(n)
   162  	case *ast.ArrayNode:
   163  		t, i = v.ArrayNode(n)
   164  	case *ast.MapNode:
   165  		t, i = v.MapNode(n)
   166  	case *ast.PairNode:
   167  		t, i = v.PairNode(n)
   168  	default:
   169  		panic(fmt.Sprintf("undefined node type (%T)", node))
   170  	}
   171  	node.SetType(t)
   172  	return t, i
   173  }
   174  
   175  func (v *checker) error(node ast.Node, format string, args ...any) (reflect.Type, info) {
   176  	if v.err == nil { // show first error
   177  		v.err = &file.Error{
   178  			Location: node.Location(),
   179  			Message:  fmt.Sprintf(format, args...),
   180  		}
   181  	}
   182  	return anyType, info{} // interface represent undefined type
   183  }
   184  
   185  func (v *checker) NilNode(*ast.NilNode) (reflect.Type, info) {
   186  	return nilType, info{}
   187  }
   188  
   189  func (v *checker) IdentifierNode(node *ast.IdentifierNode) (reflect.Type, info) {
   190  	if s, ok := v.lookupVariable(node.Value); ok {
   191  		return s.vtype, s.info
   192  	}
   193  	if node.Value == "$env" {
   194  		return mapType, info{}
   195  	}
   196  	return v.ident(node, node.Value, true, true)
   197  }
   198  
   199  // ident method returns type of environment variable, builtin or function.
   200  func (v *checker) ident(node ast.Node, name string, strict, builtins bool) (reflect.Type, info) {
   201  	if t, ok := v.config.Types[name]; ok {
   202  		if t.Ambiguous {
   203  			return v.error(node, "ambiguous identifier %v", name)
   204  		}
   205  		return t.Type, info{method: t.Method}
   206  	}
   207  	if builtins {
   208  		if fn, ok := v.config.Functions[name]; ok {
   209  			return fn.Type(), info{fn: fn}
   210  		}
   211  		if fn, ok := v.config.Builtins[name]; ok {
   212  			return fn.Type(), info{fn: fn}
   213  		}
   214  	}
   215  	if v.config.Strict && strict {
   216  		return v.error(node, "unknown name %v", name)
   217  	}
   218  	if v.config.DefaultType != nil {
   219  		return v.config.DefaultType, info{}
   220  	}
   221  	return anyType, info{}
   222  }
   223  
   224  func (v *checker) IntegerNode(*ast.IntegerNode) (reflect.Type, info) {
   225  	return integerType, info{}
   226  }
   227  
   228  func (v *checker) FloatNode(*ast.FloatNode) (reflect.Type, info) {
   229  	return floatType, info{}
   230  }
   231  
   232  func (v *checker) BoolNode(*ast.BoolNode) (reflect.Type, info) {
   233  	return boolType, info{}
   234  }
   235  
   236  func (v *checker) StringNode(*ast.StringNode) (reflect.Type, info) {
   237  	return stringType, info{}
   238  }
   239  
   240  func (v *checker) ConstantNode(node *ast.ConstantNode) (reflect.Type, info) {
   241  	return reflect.TypeOf(node.Value), info{}
   242  }
   243  
   244  func (v *checker) UnaryNode(node *ast.UnaryNode) (reflect.Type, info) {
   245  	t, _ := v.visit(node.Node)
   246  	t = deref.Type(t)
   247  
   248  	switch node.Operator {
   249  
   250  	case "!", "not":
   251  		if isBool(t) {
   252  			return boolType, info{}
   253  		}
   254  		if isAny(t) {
   255  			return boolType, info{}
   256  		}
   257  
   258  	case "+", "-":
   259  		if isNumber(t) {
   260  			return t, info{}
   261  		}
   262  		if isAny(t) {
   263  			return anyType, info{}
   264  		}
   265  
   266  	default:
   267  		return v.error(node, "unknown operator (%v)", node.Operator)
   268  	}
   269  
   270  	return v.error(node, `invalid operation: %v (mismatched type %v)`, node.Operator, t)
   271  }
   272  
   273  func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) {
   274  	l, _ := v.visit(node.Left)
   275  	r, ri := v.visit(node.Right)
   276  
   277  	l = deref.Type(l)
   278  	r = deref.Type(r)
   279  
   280  	switch node.Operator {
   281  	case "==", "!=":
   282  		if isComparable(l, r) {
   283  			return boolType, info{}
   284  		}
   285  
   286  	case "or", "||", "and", "&&":
   287  		if isBool(l) && isBool(r) {
   288  			return boolType, info{}
   289  		}
   290  		if or(l, r, isBool) {
   291  			return boolType, info{}
   292  		}
   293  
   294  	case "<", ">", ">=", "<=":
   295  		if isNumber(l) && isNumber(r) {
   296  			return boolType, info{}
   297  		}
   298  		if isString(l) && isString(r) {
   299  			return boolType, info{}
   300  		}
   301  		if isTime(l) && isTime(r) {
   302  			return boolType, info{}
   303  		}
   304  		if or(l, r, isNumber, isString, isTime) {
   305  			return boolType, info{}
   306  		}
   307  
   308  	case "-":
   309  		if isNumber(l) && isNumber(r) {
   310  			return combined(l, r), info{}
   311  		}
   312  		if isTime(l) && isTime(r) {
   313  			return durationType, info{}
   314  		}
   315  		if isTime(l) && isDuration(r) {
   316  			return timeType, info{}
   317  		}
   318  		if or(l, r, isNumber, isTime) {
   319  			return anyType, info{}
   320  		}
   321  
   322  	case "*":
   323  		if isNumber(l) && isNumber(r) {
   324  			return combined(l, r), info{}
   325  		}
   326  		if or(l, r, isNumber) {
   327  			return anyType, info{}
   328  		}
   329  
   330  	case "/":
   331  		if isNumber(l) && isNumber(r) {
   332  			return floatType, info{}
   333  		}
   334  		if or(l, r, isNumber) {
   335  			return floatType, info{}
   336  		}
   337  
   338  	case "**", "^":
   339  		if isNumber(l) && isNumber(r) {
   340  			return floatType, info{}
   341  		}
   342  		if or(l, r, isNumber) {
   343  			return floatType, info{}
   344  		}
   345  
   346  	case "%":
   347  		if isInteger(l) && isInteger(r) {
   348  			return combined(l, r), info{}
   349  		}
   350  		if or(l, r, isInteger) {
   351  			return anyType, info{}
   352  		}
   353  
   354  	case "+":
   355  		if isNumber(l) && isNumber(r) {
   356  			return combined(l, r), info{}
   357  		}
   358  		if isString(l) && isString(r) {
   359  			return stringType, info{}
   360  		}
   361  		if isTime(l) && isDuration(r) {
   362  			return timeType, info{}
   363  		}
   364  		if isDuration(l) && isTime(r) {
   365  			return timeType, info{}
   366  		}
   367  		if or(l, r, isNumber, isString, isTime, isDuration) {
   368  			return anyType, info{}
   369  		}
   370  
   371  	case "in":
   372  		if (isString(l) || isAny(l)) && isStruct(r) {
   373  			return boolType, info{}
   374  		}
   375  		if isMap(r) {
   376  			if l == nil { // It is possible to compare with nil.
   377  				return boolType, info{}
   378  			}
   379  			if !isAny(l) && !l.AssignableTo(r.Key()) {
   380  				return v.error(node, "cannot use %v as type %v in map key", l, r.Key())
   381  			}
   382  			return boolType, info{}
   383  		}
   384  		if isArray(r) {
   385  			if l == nil { // It is possible to compare with nil.
   386  				return boolType, info{}
   387  			}
   388  			if !isComparable(l, r.Elem()) {
   389  				return v.error(node, "cannot use %v as type %v in array", l, r.Elem())
   390  			}
   391  			if !isComparable(l, ri.elem) {
   392  				return v.error(node, "cannot use %v as type %v in array", l, ri.elem)
   393  			}
   394  			return boolType, info{}
   395  		}
   396  		if isAny(l) && anyOf(r, isString, isArray, isMap) {
   397  			return boolType, info{}
   398  		}
   399  		if isAny(r) {
   400  			return boolType, info{}
   401  		}
   402  
   403  	case "matches":
   404  		if s, ok := node.Right.(*ast.StringNode); ok {
   405  			_, err := regexp.Compile(s.Value)
   406  			if err != nil {
   407  				return v.error(node, err.Error())
   408  			}
   409  		}
   410  		if isString(l) && isString(r) {
   411  			return boolType, info{}
   412  		}
   413  		if or(l, r, isString) {
   414  			return boolType, info{}
   415  		}
   416  
   417  	case "contains", "startsWith", "endsWith":
   418  		if isString(l) && isString(r) {
   419  			return boolType, info{}
   420  		}
   421  		if or(l, r, isString) {
   422  			return boolType, info{}
   423  		}
   424  
   425  	case "..":
   426  		ret := reflect.SliceOf(integerType)
   427  		if isInteger(l) && isInteger(r) {
   428  			return ret, info{}
   429  		}
   430  		if or(l, r, isInteger) {
   431  			return ret, info{}
   432  		}
   433  
   434  	case "??":
   435  		if l == nil && r != nil {
   436  			return r, info{}
   437  		}
   438  		if l != nil && r == nil {
   439  			return l, info{}
   440  		}
   441  		if l == nil && r == nil {
   442  			return nilType, info{}
   443  		}
   444  		if r.AssignableTo(l) {
   445  			return l, info{}
   446  		}
   447  		return anyType, info{}
   448  
   449  	default:
   450  		return v.error(node, "unknown operator (%v)", node.Operator)
   451  
   452  	}
   453  
   454  	return v.error(node, `invalid operation: %v (mismatched types %v and %v)`, node.Operator, l, r)
   455  }
   456  
   457  func (v *checker) ChainNode(node *ast.ChainNode) (reflect.Type, info) {
   458  	return v.visit(node.Node)
   459  }
   460  
   461  func (v *checker) MemberNode(node *ast.MemberNode) (reflect.Type, info) {
   462  	// $env variable
   463  	if an, ok := node.Node.(*ast.IdentifierNode); ok && an.Value == "$env" {
   464  		if name, ok := node.Property.(*ast.StringNode); ok {
   465  			strict := v.config.Strict
   466  			if node.Optional {
   467  				// If user explicitly set optional flag, then we should not
   468  				// throw error if field is not found (as user trying to handle
   469  				// this case). But if user did not set optional flag, then we
   470  				// should throw error if field is not found & v.config.Strict.
   471  				strict = false
   472  			}
   473  			return v.ident(node, name.Value, strict, false /* no builtins and no functions */)
   474  		}
   475  		return anyType, info{}
   476  	}
   477  
   478  	base, _ := v.visit(node.Node)
   479  	prop, _ := v.visit(node.Property)
   480  
   481  	if name, ok := node.Property.(*ast.StringNode); ok {
   482  		if base == nil {
   483  			return v.error(node, "type %v has no field %v", base, name.Value)
   484  		}
   485  		// First, check methods defined on base type itself,
   486  		// independent of which type it is. Without dereferencing.
   487  		if m, ok := base.MethodByName(name.Value); ok {
   488  			if kind(base) == reflect.Interface {
   489  				// In case of interface type method will not have a receiver,
   490  				// and to prevent checker decreasing numbers of in arguments
   491  				// return method type as not method (second argument is false).
   492  
   493  				// Also, we can not use m.Index here, because it will be
   494  				// different indexes for different types which implement
   495  				// the same interface.
   496  				return m.Type, info{}
   497  			} else {
   498  				return m.Type, info{method: true}
   499  			}
   500  		}
   501  	}
   502  
   503  	if kind(base) == reflect.Ptr {
   504  		base = base.Elem()
   505  	}
   506  
   507  	switch kind(base) {
   508  	case reflect.Interface:
   509  		return anyType, info{}
   510  
   511  	case reflect.Map:
   512  		if prop != nil && !prop.AssignableTo(base.Key()) && !isAny(prop) {
   513  			return v.error(node.Property, "cannot use %v to get an element from %v", prop, base)
   514  		}
   515  		return base.Elem(), info{}
   516  
   517  	case reflect.Array, reflect.Slice:
   518  		if !isInteger(prop) && !isAny(prop) {
   519  			return v.error(node.Property, "array elements can only be selected using an integer (got %v)", prop)
   520  		}
   521  		return base.Elem(), info{}
   522  
   523  	case reflect.Struct:
   524  		if name, ok := node.Property.(*ast.StringNode); ok {
   525  			propertyName := name.Value
   526  			if field, ok := fetchField(base, propertyName); ok {
   527  				return field.Type, info{}
   528  			}
   529  			if node.Method {
   530  				return v.error(node, "type %v has no method %v", base, propertyName)
   531  			}
   532  			return v.error(node, "type %v has no field %v", base, propertyName)
   533  		}
   534  	}
   535  
   536  	return v.error(node, "type %v[%v] is undefined", base, prop)
   537  }
   538  
   539  func (v *checker) SliceNode(node *ast.SliceNode) (reflect.Type, info) {
   540  	t, _ := v.visit(node.Node)
   541  
   542  	switch kind(t) {
   543  	case reflect.Interface:
   544  		// ok
   545  	case reflect.String, reflect.Array, reflect.Slice:
   546  		// ok
   547  	default:
   548  		return v.error(node, "cannot slice %v", t)
   549  	}
   550  
   551  	if node.From != nil {
   552  		from, _ := v.visit(node.From)
   553  		if !isInteger(from) && !isAny(from) {
   554  			return v.error(node.From, "non-integer slice index %v", from)
   555  		}
   556  	}
   557  	if node.To != nil {
   558  		to, _ := v.visit(node.To)
   559  		if !isInteger(to) && !isAny(to) {
   560  			return v.error(node.To, "non-integer slice index %v", to)
   561  		}
   562  	}
   563  	return t, info{}
   564  }
   565  
   566  func (v *checker) CallNode(node *ast.CallNode) (reflect.Type, info) {
   567  	t, i := v.functionReturnType(node)
   568  
   569  	// Check if type was set on node (for example, by patcher)
   570  	// and use node type instead of function return type.
   571  	//
   572  	// If node type is anyType, then we should use function
   573  	// return type. For example, on error we return anyType
   574  	// for a call `errCall().Method()` and method will be
   575  	// evaluated on `anyType.Method()`, so return type will
   576  	// be anyType `anyType.Method(): anyType`. Patcher can
   577  	// fix `errCall()` to return proper type, so on second
   578  	// checker pass we should replace anyType on method node
   579  	// with new correct function return type.
   580  	if node.Type() != nil && node.Type() != anyType {
   581  		return node.Type(), i
   582  	}
   583  
   584  	return t, i
   585  }
   586  
   587  func (v *checker) functionReturnType(node *ast.CallNode) (reflect.Type, info) {
   588  	fn, fnInfo := v.visit(node.Callee)
   589  
   590  	if fnInfo.fn != nil {
   591  		return v.checkFunction(fnInfo.fn, node, node.Arguments)
   592  	}
   593  
   594  	fnName := "function"
   595  	if identifier, ok := node.Callee.(*ast.IdentifierNode); ok {
   596  		fnName = identifier.Value
   597  	}
   598  	if member, ok := node.Callee.(*ast.MemberNode); ok {
   599  		if name, ok := member.Property.(*ast.StringNode); ok {
   600  			fnName = name.Value
   601  		}
   602  	}
   603  
   604  	if fn == nil {
   605  		return v.error(node, "%v is nil; cannot call nil as function", fnName)
   606  	}
   607  
   608  	switch fn.Kind() {
   609  	case reflect.Interface:
   610  		return anyType, info{}
   611  	case reflect.Func:
   612  		outType, err := v.checkArguments(fnName, fn, fnInfo.method, node.Arguments, node)
   613  		if err != nil {
   614  			if v.err == nil {
   615  				v.err = err
   616  			}
   617  			return anyType, info{}
   618  		}
   619  		return outType, info{}
   620  	}
   621  	return v.error(node, "%v is not callable", fn)
   622  }
   623  
// BuiltinNode type-checks a call to a builtin function.
//
// The collection builtins (all, none, any, one, filter, map, count, sum,
// find, findLast, findIndex, findLastIndex, groupBy, sortBy, reduce) are
// handled inline: the first argument must be an array or any, and the
// closure argument is visited inside a predicate scope opened with
// v.begin(collection, ...) and closed with v.end(). Any other name present in
// builtin.Index is checked via checkBuiltinGet ("get") or checkFunction.
//
// NOTE(review): node.Arguments[0] / [1] are indexed without a length check;
// this presumably relies on the parser enforcing builtin arity — confirm.
func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) {
	switch node.Name {
	// Boolean verdict over the collection; predicate must return bool or any.
	case "all", "none", "any", "one":
		collection, _ := v.visit(node.Arguments[0])
		if !isArray(collection) && !isAny(collection) {
			return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
		}

		v.begin(collection)
		closure, _ := v.visit(node.Arguments[1])
		v.end()

		if isFunc(closure) &&
			closure.NumOut() == 1 &&
			closure.NumIn() == 1 && isAny(closure.In(0)) {

			if !isBool(closure.Out(0)) && !isAny(closure.Out(0)) {
				return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String())
			}
			return boolType, info{}
		}
		return v.error(node.Arguments[1], "predicate should has one input and one output param")

	// filter yields arrayType; both branches below return the same type.
	case "filter":
		collection, _ := v.visit(node.Arguments[0])
		if !isArray(collection) && !isAny(collection) {
			return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
		}

		v.begin(collection)
		closure, _ := v.visit(node.Arguments[1])
		v.end()

		if isFunc(closure) &&
			closure.NumOut() == 1 &&
			closure.NumIn() == 1 && isAny(closure.In(0)) {

			if !isBool(closure.Out(0)) && !isAny(closure.Out(0)) {
				return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String())
			}
			if isAny(collection) {
				return arrayType, info{}
			}
			return arrayType, info{}
		}
		return v.error(node.Arguments[1], "predicate should has one input and one output param")

	// map exposes an integer "index" variable to the closure; the result is
	// arrayType whatever the closure returns.
	case "map":
		collection, _ := v.visit(node.Arguments[0])
		if !isArray(collection) && !isAny(collection) {
			return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
		}

		v.begin(collection, scopeVar{"index", integerType})
		closure, _ := v.visit(node.Arguments[1])
		v.end()

		if isFunc(closure) &&
			closure.NumOut() == 1 &&
			closure.NumIn() == 1 && isAny(closure.In(0)) {

			return arrayType, info{}
		}
		return v.error(node.Arguments[1], "predicate should has one input and one output param")

	// count takes an optional predicate; the result is always integerType.
	case "count":
		collection, _ := v.visit(node.Arguments[0])
		if !isArray(collection) && !isAny(collection) {
			return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
		}

		if len(node.Arguments) == 1 {
			return integerType, info{}
		}

		v.begin(collection)
		closure, _ := v.visit(node.Arguments[1])
		v.end()

		if isFunc(closure) &&
			closure.NumOut() == 1 &&
			closure.NumIn() == 1 && isAny(closure.In(0)) {
			if !isBool(closure.Out(0)) && !isAny(closure.Out(0)) {
				return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String())
			}

			return integerType, info{}
		}
		return v.error(node.Arguments[1], "predicate should has one input and one output param")

	// sum: with a closure, the closure's return type; without one, the
	// collection's element type (anyType for an any collection).
	// NOTE(review): when the closure fails the shape check, control leaves
	// this switch and reaches the builtin.Index lookup below, which re-visits
	// the arguments outside the predicate scope — confirm this is intended.
	case "sum":
		collection, _ := v.visit(node.Arguments[0])
		if !isArray(collection) && !isAny(collection) {
			return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
		}

		if len(node.Arguments) == 2 {
			v.begin(collection)
			closure, _ := v.visit(node.Arguments[1])
			v.end()

			if isFunc(closure) &&
				closure.NumOut() == 1 &&
				closure.NumIn() == 1 && isAny(closure.In(0)) {
				return closure.Out(0), info{}
			}
		} else {
			if isAny(collection) {
				return anyType, info{}
			}
			return collection.Elem(), info{}
		}

	// find/findLast yield the element type (anyType for an any collection).
	case "find", "findLast":
		collection, _ := v.visit(node.Arguments[0])
		if !isArray(collection) && !isAny(collection) {
			return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
		}

		v.begin(collection)
		closure, _ := v.visit(node.Arguments[1])
		v.end()

		if isFunc(closure) &&
			closure.NumOut() == 1 &&
			closure.NumIn() == 1 && isAny(closure.In(0)) {

			if !isBool(closure.Out(0)) && !isAny(closure.Out(0)) {
				return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String())
			}
			if isAny(collection) {
				return anyType, info{}
			}
			return collection.Elem(), info{}
		}
		return v.error(node.Arguments[1], "predicate should has one input and one output param")

	// findIndex/findLastIndex yield integerType.
	case "findIndex", "findLastIndex":
		collection, _ := v.visit(node.Arguments[0])
		if !isArray(collection) && !isAny(collection) {
			return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
		}

		v.begin(collection)
		closure, _ := v.visit(node.Arguments[1])
		v.end()

		if isFunc(closure) &&
			closure.NumOut() == 1 &&
			closure.NumIn() == 1 && isAny(closure.In(0)) {

			if !isBool(closure.Out(0)) && !isAny(closure.Out(0)) {
				return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String())
			}
			return integerType, info{}
		}
		return v.error(node.Arguments[1], "predicate should has one input and one output param")

	// groupBy yields map[any][]any regardless of the closure's return type.
	case "groupBy":
		collection, _ := v.visit(node.Arguments[0])
		if !isArray(collection) && !isAny(collection) {
			return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
		}

		v.begin(collection)
		closure, _ := v.visit(node.Arguments[1])
		v.end()

		if isFunc(closure) &&
			closure.NumOut() == 1 &&
			closure.NumIn() == 1 && isAny(closure.In(0)) {

			return reflect.TypeOf(map[any][]any{}), info{}
		}
		return v.error(node.Arguments[1], "predicate should has one input and one output param")

	// sortBy yields []any. An optional third argument is visited outside the
	// predicate scope and its type is not checked here.
	case "sortBy":
		collection, _ := v.visit(node.Arguments[0])
		if !isArray(collection) && !isAny(collection) {
			return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
		}

		v.begin(collection)
		closure, _ := v.visit(node.Arguments[1])
		v.end()

		if len(node.Arguments) == 3 {
			_, _ = v.visit(node.Arguments[2])
		}

		if isFunc(closure) &&
			closure.NumOut() == 1 &&
			closure.NumIn() == 1 && isAny(closure.In(0)) {

			return reflect.TypeOf([]any{}), info{}
		}
		return v.error(node.Arguments[1], "predicate should has one input and one output param")

	// reduce exposes "index" (integerType) and "acc" (anyType) to the closure
	// and yields the closure's return type. An optional third argument
	// (presumably the initial accumulator) is visited outside the scope.
	case "reduce":
		collection, _ := v.visit(node.Arguments[0])
		if !isArray(collection) && !isAny(collection) {
			return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
		}

		v.begin(collection, scopeVar{"index", integerType}, scopeVar{"acc", anyType})
		closure, _ := v.visit(node.Arguments[1])
		v.end()

		if len(node.Arguments) == 3 {
			_, _ = v.visit(node.Arguments[2])
		}

		if isFunc(closure) && closure.NumOut() == 1 {
			return closure.Out(0), info{}
		}
		return v.error(node.Arguments[1], "predicate should has two input and one output param")

	}

	// Remaining builtins are described by builtin.Builtins; "get" has its own
	// checker.
	if id, ok := builtin.Index[node.Name]; ok {
		switch node.Name {
		case "get":
			return v.checkBuiltinGet(node)
		}
		return v.checkFunction(builtin.Builtins[id], node, node.Arguments)
	}

	return v.error(node, "unknown builtin %v", node.Name)
}
   853  
// scopeVar is an extra named variable (with its type) made visible inside a
// predicate scope opened by begin, e.g. "index" for map and reduce.
// Field order matters: callers construct it with positional literals.
type scopeVar struct {
	name  string
	vtype reflect.Type
}
   858  
   859  func (v *checker) begin(vtype reflect.Type, vars ...scopeVar) {
   860  	scope := predicateScope{vtype: vtype, vars: make(map[string]reflect.Type)}
   861  	for _, v := range vars {
   862  		scope.vars[v.name] = v.vtype
   863  	}
   864  	v.predicateScopes = append(v.predicateScopes, scope)
   865  }
   866  
   867  func (v *checker) end() {
   868  	v.predicateScopes = v.predicateScopes[:len(v.predicateScopes)-1]
   869  }
   870  
   871  func (v *checker) checkBuiltinGet(node *ast.BuiltinNode) (reflect.Type, info) {
   872  	if len(node.Arguments) != 2 {
   873  		return v.error(node, "invalid number of arguments (expected 2, got %d)", len(node.Arguments))
   874  	}
   875  
   876  	val := node.Arguments[0]
   877  	prop := node.Arguments[1]
   878  	if id, ok := val.(*ast.IdentifierNode); ok && id.Value == "$env" {
   879  		if s, ok := prop.(*ast.StringNode); ok {
   880  			return v.config.Types[s.Value].Type, info{}
   881  		}
   882  		return anyType, info{}
   883  	}
   884  
   885  	t, _ := v.visit(val)
   886  
   887  	switch kind(t) {
   888  	case reflect.Interface:
   889  		return anyType, info{}
   890  	case reflect.Slice, reflect.Array:
   891  		p, _ := v.visit(prop)
   892  		if p == nil {
   893  			return v.error(prop, "cannot use nil as slice index")
   894  		}
   895  		if !isInteger(p) && !isAny(p) {
   896  			return v.error(prop, "non-integer slice index %v", p)
   897  		}
   898  		return t.Elem(), info{}
   899  	case reflect.Map:
   900  		p, _ := v.visit(prop)
   901  		if p == nil {
   902  			return v.error(prop, "cannot use nil as map index")
   903  		}
   904  		if !p.AssignableTo(t.Key()) && !isAny(p) {
   905  			return v.error(prop, "cannot use %v to get an element from %v", p, t)
   906  		}
   907  		return t.Elem(), info{}
   908  	}
   909  	return v.error(val, "type %v does not support indexing", t)
   910  }
   911  
   912  func (v *checker) checkFunction(f *builtin.Function, node ast.Node, arguments []ast.Node) (reflect.Type, info) {
   913  	if f.Validate != nil {
   914  		args := make([]reflect.Type, len(arguments))
   915  		for i, arg := range arguments {
   916  			args[i], _ = v.visit(arg)
   917  		}
   918  		t, err := f.Validate(args)
   919  		if err != nil {
   920  			return v.error(node, "%v", err)
   921  		}
   922  		return t, info{}
   923  	} else if len(f.Types) == 0 {
   924  		t, err := v.checkArguments(f.Name, f.Type(), false, arguments, node)
   925  		if err != nil {
   926  			if v.err == nil {
   927  				v.err = err
   928  			}
   929  			return anyType, info{}
   930  		}
   931  		// No type was specified, so we assume the function returns any.
   932  		return t, info{}
   933  	}
   934  	var lastErr *file.Error
   935  	for _, t := range f.Types {
   936  		outType, err := v.checkArguments(f.Name, t, false, arguments, node)
   937  		if err != nil {
   938  			lastErr = err
   939  			continue
   940  		}
   941  		return outType, info{}
   942  	}
   943  	if lastErr != nil {
   944  		if v.err == nil {
   945  			v.err = lastErr
   946  		}
   947  		return anyType, info{}
   948  	}
   949  
   950  	return v.error(node, "no matching overload for %v", f.Name)
   951  }
   952  
   953  func (v *checker) checkArguments(
   954  	name string,
   955  	fn reflect.Type,
   956  	method bool,
   957  	arguments []ast.Node,
   958  	node ast.Node,
   959  ) (reflect.Type, *file.Error) {
   960  	if isAny(fn) {
   961  		return anyType, nil
   962  	}
   963  
   964  	if fn.NumOut() == 0 {
   965  		return anyType, &file.Error{
   966  			Location: node.Location(),
   967  			Message:  fmt.Sprintf("func %v doesn't return value", name),
   968  		}
   969  	}
   970  	if numOut := fn.NumOut(); numOut > 2 {
   971  		return anyType, &file.Error{
   972  			Location: node.Location(),
   973  			Message:  fmt.Sprintf("func %v returns more then two values", name),
   974  		}
   975  	}
   976  
   977  	// If func is method on an env, first argument should be a receiver,
   978  	// and actual arguments less than fnNumIn by one.
   979  	fnNumIn := fn.NumIn()
   980  	if method {
   981  		fnNumIn--
   982  	}
   983  	// Skip first argument in case of the receiver.
   984  	fnInOffset := 0
   985  	if method {
   986  		fnInOffset = 1
   987  	}
   988  
   989  	var err *file.Error
   990  	if fn.IsVariadic() {
   991  		if len(arguments) < fnNumIn-1 {
   992  			err = &file.Error{
   993  				Location: node.Location(),
   994  				Message:  fmt.Sprintf("not enough arguments to call %v", name),
   995  			}
   996  		}
   997  	} else {
   998  		if len(arguments) > fnNumIn {
   999  			err = &file.Error{
  1000  				Location: node.Location(),
  1001  				Message:  fmt.Sprintf("too many arguments to call %v", name),
  1002  			}
  1003  		}
  1004  		if len(arguments) < fnNumIn {
  1005  			err = &file.Error{
  1006  				Location: node.Location(),
  1007  				Message:  fmt.Sprintf("not enough arguments to call %v", name),
  1008  			}
  1009  		}
  1010  	}
  1011  
  1012  	if err != nil {
  1013  		// If we have an error, we should still visit all arguments to
  1014  		// type check them, as a patch can fix the error later.
  1015  		for _, arg := range arguments {
  1016  			_, _ = v.visit(arg)
  1017  		}
  1018  		return fn.Out(0), err
  1019  	}
  1020  
  1021  	for i, arg := range arguments {
  1022  		t, _ := v.visit(arg)
  1023  
  1024  		var in reflect.Type
  1025  		if fn.IsVariadic() && i >= fnNumIn-1 {
  1026  			// For variadic arguments fn(xs ...int), go replaces type of xs (int) with ([]int).
  1027  			// As we compare arguments one by one, we need underling type.
  1028  			in = fn.In(fn.NumIn() - 1).Elem()
  1029  		} else {
  1030  			in = fn.In(i + fnInOffset)
  1031  		}
  1032  
  1033  		if isFloat(in) && isInteger(t) {
  1034  			traverseAndReplaceIntegerNodesWithFloatNodes(&arguments[i], in)
  1035  			continue
  1036  		}
  1037  
  1038  		if isInteger(in) && isInteger(t) && kind(t) != kind(in) {
  1039  			traverseAndReplaceIntegerNodesWithIntegerNodes(&arguments[i], in)
  1040  			continue
  1041  		}
  1042  
  1043  		if t == nil {
  1044  			continue
  1045  		}
  1046  
  1047  		if !(t.AssignableTo(in) || deref.Type(t).AssignableTo(in)) && kind(t) != reflect.Interface {
  1048  			return anyType, &file.Error{
  1049  				Location: arg.Location(),
  1050  				Message:  fmt.Sprintf("cannot use %v as argument (type %v) to call %v ", t, in, name),
  1051  			}
  1052  		}
  1053  	}
  1054  
  1055  	return fn.Out(0), nil
  1056  }
  1057  
  1058  func traverseAndReplaceIntegerNodesWithFloatNodes(node *ast.Node, newType reflect.Type) {
  1059  	switch (*node).(type) {
  1060  	case *ast.IntegerNode:
  1061  		*node = &ast.FloatNode{Value: float64((*node).(*ast.IntegerNode).Value)}
  1062  		(*node).SetType(newType)
  1063  	case *ast.UnaryNode:
  1064  		unaryNode := (*node).(*ast.UnaryNode)
  1065  		traverseAndReplaceIntegerNodesWithFloatNodes(&unaryNode.Node, newType)
  1066  	case *ast.BinaryNode:
  1067  		binaryNode := (*node).(*ast.BinaryNode)
  1068  		switch binaryNode.Operator {
  1069  		case "+", "-", "*":
  1070  			traverseAndReplaceIntegerNodesWithFloatNodes(&binaryNode.Left, newType)
  1071  			traverseAndReplaceIntegerNodesWithFloatNodes(&binaryNode.Right, newType)
  1072  		}
  1073  	}
  1074  }
  1075  
  1076  func traverseAndReplaceIntegerNodesWithIntegerNodes(node *ast.Node, newType reflect.Type) {
  1077  	switch (*node).(type) {
  1078  	case *ast.IntegerNode:
  1079  		(*node).SetType(newType)
  1080  	case *ast.UnaryNode:
  1081  		(*node).SetType(newType)
  1082  		unaryNode := (*node).(*ast.UnaryNode)
  1083  		traverseAndReplaceIntegerNodesWithIntegerNodes(&unaryNode.Node, newType)
  1084  	case *ast.BinaryNode:
  1085  		// TODO: Binary node return type is dependent on the type of the operands. We can't just change the type of the node.
  1086  		binaryNode := (*node).(*ast.BinaryNode)
  1087  		switch binaryNode.Operator {
  1088  		case "+", "-", "*":
  1089  			traverseAndReplaceIntegerNodesWithIntegerNodes(&binaryNode.Left, newType)
  1090  			traverseAndReplaceIntegerNodesWithIntegerNodes(&binaryNode.Right, newType)
  1091  		}
  1092  	}
  1093  }
  1094  
  1095  func (v *checker) ClosureNode(node *ast.ClosureNode) (reflect.Type, info) {
  1096  	t, _ := v.visit(node.Node)
  1097  	if t == nil {
  1098  		return v.error(node.Node, "closure cannot be nil")
  1099  	}
  1100  	return reflect.FuncOf([]reflect.Type{anyType}, []reflect.Type{t}, false), info{}
  1101  }
  1102  
  1103  func (v *checker) PointerNode(node *ast.PointerNode) (reflect.Type, info) {
  1104  	if len(v.predicateScopes) == 0 {
  1105  		return v.error(node, "cannot use pointer accessor outside closure")
  1106  	}
  1107  	scope := v.predicateScopes[len(v.predicateScopes)-1]
  1108  	if node.Name == "" {
  1109  		switch scope.vtype.Kind() {
  1110  		case reflect.Interface:
  1111  			return anyType, info{}
  1112  		case reflect.Array, reflect.Slice:
  1113  			return scope.vtype.Elem(), info{}
  1114  		}
  1115  		return v.error(node, "cannot use %v as array", scope)
  1116  	}
  1117  	if scope.vars != nil {
  1118  		if t, ok := scope.vars[node.Name]; ok {
  1119  			return t, info{}
  1120  		}
  1121  	}
  1122  	return v.error(node, "unknown pointer #%v", node.Name)
  1123  }
  1124  
  1125  func (v *checker) VariableDeclaratorNode(node *ast.VariableDeclaratorNode) (reflect.Type, info) {
  1126  	if _, ok := v.config.Types[node.Name]; ok {
  1127  		return v.error(node, "cannot redeclare %v", node.Name)
  1128  	}
  1129  	if _, ok := v.config.Functions[node.Name]; ok {
  1130  		return v.error(node, "cannot redeclare function %v", node.Name)
  1131  	}
  1132  	if _, ok := v.config.Builtins[node.Name]; ok {
  1133  		return v.error(node, "cannot redeclare builtin %v", node.Name)
  1134  	}
  1135  	if _, ok := v.lookupVariable(node.Name); ok {
  1136  		return v.error(node, "cannot redeclare variable %v", node.Name)
  1137  	}
  1138  	vtype, vinfo := v.visit(node.Value)
  1139  	v.varScopes = append(v.varScopes, varScope{node.Name, vtype, vinfo})
  1140  	t, i := v.visit(node.Expr)
  1141  	v.varScopes = v.varScopes[:len(v.varScopes)-1]
  1142  	return t, i
  1143  }
  1144  
  1145  func (v *checker) lookupVariable(name string) (varScope, bool) {
  1146  	for i := len(v.varScopes) - 1; i >= 0; i-- {
  1147  		if v.varScopes[i].name == name {
  1148  			return v.varScopes[i], true
  1149  		}
  1150  	}
  1151  	return varScope{}, false
  1152  }
  1153  
  1154  func (v *checker) ConditionalNode(node *ast.ConditionalNode) (reflect.Type, info) {
  1155  	c, _ := v.visit(node.Cond)
  1156  	if !isBool(c) && !isAny(c) {
  1157  		return v.error(node.Cond, "non-bool expression (type %v) used as condition", c)
  1158  	}
  1159  
  1160  	t1, _ := v.visit(node.Exp1)
  1161  	t2, _ := v.visit(node.Exp2)
  1162  
  1163  	if t1 == nil && t2 != nil {
  1164  		return t2, info{}
  1165  	}
  1166  	if t1 != nil && t2 == nil {
  1167  		return t1, info{}
  1168  	}
  1169  	if t1 == nil && t2 == nil {
  1170  		return nilType, info{}
  1171  	}
  1172  	if t1.AssignableTo(t2) {
  1173  		return t1, info{}
  1174  	}
  1175  	return anyType, info{}
  1176  }
  1177  
  1178  func (v *checker) ArrayNode(node *ast.ArrayNode) (reflect.Type, info) {
  1179  	var prev reflect.Type
  1180  	allElementsAreSameType := true
  1181  	for i, node := range node.Nodes {
  1182  		curr, _ := v.visit(node)
  1183  		if i > 0 {
  1184  			if curr == nil || prev == nil {
  1185  				allElementsAreSameType = false
  1186  			} else if curr.Kind() != prev.Kind() {
  1187  				allElementsAreSameType = false
  1188  			}
  1189  		}
  1190  		prev = curr
  1191  	}
  1192  	if allElementsAreSameType && prev != nil {
  1193  		return arrayType, info{elem: prev}
  1194  	}
  1195  	return arrayType, info{}
  1196  }
  1197  
  1198  func (v *checker) MapNode(node *ast.MapNode) (reflect.Type, info) {
  1199  	for _, pair := range node.Pairs {
  1200  		v.visit(pair)
  1201  	}
  1202  	return mapType, info{}
  1203  }
  1204  
  1205  func (v *checker) PairNode(node *ast.PairNode) (reflect.Type, info) {
  1206  	v.visit(node.Key)
  1207  	v.visit(node.Value)
  1208  	return nilType, info{}
  1209  }