go.mondoo.com/cnquery@v0.0.0-20231005093811-59568235f6ea/mqlc/assertion.go (about)

     1  // Copyright (c) Mondoo, Inc.
     2  // SPDX-License-Identifier: BUSL-1.1
     3  
     4  package mqlc
     5  
     6  import (
     7  	"errors"
     8  	"strconv"
     9  	"strings"
    10  
    11  	"go.mondoo.com/cnquery/llx"
    12  	"go.mondoo.com/cnquery/mqlc/parser"
    13  	"go.mondoo.com/cnquery/types"
    14  )
    15  
    16  func extractComments(c *parser.Expression) string {
    17  	// TODO: we need to clarify how many of the comments we really want to extract.
    18  	// For now we only grab the operand and ignore the rest
    19  	if c == nil || c.Operand == nil {
    20  		return ""
    21  	}
    22  	return c.Operand.Comments
    23  }
    24  
    // extractMsgTag returns the text of the first "@msg " tag in comment.
//
// The tag's text starts right after the "@msg " prefix. It continues over
// the following lines until a line that begins with '@' (the next tag) or
// the end of the comment. Every collected line is terminated with '\n'.
// It returns "" when no line starts with "@msg " (note the trailing space).
func extractMsgTag(comment string) string {
	lines := strings.Split(comment, "\n")

	start := -1
	for idx, line := range lines {
		if strings.HasPrefix(line, "@msg ") {
			start = idx
			break
		}
	}
	if start < 0 {
		return ""
	}

	var out strings.Builder
	out.WriteString(strings.TrimPrefix(lines[start], "@msg "))
	out.WriteByte('\n')

	for _, line := range lines[start+1:] {
		// any line starting with '@' begins another tag and ends this one
		if strings.HasPrefix(line, "@") {
			break
		}
		out.WriteString(line)
		out.WriteByte('\n')
	}

	return out.String()
}
    56  
    // extractMql scans s, the text that follows a "${" opener in an @msg tag.
