github.com/rajeev159/opa@v0.45.0/ast/check.go (about)

     1  // Copyright 2017 The OPA Authors.  All rights reserved.
     2  // Use of this source code is governed by an Apache2
     3  // license that can be found in the LICENSE file.
     4  
     5  package ast
     6  
     7  import (
     8  	"fmt"
     9  	"sort"
    10  	"strings"
    11  
    12  	"github.com/open-policy-agent/opa/types"
    13  	"github.com/open-policy-agent/opa/util"
    14  )
    15  
    16  type varRewriter func(Ref) Ref
    17  
    18  // exprChecker defines the interface for executing type checking on a single
    19  // expression. The exprChecker must update the provided TypeEnv with inferred
    20  // types of vars.
    21  type exprChecker func(*TypeEnv, *Expr) *Error
    22  
    23  // typeChecker implements type checking on queries and rules. Errors are
    24  // accumulated on the typeChecker so that a single run can report multiple
    25  // issues.
    26  type typeChecker struct {
    27  	errs         Errors
    28  	exprCheckers map[string]exprChecker
    29  	varRewriter  varRewriter
    30  	ss           *SchemaSet
    31  	allowNet     []string
    32  	input        types.Type
    33  }
    34  
    35  // newTypeChecker returns a new typeChecker object that has no errors.
    36  func newTypeChecker() *typeChecker {
    37  	tc := &typeChecker{}
    38  	tc.exprCheckers = map[string]exprChecker{
    39  		"eq": tc.checkExprEq,
    40  	}
    41  	return tc
    42  }
    43  
    44  func (tc *typeChecker) newEnv(exist *TypeEnv) *TypeEnv {
    45  	if exist != nil {
    46  		return exist.wrap()
    47  	}
    48  	env := newTypeEnv(tc.copy)
    49  	if tc.input != nil {
    50  		env.tree.Put(InputRootRef, tc.input)
    51  	}
    52  	return env
    53  }
    54  
    55  func (tc *typeChecker) copy() *typeChecker {
    56  	return newTypeChecker().
    57  		WithVarRewriter(tc.varRewriter).
    58  		WithSchemaSet(tc.ss).
    59  		WithAllowNet(tc.allowNet).
    60  		WithInputType(tc.input)
    61  }
    62  
    63  func (tc *typeChecker) WithSchemaSet(ss *SchemaSet) *typeChecker {
    64  	tc.ss = ss
    65  	return tc
    66  }
    67  
    68  func (tc *typeChecker) WithAllowNet(hosts []string) *typeChecker {
    69  	tc.allowNet = hosts
    70  	return tc
    71  }
    72  
    73  func (tc *typeChecker) WithVarRewriter(f varRewriter) *typeChecker {
    74  	tc.varRewriter = f
    75  	return tc
    76  }
    77  
    78  func (tc *typeChecker) WithInputType(tpe types.Type) *typeChecker {
    79  	tc.input = tpe
    80  	return tc
    81  }
    82  
    83  // Env returns a type environment for the specified built-ins with any other
    84  // global types configured on the checker. In practice, this is the default
    85  // environment that other statements will be checked against.
    86  func (tc *typeChecker) Env(builtins map[string]*Builtin) *TypeEnv {
    87  	env := tc.newEnv(nil)
    88  	for _, bi := range builtins {
    89  		env.tree.Put(bi.Ref(), bi.Decl)
    90  	}
    91  	return env
    92  }
    93  
    94  // CheckBody runs type checking on the body and returns a TypeEnv if no errors
    95  // are found. The resulting TypeEnv wraps the provided one. The resulting
    96  // TypeEnv will be able to resolve types of vars contained in the body.
    97  func (tc *typeChecker) CheckBody(env *TypeEnv, body Body) (*TypeEnv, Errors) {
    98  
    99  	errors := []*Error{}
   100  	env = tc.newEnv(env)
   101  
   102  	WalkExprs(body, func(expr *Expr) bool {
   103  
   104  		closureErrs := tc.checkClosures(env, expr)
   105  		for _, err := range closureErrs {
   106  			errors = append(errors, err)
   107  		}
   108  
   109  		hasClosureErrors := len(closureErrs) > 0
   110  
   111  		vis := newRefChecker(env, tc.varRewriter)
   112  		NewGenericVisitor(vis.Visit).Walk(expr)
   113  		for _, err := range vis.errs {
   114  			errors = append(errors, err)
   115  		}
   116  
   117  		hasRefErrors := len(vis.errs) > 0
   118  
   119  		if err := tc.checkExpr(env, expr); err != nil {
   120  			// Suppress this error if a more actionable one has occurred. In
   121  			// this case, if an error occurred in a ref or closure contained in
   122  			// this expression, and the error is due to a nil type, then it's
   123  			// likely to be the result of the more specific error.
   124  			skip := (hasClosureErrors || hasRefErrors) && causedByNilType(err)
   125  			if !skip {
   126  				errors = append(errors, err)
   127  			}
   128  		}
   129  		return true
   130  	})
   131  
   132  	tc.err(errors)
   133  	return env, errors
   134  }
   135  
   136  // CheckTypes runs type checking on the rules returns a TypeEnv if no errors
   137  // are found. The resulting TypeEnv wraps the provided one. The resulting
   138  // TypeEnv will be able to resolve types of refs that refer to rules.
   139  func (tc *typeChecker) CheckTypes(env *TypeEnv, sorted []util.T, as *AnnotationSet) (*TypeEnv, Errors) {
   140  	env = tc.newEnv(env)
   141  	for _, s := range sorted {
   142  		tc.checkRule(env, as, s.(*Rule))
   143  	}
   144  	tc.errs.Sort()
   145  	return env, tc.errs
   146  }
   147  
   148  func (tc *typeChecker) checkClosures(env *TypeEnv, expr *Expr) Errors {
   149  	var result Errors
   150  	WalkClosures(expr, func(x interface{}) bool {
   151  		switch x := x.(type) {
   152  		case *ArrayComprehension:
   153  			_, errs := tc.copy().CheckBody(env, x.Body)
   154  			if len(errs) > 0 {
   155  				result = errs
   156  				return true
   157  			}
   158  		case *SetComprehension:
   159  			_, errs := tc.copy().CheckBody(env, x.Body)
   160  			if len(errs) > 0 {
   161  				result = errs
   162  				return true
   163  			}
   164  		case *ObjectComprehension:
   165  			_, errs := tc.copy().CheckBody(env, x.Body)
   166  			if len(errs) > 0 {
   167  				result = errs
   168  				return true
   169  			}
   170  		}
   171  		return false
   172  	})
   173  	return result
   174  }
   175  
