go.mondoo.com/cnquery@v0.0.0-20231005093811-59568235f6ea/explorer/executor/executor.go (about)

     1  // Copyright (c) Mondoo, Inc.
     2  // SPDX-License-Identifier: BUSL-1.1
     3  
     4  package executor
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/rs/zerolog/log"
    13  	"go.mondoo.com/cnquery"
    14  	"go.mondoo.com/cnquery/cli/progress"
    15  	"go.mondoo.com/cnquery/explorer"
    16  	"go.mondoo.com/cnquery/llx"
    17  	"go.mondoo.com/cnquery/utils/multierr"
    18  )
    19  
    20  func RunExecutionJob(
    21  	runtime llx.Runtime, collectorSvc explorer.QueryConductor, assetMrn string,
    22  	job *explorer.ExecutionJob, features cnquery.Features, progressReporter progress.Progress,
    23  ) (*instance, error) {
    24  	// We are setting a sensible default timeout for jobs here. This will need
    25  	// user-configuration.
    26  	timeout := 30 * time.Minute
    27  
    28  	res := newInstance(runtime, progressReporter)
    29  	res.assetMrn = assetMrn
    30  	res.collector = collectorSvc
    31  	res.datapoints = job.Datapoints
    32  
    33  	return res, res.runCode(job.Queries, timeout)
    34  }
    35  
    36  func ExecuteFilterQueries(runtime llx.Runtime, queries []*explorer.Mquery, timeout time.Duration) ([]*explorer.Mquery, []error) {
    37  	equeries := map[string]*explorer.ExecutionQuery{}
    38  	mqueries := map[string]*explorer.Mquery{}
    39  	for i := range queries {
    40  		query := queries[i]
    41  		code, err := query.Compile(nil, runtime.Schema())
    42  		// Errors for filter queries are common when they reference resources for
    43  		// providers that are not found on the system.
    44  		if err != nil {
    45  			log.Debug().Err(err).Str("mql", query.Mql).Msg("skipping filter query, not supported")
    46  			continue
    47  		}
    48  
    49  		equeries[code.CodeV2.Id] = &explorer.ExecutionQuery{
    50  			Query: query.Mql,
    51  			Code:  code,
    52  		}
    53  		mqueries[code.CodeV2.Id] = query
    54  	}
    55  
    56  	instance := newInstance(runtime, nil)
    57  	err := instance.runCode(equeries, timeout)
    58  	if err != nil {
    59  		return nil, []error{err}
    60  	}
    61  
    62  	instance.WaitUntilDone(timeout)
    63  
    64  	var errs []error
    65  	res := []*explorer.Mquery{}
    66  	for _, equery := range equeries {
    67  		bundle := equery.Code
    68  		entrypoints := bundle.EntrypointChecksums()
    69  
    70  		allTrue := true
    71  		for j := range entrypoints {
    72  			ep := entrypoints[j]
    73  			res := instance.results[ep]
    74  			if isTrue, _ := res.Data.IsSuccess(); !isTrue {
    75  				allTrue = false
    76  			}
    77  		}
    78  
    79  		if allTrue {
    80  			query, ok := mqueries[bundle.CodeV2.Id]
    81  			if ok {
    82  				res = append(res, query)
    83  			} else {
    84  				errs = append(errs, errors.New("cannot find filter-query for result of bundle "+bundle.CodeV2.Id))
    85  			}
    86  		}
    87  	}
    88  
    89  	return res, errs
    90  }
    91  
    92  func (e *instance) runCode(queries map[string]*explorer.ExecutionQuery, timeout time.Duration) error {
    93  	e.execs = make(map[string]*llx.MQLExecutorV2, len(queries))
    94  
    95  	for i := range queries {
    96  		query := queries[i]
    97  		bundle := query.Code
    98  
    99  		e.queries[bundle.CodeV2.Id] = query
   100  
   101  		checksums := bundle.DatapointChecksums()
   102  		for j := range checksums {
   103  			e.datapointTracker[checksums[j]] = nil
   104  		}
   105  
   106  		checksums = bundle.EntrypointChecksums()
   107  		for j := range checksums {
   108  			e.datapointTracker[checksums[j]] = nil
   109  		}
   110  
   111  		for _, codeId := range query.Properties {
   112  			arr := e.notifyQuery[codeId]
   113  			arr = append(arr, query)
   114  			e.notifyQuery[codeId] = arr
   115  		}
   116  	}
   117  
   118  	// we need to only retain the checksums that notify other queries
   119  	// to be run later on
   120  	for codeID := range e.notifyQuery {
   121  		query := queries[codeID]
   122  		checksums := query.Code.EntrypointChecksums()
   123  
   124  		for k := range checksums {
   125  			checksum := checksums[k]
   126  
   127  			arr := e.datapointTracker[checksum]
   128  			arr = append(arr, query)
   129  			e.datapointTracker[checksum] = arr
   130  		}
   131  	}
   132  
   133  	var errs multierr.Errors
   134  	for i := range queries {
   135  		query := queries[i]
   136  		if len(query.Properties) != 0 {
   137  			continue
   138  		}
   139  
   140  		if err := e.runQuery(query.Code, nil); err != nil {
   141  			errs.Add(err)
   142  		}
   143  	}
   144  
   145  	return errs.Deduplicate()
   146  }
   147  
   148  // One instance of the executor. May be returned but not instantiated
   149  // from outside this package.
   150  type instance struct {
   151  	runtime llx.Runtime
   152  	// raw list of executino queries mapped via CodeID
   153  	queries map[string]*explorer.ExecutionQuery
   154  	// an optional list of datapoints as an allow-list of data that will be returned
   155  	datapoints map[string]*explorer.DataQueryInfo
   156  	// a tracker for all datapoints, that also references the queries that
   157  	// created them
   158  	datapointTracker map[string][]*explorer.ExecutionQuery
   159  	// all code executors that have been started
   160  	execs map[string]*llx.MQLExecutorV2
   161  	// raw results from CodeID to result
   162  	results map[string]*llx.RawResult
   163  	// identifies which queries (CodeID) trigger other queries
   164  	// this is used for properties, where a prop notifies a query that uses it
   165  	notifyQuery      map[string][]*explorer.ExecutionQuery
   166  	mutex            sync.Mutex
   167  	isAborted        bool
   168  	isDone           bool
   169  	errors           error
   170  	done             chan struct{}
   171  	progressReporter progress.Progress
   172  	collector        explorer.QueryConductor
   173  	assetMrn         string
   174  }
   175  
   176  func newInstance(runtime llx.Runtime, progressReporter progress.Progress) *instance {
   177  	if progressReporter == nil {
   178  		progressReporter = progress.Noop{}
   179  	}
   180  
   181  	return &instance{
   182  		runtime:          runtime,
   183  		datapointTracker: map[string][]*explorer.ExecutionQuery{},
   184  		queries:          map[string]*explorer.ExecutionQuery{},
   185  		results:          map[string]*llx.RawResult{},
   186  		notifyQuery:      map[string][]*explorer.ExecutionQuery{},
   187  		isAborted:        false,
   188  		isDone:           false,
   189  		done:             make(chan struct{}),
   190  		progressReporter: progressReporter,
   191  		assetMrn:         runtime.AssetMRN(),
   192  	}
   193  }
   194  
   195  func (e *instance) runQuery(bundle *llx.CodeBundle, props map[string]*llx.Primitive) error {
   196  	exec, err := llx.NewExecutorV2(bundle.CodeV2, e.runtime, props, e.collect)
   197  	if err != nil {
   198  		return err
   199  	}
   200  
   201  	err = exec.Run()
   202  	if err != nil {
   203  		return err
   204  	}
   205  
   206  	e.execs[bundle.CodeV2.Id] = exec
   207  	return nil
   208  }
   209  
   210  func (e *instance) WaitUntilDone(timeout time.Duration) error {
   211  	select {
   212  	case <-e.done:
   213  		return nil
   214  
   215  	case <-time.After(timeout):
   216  		e.mutex.Lock()
   217  		e.isAborted = true
   218  		isDone := e.isDone
   219  		e.mutex.Unlock()
   220  
   221  		if isDone {
   222  			return nil
   223  		}
   224  		return errors.New("execution timed out after " + timeout.String())
   225  	}
   226  }
   227  
   228  func (e *instance) snapshotResults() map[string]*llx.Result {
   229  	if e.datapoints != nil {
   230  		e.mutex.Lock()
   231  		results := make(map[string]*llx.Result, len(e.datapoints))
   232  		for id := range e.datapoints {
   233  			c := e.results[id]
   234  			if c != nil {
   235  				results[id] = c.Result()
   236  			}
   237  		}
   238  		e.mutex.Unlock()
   239  		return results
   240  	}
   241  
   242  	e.mutex.Lock()
   243  	results := make(map[string]*llx.Result, len(e.results))
   244  	for id, v := range e.results {
   245  		results[id] = v.Result()
   246  	}
   247  	e.mutex.Unlock()
   248  	return results
   249  }
   250  
   251  func (e *instance) StoreData() error {
   252  	if e.collector == nil {
   253  		return errors.New("cannot store data, no collector provided")
   254  	}
   255  
   256  	_, err := e.collector.StoreResults(context.Background(), &explorer.StoreResultsReq{
   257  		AssetMrn: e.assetMrn,
   258  		Data:     e.snapshotResults(),
   259  	})
   260  
   261  	return err
   262  }
   263  
   264  func (e *instance) isCollected(query *llx.CodeBundle) bool {
   265  	checksums := query.EntrypointChecksums()
   266  	for i := range checksums {
   267  		checksum := checksums[i]
   268  		if _, ok := e.results[checksum]; !ok {
   269  			return false
   270  		}
   271  	}
   272  
   273  	return true
   274  }
   275  
   276  func (e *instance) getProps(query *explorer.ExecutionQuery) (map[string]*llx.Primitive, error) {
   277  	res := map[string]*llx.Primitive{}
   278  
   279  	for name, queryID := range query.Properties {
   280  		query, ok := e.queries[queryID]
   281  		if !ok {
   282  			return nil, errors.New("cannot find running process for properties of query " + query.Code.Source)
   283  		}
   284  
   285  		eps := query.Code.EntrypointChecksums()
   286  		checksum := eps[0]
   287  		result := e.results[checksum]
   288  		if result == nil {
   289  			return nil, errors.New("cannot find result for property of query " + query.Code.Source)
   290  		}
   291  
   292  		res[name] = result.Result().Data
   293  	}
   294  
   295  	return res, nil
   296  }
   297  
   298  func (e *instance) collect(res *llx.RawResult) {
   299  	var runQueries []*explorer.ExecutionQuery
   300  
   301  	e.mutex.Lock()
   302  
   303  	e.results[res.CodeID] = res
   304  	cur := len(e.results)
   305  	max := len(e.datapointTracker)
   306  	isDone := cur == max
   307  	e.isDone = isDone
   308  	isAborted := e.isAborted
   309  	e.progressReporter.OnProgress(cur, max)
   310  	if isDone {
   311  		e.progressReporter.Completed()
   312  	}
   313  
   314  	// collect all the queries we need to notify + update that list to remove
   315  	// any query that we are about to start (all while inside of mutex lock)
   316  	queries := e.datapointTracker[res.CodeID]
   317  	if len(queries) != 0 {
   318  		remaining := []*explorer.ExecutionQuery{}
   319  		for j := range queries {
   320  			if !e.isCollected(queries[j].Code) {
   321  				remaining = append(remaining, queries[j])
   322  				continue
   323  			}
   324  
   325  			codeID := queries[j].Code.CodeV2.Id
   326  			notified := e.notifyQuery[codeID]
   327  			for k := range notified {
   328  				runQueries = append(runQueries, notified[k])
   329  			}
   330  		}
   331  		e.datapointTracker[res.CodeID] = remaining
   332  	}
   333  
   334  	e.mutex.Unlock()
   335  
   336  	if len(runQueries) != 0 {
   337  		var fatalErr error
   338  		for i := range runQueries {
   339  			query := runQueries[i]
   340  			props, err := e.getProps(query)
   341  			if err != nil {
   342  				fatalErr = err
   343  				break
   344  			}
   345  
   346  			err = e.runQuery(query.Code, props)
   347  			if err != nil {
   348  				fatalErr = err
   349  				break
   350  			}
   351  		}
   352  
   353  		if fatalErr != nil {
   354  			e.mutex.Lock()
   355  			e.errors = errors.Join(e.errors, fatalErr)
   356  			e.isAborted = true
   357  			isAborted = true
   358  			e.mutex.Unlock()
   359  		}
   360  	}
   361  
   362  	if isDone && !isAborted {
   363  		go func() {
   364  			e.done <- struct{}{}
   365  		}()
   366  	}
   367  }