github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/prog/expr.go (about)

     1  // Copyright 2023 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  package prog
     5  
     6  import (
     7  	"errors"
     8  	"fmt"
     9  )
    10  
    11  func (bo BinaryExpression) Evaluate(finder ArgFinder) (uint64, bool) {
    12  	left, ok := bo.Left.Evaluate(finder)
    13  	if !ok {
    14  		return 0, false
    15  	}
    16  	right, ok := bo.Right.Evaluate(finder)
    17  	if !ok {
    18  		return 0, false
    19  	}
    20  	switch bo.Operator {
    21  	case OperatorCompareEq:
    22  		if left == right {
    23  			return 1, true
    24  		}
    25  		return 0, true
    26  	case OperatorCompareNeq:
    27  		if left != right {
    28  			return 1, true
    29  		}
    30  		return 0, true
    31  	case OperatorBinaryAnd:
    32  		return left & right, true
    33  	case OperatorOr:
    34  		if left != 0 || right != 0 {
    35  			return 1, true
    36  		}
    37  		return 0, true
    38  	}
    39  	panic(fmt.Sprintf("unknown operator %q", bo.Operator))
    40  }
    41  
    42  func (v *Value) Evaluate(finder ArgFinder) (uint64, bool) {
    43  	if len(v.Path) == 0 {
    44  		return v.Value, true
    45  	}
    46  	found := finder(v.Path)
    47  	if found == SquashedArgFound {
    48  		// This is expectable.
    49  		return 0, false
    50  	}
    51  	if found == nil {
    52  		panic(fmt.Sprintf("no argument was found by %v", v.Path))
    53  	}
    54  	constArg, ok := found.(*ConstArg)
    55  	if !ok {
    56  		panic("value expressions must only rely on int fields")
    57  	}
    58  	return constArg.Val, true
    59  }
    60  
    61  func makeArgFinder(t *Target, c *Call, unionArg *UnionArg, parents parentStack) ArgFinder {
    62  	return func(path []string) Arg {
    63  		f := t.findArg(unionArg.Option, path, nil, nil, parents, 0)
    64  		if f == nil {
    65  			return nil
    66  		}
    67  		if f.isAnyPtr {
    68  			return SquashedArgFound
    69  		}
    70  		return f.arg
    71  	}
    72  }
    73  
    74  func (r *randGen) patchConditionalFields(c *Call, s *state) (extra []*Call, changed bool) {
    75  	if r.patchConditionalDepth > 1 {
    76  		// Some nested patchConditionalFields() calls are fine as we could trigger a resource
    77  		// constructor via generateArg(). But since nested createResource() calls are prohibited,
    78  		// patchConditionalFields() should never be nested more than 2 times.
    79  		panic("third nested patchConditionalFields call")
    80  	}
    81  	r.patchConditionalDepth++
    82  	defer func() { r.patchConditionalDepth-- }()
    83  
    84  	var extraCalls []*Call
    85  	var anyPatched bool
    86  	for {
    87  		replace := map[Arg]Arg{}
    88  		forEachStaleUnion(r.target, c,
    89  			func(unionArg *UnionArg, unionType *UnionType, okIndices []int) {
    90  				idx := okIndices[r.Intn(len(okIndices))]
    91  				newType, newDir := unionType.Fields[idx].Type,
    92  					unionType.Fields[idx].Dir(unionArg.Dir())
    93  				newTypeArg, newCalls := r.generateArg(s, newType, newDir)
    94  				replace[unionArg] = MakeUnionArg(unionType, newDir, newTypeArg, idx)
    95  				extraCalls = append(extraCalls, newCalls...)
    96  				anyPatched = true
    97  			})
    98  		for old, new := range replace {
    99  			replaceArg(old, new)
   100  		}
   101  		// The newly inserted argument might contain more arguments we need
   102  		// to patch.
   103  		// Repeat until we have to change nothing.
   104  		if len(replace) == 0 {
   105  			break
   106  		}
   107  	}
   108  	return extraCalls, anyPatched
   109  }
   110  
   111  func forEachStaleUnion(target *Target, c *Call, cb func(*UnionArg, *UnionType, []int)) {
   112  	for _, callArg := range c.Args {
   113  		foreachSubArgWithStack(callArg, func(arg Arg, argCtx *ArgCtx) {
   114  			if target.isAnyPtr(arg.Type()) {
   115  				argCtx.Stop = true
   116  				return
   117  			}
   118  			unionArg, ok := arg.(*UnionArg)
   119  			if !ok {
   120  				return
   121  			}
   122  			unionType, ok := arg.Type().(*UnionType)
   123  			if !ok || !unionType.isConditional() {
   124  				return
   125  			}
   126  			argFinder := makeArgFinder(target, c, unionArg, argCtx.parentStack)
   127  			ok, calculated := checkUnionArg(unionArg.Index, unionType, argFinder)
   128  			if !calculated {
   129  				// Let it stay as is.
   130  				return
   131  			}
   132  			if !unionArg.transient && ok {
   133  				return
   134  			}
   135  			matchingIndices := matchingUnionArgs(unionType, argFinder)
   136  			if len(matchingIndices) == 0 {
   137  				// Conditional fields are transformed in such a way
   138  				// that one field always matches.
   139  				// For unions we demand that there's a field w/o conditions.
   140  				panic(fmt.Sprintf("no matching union fields: %#v", unionType))
   141  			}
   142  			cb(unionArg, unionType, matchingIndices)
   143  		})
   144  	}
   145  }
   146  
   147  func checkUnionArg(idx int, typ *UnionType, finder ArgFinder) (ok, calculated bool) {
   148  	field := typ.Fields[idx]
   149  	if field.Condition == nil {
   150  		return true, true
   151  	}
   152  	val, ok := field.Condition.Evaluate(finder)
   153  	if !ok {
   154  		// We could not calculate the expression.
   155  		// Let the union stay as it was.
   156  		return true, false
   157  	}
   158  	return val != 0, true
   159  }
   160  
   161  func matchingUnionArgs(typ *UnionType, finder ArgFinder) []int {
   162  	var ret []int
   163  	for i := range typ.Fields {
   164  		ok, _ := checkUnionArg(i, typ, finder)
   165  		if ok {
   166  			ret = append(ret, i)
   167  		}
   168  	}
   169  	return ret
   170  }
   171  
   172  func (p *Prog) checkConditions() error {
   173  	for _, c := range p.Calls {
   174  		err := c.checkConditions(p.Target, false)
   175  		if err != nil {
   176  			return err
   177  		}
   178  	}
   179  	return nil
   180  }
   181  
   182  var ErrViolatedConditions = errors.New("conditional fields rules violation")
   183  
   184  func (c *Call) checkConditions(target *Target, ignoreTransient bool) error {
   185  	var ret error
   186  	forEachStaleUnion(target, c,
   187  		func(a *UnionArg, t *UnionType, okIndices []int) {
   188  			if ignoreTransient && a.transient {
   189  				return
   190  			}
   191  			ret = fmt.Errorf("%w union %s field is #%d(%s), but %v satisfy conditions",
   192  				ErrViolatedConditions, t.Name(), a.Index, t.Fields[a.Index].Name,
   193  				okIndices)
   194  		})
   195  	return ret
   196  }
   197  
   198  func (c *Call) setDefaultConditions(target *Target, transientOnly bool) bool {
   199  	var anyReplaced bool
   200  	// Replace stale conditions with the default values of their correct types.
   201  	for {
   202  		replace := map[Arg]Arg{}
   203  		forEachStaleUnion(target, c,
   204  			func(unionArg *UnionArg, unionType *UnionType, okIndices []int) {
   205  				if transientOnly && !unionArg.transient {
   206  					return
   207  				}
   208  				idx := okIndices[0]
   209  				if defIdx, ok := unionType.defaultField(); ok {
   210  					// If there's a default value available, use it.
   211  					idx = defIdx
   212  				}
   213  				field := unionType.Fields[idx]
   214  				replace[unionArg] = MakeUnionArg(unionType,
   215  					unionArg.Dir(),
   216  					field.DefaultArg(field.Dir(unionArg.Dir())),
   217  					idx)
   218  			})
   219  		for old, new := range replace {
   220  			anyReplaced = true
   221  			replaceArg(old, new)
   222  		}
   223  		if len(replace) == 0 {
   224  			break
   225  		}
   226  	}
   227  	return anyReplaced
   228  }