github.com/rajeev159/opa@v0.45.0/topdown/query.go (about)

     1  package topdown
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"io"
     7  	"sort"
     8  	"time"
     9  
    10  	"github.com/open-policy-agent/opa/ast"
    11  	"github.com/open-policy-agent/opa/metrics"
    12  	"github.com/open-policy-agent/opa/resolver"
    13  	"github.com/open-policy-agent/opa/storage"
    14  	"github.com/open-policy-agent/opa/topdown/builtins"
    15  	"github.com/open-policy-agent/opa/topdown/cache"
    16  	"github.com/open-policy-agent/opa/topdown/copypropagation"
    17  	"github.com/open-policy-agent/opa/topdown/print"
    18  	"github.com/open-policy-agent/opa/tracing"
    19  )
    20  
    21  // QueryResultSet represents a collection of results returned by a query.
    22  type QueryResultSet []QueryResult
    23  
    24  // QueryResult represents a single result returned by a query. The result
    25  // contains bindings for all variables that appear in the query.
    26  type QueryResult map[ast.Var]*ast.Term
    27  
// Query provides a configurable interface for performing query evaluation.
// Create one with NewQuery, configure it with the With* methods, and evaluate
// it with Run, Iter, or PartialRun.
type Query struct {
	// Evaluation environment. A nil seed is replaced with crypto/rand.Reader
	// when the query is evaluated; time backs the time.now_ns() built-in.
	seed   io.Reader
	time   time.Time
	cancel Cancel

	// What to evaluate and the compilers/data it is evaluated against.
	// external holds the resolvers registered with WithResolver.
	query         ast.Body
	queryCompiler ast.QueryCompiler
	compiler      *ast.Compiler
	store         storage.Store
	txn           storage.Transaction
	input         *ast.Term
	external      *resolverTrie

	// Tracing. plugTraceVars becomes true once any registered tracer's
	// config requests PlugLocalVars.
	tracers       []QueryTracer
	plugTraceVars bool

	// Partial evaluation settings (used only by PartialRun).
	unknowns          []*ast.Term
	partialNamespace  string
	skipSaveNamespace bool

	// Instrumentation. A nil metrics collection is replaced with
	// metrics.New() when the query is evaluated.
	metrics metrics.Metrics
	instr   *Instrumentation

	// Inlining controls for partial evaluation.
	disableInlining []ast.Ref
	shallowInlining bool

	// Evaluator options. NewQuery defaults genvarprefix to ast.WildcardPrefix
	// and enables indexing and earlyExit.
	genvarprefix           string
	runtime                *ast.Term
	builtins               map[string]*Builtin
	indexing               bool
	earlyExit              bool
	interQueryBuiltinCache cache.InterQueryCache
	strictBuiltinErrors    bool
	printHook              print.Hook
	tracingOpts            tracing.Options
}
    59  
// Builtin represents a built-in function that queries can call. Custom
// built-ins are made available to a query with Query.WithBuiltins.
type Builtin struct {
	Decl *ast.Builtin // declaration of the built-in function
	Func BuiltinFunc  // implementation of the built-in function
}
    65  
    66  // NewQuery returns a new Query object that can be run.
    67  func NewQuery(query ast.Body) *Query {
    68  	return &Query{
    69  		query:        query,
    70  		genvarprefix: ast.WildcardPrefix,
    71  		indexing:     true,
    72  		earlyExit:    true,
    73  		external:     newResolverTrie(),
    74  	}
    75  }
    76  
// WithQueryCompiler sets the queryCompiler used for the query.
//
// The With* setters operate on q itself (not a copy) and return it, so
// configuration calls can be chained.
func (q *Query) WithQueryCompiler(queryCompiler ast.QueryCompiler) *Query {
	q.queryCompiler = queryCompiler
	return q
}

// WithCompiler sets the compiler to use for the query. PartialRun also uses it
// to type-check each partially evaluated body and to configure copy
// propagation.
func (q *Query) WithCompiler(compiler *ast.Compiler) *Query {
	q.compiler = compiler
	return q
}

