go.mondoo.com/cnquery@v0.0.0-20231005093811-59568235f6ea/mql/internal/nodes_test.go (about) 1 // Copyright (c) Mondoo, Inc. 2 // SPDX-License-Identifier: BUSL-1.1 3 4 package internal 5 6 import ( 7 "errors" 8 "testing" 9 10 "github.com/stretchr/testify/assert" 11 "github.com/stretchr/testify/require" 12 "go.mondoo.com/cnquery/llx" 13 "go.mondoo.com/cnquery/types" 14 ) 15 16 func TestDatapointNode(t *testing.T) { 17 newNodeData := func() *DatapointNodeData { 18 return &DatapointNodeData{} 19 } 20 t.Run("initialize/recalculate", func(t *testing.T) { 21 t.Run("does not recalculate if data is not provided", func(t *testing.T) { 22 nodeData := newNodeData() 23 24 nodeData.initialize() 25 data := nodeData.recalculate() 26 27 assert.Nil(t, data) 28 }) 29 30 t.Run("recalculates if data is provided", func(t *testing.T) { 31 nodeData := newNodeData() 32 nodeData.res = &llx.RawResult{ 33 CodeID: "checksum", 34 Data: llx.BoolTrue, 35 } 36 37 nodeData.initialize() 38 data := nodeData.recalculate() 39 40 require.NotNil(t, data) 41 require.NotNil(t, data.res) 42 assert.Equal(t, "checksum", data.res.CodeID) 43 assert.Equal(t, llx.BoolTrue, data.res.Data) 44 }) 45 46 t.Run("casts if required type is provided", func(t *testing.T) { 47 nodeData := newNodeData() 48 typ := string(types.Bool) 49 nodeData.expectedType = &typ 50 nodeData.res = &llx.RawResult{ 51 CodeID: "checksum", 52 Data: llx.StringData("hello"), 53 } 54 55 nodeData.initialize() 56 data := nodeData.recalculate() 57 58 require.NotNil(t, data) 59 require.NotNil(t, data.res) 60 assert.Equal(t, "checksum", data.res.CodeID) 61 assert.Equal(t, llx.BoolTrue, data.res.Data) 62 }) 63 }) 64 65 t.Run("consume/recalculate", func(t *testing.T) { 66 t.Run("ignores nils", func(t *testing.T) { 67 nodeData := newNodeData() 68 69 nodeData.initialize() 70 nodeData.recalculate() 71 72 nodeData.consume(NodeID("__executor__"), &envelope{}) 73 data := nodeData.recalculate() 74 assert.Nil(t, data) 75 }) 76 77 t.Run("recalculate when data arrives", func(t *testing.T) { 78 nodeData := newNodeData() 79 80 nodeData.initialize() 81 nodeData.recalculate() 82 83 nodeData.consume(NodeID("__executor__"), &envelope{ 84 res: &llx.RawResult{ 85 CodeID: "checksum", 86 Data: llx.BoolTrue, 87 }, 88 }) 89 data := nodeData.recalculate() 90 91 require.NotNil(t, data) 92 require.NotNil(t, data.res) 93 assert.Equal(t, "checksum", data.res.CodeID) 94 assert.Equal(t, llx.BoolTrue, data.res.Data) 95 }) 96 97 t.Run("doesn't recalculate multiple times", func(t *testing.T) { 98 nodeData := newNodeData() 99 nodeData.res = &llx.RawResult{ 100 CodeID: "checksum", 101 Data: llx.BoolTrue, 102 } 103 104 nodeData.initialize() 105 data := nodeData.recalculate() 106 require.NotNil(t, data) 107 assert.NotNil(t, data.res) 108 109 nodeData.consume(NodeID("__executor__"), &envelope{ 110 res: &llx.RawResult{ 111 CodeID: "checksum", 112 Data: llx.BoolFalse, 113 }, 114 }) 115 data = nodeData.recalculate() 116 assert.Nil(t, data) 117 }) 118 119 t.Run("casts if required type is provided", func(t *testing.T) { 120 nodeData := newNodeData() 121 typ := string(types.Bool) 122 nodeData.expectedType = &typ 123 124 nodeData.initialize() 125 nodeData.recalculate() 126 127 nodeData.consume(NodeID("__executor__"), &envelope{ 128 res: &llx.RawResult{ 129 CodeID: "checksum", 130 Data: llx.StringData("hello"), 131 }, 132 }) 133 data := nodeData.recalculate() 134 135 require.NotNil(t, data) 136 require.NotNil(t, data.res) 137 assert.Equal(t, "checksum", data.res.CodeID) 138 assert.Equal(t, llx.BoolTrue, data.res.Data) 139 }) 140 141 t.Run("skips cast if required type are same", func(t *testing.T) { 142 nodeData := newNodeData() 143 typ := string(types.String) 144 nodeData.expectedType = &typ 145 146 nodeData.initialize() 147 nodeData.recalculate() 148 149 resData := llx.StringData("hello") 150 nodeData.consume(NodeID("__executor__"), &envelope{ 151 res: &llx.RawResult{ 152 CodeID: "checksum", 153 Data: resData, 154 }, 155 }) 156 data := nodeData.recalculate() 157 158 require.NotNil(t, data) 159 require.NotNil(t, data.res) 160 assert.Equal(t, "checksum", data.res.CodeID) 161 assert.Equal(t, resData, data.res.Data) 162 }) 163 164 t.Run("skips cast if datapoint is error", func(t *testing.T) { 165 nodeData := newNodeData() 166 typ := string(types.String) 167 nodeData.expectedType = &typ 168 169 nodeData.initialize() 170 nodeData.recalculate() 171 172 nodeData.consume(NodeID("__executor__"), &envelope{ 173 res: &llx.RawResult{ 174 CodeID: "checksum", 175 Data: &llx.RawData{ 176 Error: errors.New("error happened"), 177 }, 178 }, 179 }) 180 data := nodeData.recalculate() 181 182 require.NotNil(t, data) 183 require.NotNil(t, data.res) 184 assert.Equal(t, "checksum", data.res.CodeID) 185 require.NotNil(t, data.res.Data.Error) 186 assert.Equal(t, "error happened", data.res.Data.Error.Error()) 187 assert.Nil(t, data.res.Data.Value) 188 }) 189 190 t.Run("skips cast if expected type is unset", func(t *testing.T) { 191 nodeData := newNodeData() 192 typ := string(types.Unset) 193 nodeData.expectedType = &typ 194 195 nodeData.initialize() 196 nodeData.recalculate() 197 198 resData := llx.StringData("hello") 199 nodeData.consume(NodeID("__executor__"), &envelope{ 200 res: &llx.RawResult{ 201 CodeID: "checksum", 202 Data: resData, 203 }, 204 }) 205 data := nodeData.recalculate() 206 207 require.NotNil(t, data) 208 require.NotNil(t, data.res) 209 assert.Equal(t, "checksum", data.res.CodeID) 210 assert.Equal(t, resData, data.res.Data) 211 }) 212 }) 213 } 214 215 func TestExecutionQueryNode(t *testing.T) { 216 newNodeData := func() (*ExecutionQueryNodeData, chan runQueueItem) { 217 q := make(chan runQueueItem, 1) 218 data := &ExecutionQueryNodeData{ 219 queryID: "testqueryid", 220 requiredProperties: map[string]*executionQueryProperty{}, 221 runState: notReadyQueryNotReady, 222 runQueue: q, 223 codeBundle: &llx.CodeBundle{ 224 CodeV2: &llx.CodeV2{ 225 Id: "testqueryid", 226 }, 227 }, 228 } 229 return data, q 230 } 231 t.Run("initialize/recalculate", func(t *testing.T) { 232 t.Run("does not recalculate if dependencies not satisfied", func(t *testing.T) { 233 nodeData, q := newNodeData() 234 nodeData.requiredProperties = map[string]*executionQueryProperty{ 235 "prop1": { 236 name: "prop1", 237 checksum: "checksum1", 238 resolved: false, 239 }, 240 } 241 nodeData.initialize() 242 data := nodeData.recalculate() 243 assert.Nil(t, data) 244 select { 245 case <-q: 246 assert.Fail(t, "not ready for execution") 247 default: 248 } 249 }) 250 t.Run("recalculates if dependencies are satisfied", func(t *testing.T) { 251 nodeData, q := newNodeData() 252 nodeData.requiredProperties = map[string]*executionQueryProperty{ 253 "prop1": { 254 name: "prop1", 255 checksum: "checksum1", 256 resolved: true, 257 value: llx.BoolFalse.Result(), 258 }, 259 "prop2": { 260 name: "prop2", 261 checksum: "checksum1", 262 resolved: true, 263 value: llx.BoolFalse.Result(), 264 }, 265 } 266 nodeData.initialize() 267 data := nodeData.recalculate() 268 assert.NotNil(t, data) 269 assert.Nil(t, data.res) 270 select { 271 case item := <-q: 272 require.NotNil(t, item.codeBundle) 273 assert.Equal(t, "testqueryid", item.codeBundle.CodeV2.Id) 274 assert.Contains(t, item.props, "prop1") 275 default: 276 assert.Fail(t, "expected something to be executed") 277 } 278 }) 279 }) 280 281 t.Run("consume/recalculate", func(t *testing.T) { 282 t.Run("does not recalculate if dependencies not satisfied", func(t *testing.T) { 283 nodeData, q := newNodeData() 284 nodeData.requiredProperties = map[string]*executionQueryProperty{ 285 "prop1": { 286 name: "prop1", 287 checksum: "checksum1", 288 }, 289 "prop2": { 290 name: "prop2", 291 checksum: "checksum2", 292 }, 293 } 294 nodeData.initialize() 295 data := nodeData.recalculate() 296 assert.Nil(t, data) 297 nodeData.consume(NodeID("checksum1"), &envelope{ 298 res: &llx.RawResult{ 299 CodeID: "checksum1", 300 Data: llx.BoolTrue, 301 }, 302 }) 303 304 select { 305 case <-q: 306 assert.Fail(t, "not ready for execution") 307 default: 308 } 309 }) 310 t.Run("only recalculates once", func(t *testing.T) { 311 nodeData, q := newNodeData() 312 nodeData.requiredProperties = map[string]*executionQueryProperty{ 313 "prop1": { 314 name: "prop1", 315 checksum: "checksum1", 316 }, 317 "prop2": { 318 name: "prop2", 319 checksum: "checksum1", 320 }, 321 } 322 nodeData.initialize() 323 data := nodeData.recalculate() 324 assert.Nil(t, data) 325 nodeData.consume(NodeID("checksum1"), &envelope{ 326 res: &llx.RawResult{ 327 CodeID: "checksum1", 328 Data: llx.BoolTrue, 329 }, 330 }) 331 data = nodeData.recalculate() 332 assert.NotNil(t, data) 333 select { 334 case _ = <-q: 335 default: 336 assert.Fail(t, "expected something to be executed") 337 } 338 339 nodeData.consume(NodeID("checksum1"), &envelope{ 340 res: &llx.RawResult{ 341 CodeID: "checksum1", 342 Data: llx.BoolTrue, 343 }, 344 }) 345 data = nodeData.recalculate() 346 select { 347 case _ = <-q: 348 assert.Fail(t, "query should not re-execute") 349 default: 350 } 351 }) 352 t.Run("recalculates after all dependencies are satisfied", func(t *testing.T) {}) 353 }) 354 } 355 356 func TestCollectionFinisherNode(t *testing.T) { 357 newNodeData := func(reporter func(numCompleted int, total int)) *CollectionFinisherNodeData { 358 data := &CollectionFinisherNodeData{ 359 progressReporter: ProgressReporterFunc(reporter), 360 doneChan: make(chan struct{}), 361 } 362 return data 363 } 364 365 results := map[string]*llx.RawResult{ 366 "codeID1": { 367 CodeID: "codeID1", 368 Data: llx.BoolData(true), 369 }, 370 } 371 372 t.Run("initialize/recalculate", func(t *testing.T) { 373 t.Run("recalculates if there are no remaining datapoints", func(t *testing.T) { 374 nodeData := newNodeData(func(completed int, total int) { 375 assert.Equal(t, 0, completed) 376 assert.Equal(t, 0, total) 377 }) 378 379 nodeData.initialize() 380 nodeData.recalculate() 381 382 select { 383 case _, ok := <-nodeData.doneChan: 384 assert.False(t, ok) 385 default: 386 assert.Fail(t, "expected channel to be closed") 387 } 388 }) 389 t.Run("does not recalculate if there are remaining datapoints", func(t *testing.T) { 390 nodeData := newNodeData(func(completed int, total int) { 391 assert.Fail(t, "should not recalculate") 392 }) 393 394 nodeData.totalDatapoints = 2 395 nodeData.remainingDatapoints = map[string]struct{}{ 396 "codeID1": {}, 397 "codeID2": {}, 398 } 399 400 nodeData.initialize() 401 nodeData.recalculate() 402 403 select { 404 case _, _ = <-nodeData.doneChan: 405 assert.Fail(t, "expected channel to be open") 406 default: 407 } 408 }) 409 }) 410 411 t.Run("consume/recalculate", func(t *testing.T) { 412 t.Run("notifies progress when partially complete", func(t *testing.T) { 413 progressCalled := false 414 nodeData := newNodeData(func(completed int, total int) { 415 progressCalled = true 416 assert.Equal(t, 1, completed) 417 assert.Equal(t, 2, total) 418 }) 419 nodeData.totalDatapoints = 2 420 nodeData.remainingDatapoints = map[string]struct{}{ 421 "codeID1": {}, 422 "codeID2": {}, 423 } 424 nodeData.initialize() 425 nodeData.consume("codeID1", &envelope{ 426 res: results["codeID1"], 427 }) 428 nodeData.recalculate() 429 430 assert.True(t, progressCalled) 431 select { 432 case _, _ = <-nodeData.doneChan: 433 assert.Fail(t, "expected channel to be open") 434 default: 435 } 436 }) 437 t.Run("notifies progress and signals finish when fully complete", func(t *testing.T) { 438 progressCalled := false 439 nodeData := newNodeData(func(completed int, total int) { 440 progressCalled = true 441 assert.Equal(t, 1, completed) 442 assert.Equal(t, 1, total) 443 }) 444 nodeData.totalDatapoints = 1 445 nodeData.remainingDatapoints = map[string]struct{}{ 446 "codeID1": {}, 447 } 448 nodeData.initialize() 449 nodeData.consume("codeID1", &envelope{ 450 res: results["codeID1"], 451 }) 452 nodeData.recalculate() 453 454 assert.True(t, progressCalled) 455 select { 456 case _, ok := <-nodeData.doneChan: 457 assert.False(t, ok) 458 default: 459 assert.Fail(t, "expected channel to be closed") 460 } 461 }) 462 }) 463 } 464 465 func TestDatapointCollectorNode(t *testing.T) { 466 newNodeData := func(collectorFunc func(results []*llx.RawResult)) *DatapointCollectorNodeData { 467 data := &DatapointCollectorNodeData{ 468 unreported: make(map[string]*llx.RawResult), 469 collectors: []DatapointCollector{ 470 &FuncCollector{ 471 SinkDataFunc: collectorFunc, 472 }, 473 }, 474 } 475 return data 476 } 477 478 initExpectedData := func() map[string]*llx.RawResult { 479 return map[string]*llx.RawResult{ 480 "codeID1": { 481 CodeID: "codeID1", 482 Data: llx.BoolData(true), 483 }, 484 "codeID2": { 485 CodeID: "codeID2", 486 Data: llx.BoolData(false), 487 }, 488 } 489 } 490 t.Run("initialize/recalculate", func(t *testing.T) { 491 t.Run("recalculates if unreported datapoints are available", func(t *testing.T) { 492 collected := map[string]int{} 493 expectedData := initExpectedData() 494 nodeData := newNodeData(func(results []*llx.RawResult) { 495 for _, r := range results { 496 assert.Equal(t, expectedData[r.CodeID], r) 497 collected[r.CodeID] = collected[r.CodeID] + 1 498 } 499 }) 500 501 nodeData.unreported = expectedData 502 503 nodeData.initialize() 504 nodeData.recalculate() 505 506 assert.Equal(t, 2, len(collected)) 507 for _, v := range collected { 508 assert.Equal(t, 1, v) 509 } 510 }) 511 512 t.Run("does not recalculate if no unreported data", func(t *testing.T) { 513 calls := 0 514 nodeData := newNodeData(func(results []*llx.RawResult) { 515 calls += 1 516 }) 517 518 nodeData.initialize() 519 nodeData.recalculate() 520 521 assert.Equal(t, 0, calls) 522 }) 523 }) 524 525 t.Run("consume/recalculate", func(t *testing.T) { 526 t.Run("recalculates if unreported datapoints are available", func(t *testing.T) { 527 collected := map[string]int{} 528 expectedData := initExpectedData() 529 530 nodeData := newNodeData(func(results []*llx.RawResult) { 531 for _, r := range results { 532 assert.Equal(t, expectedData[r.CodeID], r) 533 collected[r.CodeID] = collected[r.CodeID] + 1 534 } 535 }) 536 537 nodeData.initialize() 538 nodeData.consume("codeID1", &envelope{ 539 res: expectedData["codeID1"], 540 }) 541 nodeData.consume("rjID1", &envelope{ 542 res: expectedData["codeID2"], 543 }) 544 nodeData.recalculate() 545 546 assert.Equal(t, 2, len(collected)) 547 for _, v := range collected { 548 assert.Equal(t, 1, v) 549 } 550 }) 551 }) 552 }