go.mondoo.com/cnquery@v0.0.0-20231005093811-59568235f6ea/providers-sdk/v1/testutils/testutils.go (about) 1 // Copyright (c) Mondoo, Inc. 2 // SPDX-License-Identifier: BUSL-1.1 3 4 package testutils 5 6 import ( 7 "fmt" 8 "os" 9 "path" 10 "path/filepath" 11 "runtime" 12 "sort" 13 "strconv" 14 "strings" 15 "testing" 16 17 "github.com/stretchr/testify/assert" 18 "github.com/stretchr/testify/require" 19 "go.mondoo.com/cnquery" 20 "go.mondoo.com/cnquery/llx" 21 "go.mondoo.com/cnquery/logger" 22 "go.mondoo.com/cnquery/mql" 23 "go.mondoo.com/cnquery/mqlc" 24 "go.mondoo.com/cnquery/providers" 25 "go.mondoo.com/cnquery/providers-sdk/v1/lr" 26 "go.mondoo.com/cnquery/providers-sdk/v1/lr/docs" 27 "go.mondoo.com/cnquery/providers-sdk/v1/resources" 28 "go.mondoo.com/cnquery/providers-sdk/v1/testutils/mockprovider" 29 networkconf "go.mondoo.com/cnquery/providers/network/config" 30 networkprovider "go.mondoo.com/cnquery/providers/network/provider" 31 osconf "go.mondoo.com/cnquery/providers/os/config" 32 osprovider "go.mondoo.com/cnquery/providers/os/provider" 33 "sigs.k8s.io/yaml" 34 ) 35 36 var ( 37 Features cnquery.Features 38 TestutilsDir string 39 ) 40 41 func init() { 42 logger.InitTestEnv() 43 Features = getEnvFeatures() 44 45 _, pathToFile, _, ok := runtime.Caller(0) 46 if !ok { 47 panic("unable to get runtime for testutils for cnquery providers") 48 } 49 TestutilsDir = path.Dir(pathToFile) 50 } 51 52 func getEnvFeatures() cnquery.Features { 53 env := os.Getenv("FEATURES") 54 if env == "" { 55 return cnquery.Features{byte(cnquery.PiperCode)} 56 } 57 58 arr := strings.Split(env, ",") 59 var fts cnquery.Features 60 for i := range arr { 61 v, ok := cnquery.FeaturesValue[arr[i]] 62 if ok { 63 fmt.Println("--> activate feature: " + arr[i]) 64 fts = append(Features, byte(v)) 65 } else { 66 panic("cannot find requested feature: " + arr[i]) 67 } 68 } 69 return fts 70 } 71 72 type tester struct { 73 Runtime llx.Runtime 74 } 75 76 type SchemaProvider struct { 77 Provider string 78 Path string 79 } 80 81 func InitTester(runtime llx.Runtime) *tester { 82 return &tester{ 83 Runtime: runtime, 84 } 85 } 86 87 func (ctx *tester) Compile(query string) (*llx.CodeBundle, error) { 88 return mqlc.Compile(query, nil, mqlc.NewConfig(ctx.Runtime.Schema(), Features)) 89 } 90 91 func (ctx *tester) ExecuteCode(bundle *llx.CodeBundle, props map[string]*llx.Primitive) (map[string]*llx.RawResult, error) { 92 return mql.ExecuteCode(ctx.Runtime, bundle, props, Features) 93 } 94 95 func (ctx *tester) TestQueryP(t *testing.T, query string, props map[string]*llx.Primitive) []*llx.RawResult { 96 t.Helper() 97 bundle, err := mqlc.Compile(query, props, mqlc.NewConfig(ctx.Runtime.Schema(), Features)) 98 if err != nil { 99 t.Fatal("failed to compile code: " + err.Error()) 100 } 101 err = mqlc.Invariants.Check(bundle) 102 require.NoError(t, err) 103 return ctx.TestMqlc(t, bundle, props) 104 } 105 106 func (ctx *tester) TestQuery(t *testing.T, query string) []*llx.RawResult { 107 return ctx.TestQueryP(t, query, nil) 108 } 109 110 func (ctx *tester) TestMqlc(t *testing.T, bundle *llx.CodeBundle, props map[string]*llx.Primitive) []*llx.RawResult { 111 t.Helper() 112 113 resultMap, err := mql.ExecuteCode(ctx.Runtime, bundle, props, Features) 114 require.NoError(t, err) 115 116 lastQueryResult := &llx.RawResult{} 117 results := make([]*llx.RawResult, 0, len(resultMap)+1) 118 119 refs := make([]uint64, 0, len(bundle.CodeV2.Checksums)) 120 for _, datapointArr := range [][]uint64{bundle.CodeV2.Datapoints(), bundle.CodeV2.Entrypoints()} { 121 refs = append(refs, datapointArr...) 122 } 123 124 sort.Slice(refs, func(i, j int) bool { 125 return refs[i] < refs[j] 126 }) 127 128 for idx, ref := range refs { 129 checksum := bundle.CodeV2.Checksums[ref] 130 if d, ok := resultMap[checksum]; ok { 131 results = append(results, d) 132 if idx+1 == len(refs) { 133 lastQueryResult.CodeID = d.CodeID 134 if d.Data.Error != nil { 135 lastQueryResult.Data = &llx.RawData{ 136 Error: d.Data.Error, 137 } 138 } else { 139 success, valid := d.Data.IsSuccess() 140 lastQueryResult.Data = llx.BoolData(success && valid) 141 } 142 } 143 } 144 } 145 146 results = append(results, lastQueryResult) 147 return results 148 } 149 150 func MustLoadSchema(provider SchemaProvider) *resources.Schema { 151 if provider.Path == "" && provider.Provider == "" { 152 panic("cannot load schema without provider name or path") 153 } 154 var path string 155 // path towards the .yaml manifest, containing metadata abou the resources 156 var manifestPath string 157 if provider.Provider != "" { 158 switch provider.Provider { 159 // special handling for the mockprovider 160 case "mockprovider": 161 path = filepath.Join(TestutilsDir, "mockprovider/resources/mockprovider.lr") 162 default: 163 manifestPath = filepath.Join(TestutilsDir, "../../../providers/"+provider.Provider+"/resources/"+provider.Provider+".lr.manifest.yaml") 164 path = filepath.Join(TestutilsDir, "../../../providers/"+provider.Provider+"/resources/"+provider.Provider+".lr") 165 } 166 } else if provider.Path != "" { 167 path = provider.Path 168 } 169 170 res, err := lr.Resolve(path, func(path string) ([]byte, error) { return os.ReadFile(path) }) 171 if err != nil { 172 panic(err.Error()) 173 } 174 schema, err := lr.Schema(res) 175 if err != nil { 176 panic(err.Error()) 177 } 178 // TODO: we should make a function that takes the Schema and the metadata and merges those. 179 // Then we can use that in the LR code and the testutils code too 180 if manifestPath != "" { 181 // we will attempt to auto-detect the manifest to inject some metadata 182 // into the schema 183 raw, err := os.ReadFile(manifestPath) 184 if err == nil { 185 var lrDocsData docs.LrDocs 186 err = yaml.Unmarshal(raw, &lrDocsData) 187 if err == nil { 188 docs.InjectMetadata(schema, &lrDocsData) 189 } 190 } 191 } 192 193 return schema 194 } 195 196 func Local() llx.Runtime { 197 osSchema := MustLoadSchema(SchemaProvider{Provider: "os"}) 198 coreSchema := MustLoadSchema(SchemaProvider{Provider: "core"}) 199 networkSchema := MustLoadSchema(SchemaProvider{Provider: "network"}) 200 mockSchema := MustLoadSchema(SchemaProvider{Provider: "mockprovider"}) 201 202 runtime := providers.Coordinator.NewRuntime() 203 204 provider := &providers.RunningProvider{ 205 Name: osconf.Config.Name, 206 ID: osconf.Config.ID, 207 Plugin: osprovider.Init(), 208 Schema: osSchema.Add(coreSchema), 209 } 210 runtime.Provider = &providers.ConnectedProvider{Instance: provider} 211 runtime.AddConnectedProvider(runtime.Provider) 212 213 provider = &providers.RunningProvider{ 214 Name: networkconf.Config.Name, 215 ID: networkconf.Config.ID, 216 Plugin: networkprovider.Init(), 217 Schema: networkSchema, 218 } 219 runtime.AddConnectedProvider(&providers.ConnectedProvider{Instance: provider}) 220 221 provider = &providers.RunningProvider{ 222 Name: mockprovider.Config.Name, 223 ID: mockprovider.Config.ID, 224 Plugin: mockprovider.Init(), 225 Schema: mockSchema, 226 } 227 runtime.AddConnectedProvider(&providers.ConnectedProvider{Instance: provider}) 228 229 return runtime 230 } 231 232 func mockRuntime(testdata string) llx.Runtime { 233 return mockRuntimeAbs(filepath.Join(TestutilsDir, testdata)) 234 } 235 236 func mockRuntimeAbs(testdata string) llx.Runtime { 237 runtime := Local().(*providers.Runtime) 238 239 abs, _ := filepath.Abs(testdata) 240 recording, err := providers.LoadRecordingFile(abs) 241 if err != nil { 242 panic("failed to load recording: " + err.Error()) 243 } 244 roRecording := recording.ReadOnly() 245 246 err = runtime.SetMockRecording(roRecording, runtime.Provider.Instance.ID, true) 247 if err != nil { 248 panic("failed to set recording: " + err.Error()) 249 } 250 err = runtime.SetMockRecording(roRecording, networkconf.Config.ID, true) 251 if err != nil { 252 panic("failed to set recording: " + err.Error()) 253 } 254 err = runtime.SetMockRecording(roRecording, mockprovider.Config.ID, true) 255 if err != nil { 256 panic("failed to set recording: " + err.Error()) 257 } 258 259 return runtime 260 } 261 262 func LinuxMock() llx.Runtime { 263 return mockRuntime("testdata/arch.json") 264 } 265 266 func KubeletMock() llx.Runtime { 267 return mockRuntime("testdata/kubelet.json") 268 } 269 270 func KubeletAKSMock() llx.Runtime { 271 return mockRuntime("testdata/kubelet-aks.json") 272 } 273 274 func WindowsMock() llx.Runtime { 275 return mockRuntime("testdata/windows.json") 276 } 277 278 func RecordingMock(absTestdataPath string) llx.Runtime { 279 return mockRuntimeAbs(absTestdataPath) 280 } 281 282 type SimpleTest struct { 283 Code string 284 ResultIndex int 285 Expectation interface{} 286 } 287 288 func (ctx *tester) TestSimple(t *testing.T, tests []SimpleTest) { 289 t.Helper() 290 for i := range tests { 291 cur := tests[i] 292 t.Run(cur.Code, func(t *testing.T) { 293 res := ctx.TestQuery(t, cur.Code) 294 assert.NotEmpty(t, res) 295 296 if len(res) <= cur.ResultIndex { 297 t.Error("insufficient results, looking for result idx " + strconv.Itoa(cur.ResultIndex)) 298 return 299 } 300 301 data := res[cur.ResultIndex].Data 302 require.NoError(t, data.Error) 303 assert.Equal(t, cur.Expectation, data.Value) 304 }) 305 } 306 } 307 308 func (ctx *tester) TestNoErrorsNonEmpty(t *testing.T, tests []SimpleTest) { 309 t.Helper() 310 for i := range tests { 311 cur := tests[i] 312 t.Run(cur.Code, func(t *testing.T) { 313 res := ctx.TestQuery(t, cur.Code) 314 assert.NotEmpty(t, res) 315 }) 316 } 317 } 318 319 func (ctx *tester) TestSimpleErrors(t *testing.T, tests []SimpleTest) { 320 for i := range tests { 321 cur := tests[i] 322 t.Run(cur.Code, func(t *testing.T) { 323 res := ctx.TestQuery(t, cur.Code) 324 assert.NotEmpty(t, res) 325 assert.Equal(t, cur.Expectation, res[cur.ResultIndex].Result().Error) 326 assert.Nil(t, res[cur.ResultIndex].Data.Value) 327 }) 328 } 329 } 330 331 func TestNoResultErrors(t *testing.T, r []*llx.RawResult) bool { 332 var found bool 333 for i := range r { 334 err := r[i].Data.Error 335 if err != nil { 336 t.Error("result has error: " + err.Error()) 337 found = true 338 } 339 } 340 return found 341 }