// WithStore sets the store to use for the query.
func (q *Query) WithStore(store storage.Store) *Query {
	q.store = store
	return q
}

// WithTransaction sets the transaction to use for the query. All queries
// should be performed over a consistent snapshot of the storage layer.
func (q *Query) WithTransaction(txn storage.Transaction) *Query {
	q.txn = txn
	return q
}

// WithCancel sets the cancellation object to use for the query. Set this if
// you need to abort queries based on a deadline. This is optional.
func (q *Query) WithCancel(cancel Cancel) *Query {
	q.cancel = cancel
	return q
}

// WithInput sets the input object to use for the query. References rooted at
// input will be evaluated against this value. This is optional.
func (q *Query) WithInput(input *ast.Term) *Query {
	q.input = input
	return q
}
   115  
   116  // WithTracer adds a query tracer to use during evaluation. This is optional.
   117  // Deprecated: Use WithQueryTracer instead.
   118  func (q *Query) WithTracer(tracer Tracer) *Query {
   119  	qt, ok := tracer.(QueryTracer)
   120  	if !ok {
   121  		qt = WrapLegacyTracer(tracer)
   122  	}
   123  	return q.WithQueryTracer(qt)
   124  }
   125  
   126  // WithQueryTracer adds a query tracer to use during evaluation. This is optional.
   127  // Disabled QueryTracers will be ignored.
   128  func (q *Query) WithQueryTracer(tracer QueryTracer) *Query {
   129  	if !tracer.Enabled() {
   130  		return q
   131  	}
   132  
   133  	q.tracers = append(q.tracers, tracer)
   134  
   135  	// If *any* of the tracers require local variable metadata we need to
   136  	// enabled plugging local trace variables.
   137  	conf := tracer.Config()
   138  	if conf.PlugLocalVars {
   139  		q.plugTraceVars = true
   140  	}
   141  
   142  	return q
   143  }
   144  
// WithMetrics sets the metrics collection to add evaluation metrics to. This
// is optional; when no collection is set, a new one is created with
// metrics.New() at evaluation time.
func (q *Query) WithMetrics(m metrics.Metrics) *Query {
	q.metrics = m
	return q
}

// WithInstrumentation sets the instrumentation configuration to enable on the
// evaluation process. By default, instrumentation is turned off.
func (q *Query) WithInstrumentation(instr *Instrumentation) *Query {
	q.instr = instr
	return q
}

// WithUnknowns sets the initial set of variables or references to treat as
// unknown during query evaluation. This is required for partial evaluation.
func (q *Query) WithUnknowns(terms []*ast.Term) *Query {
	q.unknowns = terms
	return q
}

// WithPartialNamespace sets the namespace to use for supporting rules
// generated as part of the partial evaluation process. The ns value must be a
// valid package path component. If left empty, PartialRun uses "partial".
func (q *Query) WithPartialNamespace(ns string) *Query {
	q.partialNamespace = ns
	return q
}

// WithSkipPartialNamespace disables namespacing of saved support rules that are generated
// from the original policy (rules which are completely synthetic are still namespaced).
func (q *Query) WithSkipPartialNamespace(yes bool) *Query {
	q.skipSaveNamespace = yes
	return q
}

// WithDisableInlining sets the paths that should be excluded from inlining
// during partial evaluation, replacing any paths set by an earlier call.
// Inlining can be expensive in some cases (e.g., when a cross-product is
// computed). Disabling inlining avoids expensive computation at the cost of
// generating support rules.
func (q *Query) WithDisableInlining(paths []ast.Ref) *Query {
	q.disableInlining = paths
	return q
}

// WithShallowInlining disables aggressive inlining performed during partial evaluation.
// When shallow inlining is enabled rules that depend (transitively) on unknowns are not inlined.
// Only rules/values that are completely known will be inlined. PartialRun also skips copy
// propagation of its result bodies when shallow inlining is enabled.
func (q *Query) WithShallowInlining(yes bool) *Query {
	q.shallowInlining = yes
	return q
}
   197  
