github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/prog/prio.go (about)

     1  // Copyright 2015/2016 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  package prog
     5  
     6  import (
     7  	"fmt"
     8  	"math"
     9  	"math/rand"
    10  	"slices"
    11  	"sort"
    12  )
    13  
    14  // Calculation of call-to-call priorities.
    15  // For a given pair of calls X and Y, the priority is our guess as to whether
    16  // addition of call Y into a program containing call X is likely to give
    17  // new coverage or not.
    18  // The current algorithm has two components: static and dynamic.
    19  // The static component is based on analysis of argument types. For example,
    20  // if call X and call Y both accept fd[sock], then they are more likely to give
    21  // new coverage together.
    22  // The dynamic component is based on frequency of occurrence of a particular
    23  // pair of syscalls in a single program in corpus. For example, if socket and
    24  // connect frequently occur in programs together, we give higher priority to
    25  // this pair of syscalls.
    26  // Note: the current implementation is very basic, there is no theory behind any
    27  // constants.
    28  
    29  // CalculatePriorities returns the priority matrix as well as the map of generatable syscalls.
    30  // The rows/columns corresponding to the non-generatable syscalls are left to be 0.
    31  func (target *Target) CalculatePriorities(corpus []*Prog, enabled map[*Syscall]bool) ([][]int32, map[*Syscall]bool) {
    32  	enabled = target.prepareEnabledSyscalls(corpus, enabled)
    33  	static := target.calcStaticPriorities(enabled)
    34  	if len(corpus) != 0 {
    35  		// Let's just sum the static and dynamic distributions.
    36  		dynamic := target.calcDynamicPrio(corpus, enabled)
    37  		for i, prios := range dynamic {
    38  			dst := static[i]
    39  			for j, p := range prios {
    40  				dst[j] += p
    41  			}
    42  		}
    43  	}
    44  	if debug {
    45  		for _, syscall := range target.Syscalls {
    46  			if enabled[syscall] {
    47  				continue
    48  			}
    49  			for i := range static {
    50  				if static[i][syscall.ID] != 0 || static[syscall.ID][i] != 0 {
    51  					panic(fmt.Sprintf("prio matrix has non-zero value for a disabled syscall %d",
    52  						syscall.ID))
    53  				}
    54  			}
    55  		}
    56  	}
    57  	return static, enabled
    58  }
    59  
    60  func (target *Target) prepareEnabledSyscalls(corpus []*Prog, enabled map[*Syscall]bool) map[*Syscall]bool {
    61  	if enabled == nil {
    62  		enabled = make(map[*Syscall]bool)
    63  		for _, c := range target.Syscalls {
    64  			enabled[c] = true
    65  		}
    66  	}
    67  	noGenerateCalls := make(map[int]bool)
    68  	enabledCalls := make(map[*Syscall]bool)
    69  	for call := range enabled {
    70  		if call.Attrs.NoGenerate {
    71  			noGenerateCalls[call.ID] = true
    72  		} else if !call.Attrs.Disabled {
    73  			enabledCalls[call] = true
    74  		}
    75  	}
    76  	// Some validation checks.
    77  	if len(enabledCalls) == 0 {
    78  		panic("no syscalls enabled and generatable")
    79  	}
    80  	for _, p := range corpus {
    81  		for _, call := range p.Calls {
    82  			if !enabledCalls[call.Meta] && !noGenerateCalls[call.Meta.ID] {
    83  				fmt.Printf("corpus contains disabled syscall %v\n", call.Meta.Name)
    84  				for call := range enabled {
    85  					fmt.Printf("%s: enabled\n", call.Name)
    86  				}
    87  				panic("disabled syscall")
    88  			}
    89  		}
    90  	}
    91  	return enabledCalls
    92  }
    93  
    94  func (target *Target) calcStaticPriorities(enabled map[*Syscall]bool) [][]int32 {
    95  	uses := target.calcResourceUsage(enabled)
    96  	prios := make([][]int32, len(target.Syscalls))
    97  	for i := range prios {
    98  		prios[i] = make([]int32, len(target.Syscalls))
    99  	}
   100  	for _, weights := range uses {
   101  		for _, w0 := range weights {
   102  			for _, w1 := range weights {
   103  				if w0.call == w1.call {
   104  					// Self-priority is assigned below.
   105  					continue
   106  				}
   107  				// The static priority is assigned based on the direction of arguments. A higher priority will be
   108  				// assigned when c0 is a call that produces a resource and c1 a call that uses that resource.
   109  				prios[w0.call][w1.call] += w0.inout*w1.in*3/2 + w0.inout*w1.inout
   110  			}
   111  		}
   112  	}
   113  	// The value assigned for self-priority (call wrt itself) have to be high, but not too high.
   114  	for c := range enabled {
   115  		id, pp := c.ID, prios[c.ID]
   116  		max := slices.Max(pp)
   117  		if max == 0 {
   118  			pp[id] = 1
   119  		} else {
   120  			pp[id] = max * 3 / 4
   121  		}
   122  	}
   123  	normalizePrios(prios, len(enabled))
   124  	return prios
   125  }
   126  
   127  func (target *Target) calcResourceUsage(enabled map[*Syscall]bool) map[string]map[int]weights {
   128  	uses := make(map[string]map[int]weights)
   129  	ForeachType(target.Syscalls, func(t Type, ctx *TypeCtx) {
   130  		c := ctx.Meta
   131  		if !enabled[c] {
   132  			ctx.Stop = true
   133  			return
   134  		}
   135  		switch a := t.(type) {
   136  		case *ResourceType:
   137  			if target.AuxResources[a.Desc.Name] {
   138  				noteUsagef(uses, c, 1, ctx.Dir, "res%v", a.Desc.Name)
   139  			} else {
   140  				str := "res"
   141  				for i, k := range a.Desc.Kind {
   142  					str += "-" + k
   143  					w := int32(10)
   144  					if i < len(a.Desc.Kind)-1 {
   145  						w = 2
   146  					}
   147  					noteUsage(uses, c, w, ctx.Dir, str)
   148  				}
   149  			}
   150  		case *PtrType:
   151  			if _, ok := a.Elem.(*StructType); ok {
   152  				noteUsagef(uses, c, 10, ctx.Dir, "ptrto-%v", a.Elem.Name())
   153  			}
   154  			if _, ok := a.Elem.(*UnionType); ok {
   155  				noteUsagef(uses, c, 10, ctx.Dir, "ptrto-%v", a.Elem.Name())
   156  			}
   157  			if arr, ok := a.Elem.(*ArrayType); ok {
   158  				noteUsagef(uses, c, 10, ctx.Dir, "ptrto-%v", arr.Elem.Name())
   159  			}
   160  		case *BufferType:
   161  			switch a.Kind {
   162  			case BufferBlobRand, BufferBlobRange, BufferText, BufferCompressed:
   163  			case BufferString, BufferGlob:
   164  				if a.SubKind != "" {
   165  					noteUsagef(uses, c, 2, ctx.Dir, "str-%v", a.SubKind)
   166  				}
   167  			case BufferFilename:
   168  				noteUsage(uses, c, 10, DirIn, "filename")
   169  			default:
   170  				panic("unknown buffer kind")
   171  			}
   172  		case *VmaType:
   173  			noteUsage(uses, c, 5, ctx.Dir, "vma")
   174  		case *IntType:
   175  			switch a.Kind {
   176  			case IntPlain, IntRange:
   177  			default:
   178  				panic("unknown int kind")
   179  			}
   180  		}
   181  	})
   182  	return uses
   183  }
   184  
   185  type weights struct {
   186  	call  int
   187  	in    int32
   188  	inout int32
   189  }
   190  
   191  func noteUsage(uses map[string]map[int]weights, c *Syscall, weight int32, dir Dir, str string) {
   192  	noteUsagef(uses, c, weight, dir, "%v", str)
   193  }
   194  
   195  func noteUsagef(uses map[string]map[int]weights, c *Syscall, weight int32, dir Dir, str string, args ...interface{}) {
   196  	id := fmt.Sprintf(str, args...)
   197  	if uses[id] == nil {
   198  		uses[id] = make(map[int]weights)
   199  	}
   200  	callWeight := uses[id][c.ID]
   201  	callWeight.call = c.ID
   202  	if dir != DirOut {
   203  		if weight > uses[id][c.ID].in {
   204  			callWeight.in = weight
   205  		}
   206  	}
   207  	if weight > uses[id][c.ID].inout {
   208  		callWeight.inout = weight
   209  	}
   210  	uses[id][c.ID] = callWeight
   211  }
   212  
   213  func (target *Target) calcDynamicPrio(corpus []*Prog, enabled map[*Syscall]bool) [][]int32 {
   214  	prios := make([][]int32, len(target.Syscalls))
   215  	for i := range prios {
   216  		prios[i] = make([]int32, len(target.Syscalls))
   217  	}
   218  	for _, p := range corpus {
   219  		for idx0, c0 := range p.Calls {
   220  			if !enabled[c0.Meta] {
   221  				continue
   222  			}
   223  			for _, c1 := range p.Calls[idx0+1:] {
   224  				if !enabled[c1.Meta] {
   225  					continue
   226  				}
   227  				prios[c0.Meta.ID][c1.Meta.ID]++
   228  			}
   229  		}
   230  	}
   231  	for i := range prios {
   232  		for j, val := range prios[i] {
   233  			// It's more important that some calls do coexist than whether
   234  			// it happened 50 or 100 times.
   235  			// Let's use sqrt() to lessen the effect of large counts.
   236  			prios[i][j] = int32(2.0 * math.Sqrt(float64(val)))
   237  		}
   238  	}
   239  	normalizePrios(prios, len(enabled))
   240  	return prios
   241  }
   242  
   243  // normalizePrio distributes |N| * 10 points proportional to the values in the matrix.
   244  // |N| is the number of the generatable syscalls.
   245  func normalizePrios(prios [][]int32, n int) {
   246  	total := 10 * int32(n)
   247  	for _, prio := range prios {
   248  		sum := int32(0)
   249  		for _, p := range prio {
   250  			sum += p
   251  		}
   252  		if sum == 0 {
   253  			continue
   254  		}
   255  		for i, p := range prio {
   256  			prio[i] = p * total / sum
   257  		}
   258  	}
   259  }
   260  
