go.mondoo.com/cnquery@v0.0.0-20231005093811-59568235f6ea/providers-sdk/v1/testutils/testutils.go (about)

     1  // Copyright (c) Mondoo, Inc.
     2  // SPDX-License-Identifier: BUSL-1.1
     3  
     4  package testutils
     5  
     6  import (
     7  	"fmt"
     8  	"os"
     9  	"path"
    10  	"path/filepath"
    11  	"runtime"
    12  	"sort"
    13  	"strconv"
    14  	"strings"
    15  	"testing"
    16  
    17  	"github.com/stretchr/testify/assert"
    18  	"github.com/stretchr/testify/require"
    19  	"go.mondoo.com/cnquery"
    20  	"go.mondoo.com/cnquery/llx"
    21  	"go.mondoo.com/cnquery/logger"
    22  	"go.mondoo.com/cnquery/mql"
    23  	"go.mondoo.com/cnquery/mqlc"
    24  	"go.mondoo.com/cnquery/providers"
    25  	"go.mondoo.com/cnquery/providers-sdk/v1/lr"
    26  	"go.mondoo.com/cnquery/providers-sdk/v1/lr/docs"
    27  	"go.mondoo.com/cnquery/providers-sdk/v1/resources"
    28  	"go.mondoo.com/cnquery/providers-sdk/v1/testutils/mockprovider"
    29  	networkconf "go.mondoo.com/cnquery/providers/network/config"
    30  	networkprovider "go.mondoo.com/cnquery/providers/network/provider"
    31  	osconf "go.mondoo.com/cnquery/providers/os/config"
    32  	osprovider "go.mondoo.com/cnquery/providers/os/provider"
    33  	"sigs.k8s.io/yaml"
    34  )
    35  
    36  var (
    37  	Features     cnquery.Features
    38  	TestutilsDir string
    39  )
    40  
    41  func init() {
    42  	logger.InitTestEnv()
    43  	Features = getEnvFeatures()
    44  
    45  	_, pathToFile, _, ok := runtime.Caller(0)
    46  	if !ok {
    47  		panic("unable to get runtime for testutils for cnquery providers")
    48  	}
    49  	TestutilsDir = path.Dir(pathToFile)
    50  }
    51  
    52  func getEnvFeatures() cnquery.Features {
    53  	env := os.Getenv("FEATURES")
    54  	if env == "" {
    55  		return cnquery.Features{byte(cnquery.PiperCode)}
    56  	}
    57  
    58  	arr := strings.Split(env, ",")
    59  	var fts cnquery.Features
    60  	for i := range arr {
    61  		v, ok := cnquery.FeaturesValue[arr[i]]
    62  		if ok {
    63  			fmt.Println("--> activate feature: " + arr[i])
    64  			fts = append(Features, byte(v))
    65  		} else {
    66  			panic("cannot find requested feature: " + arr[i])
    67  		}
    68  	}
    69  	return fts
    70  }
    71  
    72  type tester struct {
    73  	Runtime llx.Runtime
    74  }
    75  
    76  type SchemaProvider struct {
    77  	Provider string
    78  	Path     string
    79  }
    80  
    81  func InitTester(runtime llx.Runtime) *tester {
    82  	return &tester{
    83  		Runtime: runtime,
    84  	}
    85  }
    86  
    87  func (ctx *tester) Compile(query string) (*llx.CodeBundle, error) {
    88  	return mqlc.Compile(query, nil, mqlc.NewConfig(ctx.Runtime.Schema(), Features))
    89  }
    90  
    91  func (ctx *tester) ExecuteCode(bundle *llx.CodeBundle, props map[string]*llx.Primitive) (map[string]*llx.RawResult, error) {
    92  	return mql.ExecuteCode(ctx.Runtime, bundle, props, Features)
    93  }
    94  
    95  func (ctx *tester) TestQueryP(t *testing.T, query string, props map[string]*llx.Primitive) []*llx.RawResult {
    96  	t.Helper()
    97  	bundle, err := mqlc.Compile(query, props, mqlc.NewConfig(ctx.Runtime.Schema(), Features))
    98  	if err != nil {
    99  		t.Fatal("failed to compile code: " + err.Error())
   100  	}
   101  	err = mqlc.Invariants.Check(bundle)
   102  	require.NoError(t, err)
   103  	return ctx.TestMqlc(t, bundle, props)
   104  }
   105  
   106  func (ctx *tester) TestQuery(t *testing.T, query string) []*llx.RawResult {
   107  	return ctx.TestQueryP(t, query, nil)
   108  }
   109  
   110  func (ctx *tester) TestMqlc(t *testing.T, bundle *llx.CodeBundle, props map[string]*llx.Primitive) []*llx.RawResult {
   111  	t.Helper()
   112  
   113  	resultMap, err := mql.ExecuteCode(ctx.Runtime, bundle, props, Features)
   114  	require.NoError(t, err)
   115  
   116  	lastQueryResult := &llx.RawResult{}
   117  	results := make([]*llx.RawResult, 0, len(resultMap)+1)
   118  
   119  	refs := make([]uint64, 0, len(bundle.CodeV2.Checksums))
   120  	for _, datapointArr := range [][]uint64{bundle.CodeV2.Datapoints(), bundle.CodeV2.Entrypoints()} {
   121  		refs = append(refs, datapointArr...)
   122  	}
   123  
   124  	sort.Slice(refs, func(i, j int) bool {
   125  		return refs[i] < refs[j]
   126  	})
   127  
   128  	for idx, ref := range refs {
   129  		checksum := bundle.CodeV2.Checksums[ref]
   130  		if d, ok := resultMap[checksum]; ok {
   131  			results = append(results, d)
   132  			if idx+1 == len(refs) {
   133  				lastQueryResult.CodeID = d.CodeID
   134  				if d.Data.Error != nil {
   135  					lastQueryResult.Data = &llx.RawData{
   136  						Error: d.Data.Error,
   137  					}
   138  				} else {
   139  					success, valid := d.Data.IsSuccess()
   140  					lastQueryResult.Data = llx.BoolData(success && valid)
   141  				}
   142  			}
   143  		}
   144  	}
   145  
   146  	results = append(results, lastQueryResult)
   147  	return results
   148  }
   149  
   150  func MustLoadSchema(provider SchemaProvider) *resources.Schema {
   151  	if provider.Path == "" && provider.Provider == "" {
   152  		panic("cannot load schema without provider name or path")
   153  	}
   154  	var path string
   155  	// path towards the .yaml manifest, containing metadata abou the resources
   156  	var manifestPath string
   157  	if provider.Provider != "" {
   158  		switch provider.Provider {
   159  		// special handling for the mockprovider
   160  		case "mockprovider":
   161  			path = filepath.Join(TestutilsDir, "mockprovider/resources/mockprovider.lr")
   162  		default:
   163  			manifestPath = filepath.Join(TestutilsDir, "../../../providers/"+provider.Provider+"/resources/"+provider.Provider+".lr.manifest.yaml")
   164  			path = filepath.Join(TestutilsDir, "../../../providers/"+provider.Provider+"/resources/"+provider.Provider+".lr")
   165  		}
   166  	} else if provider.Path != "" {
   167  		path = provider.Path
   168  	}
   169  
   170  	res, err := lr.Resolve(path, func(path string) ([]byte, error) { return os.ReadFile(path) })
   171  	if err != nil {
   172  		panic(err.Error())
   173  	}
   174  	schema, err := lr.Schema(res)
   175  	if err != nil {
   176  		panic(err.Error())
   177  	}
   178  	// TODO: we should make a function that takes the Schema and the metadata and merges those.
   179  	// Then we can use that in the LR code and the testutils code too
   180  	if manifestPath != "" {
   181  		// we will attempt to auto-detect the manifest to inject some metadata
   182  		// into the schema
   183  		raw, err := os.ReadFile(manifestPath)
   184  		if err == nil {
   185  			var lrDocsData docs.LrDocs
   186  			err = yaml.Unmarshal(raw, &lrDocsData)
   187  			if err == nil {
   188  				docs.InjectMetadata(schema, &lrDocsData)
   189  			}
   190  		}
   191  	}
   192  
   193  	return schema
   194  }
   195  
   196  func Local() llx.Runtime {
   197  	osSchema := MustLoadSchema(SchemaProvider{Provider: "os"})
   198  	coreSchema := MustLoadSchema(SchemaProvider{Provider: "core"})
   199  	networkSchema := MustLoadSchema(SchemaProvider{Provider: "network"})
   200  	mockSchema := MustLoadSchema(SchemaProvider{Provider: "mockprovider"})
   201  
   202  	runtime := providers.Coordinator.NewRuntime()
   203  
   204  	provider := &providers.RunningProvider{
   205  		Name:   osconf.Config.Name,
   206  		ID:     osconf.Config.ID,
   207  		Plugin: osprovider.Init(),
   208  		Schema: osSchema.Add(coreSchema),
   209  	}
   210  	runtime.Provider = &providers.ConnectedProvider{Instance: provider}
   211  	runtime.AddConnectedProvider(runtime.Provider)
   212  
   213  	provider = &providers.RunningProvider{
   214  		Name:   networkconf.Config.Name,
   215  		ID:     networkconf.Config.ID,
   216  		Plugin: networkprovider.Init(),
   217  		Schema: networkSchema,
   218  	}
   219  	runtime.AddConnectedProvider(&providers.ConnectedProvider{Instance: provider})
   220  
   221  	provider = &providers.RunningProvider{
   222  		Name:   mockprovider.Config.Name,
   223  		ID:     mockprovider.Config.ID,
   224  		Plugin: mockprovider.Init(),
   225  		Schema: mockSchema,
   226  	}
   227  	runtime.AddConnectedProvider(&providers.ConnectedProvider{Instance: provider})
   228  
   229  	return runtime
   230  }
   231  
   232  func mockRuntime(testdata string) llx.Runtime {
   233  	return mockRuntimeAbs(filepath.Join(TestutilsDir, testdata))
   234  }
   235  
   236  func mockRuntimeAbs(testdata string) llx.Runtime {
   237  	runtime := Local().(*providers.Runtime)
   238  
   239  	abs, _ := filepath.Abs(testdata)
   240  	recording, err := providers.LoadRecordingFile(abs)
   241  	if err != nil {
   242  		panic("failed to load recording: " + err.Error())
   243  	}
   244  	roRecording := recording.ReadOnly()
   245  
   246  	err = runtime.SetMockRecording(roRecording, runtime.Provider.Instance.ID, true)
   247  	if err != nil {
   248  		panic("failed to set recording: " + err.Error())
   249  	}
   250  	err = runtime.SetMockRecording(roRecording, networkconf.Config.ID, true)
   251  	if err != nil {
   252  		panic("failed to set recording: " + err.Error())
   253  	}
   254  	err = runtime.SetMockRecording(roRecording, mockprovider.Config.ID, true)
   255  	if err != nil {
   256  		panic("failed to set recording: " + err.Error())
   257  	}
   258  
   259  	return runtime
   260  }
   261  
   262  func LinuxMock() llx.Runtime {
   263  	return mockRuntime("testdata/arch.json")
   264  }
   265  
   266  func KubeletMock() llx.Runtime {
   267  	return mockRuntime("testdata/kubelet.json")
   268  }
   269  
   270  func KubeletAKSMock() llx.Runtime {
   271  	return mockRuntime("testdata/kubelet-aks.json")
   272  }
   273  
   274  func WindowsMock() llx.Runtime {
   275  	return mockRuntime("testdata/windows.json")
   276  }
   277  
   278  func RecordingMock(absTestdataPath string) llx.Runtime {
   279  	return mockRuntimeAbs(absTestdataPath)
   280  }
   281  
   282  type SimpleTest struct {
   283  	Code        string
   284  	ResultIndex int
   285  	Expectation interface{}
   286  }
   287  
   288  func (ctx *tester) TestSimple(t *testing.T, tests []SimpleTest) {
   289  	t.Helper()
   290  	for i := range tests {
   291  		cur := tests[i]
   292  		t.Run(cur.Code, func(t *testing.T) {
   293  			res := ctx.TestQuery(t, cur.Code)
   294  			assert.NotEmpty(t, res)
   295  
   296  			if len(res) <= cur.ResultIndex {
   297  				t.Error("insufficient results, looking for result idx " + strconv.Itoa(cur.ResultIndex))
   298  				return
   299  			}
   300  
   301  			data := res[cur.ResultIndex].Data
   302  			require.NoError(t, data.Error)
   303  			assert.Equal(t, cur.Expectation, data.Value)
   304  		})
   305  	}
   306  }
   307  
   308  func (ctx *tester) TestNoErrorsNonEmpty(t *testing.T, tests []SimpleTest) {
   309  	t.Helper()
   310  	for i := range tests {
   311  		cur := tests[i]
   312  		t.Run(cur.Code, func(t *testing.T) {
   313  			res := ctx.TestQuery(t, cur.Code)
   314  			assert.NotEmpty(t, res)
   315  		})
   316  	}
   317  }
   318  
   319  func (ctx *tester) TestSimpleErrors(t *testing.T, tests []SimpleTest) {
   320  	for i := range tests {
   321  		cur := tests[i]
   322  		t.Run(cur.Code, func(t *testing.T) {
   323  			res := ctx.TestQuery(t, cur.Code)
   324  			assert.NotEmpty(t, res)
   325  			assert.Equal(t, cur.Expectation, res[cur.ResultIndex].Result().Error)
   326  			assert.Nil(t, res[cur.ResultIndex].Data.Value)
   327  		})
   328  	}
   329  }
   330  
   331  func TestNoResultErrors(t *testing.T, r []*llx.RawResult) bool {
   332  	var found bool
   333  	for i := range r {
   334  		err := r[i].Data.Error
   335  		if err != nil {
   336  			t.Error("result has error: " + err.Error())
   337  			found = true
   338  		}
   339  	}
   340  	return found
   341  }