// WithRuntime sets the runtime data to execute the query with. The runtime data
// can be returned by the `opa.runtime` built-in function.
func (q *Query) WithRuntime(runtime *ast.Term) *Query {
	q.runtime = runtime
	return q
}

// WithBuiltins sets the map of custom built-in functions that can be called
// by the query. It replaces any map set by an earlier call.
func (q *Query) WithBuiltins(builtins map[string]*Builtin) *Query {
	q.builtins = builtins
	return q
}

// WithIndexing will enable or disable using rule indexing for the evaluation
// of the query. The default is enabled.
func (q *Query) WithIndexing(enabled bool) *Query {
	q.indexing = enabled
	return q
}

// WithEarlyExit will enable or disable using 'early exit' for the evaluation
// of the query. The default is enabled.
func (q *Query) WithEarlyExit(enabled bool) *Query {
	q.earlyExit = enabled
	return q
}

// WithSeed sets a reader that will seed randomization required by built-in functions.
// If a seed is not provided crypto/rand.Reader is used.
func (q *Query) WithSeed(r io.Reader) *Query {
	q.seed = r
	return q
}

// WithTime sets the time that will be returned by the time.now_ns() built-in function.
// When no time is set, Iter uses time.Now() at evaluation time.
func (q *Query) WithTime(x time.Time) *Query {
	q.time = x
	return q
}

// WithInterQueryBuiltinCache sets the inter-query cache that built-in functions can utilize.
func (q *Query) WithInterQueryBuiltinCache(c cache.InterQueryCache) *Query {
	q.interQueryBuiltinCache = c
	return q
}

// WithStrictBuiltinErrors tells the evaluator to treat all built-in function errors as fatal errors.
// When enabled, the first built-in error recorded during evaluation is reported as the query's error.
func (q *Query) WithStrictBuiltinErrors(yes bool) *Query {
	q.strictBuiltinErrors = yes
	return q
}
   250  
// WithResolver configures an external resolver to use for the given ref.
// It may be called repeatedly; each call registers r under ref in the query's
// resolver trie. Assumes q was created by NewQuery, which allocates that trie.
func (q *Query) WithResolver(ref ast.Ref, r resolver.Resolver) *Query {
	q.external.Put(ref, r)
	return q
}

// WithPrintHook sets the print.Hook that is passed to the evaluator by both
// Iter and PartialRun (presumably to receive output of the print built-in).
func (q *Query) WithPrintHook(h print.Hook) *Query {
	q.printHook = h
	return q
}

