go.mondoo.com/cnquery@v0.0.0-20231005093811-59568235f6ea/mqlc/builtin_map.go (about)

     1  // Copyright (c) Mondoo, Inc.
     2  // SPDX-License-Identifier: BUSL-1.1
     3  
     4  package mqlc
     5  
     6  import (
     7  	"errors"
     8  
     9  	"go.mondoo.com/cnquery/llx"
    10  	"go.mondoo.com/cnquery/mqlc/parser"
    11  	"go.mondoo.com/cnquery/types"
    12  )
    13  
    14  func compileDictWhere(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
    15  	if call == nil {
    16  		return types.Nil, errors.New("missing filter argument for calling '" + id + "'")
    17  	}
    18  	if len(call.Function) > 1 {
    19  		return types.Nil, errors.New("too many arguments when calling '" + id + "', only 1 is supported")
    20  	}
    21  
    22  	// if the where function is called without arguments, we don't have to do anything
    23  	// so we just return the caller type as no additional step in the compiler is necessary
    24  	if len(call.Function) == 0 {
    25  		return typ, nil
    26  	}
    27  
    28  	arg := call.Function[0]
    29  	if arg.Name != "" {
    30  		return types.Nil, errors.New("called '" + id + "' with a named parameter, which is not supported")
    31  	}
    32  
    33  	keyType := types.Dict
    34  	valueType := types.Dict
    35  	bindingChecksum := c.Result.CodeV2.Checksums[c.tailRef()]
    36  
    37  	blockCompiler := c.newBlockCompiler(&variable{
    38  		typ: typ,
    39  		ref: ref,
    40  	})
    41  
    42  	blockCompiler.addArgumentPlaceholder(keyType, bindingChecksum)
    43  	blockCompiler.vars.add("key", variable{
    44  		ref: blockCompiler.tailRef(),
    45  		typ: keyType,
    46  		callback: func() {
    47  			blockCompiler.standalone = false
    48  		},
    49  	})
    50  
    51  	blockCompiler.addArgumentPlaceholder(valueType, bindingChecksum)
    52  	blockCompiler.vars.add("value", variable{
    53  		ref: blockCompiler.tailRef(),
    54  		typ: valueType,
    55  		callback: func() {
    56  			blockCompiler.standalone = false
    57  		},
    58  	})
    59  
    60  	// we want to make sure the `_` points to the value, which is useful when dealing
    61  	// with arrays and the default in maps
    62  	blockCompiler.Binding.ref = blockCompiler.tailRef()
    63  
    64  	err := blockCompiler.compileExpressions([]*parser.Expression{arg.Value})
    65  	c.Result.Suggestions = append(c.Result.Suggestions, blockCompiler.Result.Suggestions...)
    66  	if err != nil {
    67  		return typ, err
    68  	}
    69  
    70  	// if we have a standalone body in the where clause, then we need to check if
    71  	// it's a value, in which case we need to compare the array value to it
    72  	if blockCompiler.standalone {
    73  		block := blockCompiler.block
    74  		blockValueRef := block.TailRef(blockCompiler.blockRef)
    75  
    76  		blockTyp := c.Result.CodeV2.DereferencedBlockType(block)
    77  		childType := typ.Child()
    78  		chunkId := "==" + string(childType)
    79  		if blockTyp != childType {
    80  			chunkId = "==" + string(blockTyp)
    81  			_, err := llx.BuiltinFunctionV2(childType, chunkId)
    82  			if err != nil {
    83  				return types.Nil, errors.New("called '" + id + "' with wrong type; either provide a type " + childType.Label() + " value or write it as an expression (e.g. \"_ == 123\")")
    84  			}
    85  		}
    86  
    87  		block.AddChunk(c.Result.CodeV2, blockCompiler.blockRef, &llx.Chunk{
    88  			Call: llx.Chunk_FUNCTION,
    89  			Id:   chunkId,
    90  			Function: &llx.Function{
    91  				Type:    string(types.Bool),
    92  				Binding: blockCompiler.blockRef | 2,
    93  				Args:    []*llx.Primitive{llx.RefPrimitiveV2(blockValueRef)},
    94  			},
    95  		})
    96  
    97  		block.Entrypoints = []uint64{block.TailRef(blockCompiler.blockRef)}
    98  	}
    99  
   100  	argExpectation := llx.FunctionPrimitive(blockCompiler.blockRef)
   101  
   102  	args := []*llx.Primitive{
   103  		llx.RefPrimitiveV2(ref),
   104  		argExpectation,
   105  	}
   106  	for _, v := range blockCompiler.blockDeps {
   107  		if c.isInMyBlock(v) {
   108  			args = append(args, llx.RefPrimitiveV2(v))
   109  		}
   110  	}
   111  	c.blockDeps = append(c.blockDeps, blockCompiler.blockDeps...)
   112  
   113  	c.addChunk(&llx.Chunk{
   114  		Call: llx.Chunk_FUNCTION,
   115  		Id:   id,
   116  		Function: &llx.Function{
   117  			Type:    string(typ),
   118  			Binding: ref,
   119  			Args:    args,
   120  		},
   121  	})
   122  	return typ, nil
   123  }
   124  
   125  func compileDictContains(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
   126  	_, err := compileDictWhere(c, typ, ref, "where", call)
   127  	if err != nil {
   128  		return types.Nil, err
   129  	}
   130  
   131  	// .length
   132  	c.addChunk(&llx.Chunk{
   133  		Call: llx.Chunk_FUNCTION,
   134  		Id:   "length",
   135  		Function: &llx.Function{
   136  			Type:    string(types.Int),
   137  			Binding: c.tailRef(),
   138  		},
   139  	})
   140  
   141  	// > 0
   142  	c.addChunk(&llx.Chunk{
   143  		Call: llx.Chunk_FUNCTION,
   144  		Id:   string(">" + types.Int),
   145  		Function: &llx.Function{
   146  			Type:    string(types.Bool),
   147  			Binding: c.tailRef(),
   148  			Args: []*llx.Primitive{
   149  				llx.IntPrimitive(0),
   150  			},
   151  		},
   152  	})
   153  
   154  	checksum := c.Result.CodeV2.Checksums[c.tailRef()]
   155  	c.Result.Labels.Labels[checksum] = "[].contains()"
   156  
   157  	return types.Bool, nil
   158  }
   159  
   160  func compileDictContainsOnly(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
   161  	if call == nil || len(call.Function) != 1 {
   162  		return types.Nil, errors.New("function " + id + " needs one argument (dict)")
   163  	}
   164  
   165  	f := call.Function[0]
   166  	if f.Value == nil || f.Value.Operand == nil {
   167  		return types.Nil, errors.New("function " + id + " needs one argument (dict)")
   168  	}
   169  
   170  	val, err := c.compileOperand(f.Value.Operand)
   171  	if err != nil {
   172  		return types.Nil, err
   173  	}
   174  
   175  	valType, err := c.dereferenceType(val)
   176  	if err != nil {
   177  		return types.Nil, err
   178  	}
   179  
   180  	chunkId := "==" + string(typ)
   181  	if valType != typ {
   182  		chunkId = "==" + string(valType)
   183  		_, err := llx.BuiltinFunctionV2(typ, chunkId)
   184  		if err != nil {
   185  			return types.Nil, errors.New("called '" + id + "' with wrong type; either provide a type " + typ.Label() + " value or write it as an expression (e.g. \"_ == 123\")")
   186  		}
   187  	}
   188  
   189  	// .difference
   190  	c.addChunk(&llx.Chunk{
   191  		Call: llx.Chunk_FUNCTION,
   192  		Id:   "difference",
   193  		Function: &llx.Function{
   194  			Type:    string(typ),
   195  			Binding: ref,
   196  			Args: []*llx.Primitive{
   197  				val,
   198  			},
   199  		},
   200  	})
   201  
   202  	// == []
   203  	c.addChunk(&llx.Chunk{
   204  		Call: llx.Chunk_FUNCTION,
   205  		Id:   chunkId,
   206  		Function: &llx.Function{
   207  			Type:    string(types.Bool),
   208  			Binding: c.tailRef(),
   209  			Args: []*llx.Primitive{
   210  				llx.ArrayPrimitive([]*llx.Primitive{}, typ.Child()),
   211  			},
   212  		},
   213  	})
   214  
   215  	checksum := c.Result.CodeV2.Checksums[c.tailRef()]
   216  	c.Result.Labels.Labels[checksum] = "[].containsOnly()"
   217  
   218  	return types.Bool, nil
   219  }
   220  
   221  func compileDictContainsNone(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
   222  	if call == nil || len(call.Function) != 1 {
   223  		return types.Nil, errors.New("function " + id + " needs one argument (dict)")
   224  	}
   225  
   226  	f := call.Function[0]
   227  	if f.Value == nil || f.Value.Operand == nil {
   228  		return types.Nil, errors.New("function " + id + " needs one argument (dict)")
   229  	}
   230  
   231  	val, err := c.compileOperand(f.Value.Operand)
   232  	if err != nil {
   233  		return types.Nil, err
   234  	}
   235  
   236  	valType, err := c.dereferenceType(val)
   237  	if err != nil {
   238  		return types.Nil, err
   239  	}
   240  
   241  	chunkId := "==" + string(typ)
   242  	if valType != typ {
   243  		chunkId = "==" + string(valType)
   244  		_, err := llx.BuiltinFunctionV2(typ, chunkId)
   245  		if err != nil {
   246  			return types.Nil, errors.New("called '" + id + "' with wrong type; either provide a type " + typ.Label() + " value or write it as an expression (e.g. \"_ == 123\")")
   247  		}
   248  	}
   249  
   250  	// .containsNone
   251  	c.addChunk(&llx.Chunk{
   252  		Call: llx.Chunk_FUNCTION,
   253  		Id:   "containsNone",
   254  		Function: &llx.Function{
   255  			Type:    string(typ),
   256  			Binding: ref,
   257  			Args: []*llx.Primitive{
   258  				val,
   259  			},
   260  		},
   261  	})
   262  
   263  	// == []
   264  	c.addChunk(&llx.Chunk{
   265  		Call: llx.Chunk_FUNCTION,
   266  		Id:   chunkId,
   267  		Function: &llx.Function{
   268  			Type:    string(types.Bool),
   269  			Binding: c.tailRef(),
   270  			Args: []*llx.Primitive{
   271  				llx.ArrayPrimitive([]*llx.Primitive{}, typ.Child()),
   272  			},
   273  		},
   274  	})
   275  
   276  	checksum := c.Result.CodeV2.Checksums[c.tailRef()]
   277  	c.Result.Labels.Labels[checksum] = "[].containsNone()"
   278  
   279  	return types.Bool, nil
   280  }
   281  
   282  func compileDictAll(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
   283  	_, err := compileDictWhere(c, typ, ref, "$whereNot", call)
   284  	if err != nil {
   285  		return types.Nil, err
   286  	}
   287  	listRef := c.tailRef()
   288  
   289  	if err := compileListAssertionMsg(c, typ, ref, listRef, listRef); err != nil {
   290  		return types.Nil, err
   291  	}
   292  
   293  	c.addChunk(&llx.Chunk{
   294  		Call: llx.Chunk_FUNCTION,
   295  		Id:   "$all",
   296  		Function: &llx.Function{
   297  			Type:    string(types.Bool),
   298  			Binding: listRef,
   299  		},
   300  	})
   301  
   302  	checksum := c.Result.CodeV2.Checksums[c.tailRef()]
   303  	c.Result.Labels.Labels[checksum] = "[].all()"
   304  
   305  	return types.Bool, nil
   306  }
   307  
   308  func compileDictAny(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
   309  	_, err := compileDictWhere(c, typ, ref, "where", call)
   310  	if err != nil {
   311  		return types.Nil, err
   312  	}
   313  	listRef := c.tailRef()
   314  
   315  	if err := compileListAssertionMsg(c, typ, ref, ref, listRef); err != nil {
   316  		return types.Nil, err
   317  	}
   318  
   319  	c.addChunk(&llx.Chunk{
   320  		Call: llx.Chunk_FUNCTION,
   321  		Id:   "$any",
   322  		Function: &llx.Function{
   323  			Type:    string(types.Bool),
   324  			Binding: listRef,
   325  		},
   326  	})
   327  
   328  	checksum := c.Result.CodeV2.Checksums[c.tailRef()]
   329  	c.Result.Labels.Labels[checksum] = "[].any()"
   330  
   331  	return types.Bool, nil
   332  }
   333  
   334  func compileDictOne(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
   335  	_, err := compileDictWhere(c, typ, ref, "where", call)
   336  	if err != nil {
   337  		return types.Nil, err
   338  	}
   339  	listRef := c.tailRef()
   340  
   341  	if err := compileListAssertionMsg(c, typ, ref, listRef, listRef); err != nil {
   342  		return types.Nil, err
   343  	}
   344  
   345  	c.addChunk(&llx.Chunk{
   346  		Call: llx.Chunk_FUNCTION,
   347  		Id:   "$one",
   348  		Function: &llx.Function{
   349  			Type:    string(types.Bool),
   350  			Binding: listRef,
   351  		},
   352  	})
   353  
   354  	checksum := c.Result.CodeV2.Checksums[c.tailRef()]
   355  	c.Result.Labels.Labels[checksum] = "[].one()"
   356  
   357  	return types.Bool, nil
   358  }
   359  
   360  func compileDictNone(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
   361  	_, err := compileDictWhere(c, typ, ref, "where", call)
   362  	if err != nil {
   363  		return types.Nil, err
   364  	}
   365  	listRef := c.tailRef()
   366  
   367  	if err := compileListAssertionMsg(c, typ, ref, listRef, listRef); err != nil {
   368  		return types.Nil, err
   369  	}
   370  
   371  	c.addChunk(&llx.Chunk{
   372  		Call: llx.Chunk_FUNCTION,
   373  		Id:   "$none",
   374  		Function: &llx.Function{
   375  			Type:    string(types.Bool),
   376  			Binding: listRef,
   377  		},
   378  	})
   379  
   380  	checksum := c.Result.CodeV2.Checksums[c.tailRef()]
   381  	c.Result.Labels.Labels[checksum] = "[].none()"
   382  
   383  	return types.Bool, nil
   384  }
   385  
   386  func compileDictFlat(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
   387  	if call != nil && len(call.Function) > 0 {
   388  		return types.Nil, errors.New("no arguments supported for '" + id + "'")
   389  	}
   390  
   391  	typ = types.Array(types.Dict)
   392  	c.addChunk(&llx.Chunk{
   393  		Call: llx.Chunk_FUNCTION,
   394  		Id:   id,
   395  		Function: &llx.Function{
   396  			Type:    string(typ),
   397  			Binding: ref,
   398  		},
   399  	})
   400  	return typ, nil
   401  }
   402  
   403  func compileMapWhere(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
   404  	if call == nil {
   405  		return types.Nil, errors.New("missing filter argument for calling '" + id + "'")
   406  	}
   407  	if len(call.Function) > 1 {
   408  		return types.Nil, errors.New("too many arguments when calling '" + id + "', only 1 is supported")
   409  	}
   410  
   411  	// if the where function is called without arguments, we don't have to do anything
   412  	// so we just return the caller type as no additional step in the compiler is necessary
   413  	if len(call.Function) == 0 {
   414  		return typ, nil
   415  	}
   416  
   417  	arg := call.Function[0]
   418  	if arg.Name != "" {
   419  		return types.Nil, errors.New("called '" + id + "' with a named parameter, which is not supported")
   420  	}
   421  
   422  	keyType := typ.Key()
   423  	valueType := typ.Child()
   424  	bindingChecksum := c.Result.CodeV2.Checksums[c.tailRef()]
   425  
   426  	blockCompiler := c.newBlockCompiler(&variable{
   427  		typ: typ,
   428  		ref: ref,
   429  	})
   430  
   431  	blockCompiler.addArgumentPlaceholder(keyType, bindingChecksum)
   432  	blockCompiler.vars.add("key", variable{
   433  		ref: blockCompiler.tailRef(),
   434  		typ: keyType,
   435  		callback: func() {
   436  			blockCompiler.standalone = false
   437  		},
   438  	})
   439  
   440  	blockCompiler.addArgumentPlaceholder(valueType, bindingChecksum)
   441  	blockCompiler.vars.add("value", variable{
   442  		ref: blockCompiler.tailRef(),
   443  		typ: valueType,
   444  		callback: func() {
   445  			blockCompiler.standalone = false
   446  		},
   447  	})
   448  
   449  	// we want to make sure the `_` points to the value, which is useful when dealing
   450  	// with arrays and the default in maps
   451  	blockCompiler.Binding.ref = blockCompiler.tailRef()
   452  
   453  	err := blockCompiler.compileExpressions([]*parser.Expression{arg.Value})
   454  	c.Result.Suggestions = append(c.Result.Suggestions, blockCompiler.Result.Suggestions...)
   455  	if err != nil {
   456  		return typ, err
   457  	}
   458  
   459  	argExpectation := llx.FunctionPrimitive(blockCompiler.blockRef)
   460  
   461  	args := []*llx.Primitive{
   462  		llx.RefPrimitiveV2(ref),
   463  		argExpectation,
   464  	}
   465  	for _, v := range blockCompiler.blockDeps {
   466  		if c.isInMyBlock(v) {
   467  			args = append(args, llx.RefPrimitiveV2(v))
   468  		}
   469  	}
   470  	c.blockDeps = append(c.blockDeps, blockCompiler.blockDeps...)
   471  
   472  	c.addChunk(&llx.Chunk{
   473  		Call: llx.Chunk_FUNCTION,
   474  		Id:   id,
   475  		Function: &llx.Function{
   476  			Type:    string(typ),
   477  			Binding: ref,
   478  			Args:    args,
   479  		},
   480  	})
   481  	return typ, nil
   482  }
   483  
   484  func compileMapValues(c *compiler, typ types.Type, ref uint64, id string, call *parser.Call) (types.Type, error) {
   485  	typ = types.Array(typ.Child())
   486  	c.addChunk(&llx.Chunk{
   487  		Call: llx.Chunk_FUNCTION,
   488  		Id:   id,
   489  		Function: &llx.Function{
   490  			Type:    string(typ),
   491  			Binding: ref,
   492  		},
   493  	})
   494  	return typ, nil
   495  }