github.com/google/syzkaller@v0.0.0-20240517125934-c0f1611a36d6/prog/expr.go (about)

     1  // Copyright 2023 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  package prog
     5  
     6  import (
     7  	"errors"
     8  	"fmt"
     9  )
    10  
    11  func (bo *BinaryExpression) Evaluate(finder ArgFinder) (uint64, bool) {
    12  	left, ok := bo.Left.Evaluate(finder)
    13  	if !ok {
    14  		return 0, false
    15  	}
    16  	right, ok := bo.Right.Evaluate(finder)
    17  	if !ok {
    18  		return 0, false
    19  	}
    20  	switch bo.Operator {
    21  	case OperatorCompareEq:
    22  		if left == right {
    23  			return 1, true
    24  		}
    25  		return 0, true
    26  	case OperatorCompareNeq:
    27  		if left != right {
    28  			return 1, true
    29  		}
    30  		return 0, true
    31  	case OperatorBinaryAnd:
    32  		return left & right, true
    33  	}
    34  	panic(fmt.Sprintf("unknown operator %q", bo.Operator))
    35  }
    36  
    37  func (v *Value) Evaluate(finder ArgFinder) (uint64, bool) {
    38  	if len(v.Path) == 0 {
    39  		return v.Value, true
    40  	}
    41  	found := finder(v.Path)
    42  	if found == SquashedArgFound {
    43  		// This is expectable.
    44  		return 0, false
    45  	}
    46  	if found == nil {
    47  		panic(fmt.Sprintf("no argument was found by %v", v.Path))
    48  	}
    49  	constArg, ok := found.(*ConstArg)
    50  	if !ok {
    51  		panic("value expressions must only rely on int fields")
    52  	}
    53  	return constArg.Val, true
    54  }
    55  
    56  func makeArgFinder(t *Target, c *Call, unionArg *UnionArg, parents parentStack) ArgFinder {
    57  	return func(path []string) Arg {
    58  		f := t.findArg(unionArg.Option, path, nil, nil, parents, 0)
    59  		if f == nil {
    60  			return nil
    61  		}
    62  		if f.isAnyPtr {
    63  			return SquashedArgFound
    64  		}
    65  		return f.arg
    66  	}
    67  }
    68  
    69  func (r *randGen) patchConditionalFields(c *Call, s *state) (extra []*Call, changed bool) {
    70  	if r.inPatchConditional {
    71  		return nil, false
    72  	}
    73  	r.inPatchConditional = true
    74  	defer func() { r.inPatchConditional = false }()
    75  
    76  	var extraCalls []*Call
    77  	var anyPatched bool
    78  	for {
    79  		replace := map[Arg]Arg{}
    80  		forEachStaleUnion(r.target, c,
    81  			func(unionArg *UnionArg, unionType *UnionType, okIndices []int) {
    82  				idx := okIndices[r.Intn(len(okIndices))]
    83  				newType, newDir := unionType.Fields[idx].Type,
    84  					unionType.Fields[idx].Dir(unionArg.Dir())
    85  				newTypeArg, newCalls := r.generateArg(s, newType, newDir)
    86  				replace[unionArg] = MakeUnionArg(unionType, newDir, newTypeArg, idx)
    87  				extraCalls = append(extraCalls, newCalls...)
    88  				anyPatched = true
    89  			})
    90  		for old, new := range replace {
    91  			replaceArg(old, new)
    92  		}
    93  		// The newly inserted argument might contain more arguments we need
    94  		// to patch.
    95  		// Repeat until we have to change nothing.
    96  		if len(replace) == 0 {
    97  			break
    98  		}
    99  	}
   100  	return extraCalls, anyPatched
   101  }
   102  
   103  func forEachStaleUnion(target *Target, c *Call, cb func(*UnionArg, *UnionType, []int)) {
   104  	for _, callArg := range c.Args {
   105  		foreachSubArgWithStack(callArg, func(arg Arg, argCtx *ArgCtx) {
   106  			if target.isAnyPtr(arg.Type()) {
   107  				argCtx.Stop = true
   108  				return
   109  			}
   110  			unionArg, ok := arg.(*UnionArg)
   111  			if !ok {
   112  				return
   113  			}
   114  			unionType, ok := arg.Type().(*UnionType)
   115  			if !ok || !unionType.isConditional() {
   116  				return
   117  			}
   118  			argFinder := makeArgFinder(target, c, unionArg, argCtx.parentStack)
   119  			ok, calculated := checkUnionArg(unionArg.Index, unionType, argFinder)
   120  			if !calculated {
   121  				// Let it stay as is.
   122  				return
   123  			}
   124  			if !unionArg.transient && ok {
   125  				return
   126  			}
   127  			matchingIndices := matchingUnionArgs(unionType, argFinder)
   128  			if len(matchingIndices) == 0 {
   129  				// Conditional fields are transformed in such a way
   130  				// that one field always matches.
   131  				// For unions we demand that there's a field w/o conditions.
   132  				panic(fmt.Sprintf("no matching union fields: %#v", unionType))
   133  			}
   134  			cb(unionArg, unionType, matchingIndices)
   135  		})
   136  	}
   137  }
   138  
   139  func checkUnionArg(idx int, typ *UnionType, finder ArgFinder) (ok, calculated bool) {
   140  	field := typ.Fields[idx]
   141  	if field.Condition == nil {
   142  		return true, true
   143  	}
   144  	val, ok := field.Condition.Evaluate(finder)
   145  	if !ok {
   146  		// We could not calculate the expression.
   147  		// Let the union stay as it was.
   148  		return true, false
   149  	}
   150  	return val != 0, true
   151  }
   152  
   153  func matchingUnionArgs(typ *UnionType, finder ArgFinder) []int {
   154  	var ret []int
   155  	for i := range typ.Fields {
   156  		ok, _ := checkUnionArg(i, typ, finder)
   157  		if ok {
   158  			ret = append(ret, i)
   159  		}
   160  	}
   161  	return ret
   162  }
   163  
   164  func (p *Prog) checkConditions() error {
   165  	for _, c := range p.Calls {
   166  		err := c.checkConditions(p.Target, false)
   167  		if err != nil {
   168  			return err
   169  		}
   170  	}
   171  	return nil
   172  }
   173  
   174  var ErrViolatedConditions = errors.New("conditional fields rules violation")
   175  
   176  func (c *Call) checkConditions(target *Target, ignoreTransient bool) error {
   177  	var ret error
   178  	forEachStaleUnion(target, c,
   179  		func(a *UnionArg, t *UnionType, okIndices []int) {
   180  			if ignoreTransient && a.transient {
   181  				return
   182  			}
   183  			ret = fmt.Errorf("%w union %s field is #%d(%s), but %v satisfy conditions",
   184  				ErrViolatedConditions, t.Name(), a.Index, t.Fields[a.Index].Name,
   185  				okIndices)
   186  		})
   187  	return ret
   188  }
   189  
   190  func (c *Call) setDefaultConditions(target *Target, transientOnly bool) bool {
   191  	var anyReplaced bool
   192  	// Replace stale conditions with the default values of their correct types.
   193  	for {
   194  		replace := map[Arg]Arg{}
   195  		forEachStaleUnion(target, c,
   196  			func(unionArg *UnionArg, unionType *UnionType, okIndices []int) {
   197  				if transientOnly && !unionArg.transient {
   198  					return
   199  				}
   200  				// If several union options match, take the first one.
   201  				idx := okIndices[0]
   202  				field := unionType.Fields[idx]
   203  				replace[unionArg] = MakeUnionArg(unionType,
   204  					unionArg.Dir(),
   205  					field.DefaultArg(field.Dir(unionArg.Dir())),
   206  					idx)
   207  			})
   208  		for old, new := range replace {
   209  			anyReplaced = true
   210  			replaceArg(old, new)
   211  		}
   212  		if len(replace) == 0 {
   213  			break
   214  		}
   215  	}
   216  	return anyReplaced
   217  }