// checkRule type checks a single rule and records the rule's inferred type in
// env under the rule's path. Errors are accumulated on tc.
//
// Schema annotations for the rule (looked up in as) are applied to a wrapped
// copy of env before the body is checked. If the body fails to type check,
// the rule's path is typed as types.A so that references to the rule do not
// produce follow-on errors. Otherwise the type is derived from the head:
//   - functions: a function type built from the arg and value types,
//     unioned with any type already recorded for the path;
//   - complete docs: the value type unioned with the existing type;
//   - partial object docs: an object with a dynamic property whose key and
//     value types are unioned with the existing keys/values;
//   - partial set docs: a set whose element type is unioned with the
//     existing element type.
//
// If no type can be inferred (e.g. a nil value type), nothing is recorded.
func (tc *typeChecker) checkRule(env *TypeEnv, as *AnnotationSet, rule *Rule) {

	// Schema annotation types are placed in this wrapped scope only; the
	// rule's own type is later recorded in env.next (the env passed in).
	env = env.wrap()

	if schemaAnnots := getRuleAnnotation(as, rule); schemaAnnots != nil {
		for _, schemaAnnot := range schemaAnnots {
			ref, refType, err := processAnnotation(tc.ss, schemaAnnot, rule, tc.allowNet)
			if err != nil {
				tc.err([]*Error{err})
				continue
			}
			// NOTE(review): getPrefix presumably returns the longest prefix of
			// ref that already has a type in env. When a proper prefix is typed,
			// the annotation type is merged into it via override instead of
			// being put at ref directly.
			prefixRef, t := getPrefix(env, ref)
			if t == nil || len(prefixRef) == len(ref) {
				env.tree.Put(ref, refType)
			} else {
				newType, err := override(ref[len(prefixRef):], t, refType, rule)
				if err != nil {
					tc.err([]*Error{err})
					continue
				}
				env.tree.Put(prefixRef, newType)
			}
		}
	}

	// cpy resolves the types of vars bound while checking the body.
	cpy, err := tc.CheckBody(env, rule.Body)
	env = env.next
	path := rule.Path()

	if len(err) > 0 {
		// if the rule/function contains an error, add it to the type env so
		// that expressions that refer to this rule/function do not encounter
		// type errors.
		env.tree.Put(path, types.A)
		return
	}

	var tpe types.Type

	if len(rule.Head.Args) > 0 {

		// If args are not referred to in body, infer as any.
		WalkVars(rule.Head.Args, func(v Var) bool {
			if cpy.Get(v) == nil {
				cpy.tree.PutOne(v, types.A)
			}
			return false
		})

		// Construct function type.
		args := make([]types.Type, len(rule.Head.Args))
		for i := 0; i < len(rule.Head.Args); i++ {
			args[i] = cpy.Get(rule.Head.Args[i])
		}

		f := types.NewFunction(args, cpy.Get(rule.Head.Value))

		// Union with existing.
		exist := env.tree.Get(path)
		tpe = types.Or(exist, f)

	} else {
		switch rule.Head.DocKind() {
		case CompleteDoc:
			typeV := cpy.Get(rule.Head.Value)
			if typeV != nil {
				exist := env.tree.Get(path)
				tpe = types.Or(typeV, exist)
			}
		case PartialObjectDoc:
			typeK := cpy.Get(rule.Head.Key)
			typeV := cpy.Get(rule.Head.Value)
			if typeK != nil && typeV != nil {
				// Merge with key/value types from other rules of the same path
				// that were checked earlier.
				exist := env.tree.Get(path)
				typeV = types.Or(types.Values(exist), typeV)
				typeK = types.Or(types.Keys(exist), typeK)
				tpe = types.NewObject(nil, types.NewDynamicProperty(typeK, typeV))
			}
		case PartialSetDoc:
			typeK := cpy.Get(rule.Head.Key)
			if typeK != nil {
				exist := env.tree.Get(path)
				typeK = types.Or(types.Keys(exist), typeK)
				tpe = types.NewSet(typeK)
			}
		}
	}

	if tpe != nil {
		env.tree.Put(path, tpe)
	}
}
   268  
   269  func (tc *typeChecker) checkExpr(env *TypeEnv, expr *Expr) *Error {
   270  	if err := tc.checkExprWith(env, expr, 0); err != nil {
   271  		return err
   272  	}
   273  	if !expr.IsCall() {
   274  		return nil
   275  	}
   276  
   277  	checker := tc.exprCheckers[expr.Operator().String()]
   278  	if checker != nil {
   279  		return checker(env, expr)
   280  	}
   281  
   282  	return tc.checkExprBuiltin(env, expr)
   283  }
   284  
   285  func (tc *typeChecker) checkExprBuiltin(env *TypeEnv, expr *Expr) *Error {
   286  
   287  	args := expr.Operands()
   288  	pre := getArgTypes(env, args)
   289  
   290  	// NOTE(tsandall): undefined functions will have been caught earlier in the
   291  	// compiler. We check for undefined functions before the safety check so
   292  	// that references to non-existent functions result in undefined function
   293  	// errors as opposed to unsafe var errors.
   294  	//
   295  	// We cannot run type checking before the safety check because part of the
   296  	// type checker relies on reordering (in particular for references to local
   297  	// vars).
   298  	name := expr.Operator()
   299  	tpe := env.Get(name)
   300  
   301  	if tpe == nil {
   302  		return NewError(TypeErr, expr.Location, "undefined function %v", name)
   303  	}
   304  
   305  	// check if the expression refers to a function that contains an error
   306  	_, ok := tpe.(types.Any)
   307  	if ok {
   308  		return nil
   309  	}
   310  
   311  	ftpe, ok := tpe.(*types.Function)
   312  	if !ok {
   313  		return NewError(TypeErr, expr.Location, "undefined function %v", name)
   314  	}
   315  
   316  	fargs := ftpe.FuncArgs()
   317  	namedFargs := ftpe.NamedFuncArgs()
   318  
   319  	if ftpe.Result() != nil {
   320  		fargs.Args = append(fargs.Args, ftpe.Result())
   321  		namedFargs.Args = append(namedFargs.Args, ftpe.NamedResult())
   322  	}
   323  
   324  	if len(args) > len(fargs.Args) && fargs.Variadic == nil {
   325  		return newArgError(expr.Location, name, "too many arguments", pre, namedFargs)
   326  	}
   327  
   328  	if len(args) < len(ftpe.FuncArgs().Args) {
   329  		return newArgError(expr.Location, name, "too few arguments", pre, namedFargs)
   330  	}
   331  
   332  	for i := range args {
   333  		if !unify1(env, args[i], fargs.Arg(i), false) {
   334  			post := make([]types.Type, len(args))
   335  			for i := range args {
   336  				post[i] = env.Get(args[i])
   337  			}
   338  			return newArgError(expr.Location, name, "invalid argument(s)", post, namedFargs)
   339  		}
   340  	}
   341  
   342  	return nil
   343  }
   344  
   345  func (tc *typeChecker) checkExprEq(env *TypeEnv, expr *Expr) *Error {
   346  
   347  	pre := getArgTypes(env, expr.Operands())
   348  	exp := Equality.Decl.FuncArgs()
   349  
   350  	if len(pre) < len(exp.Args) {
   351  		return newArgError(expr.Location, expr.Operator(), "too few arguments", pre, exp)
   352  	}
   353  
   354  	if len(exp.Args) < len(pre) {
   355  		return newArgError(expr.Location, expr.Operator(), "too many arguments", pre, exp)
   356  	}
   357  
   358  	a, b := expr.Operand(0), expr.Operand(1)
   359  	typeA, typeB := env.Get(a), env.Get(b)
   360  
   361  	if !unify2(env, a, typeA, b, typeB) {
   362  		err := NewError(TypeErr, expr.Location, "match error")
   363  		err.Details = &UnificationErrDetail{
   364  			Left:  typeA,
   365  			Right: typeB,
   366  		}
   367  		return err
   368  	}
   369  
   370  	return nil
   371  }
   372  
   373  func (tc *typeChecker) checkExprWith(env *TypeEnv, expr *Expr, i int) *Error {
   374  	if i == len(expr.With) {
   375  		return nil
   376  	}
   377  
   378  	target, value := expr.With[i].Target, expr.With[i].Value
   379  	targetType, valueType := env.Get(target), env.Get(value)
   380  
   381  	if t, ok := targetType.(*types.Function); ok { // built-in function replacement
   382  		switch v := valueType.(type) {
   383  		case *types.Function: // ...by function
   384  			if !unifies(targetType, valueType) {
   385  				return newArgError(expr.With[i].Loc(), target.Value.(Ref), "arity mismatch", v.Args(), t.NamedFuncArgs())
   386  			}
   387  		default: // ... by value, nothing to check
   388  		}
   389  	}
   390  
   391  	return tc.checkExprWith(env, expr, i+1)
   392  }
   393  
   394  func unify2(env *TypeEnv, a *Term, typeA types.Type, b *Term, typeB types.Type) bool {
   395  
   396  	nilA := types.Nil(typeA)
   397  	nilB := types.Nil(typeB)
   398  
   399  	if nilA && !nilB {
   400  		return unify1(env, a, typeB, false)
   401  	} else if nilB && !nilA {
   402  		return unify1(env, b, typeA, false)
   403  	} else if !nilA && !nilB {
   404  		return unifies(typeA, typeB)
   405  	}
   406  
   407  	switch a.Value.(type) {
   408  	case *Array:
   409  		return unify2Array(env, a, b)
   410  	case *object:
   411  		return unify2Object(env, a, b)
   412  	case Var:
   413  		switch b.Value.(type) {
   414  		case Var:
   415  			return unify1(env, a, types.A, false) && unify1(env, b, env.Get(a), false)
   416  		case *Array:
   417  			return unify2Array(env, b, a)
   418  		case *object:
   419  			return unify2Object(env, b, a)
   420  		}
   421  	}
   422  
   423  	return false
   424  }
   425  
   426  func unify2Array(env *TypeEnv, a *Term, b *Term) bool {
   427  	arr := a.Value.(*Array)
   428  	switch bv := b.Value.(type) {
   429  	case *Array:
   430  		if arr.Len() == bv.Len() {
   431  			for i := 0; i < arr.Len(); i++ {
   432  				if !unify2(env, arr.Elem(i), env.Get(arr.Elem(i)), bv.Elem(i), env.Get(bv.Elem(i))) {
   433  					return false
   434  				}
   435  			}
   436  			return true
   437  		}
   438  	case Var:
   439  		return unify1(env, a, types.A, false) && unify1(env, b, env.Get(a), false)
   440  	}
   441  	return false
   442  }
   443  
   444  func unify2Object(env *TypeEnv, a *Term, b *Term) bool {
   445  	obj := a.Value.(Object)
   446  	switch bv := b.Value.(type) {
   447  	case *object:
   448  		cv := obj.Intersect(bv)
   449  		if obj.Len() == bv.Len() && bv.Len() == len(cv) {
   450  			for i := range cv {
   451  				if !unify2(env, cv[i][1], env.Get(cv[i][1]), cv[i][2], env.Get(cv[i][2])) {
   452  					return false
   453  				}
   454  			}
   455  			return true
   456  		}
   457  	case Var:
   458  		return unify1(env, a, types.A, false) && unify1(env, b, env.Get(a), false)
   459  	}
   460  	return false
   461  }
   462  
