github.com/google/syzkaller@v0.0.0-20240517125934-c0f1611a36d6/prog/expr.go (about) 1 // Copyright 2023 syzkaller project authors. All rights reserved. 2 // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 3 4 package prog 5 6 import ( 7 "errors" 8 "fmt" 9 ) 10 11 func (bo *BinaryExpression) Evaluate(finder ArgFinder) (uint64, bool) { 12 left, ok := bo.Left.Evaluate(finder) 13 if !ok { 14 return 0, false 15 } 16 right, ok := bo.Right.Evaluate(finder) 17 if !ok { 18 return 0, false 19 } 20 switch bo.Operator { 21 case OperatorCompareEq: 22 if left == right { 23 return 1, true 24 } 25 return 0, true 26 case OperatorCompareNeq: 27 if left != right { 28 return 1, true 29 } 30 return 0, true 31 case OperatorBinaryAnd: 32 return left & right, true 33 } 34 panic(fmt.Sprintf("unknown operator %q", bo.Operator)) 35 } 36 37 func (v *Value) Evaluate(finder ArgFinder) (uint64, bool) { 38 if len(v.Path) == 0 { 39 return v.Value, true 40 } 41 found := finder(v.Path) 42 if found == SquashedArgFound { 43 // This is expectable. 44 return 0, false 45 } 46 if found == nil { 47 panic(fmt.Sprintf("no argument was found by %v", v.Path)) 48 } 49 constArg, ok := found.(*ConstArg) 50 if !ok { 51 panic("value expressions must only rely on int fields") 52 } 53 return constArg.Val, true 54 } 55 56 func makeArgFinder(t *Target, c *Call, unionArg *UnionArg, parents parentStack) ArgFinder { 57 return func(path []string) Arg { 58 f := t.findArg(unionArg.Option, path, nil, nil, parents, 0) 59 if f == nil { 60 return nil 61 } 62 if f.isAnyPtr { 63 return SquashedArgFound 64 } 65 return f.arg 66 } 67 } 68 69 func (r *randGen) patchConditionalFields(c *Call, s *state) (extra []*Call, changed bool) { 70 if r.inPatchConditional { 71 return nil, false 72 } 73 r.inPatchConditional = true 74 defer func() { r.inPatchConditional = false }() 75 76 var extraCalls []*Call 77 var anyPatched bool 78 for { 79 replace := map[Arg]Arg{} 80 forEachStaleUnion(r.target, c, 81 func(unionArg *UnionArg, unionType *UnionType, okIndices []int) { 82 idx := okIndices[r.Intn(len(okIndices))] 83 newType, newDir := unionType.Fields[idx].Type, 84 unionType.Fields[idx].Dir(unionArg.Dir()) 85 newTypeArg, newCalls := r.generateArg(s, newType, newDir) 86 replace[unionArg] = MakeUnionArg(unionType, newDir, newTypeArg, idx) 87 extraCalls = append(extraCalls, newCalls...) 88 anyPatched = true 89 }) 90 for old, new := range replace { 91 replaceArg(old, new) 92 } 93 // The newly inserted argument might contain more arguments we need 94 // to patch. 95 // Repeat until we have to change nothing. 96 if len(replace) == 0 { 97 break 98 } 99 } 100 return extraCalls, anyPatched 101 } 102 103 func forEachStaleUnion(target *Target, c *Call, cb func(*UnionArg, *UnionType, []int)) { 104 for _, callArg := range c.Args { 105 foreachSubArgWithStack(callArg, func(arg Arg, argCtx *ArgCtx) { 106 if target.isAnyPtr(arg.Type()) { 107 argCtx.Stop = true 108 return 109 } 110 unionArg, ok := arg.(*UnionArg) 111 if !ok { 112 return 113 } 114 unionType, ok := arg.Type().(*UnionType) 115 if !ok || !unionType.isConditional() { 116 return 117 } 118 argFinder := makeArgFinder(target, c, unionArg, argCtx.parentStack) 119 ok, calculated := checkUnionArg(unionArg.Index, unionType, argFinder) 120 if !calculated { 121 // Let it stay as is. 122 return 123 } 124 if !unionArg.transient && ok { 125 return 126 } 127 matchingIndices := matchingUnionArgs(unionType, argFinder) 128 if len(matchingIndices) == 0 { 129 // Conditional fields are transformed in such a way 130 // that one field always matches. 131 // For unions we demand that there's a field w/o conditions. 132 panic(fmt.Sprintf("no matching union fields: %#v", unionType)) 133 } 134 cb(unionArg, unionType, matchingIndices) 135 }) 136 } 137 } 138 139 func checkUnionArg(idx int, typ *UnionType, finder ArgFinder) (ok, calculated bool) { 140 field := typ.Fields[idx] 141 if field.Condition == nil { 142 return true, true 143 } 144 val, ok := field.Condition.Evaluate(finder) 145 if !ok { 146 // We could not calculate the expression. 147 // Let the union stay as it was. 148 return true, false 149 } 150 return val != 0, true 151 } 152 153 func matchingUnionArgs(typ *UnionType, finder ArgFinder) []int { 154 var ret []int 155 for i := range typ.Fields { 156 ok, _ := checkUnionArg(i, typ, finder) 157 if ok { 158 ret = append(ret, i) 159 } 160 } 161 return ret 162 } 163 164 func (p *Prog) checkConditions() error { 165 for _, c := range p.Calls { 166 err := c.checkConditions(p.Target, false) 167 if err != nil { 168 return err 169 } 170 } 171 return nil 172 } 173 174 var ErrViolatedConditions = errors.New("conditional fields rules violation") 175 176 func (c *Call) checkConditions(target *Target, ignoreTransient bool) error { 177 var ret error 178 forEachStaleUnion(target, c, 179 func(a *UnionArg, t *UnionType, okIndices []int) { 180 if ignoreTransient && a.transient { 181 return 182 } 183 ret = fmt.Errorf("%w union %s field is #%d(%s), but %v satisfy conditions", 184 ErrViolatedConditions, t.Name(), a.Index, t.Fields[a.Index].Name, 185 okIndices) 186 }) 187 return ret 188 } 189 190 func (c *Call) setDefaultConditions(target *Target, transientOnly bool) bool { 191 var anyReplaced bool 192 // Replace stale conditions with the default values of their correct types. 193 for { 194 replace := map[Arg]Arg{} 195 forEachStaleUnion(target, c, 196 func(unionArg *UnionArg, unionType *UnionType, okIndices []int) { 197 if transientOnly && !unionArg.transient { 198 return 199 } 200 // If several union options match, take the first one. 201 idx := okIndices[0] 202 field := unionType.Fields[idx] 203 replace[unionArg] = MakeUnionArg(unionType, 204 unionArg.Dir(), 205 field.DefaultArg(field.Dir(unionArg.Dir())), 206 idx) 207 }) 208 for old, new := range replace { 209 anyReplaced = true 210 replaceArg(old, new) 211 } 212 if len(replace) == 0 { 213 break 214 } 215 } 216 return anyReplaced 217 }