// ChoiceTable allows to do a weighted choice of a syscall for a given syscall
// based on call-to-call priorities and a set of enabled and generatable syscalls.
// It is built by BuildChoiceTable. runs is indexed by syscall ID: runs[i] holds
// the running (cumulative) sums of row i of the priority matrix, and is nil
// when syscall i is not generatable.
type ChoiceTable struct {
	target *Target    // target the table was built for
	runs   [][]int32  // per-syscall cumulative priority sums; nil row = not generatable
	calls  []*Syscall // generatable syscalls, sorted by ID
}
   268  
   269  func (target *Target) BuildChoiceTable(corpus []*Prog, enabled map[*Syscall]bool) *ChoiceTable {
   270  	prios, enabledCalls := target.CalculatePriorities(corpus, enabled)
   271  	var generatableCalls []*Syscall
   272  	for c := range enabledCalls {
   273  		generatableCalls = append(generatableCalls, c)
   274  	}
   275  	sort.Slice(generatableCalls, func(i, j int) bool {
   276  		return generatableCalls[i].ID < generatableCalls[j].ID
   277  	})
   278  
   279  	run := make([][]int32, len(target.Syscalls))
   280  	// ChoiceTable.runs[][] contains cumulated sum of weighted priority numbers.
   281  	// This helps in quick binary search with biases when generating programs.
   282  	// This only applies for system calls that are enabled for the target.
   283  	for i := range run {
   284  		if !enabledCalls[target.Syscalls[i]] {
   285  			continue
   286  		}
   287  		run[i] = make([]int32, len(target.Syscalls))
   288  		var sum int32
   289  		for j := range run[i] {
   290  			if enabledCalls[target.Syscalls[j]] {
   291  				sum += prios[i][j]
   292  			}
   293  			run[i][j] = sum
   294  		}
   295  	}
   296  	return &ChoiceTable{target, run, generatableCalls}
   297  }
   298  
   299  func (ct *ChoiceTable) Generatable(call int) bool {
   300  	return ct.runs[call] != nil
   301  }
   302  
   303  func (ct *ChoiceTable) choose(r *rand.Rand, bias int) int {
   304  	if r.Intn(100) < 5 {
   305  		// Let's make 5% decisions totally at random.
   306  		return ct.calls[r.Intn(len(ct.calls))].ID
   307  	}
   308  	if bias < 0 {
   309  		bias = ct.calls[r.Intn(len(ct.calls))].ID
   310  	}
   311  	if !ct.Generatable(bias) {
   312  		fmt.Printf("bias to disabled or non-generatable syscall %v\n", ct.target.Syscalls[bias].Name)
   313  		panic("disabled or non-generatable syscall")
   314  	}
   315  	run := ct.runs[bias]
   316  	runSum := int(run[len(run)-1])
   317  	x := int32(r.Intn(runSum) + 1)
   318  	res := sort.Search(len(run), func(i int) bool {
   319  		return run[i] >= x
   320  	})
   321  	if !ct.Generatable(res) {
   322  		panic("selected disabled or non-generatable syscall")
   323  	}
   324  	return res
   325  }