// WithDistributedTracingOpts sets the options to be used by distributed tracing.
func (q *Query) WithDistributedTracingOpts(tr tracing.Options) *Query {
	q.tracingOpts = tr
	return q
}
   267  
   268  // PartialRun executes partial evaluation on the query with respect to unknown
   269  // values. Partial evaluation attempts to evaluate as much of the query as
   270  // possible without requiring values for the unknowns set on the query. The
   271  // result of partial evaluation is a new set of queries that can be evaluated
   272  // once the unknown value is known. In addition to new queries, partial
   273  // evaluation may produce additional support modules that should be used in
   274  // conjunction with the partially evaluated queries.
   275  func (q *Query) PartialRun(ctx context.Context) (partials []ast.Body, support []*ast.Module, err error) {
   276  	if q.partialNamespace == "" {
   277  		q.partialNamespace = "partial" // lazily initialize partial namespace
   278  	}
   279  	if q.seed == nil {
   280  		q.seed = rand.Reader
   281  	}
   282  	if !q.time.IsZero() {
   283  		q.time = time.Now()
   284  	}
   285  	if q.metrics == nil {
   286  		q.metrics = metrics.New()
   287  	}
   288  	f := &queryIDFactory{}
   289  	b := newBindings(0, q.instr)
   290  	e := &eval{
   291  		ctx:                    ctx,
   292  		metrics:                q.metrics,
   293  		seed:                   q.seed,
   294  		time:                   ast.NumberTerm(int64ToJSONNumber(q.time.UnixNano())),
   295  		cancel:                 q.cancel,
   296  		query:                  q.query,
   297  		queryCompiler:          q.queryCompiler,
   298  		queryIDFact:            f,
   299  		queryID:                f.Next(),
   300  		bindings:               b,
   301  		compiler:               q.compiler,
   302  		store:                  q.store,
   303  		baseCache:              newBaseCache(),
   304  		targetStack:            newRefStack(),
   305  		txn:                    q.txn,
   306  		input:                  q.input,
   307  		external:               q.external,
   308  		tracers:                q.tracers,
   309  		traceEnabled:           len(q.tracers) > 0,
   310  		plugTraceVars:          q.plugTraceVars,
   311  		instr:                  q.instr,
   312  		builtins:               q.builtins,
   313  		builtinCache:           builtins.Cache{},
   314  		functionMocks:          newFunctionMocksStack(),
   315  		interQueryBuiltinCache: q.interQueryBuiltinCache,
   316  		virtualCache:           newVirtualCache(),
   317  		comprehensionCache:     newComprehensionCache(),
   318  		saveSet:                newSaveSet(q.unknowns, b, q.instr),
   319  		saveStack:              newSaveStack(),
   320  		saveSupport:            newSaveSupport(),
   321  		saveNamespace:          ast.StringTerm(q.partialNamespace),
   322  		skipSaveNamespace:      q.skipSaveNamespace,
   323  		inliningControl: &inliningControl{
   324  			shallow: q.shallowInlining,
   325  		},
   326  		genvarprefix:  q.genvarprefix,
   327  		runtime:       q.runtime,
   328  		indexing:      q.indexing,
   329  		earlyExit:     q.earlyExit,
   330  		builtinErrors: &builtinErrors{},
   331  		printHook:     q.printHook,
   332  	}
   333  
   334  	if len(q.disableInlining) > 0 {
   335  		e.inliningControl.PushDisable(q.disableInlining, false)
   336  	}
   337  
   338  	e.caller = e
   339  	q.metrics.Timer(metrics.RegoPartialEval).Start()
   340  	defer q.metrics.Timer(metrics.RegoPartialEval).Stop()
   341  
   342  	livevars := ast.NewVarSet()
   343  
   344  	ast.WalkVars(q.query, func(x ast.Var) bool {
   345  		if !x.IsGenerated() {
   346  			livevars.Add(x)
   347  		}
   348  		return false
   349  	})
   350  
   351  	p := copypropagation.New(livevars).WithCompiler(q.compiler)
   352  
   353  	err = e.Run(func(e *eval) error {
   354  
   355  		// Build output from saved expressions.
   356  		body := ast.NewBody()
   357  
   358  		for _, elem := range e.saveStack.Stack[len(e.saveStack.Stack)-1] {
   359  			body.Append(elem.Plug(e.bindings))
   360  		}
   361  
   362  		// Include bindings as exprs so that when caller evals the result, they
   363  		// can obtain values for the vars in their query.
   364  		bindingExprs := []*ast.Expr{}
   365  		_ = e.bindings.Iter(e.bindings, func(a, b *ast.Term) error {
   366  			bindingExprs = append(bindingExprs, ast.Equality.Expr(a, b))
   367  			return nil
   368  		}) // cannot return error
   369  
   370  		// Sort binding expressions so that results are deterministic.
   371  		sort.Slice(bindingExprs, func(i, j int) bool {
   372  			return bindingExprs[i].Compare(bindingExprs[j]) < 0
   373  		})
   374  
   375  		for i := range bindingExprs {
   376  			body.Append(bindingExprs[i])
   377  		}
   378  
   379  		// Skip this rule body if it fails to type-check.
   380  		// Type-checking failure means the rule body will never succeed.
   381  		if !e.compiler.PassesTypeCheck(body) {
   382  			return nil
   383  		}
   384  
   385  		if !q.shallowInlining {
   386  			body = applyCopyPropagation(p, e.instr, body)
   387  		}
   388  
   389  		partials = append(partials, body)
   390  		return nil
   391  	})
   392  
   393  	support = e.saveSupport.List()
   394  
   395  	if q.strictBuiltinErrors && len(e.builtinErrors.errs) > 0 {
   396  		err = e.builtinErrors.errs[0]
   397  	}
   398  
   399  	for i := range support {
   400  		sort.Slice(support[i].Rules, func(j, k int) bool {
   401  			return support[i].Rules[j].Compare(support[i].Rules[k]) < 0
   402  		})
   403  	}
   404  
   405  	return partials, support, err
   406  }
   407  
   408  // Run is a wrapper around Iter that accumulates query results and returns them
   409  // in one shot.
   410  func (q *Query) Run(ctx context.Context) (QueryResultSet, error) {
   411  	qrs := QueryResultSet{}
   412  	return qrs, q.Iter(ctx, func(qr QueryResult) error {
   413  		qrs = append(qrs, qr)
   414  		return nil
   415  	})
   416  }
   417  
   418  // Iter executes the query and invokes the iter function with query results
   419  // produced by evaluating the query.
   420  func (q *Query) Iter(ctx context.Context, iter func(QueryResult) error) error {
   421  	if q.seed == nil {
   422  		q.seed = rand.Reader
   423  	}
   424  	if q.time.IsZero() {
   425  		q.time = time.Now()
   426  	}
   427  	if q.metrics == nil {
   428  		q.metrics = metrics.New()
   429  	}
   430  	f := &queryIDFactory{}
   431  	e := &eval{
   432  		ctx:                    ctx,
   433  		metrics:                q.metrics,
   434  		seed:                   q.seed,
   435  		time:                   ast.NumberTerm(int64ToJSONNumber(q.time.UnixNano())),
   436  		cancel:                 q.cancel,
   437  		query:                  q.query,
   438  		queryCompiler:          q.queryCompiler,
   439  		queryIDFact:            f,
   440  		queryID:                f.Next(),
   441  		bindings:               newBindings(0, q.instr),
   442  		compiler:               q.compiler,
   443  		store:                  q.store,
   444  		baseCache:              newBaseCache(),
   445  		targetStack:            newRefStack(),
   446  		txn:                    q.txn,
   447  		input:                  q.input,
   448  		external:               q.external,
   449  		tracers:                q.tracers,
   450  		traceEnabled:           len(q.tracers) > 0,
   451  		plugTraceVars:          q.plugTraceVars,
   452  		instr:                  q.instr,
   453  		builtins:               q.builtins,
   454  		builtinCache:           builtins.Cache{},
   455  		functionMocks:          newFunctionMocksStack(),
   456  		interQueryBuiltinCache: q.interQueryBuiltinCache,
   457  		virtualCache:           newVirtualCache(),
   458  		comprehensionCache:     newComprehensionCache(),
   459  		genvarprefix:           q.genvarprefix,
   460  		runtime:                q.runtime,
   461  		indexing:               q.indexing,
   462  		earlyExit:              q.earlyExit,
   463  		builtinErrors:          &builtinErrors{},
   464  		printHook:              q.printHook,
   465  		tracingOpts:            q.tracingOpts,
   466  	}
   467  	e.caller = e
   468  	q.metrics.Timer(metrics.RegoQueryEval).Start()
   469  	err := e.Run(func(e *eval) error {
   470  		qr := QueryResult{}
   471  		_ = e.bindings.Iter(nil, func(k, v *ast.Term) error {
   472  			qr[k.Value.(ast.Var)] = v
   473  			return nil
   474  		}) // cannot return error
   475  		return iter(qr)
   476  	})
   477  
   478  	if q.strictBuiltinErrors && err == nil && len(e.builtinErrors.errs) > 0 {
   479  		err = e.builtinErrors.errs[0]
   480  	}
   481  
   482  	q.metrics.Timer(metrics.RegoQueryEval).Stop()
   483  	return err
   484  }