// unify1 unifies term with the type tpe, recording inferred types for vars
// inside term on env, and reports whether the unification succeeds.
//
// With union == false, a var that already has a type must unify with tpe and
// an untyped var is bound to tpe. With union == true, the var's type becomes
// types.Or of its existing type and tpe; this mode is used below when tpe is
// one alternative of a types.Any.
//
// Composite terms (arrays, objects, sets) are matched structurally against
// the corresponding composite types. Against types.Any:
//   - the plain "any" type (types.A) types every element as any;
//   - a union of alternatives is tried alternative by alternative.
// The union loop is deliberately not short-circuited: every alternative
// contributes its var types, and the result is true if any one unifies.
//
// Refs, comprehensions and constants already have a type in env, which is
// simply compared with tpe. Any other non-constant value is a programming
// error (panic).
func unify1(env *TypeEnv, term *Term, tpe types.Type, union bool) bool {
	switch v := term.Value.(type) {
	case *Array:
		switch tpe := tpe.(type) {
		case *types.Array:
			return unify1Array(env, v, tpe, union)
		case types.Any:
			if types.Compare(tpe, types.A) == 0 {
				for i := 0; i < v.Len(); i++ {
					unify1(env, v.Elem(i), types.A, true)
				}
				return true
			}
			// NOTE: this local shadows the package-level unifies function.
			unifies := false
			for i := range tpe {
				unifies = unify1(env, term, tpe[i], true) || unifies
			}
			return unifies
		}
		return false
	case *object:
		switch tpe := tpe.(type) {
		case *types.Object:
			return unify1Object(env, v, tpe, union)
		case types.Any:
			if types.Compare(tpe, types.A) == 0 {
				v.Foreach(func(key, value *Term) {
					unify1(env, key, types.A, true)
					unify1(env, value, types.A, true)
				})
				return true
			}
			unifies := false
			for i := range tpe {
				unifies = unify1(env, term, tpe[i], true) || unifies
			}
			return unifies
		}
		return false
	case Set:
		switch tpe := tpe.(type) {
		case *types.Set:
			return unify1Set(env, v, tpe, union)
		case types.Any:
			if types.Compare(tpe, types.A) == 0 {
				v.Foreach(func(elem *Term) {
					unify1(env, elem, types.A, true)
				})
				return true
			}
			unifies := false
			for i := range tpe {
				unifies = unify1(env, term, tpe[i], true) || unifies
			}
			return unifies
		}
		return false
	case Ref, *ArrayComprehension, *ObjectComprehension, *SetComprehension:
		return unifies(env.Get(v), tpe)
	case Var:
		if !union {
			if exist := env.Get(v); exist != nil {
				return unifies(exist, tpe)
			}
			env.tree.PutOne(term.Value, tpe)
		} else {
			env.tree.PutOne(term.Value, types.Or(env.Get(v), tpe))
		}
		return true
	default:
		if !IsConstant(v) {
			panic("unreachable")
		}
		return unifies(env.Get(term), tpe)
	}
}
   539  
   540  func unify1Array(env *TypeEnv, val *Array, tpe *types.Array, union bool) bool {
   541  	if val.Len() != tpe.Len() && tpe.Dynamic() == nil {
   542  		return false
   543  	}
   544  	for i := 0; i < val.Len(); i++ {
   545  		if !unify1(env, val.Elem(i), tpe.Select(i), union) {
   546  			return false
   547  		}
   548  	}
   549  	return true
   550  }
   551  
   552  func unify1Object(env *TypeEnv, val Object, tpe *types.Object, union bool) bool {
   553  	if val.Len() != len(tpe.Keys()) && tpe.DynamicValue() == nil {
   554  		return false
   555  	}
   556  	stop := val.Until(func(k, v *Term) bool {
   557  		if IsConstant(k.Value) {
   558  			if child := selectConstant(tpe, k); child != nil {
   559  				if !unify1(env, v, child, union) {
   560  					return true
   561  				}
   562  			} else {
   563  				return true
   564  			}
   565  		} else {
   566  			// Inferring type of value under dynamic key would involve unioning
   567  			// with all property values of tpe whose keys unify. For now, type
   568  			// these values as Any. We can investigate stricter inference in
   569  			// the future.
   570  			unify1(env, v, types.A, union)
   571  		}
   572  		return false
   573  	})
   574  	return !stop
   575  }
   576  
   577  func unify1Set(env *TypeEnv, val Set, tpe *types.Set, union bool) bool {
   578  	of := types.Values(tpe)
   579  	return !val.Until(func(elem *Term) bool {
   580  		return !unify1(env, elem, of, union)
   581  	})
   582  }
   583  
   584  func (tc *typeChecker) err(errors []*Error) {
   585  	tc.errs = append(tc.errs, errors...)
   586  }
   587  
   588  type refChecker struct {
   589  	env         *TypeEnv
   590  	errs        Errors
   591  	varRewriter varRewriter
   592  }
   593  
   594  func rewriteVarsNop(node Ref) Ref {
   595  	return node
   596  }
   597  
   598  func newRefChecker(env *TypeEnv, f varRewriter) *refChecker {
   599  
   600  	if f == nil {
   601  		f = rewriteVarsNop
   602  	}
   603  
   604  	return &refChecker{
   605  		env:         env,
   606  		errs:        nil,
   607  		varRewriter: f,
   608  	}
   609  }
   610  
   611  func (rc *refChecker) Visit(x interface{}) bool {
   612  	switch x := x.(type) {
   613  	case *ArrayComprehension, *ObjectComprehension, *SetComprehension:
   614  		return true
   615  	case *Expr:
   616  		switch terms := x.Terms.(type) {
   617  		case []*Term:
   618  			for i := 1; i < len(terms); i++ {
   619  				NewGenericVisitor(rc.Visit).Walk(terms[i])
   620  			}
   621  			return true
   622  		case *Term:
   623  			NewGenericVisitor(rc.Visit).Walk(terms)
   624  			return true
   625  		}
   626  	case Ref:
   627  		if err := rc.checkApply(rc.env, x); err != nil {
   628  			rc.errs = append(rc.errs, err)
   629  			return true
   630  		}
   631  		if err := rc.checkRef(rc.env, rc.env.tree, x, 0); err != nil {
   632  			rc.errs = append(rc.errs, err)
   633  		}
   634  	}
   635  	return false
   636  }
   637  
   638  func (rc *refChecker) checkApply(curr *TypeEnv, ref Ref) *Error {
   639  	switch tpe := curr.Get(ref).(type) {
   640  	case *types.Function: // NOTE(sr): We don't support first-class functions, except for `with`.
   641  		return newRefErrUnsupported(ref[0].Location, rc.varRewriter(ref), len(ref)-1, tpe)
   642  	}
   643  
   644  	return nil
   645  }
   646  