// It returns the prefix of s up to (but not including) the first '}' that is
// not matched by an earlier '{'.
//
// Parentheses, square brackets and braces must nest correctly within that
// prefix; a mismatched or unexpected closing bracket is an error.
// Characters inside single- or double-quoted string literals are skipped, so
// brackets inside strings are not counted. Within a string, a backslash
// escapes the character that follows it.
//
// If no unmatched '}' is found, the whole of s is returned. The caller is
// expected to detect that case (the returned code then runs to the end of s).
func extractMql(s string) (string, error) {
	var openBrackets []byte
	for i := 0; i < len(s); i++ {
		switch c := s[i]; c {
		case '"', '\'':
			// TODO: for all of these string things we need to support proper string interpolation...
			//
			// Move i onto the matching closing quote (or past the end of s for
			// an unterminated literal). The outer loop's i++ then resumes right
			// after the literal. Scanning must start one past the opening
			// quote; otherwise the opening quote itself would end the scan
			// immediately.
			quote := c
			for i++; i < len(s) && s[i] != quote; i++ {
				if s[i] == '\\' {
					i++ // skip the escaped character (e.g. \" or \')
				}
			}
		case '{', '(', '[':
			openBrackets = append(openBrackets, c)
		case '}':
			if len(openBrackets) == 0 {
				// top-level '}' closes the surrounding "${"
				return s[:i], nil
			}
			if openBrackets[len(openBrackets)-1] != '{' {
				return "", errors.New("unexpected closing bracket '" + string(c) + "'")
			}
			openBrackets = openBrackets[:len(openBrackets)-1]
		case ')', ']':
			if len(openBrackets) == 0 {
				return "", errors.New("unexpected closing bracket '" + string(c) + "'")
			}
			last := openBrackets[len(openBrackets)-1]
			if (c == ')' && last != '(') || (c == ']' && last != '[') {
				return "", errors.New("unexpected closing bracket '" + string(c) + "'")
			}
			openBrackets = openBrackets[:len(openBrackets)-1]
		}
	}

	return s, nil
}
    91  
    92  func compileAssertionMsg(msg string, c *compiler) (*llx.AssertionMessage, error) {
    93  	template := strings.Builder{}
    94  	var codes []string
    95  	var i int
    96  	max := len(msg)
    97  	textStart := i
    98  	for ; i < max; i++ {
    99  		if msg[i] != '$' {
   100  			continue
   101  		}
   102  		if i+1 == max || msg[i+1] != '{' {
   103  			continue
   104  		}
   105  
   106  		template.WriteString(msg[textStart:i])
   107  		template.WriteByte('$')
   108  		template.WriteString(strconv.Itoa(len(codes)))
   109  
   110  		// extract the code
   111  		code, err := extractMql(msg[i+2:])
   112  		if err != nil {
   113  			return nil, err
   114  		}
   115  
   116  		i += 2 + len(code)
   117  		if i >= max {
   118  			return nil, errors.New("cannot extract code in @msg (message ended before '}')")
   119  		}
   120  		if msg[i] != '}' {
   121  			return nil, errors.New("cannot extract code in @msg (expected '}' but got '" + string(msg[i]) + "')")
   122  		}
   123  		textStart = i + 1 // one past the closing '}'
   124  
   125  		codes = append(codes, code)
   126  	}
   127  
   128  	template.WriteString(msg[textStart:])
   129  
   130  	res := llx.AssertionMessage{
   131  		Template: strings.Trim(template.String(), "\n\t "),
   132  	}
   133  
   134  	for i := range codes {
   135  		code := codes[i]
   136  
   137  		// Small helper for assertion messages:
   138  		// At the moment, the parser can't delineate if a given `{}` call
   139  		// is meant to be a map creation or a block call.
   140  		//
   141  		// When it is at the beginning of an operand it is always treated
   142  		// as a map creation, e.g.:
   143  		//     {a: 123, ...}             vs
   144  		//     something { block... }
   145  		//
   146  		// However, in the assertion message case we know it's generally
   147  		// not about map-creation. So we are using a workaround to more
   148  		// easily extract values via blocks.
   149  		//
   150  		// This approach is extremely limited. It works with the most
   151  		// straightforward use-case and prohibits map any type of map
   152  		// creation in assertion messages.
   153  		//
   154  		// TODO: Find a more appropriate solution for this problem.
   155  		// Identify use-cases we don't cover well with this approach
   156  		// before changing it.
   157  
   158  		code = strings.Trim(code, " \t\n")
   159  		if code[0] == '{' {
   160  			code = "_" + code
   161  		}
   162  
   163  		ast, err := parser.Parse(code)
   164  		if err != nil {
   165  			return nil, errors.New("cannot parse code block in comment: " + code)
   166  		}
   167  
   168  		if len(ast.Expressions) == 0 {
   169  			return nil, errors.New("can't have empty calls to `${}` in comments")
   170  		}
   171  		if len(ast.Expressions) > 1 {
   172  			return nil, errors.New("can't have more than one value in `${}`")
   173  		}
   174  		expression := ast.Expressions[0]
   175  
   176  		ref, err := c.compileAndAddExpression(expression)
   177  		if err != nil {
   178  			return nil, errors.New("failed to compile comment: " + err.Error())
   179  		}
   180  
   181  		res.Refs = append(res.Refs, ref)
   182  
   183  		c.block.Datapoints = append(c.block.Datapoints, ref)
   184  	}
   185  
   186  	return &res, nil
   187  }
   188  
   // compileListAssertionMsg compiles the @msg tag found in c.comment for an
