go.mondoo.com/cnquery@v0.0.0-20231005093811-59568235f6ea/mqlc/builtin_array.go (about)

     1  // Copyright (c) Mondoo, Inc.
     2  // SPDX-License-Identifier: BUSL-1.1
     3  
     4  package mqlc
     5  
     6  import (
     7  	"errors"
     8  
     9  	"go.mondoo.com/cnquery/llx"
    10  	"go.mondoo.com/cnquery/mqlc/parser"
    11  	"go.mondoo.com/cnquery/types"
    12  )
    13  
    14  func compileArrayWhere(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
    15  	if call == nil {
    16  		return types.Nil, errors.New("missing filter argument for calling '" + id + "'")
    17  	}
    18  	if len(call.Function) > 1 {
    19  		return types.Nil, errors.New("too many arguments when calling '" + id + "', only 1 is supported")
    20  	}
    21  
    22  	// if the where function is called without arguments, we don't have to do anything
    23  	// so we just return the caller type as no additional step in the compiler is necessary
    24  	if len(call.Function) == 0 {
    25  		return typ, nil
    26  	}
    27  
    28  	arg := call.Function[0]
    29  	if arg.Name != "" {
    30  		return types.Nil, errors.New("called '" + id + "' with a named parameter, which is not supported")
    31  	}
    32  
    33  	refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, typ, ref)
    34  	if err != nil {
    35  		return types.Nil, err
    36  	}
    37  	if refs.block == 0 {
    38  		return types.Nil, errors.New("called '" + id + "' without a function block")
    39  	}
    40  	ref = refs.binding
    41  
    42  	argExpectation := llx.FunctionPrimitive(refs.block)
    43  
    44  	// if we have a standalone body in the where clause, then we need to check if
    45  	// it's a value, in which case we need to compare the array value to it
    46  	if refs.isStandalone {
    47  		block := c.Result.CodeV2.Block(refs.block)
    48  
    49  		if block == nil {
    50  			return types.Nil, errors.New("cannot find block for standalone compilation of array.where")
    51  		}
    52  		blockValueRef := block.TailRef(refs.block)
    53  
    54  		blockTyp := c.Result.CodeV2.DereferencedBlockType(block)
    55  		childType := typ.Child()
    56  		chunkId := "==" + string(childType)
    57  		if blockTyp != childType {
    58  			chunkId = "==" + string(blockTyp)
    59  			_, err := llx.BuiltinFunctionV2(childType, chunkId)
    60  			if err != nil {
    61  				return types.Nil, errors.New("called '" + id + "' with wrong type; either provide a type " + childType.Label() + " value or write it as an expression (e.g. \"_ == 123\")")
    62  			}
    63  		}
    64  
    65  		block.AddChunk(c.Result.CodeV2, refs.block, &llx.Chunk{
    66  			Call: llx.Chunk_FUNCTION,
    67  			Id:   chunkId,
    68  			Function: &llx.Function{
    69  				Type:    string(types.Bool),
    70  				Binding: refs.block | 1,
    71  				Args:    []*llx.Primitive{llx.RefPrimitiveV2(blockValueRef)},
    72  			},
    73  		})
    74  
    75  		block.Entrypoints = []uint64{block.TailRef(refs.block)}
    76  	}
    77  
    78  	args := []*llx.Primitive{
    79  		llx.RefPrimitiveV2(ref),
    80  		argExpectation,
    81  	}
    82  	for _, v := range refs.deps {
    83  		if c.isInMyBlock(v) {
    84  			args = append(args, llx.RefPrimitiveV2(v))
    85  		}
    86  	}
    87  	c.blockDeps = append(c.blockDeps, refs.deps...)
    88  
    89  	c.addChunk(&llx.Chunk{
    90  		Call: llx.Chunk_FUNCTION,
    91  		Id:   id,
    92  		Function: &llx.Function{
    93  			Type:    string(typ),
    94  			Binding: ref,
    95  			Args:    args,
    96  		},
    97  	})
    98  	return typ, nil
    99  }
   100  
   101  func compileArrayDuplicates(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
   102  	if call != nil && len(call.Function) > 1 {
   103  		return types.Nil, errors.New("too many arguments when calling '" + id + "'")
   104  	} else if call != nil && len(call.Function) == 1 {
   105  		arg := call.Function[0]
   106  
   107  		refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, typ, ref)
   108  		if err != nil {
   109  			return types.Nil, err
   110  		}
   111  		if refs.block == 0 {
   112  			return types.Nil, errors.New("called '" + id + "' without a function block")
   113  		}
   114  		ref = refs.binding
   115  		argExpectation := llx.FunctionPrimitive(refs.block)
   116  
   117  		if refs.isStandalone {
   118  			return typ, errors.New("called duplicates with a field name on an invalid type")
   119  		}
   120  
   121  		args := []*llx.Primitive{
   122  			llx.RefPrimitiveV2(ref),
   123  			argExpectation,
   124  		}
   125  
   126  		for _, v := range refs.deps {
   127  			if c.isInMyBlock(v) {
   128  				args = append(args, llx.RefPrimitiveV2(v))
   129  			}
   130  		}
   131  		c.blockDeps = append(c.blockDeps, refs.deps...)
   132  
   133  		c.addChunk(&llx.Chunk{
   134  			Call: llx.Chunk_FUNCTION,
   135  			Id:   "fieldDuplicates",
   136  			Function: &llx.Function{
   137  				Type:    string(typ),
   138  				Binding: ref,
   139  				Args:    args,
   140  			},
   141  		})
   142  		return typ, nil
   143  	}
   144  
   145  	// Duplicates is being called with 0 arguments, which means it should be on an
   146  	// array of basic types
   147  	ct := typ.Child()
   148  	_, ok := types.Equal[ct]
   149  	if !ok {
   150  		return typ, errors.New("cannot extract duplicates from array, must be a basic type. Try using a field argument.")
   151  	}
   152  
   153  	c.addChunk(&llx.Chunk{
   154  		Call: llx.Chunk_FUNCTION,
   155  		Id:   id,
   156  		Function: &llx.Function{
   157  			Type:    string(typ),
   158  			Binding: ref,
   159  		},
   160  	})
   161  	return typ, nil
   162  }
   163  
   164  func compileArrayUnique(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
   165  	if call != nil && len(call.Function) > 0 {
   166  		return types.Nil, errors.New("too many arguments when calling '" + id + "'")
   167  	}
   168  
   169  	ct := typ.Child()
   170  	_, ok := types.Equal[ct]
   171  	if !ok {
   172  		return typ, errors.New("cannot extract uniques from array, don't know how to compare entries")
   173  	}
   174  
   175  	c.addChunk(&llx.Chunk{
   176  		Call: llx.Chunk_FUNCTION,
   177  		Id:   id,
   178  		Function: &llx.Function{
   179  			Type:    string(typ),
   180  			Binding: ref,
   181  		},
   182  	})
   183  	return typ, nil
   184  }
   185  
   186  func compileArrayContains(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
   187  	_, err := compileArrayWhere(c, typ, ref, "where", call)
   188  	if err != nil {
   189  		return types.Nil, err
   190  	}
   191  
   192  	// .length
   193  	c.addChunk(&llx.Chunk{
   194  		Call: llx.Chunk_FUNCTION,
   195  		Id:   "length",
   196  		Function: &llx.Function{
   197  			Type:    string(types.Int),
   198  			Binding: c.tailRef(),
   199  		},
   200  	})
   201  
   202  	// > 0
   203  	c.addChunk(&llx.Chunk{
   204  		Call: llx.Chunk_FUNCTION,
   205  		Id:   string(">" + types.Int),
   206  		Function: &llx.Function{
   207  			Type:    string(types.Bool),
   208  			Binding: c.tailRef(),
   209  			Args: []*llx.Primitive{
   210  				llx.IntPrimitive(0),
   211  			},
   212  		},
   213  	})
   214  
   215  	checksum := c.Result.CodeV2.Checksums[c.tailRef()]
   216  	c.Result.Labels.Labels[checksum] = "[].contains()"
   217  
   218  	return types.Bool, nil
   219  }
   220  
   221  func compileArrayContainsOnly(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
   222  	if call == nil || len(call.Function) != 1 {
   223  		return types.Nil, errors.New("function " + id + " needs one argument (array)")
   224  	}
   225  
   226  	f := call.Function[0]
   227  	if f.Value == nil || f.Value.Operand == nil {
   228  		return types.Nil, errors.New("function " + id + " needs one argument (array)")
   229  	}
   230  
   231  	val, err := c.compileOperand(f.Value.Operand)
   232  	if err != nil {
   233  		return types.Nil, err
   234  	}
   235  
   236  	valType, err := c.dereferenceType(val)
   237  	if err != nil {
   238  		return types.Nil, err
   239  	}
   240  
   241  	if valType != typ {
   242  		return types.Nil, errors.New("types don't match for calling contains (got: " + valType.Label() + ", expected: " + typ.Label() + ")")
   243  	}
   244  
   245  	// .difference
   246  	c.addChunk(&llx.Chunk{
   247  		Call: llx.Chunk_FUNCTION,
   248  		Id:   "difference",
   249  		Function: &llx.Function{
   250  			Type:    string(typ),
   251  			Binding: ref,
   252  			Args: []*llx.Primitive{
   253  				val,
   254  			},
   255  		},
   256  	})
   257  
   258  	// == []
   259  	c.addChunk(&llx.Chunk{
   260  		Call: llx.Chunk_FUNCTION,
   261  		Id:   string("=="),
   262  		Function: &llx.Function{
   263  			Type:    string(types.Bool),
   264  			Binding: c.tailRef(),
   265  			Args: []*llx.Primitive{
   266  				llx.ArrayPrimitive([]*llx.Primitive{}, typ.Child()),
   267  			},
   268  		},
   269  	})
   270  
   271  	checksum := c.Result.CodeV2.Checksums[c.tailRef()]
   272  	c.Result.Labels.Labels[checksum] = "[].containsOnly()"
   273  
   274  	return types.Bool, nil
   275  }
   276  
   277  func compileArrayContainsNone(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
   278  	if call == nil || len(call.Function) != 1 {
   279  		return types.Nil, errors.New("function " + id + " needs one argument (array)")
   280  	}
   281  
   282  	f := call.Function[0]
   283  	if f.Value == nil || f.Value.Operand == nil {
   284  		return types.Nil, errors.New("function " + id + " needs one argument")
   285  	}
   286  
   287  	val, err := c.compileOperand(f.Value.Operand)
   288  	if err != nil {
   289  		return types.Nil, err
   290  	}
   291  
   292  	valType, err := c.dereferenceType(val)
   293  	if err != nil {
   294  		return types.Nil, err
   295  	}
   296  
   297  	if valType != typ {
   298  		return types.Nil, errors.New("types don't match for calling contains (got: " + valType.Label() + ", expected: " + typ.Label() + ")")
   299  	}
   300  
   301  	// .containsNone
   302  	c.addChunk(&llx.Chunk{
   303  		Call: llx.Chunk_FUNCTION,
   304  		Id:   "containsNone",
   305  		Function: &llx.Function{
   306  			Type:    string(typ),
   307  			Binding: ref,
   308  			Args: []*llx.Primitive{
   309  				val,
   310  			},
   311  		},
   312  	})
   313  
   314  	// == []
   315  	c.addChunk(&llx.Chunk{
   316  		Call: llx.Chunk_FUNCTION,
   317  		Id:   string("=="),
   318  		Function: &llx.Function{
   319  			Type:    string(types.Bool),
   320  			Binding: c.tailRef(),
   321  			Args: []*llx.Primitive{
   322  				llx.ArrayPrimitive([]*llx.Primitive{}, typ.Child()),
   323  			},
   324  		},
   325  	})
   326  
   327  	checksum := c.Result.CodeV2.Checksums[c.tailRef()]
   328  	c.Result.Labels.Labels[checksum] = "[].containsNone()"
   329  
   330  	return types.Bool, nil
   331  }
   332  
   333  func compileArrayAll(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
   334  	_, err := compileArrayWhere(c, typ, ref, "$whereNot", call)
   335  	if err != nil {
   336  		return types.Nil, err
   337  	}
   338  	listRef := c.tailRef()
   339  
   340  	if err := compileListAssertionMsg(c, typ, ref, listRef, listRef); err != nil {
   341  		return types.Nil, err
   342  	}
   343  
   344  	c.addChunk(&llx.Chunk{
   345  		Call: llx.Chunk_FUNCTION,
   346  		Id:   "$all",
   347  		Function: &llx.Function{
   348  			Type:    string(types.Bool),
   349  			Binding: listRef,
   350  		},
   351  	})
   352  
   353  	checksum := c.Result.CodeV2.Checksums[c.tailRef()]
   354  	c.Result.Labels.Labels[checksum] = "[].all()"
   355  
   356  	return types.Bool, nil
   357  }
   358  
   359  func compileArrayAny(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
   360  	_, err := compileArrayWhere(c, typ, ref, "where", call)
   361  	if err != nil {
   362  		return types.Nil, err
   363  	}
   364  	listRef := c.tailRef()
   365  
   366  	if err := compileListAssertionMsg(c, typ, ref, ref, listRef); err != nil {
   367  		return types.Nil, err
   368  	}
   369  
   370  	c.addChunk(&llx.Chunk{
   371  		Call: llx.Chunk_FUNCTION,
   372  		Id:   "$any",
   373  		Function: &llx.Function{
   374  			Type:    string(types.Bool),
   375  			Binding: listRef,
   376  		},
   377  	})
   378  
   379  	checksum := c.Result.CodeV2.Checksums[c.tailRef()]
   380  	c.Result.Labels.Labels[checksum] = "[].any()"
   381  
   382  	return types.Bool, nil
   383  }
   384  
   385  func compileArrayOne(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
   386  	_, err := compileArrayWhere(c, typ, ref, "where", call)
   387  	if err != nil {
   388  		return types.Nil, err
   389  	}
   390  	listRef := c.tailRef()
   391  
   392  	if err := compileListAssertionMsg(c, typ, ref, listRef, listRef); err != nil {
   393  		return types.Nil, err
   394  	}
   395  
   396  	c.addChunk(&llx.Chunk{
   397  		Call: llx.Chunk_FUNCTION,
   398  		Id:   "$one",
   399  		Function: &llx.Function{
   400  			Type:    string(types.Bool),
   401  			Binding: listRef,
   402  		},
   403  	})
   404  
   405  	checksum := c.Result.CodeV2.Checksums[c.tailRef()]
   406  	c.Result.Labels.Labels[checksum] = "[].one()"
   407  
   408  	return types.Bool, nil
   409  }
   410  
   411  func compileArrayNone(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
   412  	_, err := compileArrayWhere(c, typ, ref, "where", call)
   413  	if err != nil {
   414  		return types.Nil, err
   415  	}
   416  	listRef := c.tailRef()
   417  
   418  	if err := compileListAssertionMsg(c, typ, ref, listRef, listRef); err != nil {
   419  		return types.Nil, err
   420  	}
   421  
   422  	c.addChunk(&llx.Chunk{
   423  		Call: llx.Chunk_FUNCTION,
   424  		Id:   "$none",
   425  		Function: &llx.Function{
   426  			Type:    string(types.Bool),
   427  			Binding: listRef,
   428  		},
   429  	})
   430  
   431  	checksum := c.Result.CodeV2.Checksums[c.tailRef()]
   432  	c.Result.Labels.Labels[checksum] = "[].none()"
   433  
   434  	return types.Bool, nil
   435  }
   436  
   437  func compileArrayMap(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
   438  	if call == nil {
   439  		return types.Nil, errors.New("missing filter argument for calling '" + id + "'")
   440  	}
   441  	if len(call.Function) > 1 {
   442  		return types.Nil, errors.New("too many arguments when calling '" + id + "', only 1 is supported")
   443  	}
   444  
   445  	// if the map function is called without arguments, we don't have to do anything
   446  	// so we just return the caller type as no additional step in the compiler is necessary
   447  	if len(call.Function) == 0 {
   448  		return typ, nil
   449  	}
   450  
   451  	arg := call.Function[0]
   452  	if arg.Name != "" {
   453  		return types.Nil, errors.New("called '" + id + "' with a named parameter, which is not supported")
   454  	}
   455  
   456  	refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, typ, ref)
   457  	if err != nil {
   458  		return types.Nil, err
   459  	}
   460  	if refs.block == 0 {
   461  		return types.Nil, errors.New("called '" + id + "' without a function block")
   462  	}
   463  	ref = refs.binding
   464  	argExpectation := llx.FunctionPrimitive(refs.block)
   465  
   466  	block := c.Result.CodeV2.Block(refs.block)
   467  	if len(block.Entrypoints) != 1 {
   468  		return types.Nil, errors.New("called '" + id + "' with a bad function block, you can only return 1 value")
   469  	}
   470  	mappedType := c.Result.CodeV2.DereferencedBlockType(block)
   471  
   472  	args := []*llx.Primitive{
   473  		llx.RefPrimitiveV2(ref),
   474  		argExpectation,
   475  	}
   476  	for _, v := range refs.deps {
   477  		if c.isInMyBlock(v) {
   478  			args = append(args, llx.RefPrimitiveV2(v))
   479  		}
   480  	}
   481  	c.blockDeps = append(c.blockDeps, refs.deps...)
   482  
   483  	c.addChunk(&llx.Chunk{
   484  		Call: llx.Chunk_FUNCTION,
   485  		Id:   id,
   486  		Function: &llx.Function{
   487  			Type:    string(types.Array(mappedType)),
   488  			Binding: ref,
   489  			Args:    args,
   490  		},
   491  	})
   492  	return types.Array(mappedType), nil
   493  }
   494  
   495  func compileArrayFlat(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
   496  	if call != nil && len(call.Function) > 0 {
   497  		return types.Nil, errors.New("no arguments supported for '" + id + "'")
   498  	}
   499  
   500  	for typ.IsArray() {
   501  		typ = typ.Child()
   502  	}
   503  	typ = types.Array(typ)
   504  
   505  	c.addChunk(&llx.Chunk{
   506  		Call: llx.Chunk_FUNCTION,
   507  		Id:   id,
   508  		Function: &llx.Function{
   509  			Type:    string(typ),
   510  			Binding: ref,
   511  		},
   512  	})
   513  	return typ, nil
   514  }