// checkRef checks ref[idx:] against the type tree rooted at node in curr and
// returns the first error found, or nil.
//
// String operands and the ref head (idx == 0) are looked up as children of
// node:
//   - No child found: the whole ref is re-checked from the start in the
//     enclosing environment (curr.next). Once no enclosing environment
//     remains, the rest of the ref is checked against types.A, skipping the
//     head when it is a root document name.
//   - Leaf child: checking continues in checkRefLeaf.
//
// Var and ref operands at a non-leaf node must be able to hold strings. An
// untyped var is bound to types.S, and a var or ref whose type does not unify
// with types.S is an error. Any other operand kind at a non-leaf node is an
// error. After a dynamic operand, the rest of the ref is checked under every
// child of node, but errors from those checks are discarded.
//
// NOTE(review): dynamic operands are looked up in and bound on rc.env rather
// than curr.
func (rc *refChecker) checkRef(curr *TypeEnv, node *typeTreeNode, ref Ref, idx int) *Error {

	if idx == len(ref) {
		return nil
	}

	head := ref[idx]

	// Handle constant ref operands, i.e., strings or the ref head.
	if _, ok := head.Value.(String); ok || idx == 0 {

		child := node.Child(head.Value)
		if child == nil {

			// Not found in this scope: retry the whole ref in the enclosing one.
			if curr.next != nil {
				next := curr.next
				return rc.checkRef(next, next.tree, ref, 0)
			}

			if RootDocumentNames.Contains(ref[0]) {
				return rc.checkRefLeaf(types.A, ref, 1)
			}

			return rc.checkRefLeaf(types.A, ref, 0)
		}

		if child.Leaf() {
			return rc.checkRefLeaf(child.Value(), ref, idx+1)
		}

		return rc.checkRef(curr, child, ref, idx+1)
	}

	// Handle dynamic ref operands.
	switch value := head.Value.(type) {

	case Var:

		if exist := rc.env.Get(value); exist != nil {
			if !unifies(types.S, exist) {
				return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, exist, types.S, getOneOfForNode(node))
			}
		} else {
			rc.env.tree.PutOne(value, types.S)
		}

	case Ref:

		exist := rc.env.Get(value)
		if exist == nil {
			// If ref type is unknown, an error will already be reported so
			// stop here.
			return nil
		}

		if !unifies(types.S, exist) {
			return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, exist, types.S, getOneOfForNode(node))
		}

	// Catch other ref operand types here. Non-leaf nodes must be referred to
	// with string values.
	default:
		return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, nil, types.S, getOneOfForNode(node))
	}

	// Run checking on remaining portion of the ref. Note, since the ref
	// potentially refers to data for which no type information exists,
	// checking should never fail.
	node.Children().Iter(func(_, child util.T) bool {
		_ = rc.checkRef(curr, child.(*typeTreeNode), ref, idx+1) // ignore error
		return false
	})

	return nil
}
   722  