// assertion over a list. It stores the resulting message in
// c.Result.CodeV2.Assertions under the key assertionRef+2.
//
// The message's ${...} placeholders are compiled by a block compiler created
// with c.newBlockCompiler for a variable of type typ at failedRef. Inside
// that block, "$expected" refers to allRef (also of type typ). A "${}"
// function chunk bound to failedRef is then added to c, with the compiled
// block as its first argument.
//
// The stored message is marked DecodeBlock. Checksums[0] holds the checksum
// of the tail chunk of c's block after the "${}" chunk was added. The
// remaining entries hold the checksums (from the block compiler) of each
// placeholder's datapoint, in placeholder order. Refs is cleared.
//
// It returns nil without changes when the comment has no @msg tag, and an
// error if the message fails to compile or a placeholder's checksum is
// missing.
//
// NOTE(review): the meaning of the +2 offset on assertionRef is not visible
// here — confirm against the callers that produce assertionRef.
func compileListAssertionMsg(c *compiler, typ types.Type, allRef uint64, failedRef uint64, assertionRef uint64) error {
	// assertions
	msg := extractMsgTag(c.comment)
	if msg == "" {
		return nil
	}

	blockCompiler := c.newBlockCompiler(&variable{
		typ: typ,
		ref: failedRef,
	})

	// make the full list available as "$expected" inside the message block
	blockCompiler.vars.add("$expected", variable{ref: allRef, typ: typ})

	assertionMsg, err := compileAssertionMsg(msg, &blockCompiler)
	if err != nil {
		return err
	}
	// compileAssertionMsg returns a non-nil message whenever err == nil, so
	// this branch is always taken on success.
	if assertionMsg != nil {
		if c.Result.CodeV2.Assertions == nil {
			c.Result.CodeV2.Assertions = make(map[uint64]*llx.AssertionMessage)
		}
		c.Result.CodeV2.Assertions[assertionRef+2] = assertionMsg

		args := []*llx.Primitive{
			llx.FunctionPrimitive(blockCompiler.blockRef),
		}
		// only dependencies that live in c's own block are passed as args;
		// all of them are propagated to c.blockDeps below
		for _, v := range blockCompiler.blockDeps {
			if c.isInMyBlock(v) {
				args = append(args, llx.RefPrimitiveV2(v))
			}
		}
		c.blockDeps = append(c.blockDeps, blockCompiler.blockDeps...)
		c.addChunk(&llx.Chunk{
			Call: llx.Chunk_FUNCTION,
			Id:   "${}",
			Function: &llx.Function{
				Type:    string(types.Block),
				Binding: failedRef,
				Args:    args,
			},
		})

		// Since it operates on top of a block, we have to add its
		// checksum as the first entry in the list. Once the block is received,
		// all of its child entries are processed for the final result.
		// (blockRef is presumably the "${}" chunk added just above — confirm
		// that TailRef returns the last added chunk.)
		blockRef := c.block.TailRef(c.blockRef)
		checksum := c.Result.CodeV2.Checksums[blockRef]
		assertionMsg.Checksums = make([]string, len(assertionMsg.Refs)+1)
		assertionMsg.Checksums[0] = checksum
		// NOTE(review): this replaces c.block.Datapoints with Blocks[0]'s
		// datapoints plus blockRef. If c.block is not Blocks[0], c.block's own
		// datapoints are dropped, and the append may also share Blocks[0]'s
		// backing array — confirm this is intended.
		c.block.Datapoints = append(c.Result.CodeV2.Blocks[0].Datapoints, blockRef)

		// placeholder refs were compiled inside the block compiler, so their
		// checksums are looked up there
		blocksums := blockCompiler.Result.CodeV2.Checksums
		for i := range assertionMsg.Refs {
			sum, ok := blocksums[assertionMsg.Refs[i]]
			if !ok {
				return errors.New("cannot find checksum for datapoint in @msg tag")
			}

			assertionMsg.Checksums[i+1] = sum
		}
		assertionMsg.Refs = nil
		// panic("Something about blocks decoding...")
		assertionMsg.DecodeBlock = true
	}

	return nil
}
   257  
   258  // UpdateAssertions in a bundle and remove all intermediate assertion objects
   259  func UpdateAssertions(bundle *llx.CodeBundle) error {
   260  	bundle.Assertions = map[string]*llx.AssertionMessage{}
   261  	return updateCodeAssertions(bundle, bundle.CodeV2)
   262  }
   263  
   264  func updateCodeAssertions(bundle *llx.CodeBundle, code *llx.CodeV2) error {
   265  	for ref, assert := range code.Assertions {
   266  		sum, ok := code.Checksums[ref]
   267  		if !ok {
   268  			return errors.New("cannot find reference for assertion")
   269  		}
   270  
   271  		if !assert.DecodeBlock {
   272  			assert.Checksums = make([]string, len(assert.Refs))
   273  			for i := range assert.Refs {
   274  				ref := assert.Refs[i]
   275  				assert.Checksums[i], ok = code.Checksums[ref]
   276  				if !ok {
   277  					return errors.New("cannot find reference to data in assertion")
   278  				}
   279  			}
   280  			assert.Refs = nil
   281  		}
   282  
   283  		bundle.Assertions[sum] = assert
   284  	}
   285  	code.Assertions = nil
   286  
   287  	return nil
   288  }