go.mondoo.com/cnquery@v0.0.0-20231005093811-59568235f6ea/mql/internal/nodes_test.go (about)

     1  // Copyright (c) Mondoo, Inc.
     2  // SPDX-License-Identifier: BUSL-1.1
     3  
     4  package internal
     5  
     6  import (
     7  	"errors"
     8  	"testing"
     9  
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/stretchr/testify/require"
    12  	"go.mondoo.com/cnquery/llx"
    13  	"go.mondoo.com/cnquery/types"
    14  )
    15  
    16  func TestDatapointNode(t *testing.T) {
    17  	newNodeData := func() *DatapointNodeData {
    18  		return &DatapointNodeData{}
    19  	}
    20  	t.Run("initialize/recalculate", func(t *testing.T) {
    21  		t.Run("does not recalculate if data is not provided", func(t *testing.T) {
    22  			nodeData := newNodeData()
    23  
    24  			nodeData.initialize()
    25  			data := nodeData.recalculate()
    26  
    27  			assert.Nil(t, data)
    28  		})
    29  
    30  		t.Run("recalculates if data is provided", func(t *testing.T) {
    31  			nodeData := newNodeData()
    32  			nodeData.res = &llx.RawResult{
    33  				CodeID: "checksum",
    34  				Data:   llx.BoolTrue,
    35  			}
    36  
    37  			nodeData.initialize()
    38  			data := nodeData.recalculate()
    39  
    40  			require.NotNil(t, data)
    41  			require.NotNil(t, data.res)
    42  			assert.Equal(t, "checksum", data.res.CodeID)
    43  			assert.Equal(t, llx.BoolTrue, data.res.Data)
    44  		})
    45  
    46  		t.Run("casts if required type is provided", func(t *testing.T) {
    47  			nodeData := newNodeData()
    48  			typ := string(types.Bool)
    49  			nodeData.expectedType = &typ
    50  			nodeData.res = &llx.RawResult{
    51  				CodeID: "checksum",
    52  				Data:   llx.StringData("hello"),
    53  			}
    54  
    55  			nodeData.initialize()
    56  			data := nodeData.recalculate()
    57  
    58  			require.NotNil(t, data)
    59  			require.NotNil(t, data.res)
    60  			assert.Equal(t, "checksum", data.res.CodeID)
    61  			assert.Equal(t, llx.BoolTrue, data.res.Data)
    62  		})
    63  	})
    64  
    65  	t.Run("consume/recalculate", func(t *testing.T) {
    66  		t.Run("ignores nils", func(t *testing.T) {
    67  			nodeData := newNodeData()
    68  
    69  			nodeData.initialize()
    70  			nodeData.recalculate()
    71  
    72  			nodeData.consume(NodeID("__executor__"), &envelope{})
    73  			data := nodeData.recalculate()
    74  			assert.Nil(t, data)
    75  		})
    76  
    77  		t.Run("recalculate when data arrives", func(t *testing.T) {
    78  			nodeData := newNodeData()
    79  
    80  			nodeData.initialize()
    81  			nodeData.recalculate()
    82  
    83  			nodeData.consume(NodeID("__executor__"), &envelope{
    84  				res: &llx.RawResult{
    85  					CodeID: "checksum",
    86  					Data:   llx.BoolTrue,
    87  				},
    88  			})
    89  			data := nodeData.recalculate()
    90  
    91  			require.NotNil(t, data)
    92  			require.NotNil(t, data.res)
    93  			assert.Equal(t, "checksum", data.res.CodeID)
    94  			assert.Equal(t, llx.BoolTrue, data.res.Data)
    95  		})
    96  
    97  		t.Run("doesn't recalculate multiple times", func(t *testing.T) {
    98  			nodeData := newNodeData()
    99  			nodeData.res = &llx.RawResult{
   100  				CodeID: "checksum",
   101  				Data:   llx.BoolTrue,
   102  			}
   103  
   104  			nodeData.initialize()
   105  			data := nodeData.recalculate()
   106  			require.NotNil(t, data)
   107  			assert.NotNil(t, data.res)
   108  
   109  			nodeData.consume(NodeID("__executor__"), &envelope{
   110  				res: &llx.RawResult{
   111  					CodeID: "checksum",
   112  					Data:   llx.BoolFalse,
   113  				},
   114  			})
   115  			data = nodeData.recalculate()
   116  			assert.Nil(t, data)
   117  		})
   118  
   119  		t.Run("casts if required type is provided", func(t *testing.T) {
   120  			nodeData := newNodeData()
   121  			typ := string(types.Bool)
   122  			nodeData.expectedType = &typ
   123  
   124  			nodeData.initialize()
   125  			nodeData.recalculate()
   126  
   127  			nodeData.consume(NodeID("__executor__"), &envelope{
   128  				res: &llx.RawResult{
   129  					CodeID: "checksum",
   130  					Data:   llx.StringData("hello"),
   131  				},
   132  			})
   133  			data := nodeData.recalculate()
   134  
   135  			require.NotNil(t, data)
   136  			require.NotNil(t, data.res)
   137  			assert.Equal(t, "checksum", data.res.CodeID)
   138  			assert.Equal(t, llx.BoolTrue, data.res.Data)
   139  		})
   140  
   141  		t.Run("skips cast if required type are same", func(t *testing.T) {
   142  			nodeData := newNodeData()
   143  			typ := string(types.String)
   144  			nodeData.expectedType = &typ
   145  
   146  			nodeData.initialize()
   147  			nodeData.recalculate()
   148  
   149  			resData := llx.StringData("hello")
   150  			nodeData.consume(NodeID("__executor__"), &envelope{
   151  				res: &llx.RawResult{
   152  					CodeID: "checksum",
   153  					Data:   resData,
   154  				},
   155  			})
   156  			data := nodeData.recalculate()
   157  
   158  			require.NotNil(t, data)
   159  			require.NotNil(t, data.res)
   160  			assert.Equal(t, "checksum", data.res.CodeID)
   161  			assert.Equal(t, resData, data.res.Data)
   162  		})
   163  
   164  		t.Run("skips cast if datapoint is error", func(t *testing.T) {
   165  			nodeData := newNodeData()
   166  			typ := string(types.String)
   167  			nodeData.expectedType = &typ
   168  
   169  			nodeData.initialize()
   170  			nodeData.recalculate()
   171  
   172  			nodeData.consume(NodeID("__executor__"), &envelope{
   173  				res: &llx.RawResult{
   174  					CodeID: "checksum",
   175  					Data: &llx.RawData{
   176  						Error: errors.New("error happened"),
   177  					},
   178  				},
   179  			})
   180  			data := nodeData.recalculate()
   181  
   182  			require.NotNil(t, data)
   183  			require.NotNil(t, data.res)
   184  			assert.Equal(t, "checksum", data.res.CodeID)
   185  			require.NotNil(t, data.res.Data.Error)
   186  			assert.Equal(t, "error happened", data.res.Data.Error.Error())
   187  			assert.Nil(t, data.res.Data.Value)
   188  		})
   189  
   190  		t.Run("skips cast if expected type is unset", func(t *testing.T) {
   191  			nodeData := newNodeData()
   192  			typ := string(types.Unset)
   193  			nodeData.expectedType = &typ
   194  
   195  			nodeData.initialize()
   196  			nodeData.recalculate()
   197  
   198  			resData := llx.StringData("hello")
   199  			nodeData.consume(NodeID("__executor__"), &envelope{
   200  				res: &llx.RawResult{
   201  					CodeID: "checksum",
   202  					Data:   resData,
   203  				},
   204  			})
   205  			data := nodeData.recalculate()
   206  
   207  			require.NotNil(t, data)
   208  			require.NotNil(t, data.res)
   209  			assert.Equal(t, "checksum", data.res.CodeID)
   210  			assert.Equal(t, resData, data.res.Data)
   211  		})
   212  	})
   213  }
   214  
   215  func TestExecutionQueryNode(t *testing.T) {
   216  	newNodeData := func() (*ExecutionQueryNodeData, chan runQueueItem) {
   217  		q := make(chan runQueueItem, 1)
   218  		data := &ExecutionQueryNodeData{
   219  			queryID:            "testqueryid",
   220  			requiredProperties: map[string]*executionQueryProperty{},
   221  			runState:           notReadyQueryNotReady,
   222  			runQueue:           q,
   223  			codeBundle: &llx.CodeBundle{
   224  				CodeV2: &llx.CodeV2{
   225  					Id: "testqueryid",
   226  				},
   227  			},
   228  		}
   229  		return data, q
   230  	}
   231  	t.Run("initialize/recalculate", func(t *testing.T) {
   232  		t.Run("does not recalculate if dependencies not satisfied", func(t *testing.T) {
   233  			nodeData, q := newNodeData()
   234  			nodeData.requiredProperties = map[string]*executionQueryProperty{
   235  				"prop1": {
   236  					name:     "prop1",
   237  					checksum: "checksum1",
   238  					resolved: false,
   239  				},
   240  			}
   241  			nodeData.initialize()
   242  			data := nodeData.recalculate()
   243  			assert.Nil(t, data)
   244  			select {
   245  			case <-q:
   246  				assert.Fail(t, "not ready for execution")
   247  			default:
   248  			}
   249  		})
   250  		t.Run("recalculates if dependencies are satisfied", func(t *testing.T) {
   251  			nodeData, q := newNodeData()
   252  			nodeData.requiredProperties = map[string]*executionQueryProperty{
   253  				"prop1": {
   254  					name:     "prop1",
   255  					checksum: "checksum1",
   256  					resolved: true,
   257  					value:    llx.BoolFalse.Result(),
   258  				},
   259  				"prop2": {
   260  					name:     "prop2",
   261  					checksum: "checksum1",
   262  					resolved: true,
   263  					value:    llx.BoolFalse.Result(),
   264  				},
   265  			}
   266  			nodeData.initialize()
   267  			data := nodeData.recalculate()
   268  			assert.NotNil(t, data)
   269  			assert.Nil(t, data.res)
   270  			select {
   271  			case item := <-q:
   272  				require.NotNil(t, item.codeBundle)
   273  				assert.Equal(t, "testqueryid", item.codeBundle.CodeV2.Id)
   274  				assert.Contains(t, item.props, "prop1")
   275  			default:
   276  				assert.Fail(t, "expected something to be executed")
   277  			}
   278  		})
   279  	})
   280  
   281  	t.Run("consume/recalculate", func(t *testing.T) {
   282  		t.Run("does not recalculate if dependencies not satisfied", func(t *testing.T) {
   283  			nodeData, q := newNodeData()
   284  			nodeData.requiredProperties = map[string]*executionQueryProperty{
   285  				"prop1": {
   286  					name:     "prop1",
   287  					checksum: "checksum1",
   288  				},
   289  				"prop2": {
   290  					name:     "prop2",
   291  					checksum: "checksum2",
   292  				},
   293  			}
   294  			nodeData.initialize()
   295  			data := nodeData.recalculate()
   296  			assert.Nil(t, data)
   297  			nodeData.consume(NodeID("checksum1"), &envelope{
   298  				res: &llx.RawResult{
   299  					CodeID: "checksum1",
   300  					Data:   llx.BoolTrue,
   301  				},
   302  			})
   303  
   304  			select {
   305  			case <-q:
   306  				assert.Fail(t, "not ready for execution")
   307  			default:
   308  			}
   309  		})
   310  		t.Run("only recalculates once", func(t *testing.T) {
   311  			nodeData, q := newNodeData()
   312  			nodeData.requiredProperties = map[string]*executionQueryProperty{
   313  				"prop1": {
   314  					name:     "prop1",
   315  					checksum: "checksum1",
   316  				},
   317  				"prop2": {
   318  					name:     "prop2",
   319  					checksum: "checksum1",
   320  				},
   321  			}
   322  			nodeData.initialize()
   323  			data := nodeData.recalculate()
   324  			assert.Nil(t, data)
   325  			nodeData.consume(NodeID("checksum1"), &envelope{
   326  				res: &llx.RawResult{
   327  					CodeID: "checksum1",
   328  					Data:   llx.BoolTrue,
   329  				},
   330  			})
   331  			data = nodeData.recalculate()
   332  			assert.NotNil(t, data)
   333  			select {
   334  			case _ = <-q:
   335  			default:
   336  				assert.Fail(t, "expected something to be executed")
   337  			}
   338  
   339  			nodeData.consume(NodeID("checksum1"), &envelope{
   340  				res: &llx.RawResult{
   341  					CodeID: "checksum1",
   342  					Data:   llx.BoolTrue,
   343  				},
   344  			})
   345  			data = nodeData.recalculate()
   346  			select {
   347  			case _ = <-q:
   348  				assert.Fail(t, "query should not re-execute")
   349  			default:
   350  			}
   351  		})
   352  		t.Run("recalculates after all dependencies are satisfied", func(t *testing.T) {})
   353  	})
   354  }
   355  
   356  func TestCollectionFinisherNode(t *testing.T) {
   357  	newNodeData := func(reporter func(numCompleted int, total int)) *CollectionFinisherNodeData {
   358  		data := &CollectionFinisherNodeData{
   359  			progressReporter: ProgressReporterFunc(reporter),
   360  			doneChan:         make(chan struct{}),
   361  		}
   362  		return data
   363  	}
   364  
   365  	results := map[string]*llx.RawResult{
   366  		"codeID1": {
   367  			CodeID: "codeID1",
   368  			Data:   llx.BoolData(true),
   369  		},
   370  	}
   371  
   372  	t.Run("initialize/recalculate", func(t *testing.T) {
   373  		t.Run("recalculates if there are no remaining datapoints", func(t *testing.T) {
   374  			nodeData := newNodeData(func(completed int, total int) {
   375  				assert.Equal(t, 0, completed)
   376  				assert.Equal(t, 0, total)
   377  			})
   378  
   379  			nodeData.initialize()
   380  			nodeData.recalculate()
   381  
   382  			select {
   383  			case _, ok := <-nodeData.doneChan:
   384  				assert.False(t, ok)
   385  			default:
   386  				assert.Fail(t, "expected channel to be closed")
   387  			}
   388  		})
   389  		t.Run("does not recalculate if there are remaining datapoints", func(t *testing.T) {
   390  			nodeData := newNodeData(func(completed int, total int) {
   391  				assert.Fail(t, "should not recalculate")
   392  			})
   393  
   394  			nodeData.totalDatapoints = 2
   395  			nodeData.remainingDatapoints = map[string]struct{}{
   396  				"codeID1": {},
   397  				"codeID2": {},
   398  			}
   399  
   400  			nodeData.initialize()
   401  			nodeData.recalculate()
   402  
   403  			select {
   404  			case _, _ = <-nodeData.doneChan:
   405  				assert.Fail(t, "expected channel to be open")
   406  			default:
   407  			}
   408  		})
   409  	})
   410  
   411  	t.Run("consume/recalculate", func(t *testing.T) {
   412  		t.Run("notifies progress when partially complete", func(t *testing.T) {
   413  			progressCalled := false
   414  			nodeData := newNodeData(func(completed int, total int) {
   415  				progressCalled = true
   416  				assert.Equal(t, 1, completed)
   417  				assert.Equal(t, 2, total)
   418  			})
   419  			nodeData.totalDatapoints = 2
   420  			nodeData.remainingDatapoints = map[string]struct{}{
   421  				"codeID1": {},
   422  				"codeID2": {},
   423  			}
   424  			nodeData.initialize()
   425  			nodeData.consume("codeID1", &envelope{
   426  				res: results["codeID1"],
   427  			})
   428  			nodeData.recalculate()
   429  
   430  			assert.True(t, progressCalled)
   431  			select {
   432  			case _, _ = <-nodeData.doneChan:
   433  				assert.Fail(t, "expected channel to be open")
   434  			default:
   435  			}
   436  		})
   437  		t.Run("notifies progress and signals finish when fully complete", func(t *testing.T) {
   438  			progressCalled := false
   439  			nodeData := newNodeData(func(completed int, total int) {
   440  				progressCalled = true
   441  				assert.Equal(t, 1, completed)
   442  				assert.Equal(t, 1, total)
   443  			})
   444  			nodeData.totalDatapoints = 1
   445  			nodeData.remainingDatapoints = map[string]struct{}{
   446  				"codeID1": {},
   447  			}
   448  			nodeData.initialize()
   449  			nodeData.consume("codeID1", &envelope{
   450  				res: results["codeID1"],
   451  			})
   452  			nodeData.recalculate()
   453  
   454  			assert.True(t, progressCalled)
   455  			select {
   456  			case _, ok := <-nodeData.doneChan:
   457  				assert.False(t, ok)
   458  			default:
   459  				assert.Fail(t, "expected channel to be closed")
   460  			}
   461  		})
   462  	})
   463  }
   464  
   465  func TestDatapointCollectorNode(t *testing.T) {
   466  	newNodeData := func(collectorFunc func(results []*llx.RawResult)) *DatapointCollectorNodeData {
   467  		data := &DatapointCollectorNodeData{
   468  			unreported: make(map[string]*llx.RawResult),
   469  			collectors: []DatapointCollector{
   470  				&FuncCollector{
   471  					SinkDataFunc: collectorFunc,
   472  				},
   473  			},
   474  		}
   475  		return data
   476  	}
   477  
   478  	initExpectedData := func() map[string]*llx.RawResult {
   479  		return map[string]*llx.RawResult{
   480  			"codeID1": {
   481  				CodeID: "codeID1",
   482  				Data:   llx.BoolData(true),
   483  			},
   484  			"codeID2": {
   485  				CodeID: "codeID2",
   486  				Data:   llx.BoolData(false),
   487  			},
   488  		}
   489  	}
   490  	t.Run("initialize/recalculate", func(t *testing.T) {
   491  		t.Run("recalculates if unreported datapoints are available", func(t *testing.T) {
   492  			collected := map[string]int{}
   493  			expectedData := initExpectedData()
   494  			nodeData := newNodeData(func(results []*llx.RawResult) {
   495  				for _, r := range results {
   496  					assert.Equal(t, expectedData[r.CodeID], r)
   497  					collected[r.CodeID] = collected[r.CodeID] + 1
   498  				}
   499  			})
   500  
   501  			nodeData.unreported = expectedData
   502  
   503  			nodeData.initialize()
   504  			nodeData.recalculate()
   505  
   506  			assert.Equal(t, 2, len(collected))
   507  			for _, v := range collected {
   508  				assert.Equal(t, 1, v)
   509  			}
   510  		})
   511  
   512  		t.Run("does not recalculate if no unreported data", func(t *testing.T) {
   513  			calls := 0
   514  			nodeData := newNodeData(func(results []*llx.RawResult) {
   515  				calls += 1
   516  			})
   517  
   518  			nodeData.initialize()
   519  			nodeData.recalculate()
   520  
   521  			assert.Equal(t, 0, calls)
   522  		})
   523  	})
   524  
   525  	t.Run("consume/recalculate", func(t *testing.T) {
   526  		t.Run("recalculates if unreported datapoints are available", func(t *testing.T) {
   527  			collected := map[string]int{}
   528  			expectedData := initExpectedData()
   529  
   530  			nodeData := newNodeData(func(results []*llx.RawResult) {
   531  				for _, r := range results {
   532  					assert.Equal(t, expectedData[r.CodeID], r)
   533  					collected[r.CodeID] = collected[r.CodeID] + 1
   534  				}
   535  			})
   536  
   537  			nodeData.initialize()
   538  			nodeData.consume("codeID1", &envelope{
   539  				res: expectedData["codeID1"],
   540  			})
   541  			nodeData.consume("rjID1", &envelope{
   542  				res: expectedData["codeID2"],
   543  			})
   544  			nodeData.recalculate()
   545  
   546  			assert.Equal(t, 2, len(collected))
   547  			for _, v := range collected {
   548  				assert.Equal(t, 1, v)
   549  			}
   550  		})
   551  	})
   552  }