// checkRefLeaf checks ref[idx:] against the leaf type tpe, one operand at a
// time, and returns the first error found, or nil.
//
// tpe must support dereferencing. If types.Keys(tpe) is nil (e.g. a scalar),
// an unsupported-ref error pointing at the previous operand is returned.
//
// Operand handling:
//   - var: if typed, it must unify with the key type; if untyped, it is bound
//     to the key type;
//   - ref: if typed, it must unify with the key type;
//   - array/object/set: unified (via unify1) with the key type;
//   - anything else (constants): must select a child type of tpe, and
//     checking continues with that child.
//
// For non-constant operands, checking continues with types.Values(tpe).
func (rc *refChecker) checkRefLeaf(tpe types.Type, ref Ref, idx int) *Error {

	if idx == len(ref) {
		return nil
	}

	head := ref[idx]

	keys := types.Keys(tpe)
	if keys == nil {
		return newRefErrUnsupported(ref[0].Location, rc.varRewriter(ref), idx-1, tpe)
	}

	switch value := head.Value.(type) {

	case Var:
		if exist := rc.env.Get(value); exist != nil {
			if !unifies(exist, keys) {
				return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, exist, keys, getOneOfForType(tpe))
			}
		} else {
			rc.env.tree.PutOne(value, types.Keys(tpe))
		}

	case Ref:
		// An untyped ref operand is not reported here.
		if exist := rc.env.Get(value); exist != nil {
			if !unifies(exist, keys) {
				return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, exist, keys, getOneOfForType(tpe))
			}
		}

	case *Array, Object, Set:
		if !unify1(rc.env, head, keys, false) {
			return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, rc.env.Get(head), keys, nil)
		}

	default:
		// Constant operand: descend into the matching child type.
		child := selectConstant(tpe, head)
		if child == nil {
			return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, nil, types.Keys(tpe), getOneOfForType(tpe))
		}
		return rc.checkRefLeaf(child, ref, idx+1)
	}

	return rc.checkRefLeaf(types.Values(tpe), ref, idx+1)
}
   769  
   770  func unifies(a, b types.Type) bool {
   771  
   772  	if a == nil || b == nil {
   773  		return false
   774  	}
   775  
   776  	anyA, ok1 := a.(types.Any)
   777  	if ok1 {
   778  		if unifiesAny(anyA, b) {
   779  			return true
   780  		}
   781  	}
   782  
   783  	anyB, ok2 := b.(types.Any)
   784  	if ok2 {
   785  		if unifiesAny(anyB, a) {
   786  			return true
   787  		}
   788  	}
   789  
   790  	if ok1 || ok2 {
   791  		return false
   792  	}
   793  
   794  	switch a := a.(type) {
   795  	case types.Null:
   796  		_, ok := b.(types.Null)
   797  		return ok
   798  	case types.Boolean:
   799  		_, ok := b.(types.Boolean)
   800  		return ok
   801  	case types.Number:
   802  		_, ok := b.(types.Number)
   803  		return ok
   804  	case types.String:
   805  		_, ok := b.(types.String)
   806  		return ok
   807  	case *types.Array:
   808  		b, ok := b.(*types.Array)
   809  		if !ok {
   810  			return false
   811  		}
   812  		return unifiesArrays(a, b)
   813  	case *types.Object:
   814  		b, ok := b.(*types.Object)
   815  		if !ok {
   816  			return false
   817  		}
   818  		return unifiesObjects(a, b)
   819  	case *types.Set:
   820  		b, ok := b.(*types.Set)
   821  		if !ok {
   822  			return false
   823  		}
   824  		return unifies(types.Values(a), types.Values(b))
   825  	case *types.Function:
   826  		// NOTE(sr): variadic functions can only be internal ones, and we've forbidden
   827  		// their replacement via `with`; so we disregard variadic here
   828  		if types.Arity(a) == types.Arity(b) {
   829  			b := b.(*types.Function)
   830  			for i := range a.FuncArgs().Args {
   831  				if !unifies(a.FuncArgs().Arg(i), b.FuncArgs().Arg(i)) {
   832  					return false
   833  				}
   834  			}
   835  			return true
   836  		}
   837  		return false
   838  	default:
   839  		panic("unreachable")
   840  	}
   841  }
   842  
   843  func unifiesAny(a types.Any, b types.Type) bool {
   844  	if _, ok := b.(*types.Function); ok {
   845  		return false
   846  	}
   847  	for i := range a {
   848  		if unifies(a[i], b) {
   849  			return true
   850  		}
   851  	}
   852  	return len(a) == 0
   853  }
   854  
   855  func unifiesArrays(a, b *types.Array) bool {
   856  
   857  	if !unifiesArraysStatic(a, b) {
   858  		return false
   859  	}
   860  
   861  	if !unifiesArraysStatic(b, a) {
   862  		return false
   863  	}
   864  
   865  	return a.Dynamic() == nil || b.Dynamic() == nil || unifies(a.Dynamic(), b.Dynamic())
   866  }
   867  
   868  func unifiesArraysStatic(a, b *types.Array) bool {
   869  	if a.Len() != 0 {
   870  		for i := 0; i < a.Len(); i++ {
   871  			if !unifies(a.Select(i), b.Select(i)) {
   872  				return false
   873  			}
   874  		}
   875  	}
   876  	return true
   877  }
   878  
   879  func unifiesObjects(a, b *types.Object) bool {
   880  	if !unifiesObjectsStatic(a, b) {
   881  		return false
   882  	}
   883  
   884  	if !unifiesObjectsStatic(b, a) {
   885  		return false
   886  	}
   887  
   888  	return a.DynamicValue() == nil || b.DynamicValue() == nil || unifies(a.DynamicValue(), b.DynamicValue())
   889  }
   890  
   891  func unifiesObjectsStatic(a, b *types.Object) bool {
   892  	for _, k := range a.Keys() {
   893  		if !unifies(a.Select(k), b.Select(k)) {
   894  			return false
   895  		}
   896  	}
   897  	return true
   898  }
   899  
   900  // typeErrorCause defines an interface to determine the reason for a type
   901  // error. The type error details implement this interface so that type checking
   902  // can report more actionable errors.
   903  type typeErrorCause interface {
   904  	nilType() bool
   905  }
   906  
   907  func causedByNilType(err *Error) bool {
   908  	cause, ok := err.Details.(typeErrorCause)
   909  	if !ok {
   910  		return false
   911  	}
   912  	return cause.nilType()
   913  }
   914  
   915  // ArgErrDetail represents a generic argument error.
   916  type ArgErrDetail struct {
   917  	Have []types.Type   `json:"have"`
   918  	Want types.FuncArgs `json:"want"`
   919  }
   920  
   921  // Lines returns the string representation of the detail.
   922  func (d *ArgErrDetail) Lines() []string {
   923  	lines := make([]string, 2)
   924  	lines[0] = "have: " + formatArgs(d.Have)
   925  	lines[1] = "want: " + fmt.Sprint(d.Want)
   926  	return lines
   927  }
   928  
   929  func (d *ArgErrDetail) nilType() bool {
   930  	for i := range d.Have {
   931  		if types.Nil(d.Have[i]) {
   932  			return true
   933  		}
   934  	}
   935  	return false
   936  }
   937  
   938  // UnificationErrDetail describes a type mismatch error when two values are
   939  // unified (e.g., x = [1,2,y]).
   940  type UnificationErrDetail struct {
   941  	Left  types.Type `json:"a"`
   942  	Right types.Type `json:"b"`
   943  }
   944  
   945  func (a *UnificationErrDetail) nilType() bool {
   946  	return types.Nil(a.Left) || types.Nil(a.Right)
   947  }
   948  
   949  // Lines returns the string representation of the detail.
   950  func (a *UnificationErrDetail) Lines() []string {
   951  	lines := make([]string, 2)
   952  	lines[0] = fmt.Sprint("left  : ", types.Sprint(a.Left))
   953  	lines[1] = fmt.Sprint("right : ", types.Sprint(a.Right))
   954  	return lines
   955  }
   956  
   957  // RefErrUnsupportedDetail describes an undefined reference error where the
   958  // referenced value does not support dereferencing (e.g., scalars).
   959  type RefErrUnsupportedDetail struct {
   960  	Ref  Ref        `json:"ref"`  // invalid ref
   961  	Pos  int        `json:"pos"`  // invalid element
   962  	Have types.Type `json:"have"` // referenced type
   963  }
   964  
   965  // Lines returns the string representation of the detail.
   966  func (r *RefErrUnsupportedDetail) Lines() []string {
   967  	lines := []string{
   968  		r.Ref.String(),
   969  		strings.Repeat("^", len(r.Ref[:r.Pos+1].String())),
   970  		fmt.Sprintf("have: %v", r.Have),
   971  	}
   972  	return lines
   973  }
   974  
// RefErrInvalidDetail describes an undefined reference error where the referenced
// value does not support the reference operand (e.g., missing object key,
// invalid key type, etc.). When Have is nil, Lines reports the ref element at
// Pos itself; when OneOf is empty, Lines reports Want instead of OneOf.
type RefErrInvalidDetail struct {
	Ref   Ref        `json:"ref"`            // invalid ref
	Pos   int        `json:"pos"`            // index of the invalid element within Ref
	Have  types.Type `json:"have,omitempty"` // type of invalid element (for var/ref elements)
	Want  types.Type `json:"want"`           // allowed type (for non-object values)
	OneOf []Value    `json:"oneOf"`          // allowed values (e.g., for object keys)
}
   985  
   986  // Lines returns the string representation of the detail.
   987  func (r *RefErrInvalidDetail) Lines() []string {
   988  	lines := []string{r.Ref.String()}
   989  	offset := len(r.Ref[:r.Pos].String()) + 1
   990  	pad := strings.Repeat(" ", offset)
   991  	lines = append(lines, fmt.Sprintf("%s^", pad))
   992  	if r.Have != nil {
   993  		lines = append(lines, fmt.Sprintf("%shave (type): %v", pad, r.Have))
   994  	} else {
   995  		lines = append(lines, fmt.Sprintf("%shave: %v", pad, r.Ref[r.Pos]))
   996  	}
   997  	if len(r.OneOf) > 0 {
   998  		lines = append(lines, fmt.Sprintf("%swant (one of): %v", pad, r.OneOf))
   999  	} else {
  1000  		lines = append(lines, fmt.Sprintf("%swant (type): %v", pad, r.Want))
  1001  	}
  1002  	return lines
  1003  }
  1004  
  1005  func formatArgs(args []types.Type) string {
  1006  	buf := make([]string, len(args))
  1007  	for i := range args {
  1008  		buf[i] = types.Sprint(args[i])
  1009  	}
  1010  	return "(" + strings.Join(buf, ", ") + ")"
  1011  }
  1012  
  1013  func newRefErrInvalid(loc *Location, ref Ref, idx int, have, want types.Type, oneOf []Value) *Error {
  1014  	err := newRefError(loc, ref)
  1015  	err.Details = &RefErrInvalidDetail{
  1016  		Ref:   ref,
  1017  		Pos:   idx,
  1018  		Have:  have,
  1019  		Want:  want,
  1020  		OneOf: oneOf,
  1021  	}
  1022  	return err
  1023  }
  1024  
  1025  func newRefErrUnsupported(loc *Location, ref Ref, idx int, have types.Type) *Error {
  1026  	err := newRefError(loc, ref)
  1027  	err.Details = &RefErrUnsupportedDetail{
  1028  		Ref:  ref,
  1029  		Pos:  idx,
  1030  		Have: have,
  1031  	}
  1032  	return err
  1033  }
  1034  
  1035  func newRefError(loc *Location, ref Ref) *Error {
  1036  	return NewError(TypeErr, loc, "undefined ref: %v", ref)
  1037  }
  1038  
  1039  func newArgError(loc *Location, builtinName Ref, msg string, have []types.Type, want types.FuncArgs) *Error {
  1040  	err := NewError(TypeErr, loc, "%v: %v", builtinName, msg)
  1041  	err.Details = &ArgErrDetail{
  1042  		Have: have,
  1043  		Want: want,
  1044  	}
  1045  	return err
  1046  }
  1047  
  1048  func getOneOfForNode(node *typeTreeNode) (result []Value) {
  1049  	node.Children().Iter(func(k, _ util.T) bool {
  1050  		result = append(result, k.(Value))
  1051  		return false
  1052  	})
  1053  
  1054  	sortValueSlice(result)
  1055  	return result
  1056  }
  1057  
  1058  func getOneOfForType(tpe types.Type) (result []Value) {
  1059  	switch tpe := tpe.(type) {
  1060  	case *types.Object:
  1061  		for _, k := range tpe.Keys() {
  1062  			v, err := InterfaceToValue(k)
  1063  			if err != nil {
  1064  				panic(err)
  1065  			}
  1066  			result = append(result, v)
  1067  		}
  1068  
  1069  	case types.Any:
  1070  		for _, object := range tpe {
  1071  			objRes := getOneOfForType(object)
  1072  			result = append(result, objRes...)
  1073  		}
  1074  	}
  1075  
  1076  	result = removeDuplicate(result)
  1077  	sortValueSlice(result)
  1078  	return result
  1079  }
  1080  
  1081  func sortValueSlice(sl []Value) {
  1082  	sort.Slice(sl, func(i, j int) bool {
  1083  		return sl[i].Compare(sl[j]) < 0
  1084  	})
  1085  }
  1086  
  1087  func removeDuplicate(list []Value) []Value {
  1088  	seen := make(map[Value]bool)
  1089  	var newResult []Value
  1090  	for _, item := range list {
  1091  		if !seen[item] {
  1092  			newResult = append(newResult, item)
  1093  			seen[item] = true
  1094  		}
  1095  	}
  1096  	return newResult
  1097  }
  1098  
  1099  func getArgTypes(env *TypeEnv, args []*Term) []types.Type {
  1100  	pre := make([]types.Type, len(args))
  1101  	for i := range args {
  1102  		pre[i] = env.Get(args[i])
  1103  	}
  1104  	return pre
  1105  }
  1106  
  1107  // getPrefix returns the shortest prefix of ref that exists in env
  1108  func getPrefix(env *TypeEnv, ref Ref) (Ref, types.Type) {
  1109  	if len(ref) == 1 {
  1110  		t := env.Get(ref)
  1111  		if t != nil {
  1112  			return ref, t
  1113  		}
  1114  	}
  1115  	for i := 1; i < len(ref); i++ {
  1116  		t := env.Get(ref[:i])
  1117  		if t != nil {
  1118  			return ref[:i], t
  1119  		}
  1120  	}
  1121  	return nil, nil
  1122  }
  1123  
  1124  // override takes a type t and returns a type obtained from t where the path represented by ref within it has type o (overriding the original type of that path)
  1125  func override(ref Ref, t types.Type, o types.Type, rule *Rule) (types.Type, *Error) {
  1126  	var newStaticProps []*types.StaticProperty
  1127  	obj, ok := t.(*types.Object)
  1128  	if !ok {
  1129  		newType, err := getObjectType(ref, o, rule, types.NewDynamicProperty(types.A, types.A))
  1130  		if err != nil {
  1131  			return nil, err
  1132  		}
  1133  		return newType, nil
  1134  	}
  1135  	found := false
  1136  	if ok {
  1137  		staticProps := obj.StaticProperties()
  1138  		for _, prop := range staticProps {
  1139  			valueCopy := prop.Value
  1140  			key, err := InterfaceToValue(prop.Key)
  1141  			if err != nil {
  1142  				return nil, NewError(TypeErr, rule.Location, "unexpected error in override: %s", err.Error())
  1143  			}
  1144  			if len(ref) > 0 && ref[0].Value.Compare(key) == 0 {
  1145  				found = true
  1146  				if len(ref) == 1 {
  1147  					valueCopy = o
  1148  				} else {
  1149  					newVal, err := override(ref[1:], valueCopy, o, rule)
  1150  					if err != nil {
  1151  						return nil, err
  1152  					}
  1153  					valueCopy = newVal
  1154  				}
  1155  			}
  1156  			newStaticProps = append(newStaticProps, types.NewStaticProperty(prop.Key, valueCopy))
  1157  		}
  1158  	}
  1159  
  1160  	// ref[0] is not a top-level key in staticProps, so it must be added
  1161  	if !found {
  1162  		newType, err := getObjectType(ref, o, rule, obj.DynamicProperties())
  1163  		if err != nil {
  1164  			return nil, err
  1165  		}
  1166  		newStaticProps = append(newStaticProps, newType.StaticProperties()...)
  1167  	}
  1168  	return types.NewObject(newStaticProps, obj.DynamicProperties()), nil
  1169  }
  1170  
  1171  func getKeys(ref Ref, rule *Rule) ([]interface{}, *Error) {
  1172  	keys := []interface{}{}
  1173  	for _, refElem := range ref {
  1174  		key, err := JSON(refElem.Value)
  1175  		if err != nil {
  1176  			return nil, NewError(TypeErr, rule.Location, "error getting key from value: %s", err.Error())
  1177  		}
  1178  		keys = append(keys, key)
  1179  	}
  1180  	return keys, nil
  1181  }
  1182  
  1183  func getObjectTypeRec(keys []interface{}, o types.Type, d *types.DynamicProperty) *types.Object {
  1184  	if len(keys) == 1 {
  1185  		staticProps := []*types.StaticProperty{types.NewStaticProperty(keys[0], o)}
  1186  		return types.NewObject(staticProps, d)
  1187  	}
  1188  
  1189  	staticProps := []*types.StaticProperty{types.NewStaticProperty(keys[0], getObjectTypeRec(keys[1:], o, d))}
  1190  	return types.NewObject(staticProps, d)
  1191  }
  1192  
  1193  func getObjectType(ref Ref, o types.Type, rule *Rule, d *types.DynamicProperty) (*types.Object, *Error) {
  1194  	keys, err := getKeys(ref, rule)
  1195  	if err != nil {
  1196  		return nil, err
  1197  	}
  1198  	return getObjectTypeRec(keys, o, d), nil
  1199  }
  1200  
  1201  func getRuleAnnotation(as *AnnotationSet, rule *Rule) (result []*SchemaAnnotation) {
  1202  
  1203  	for _, x := range as.GetSubpackagesScope(rule.Module.Package.Path) {
  1204  		result = append(result, x.Schemas...)
  1205  	}
  1206  
  1207  	if x := as.GetPackageScope(rule.Module.Package); x != nil {
  1208  		result = append(result, x.Schemas...)
  1209  	}
  1210  
  1211  	if x := as.GetDocumentScope(rule.Path()); x != nil {
  1212  		result = append(result, x.Schemas...)
  1213  	}
  1214  
  1215  	for _, x := range as.GetRuleScope(rule) {
  1216  		result = append(result, x.Schemas...)
  1217  	}
  1218  
  1219  	return result
  1220  }
  1221  
  1222  func processAnnotation(ss *SchemaSet, annot *SchemaAnnotation, rule *Rule, allowNet []string) (Ref, types.Type, *Error) {
  1223  
  1224  	var schema interface{}
  1225  
  1226  	if annot.Schema != nil {
  1227  		schema = ss.Get(annot.Schema)
  1228  		if schema == nil {
  1229  			return nil, nil, NewError(TypeErr, rule.Location, "undefined schema: %v", annot.Schema)
  1230  		}
  1231  	} else if annot.Definition != nil {
  1232  		schema = *annot.Definition
  1233  	}
  1234  
  1235  	tpe, err := loadSchema(schema, allowNet)
  1236  	if err != nil {
  1237  		return nil, nil, NewError(TypeErr, rule.Location, err.Error())
  1238  	}
  1239  
  1240  	return annot.Path, tpe, nil
  1241  }
  1242  
  1243  func errAnnotationRedeclared(a *Annotations, other *Location) *Error {
  1244  	return NewError(TypeErr, a.Location, "%v annotation redeclared: %v", a.Scope, other)
  1245  }