github.com/google/syzkaller@v0.0.0-20240517125934-c0f1611a36d6/prog/prio.go (about)

     1  // Copyright 2015/2016 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  package prog
     5  
     6  import (
     7  	"fmt"
     8  	"math"
     9  	"math/rand"
    10  	"sort"
    11  )
    12  
// Calculation of call-to-call priorities.
// For a given pair of calls X and Y, the priority is our guess as to whether
// adding call Y to a program containing call X is likely to give
// new coverage or not.
    17  // The current algorithm has two components: static and dynamic.
    18  // The static component is based on analysis of argument types. For example,
    19  // if call X and call Y both accept fd[sock], then they are more likely to give
    20  // new coverage together.
    21  // The dynamic component is based on frequency of occurrence of a particular
    22  // pair of syscalls in a single program in corpus. For example, if socket and
    23  // connect frequently occur in programs together, we give higher priority to
    24  // this pair of syscalls.
    25  // Note: the current implementation is very basic, there is no theory behind any
    26  // constants.
    27  
    28  func (target *Target) CalculatePriorities(corpus []*Prog) [][]int32 {
    29  	static := target.calcStaticPriorities()
    30  	if len(corpus) != 0 {
    31  		// Let's just sum the static and dynamic distributions.
    32  		dynamic := target.calcDynamicPrio(corpus)
    33  		for i, prios := range dynamic {
    34  			dst := static[i]
    35  			for j, p := range prios {
    36  				dst[j] += p
    37  			}
    38  		}
    39  	}
    40  	return static
    41  }
    42  
    43  func (target *Target) calcStaticPriorities() [][]int32 {
    44  	uses := target.calcResourceUsage()
    45  	prios := make([][]int32, len(target.Syscalls))
    46  	for i := range prios {
    47  		prios[i] = make([]int32, len(target.Syscalls))
    48  	}
    49  	for _, weights := range uses {
    50  		for _, w0 := range weights {
    51  			for _, w1 := range weights {
    52  				if w0.call == w1.call {
    53  					// Self-priority is assigned below.
    54  					continue
    55  				}
    56  				// The static priority is assigned based on the direction of arguments. A higher priority will be
    57  				// assigned when c0 is a call that produces a resource and c1 a call that uses that resource.
    58  				prios[w0.call][w1.call] += w0.inout*w1.in*3/2 + w0.inout*w1.inout
    59  			}
    60  		}
    61  	}
    62  	// The value assigned for self-priority (call wrt itself) have to be high, but not too high.
    63  	for c0, pp := range prios {
    64  		var max int32
    65  		for _, p := range pp {
    66  			if p > max {
    67  				max = p
    68  			}
    69  		}
    70  		if max == 0 {
    71  			pp[c0] = 1
    72  		} else {
    73  			pp[c0] = max * 3 / 4
    74  		}
    75  	}
    76  	normalizePrios(prios)
    77  	return prios
    78  }
    79  
    80  func (target *Target) calcResourceUsage() map[string]map[int]weights {
    81  	uses := make(map[string]map[int]weights)
    82  	ForeachType(target.Syscalls, func(t Type, ctx *TypeCtx) {
    83  		c := ctx.Meta
    84  		switch a := t.(type) {
    85  		case *ResourceType:
    86  			if target.AuxResources[a.Desc.Name] {
    87  				noteUsage(uses, c, 1, ctx.Dir, "res%v", a.Desc.Name)
    88  			} else {
    89  				str := "res"
    90  				for i, k := range a.Desc.Kind {
    91  					str += "-" + k
    92  					w := int32(10)
    93  					if i < len(a.Desc.Kind)-1 {
    94  						w = 2
    95  					}
    96  					noteUsage(uses, c, w, ctx.Dir, str)
    97  				}
    98  			}
    99  		case *PtrType:
   100  			if _, ok := a.Elem.(*StructType); ok {
   101  				noteUsage(uses, c, 10, ctx.Dir, "ptrto-%v", a.Elem.Name())
   102  			}
   103  			if _, ok := a.Elem.(*UnionType); ok {
   104  				noteUsage(uses, c, 10, ctx.Dir, "ptrto-%v", a.Elem.Name())
   105  			}
   106  			if arr, ok := a.Elem.(*ArrayType); ok {
   107  				noteUsage(uses, c, 10, ctx.Dir, "ptrto-%v", arr.Elem.Name())
   108  			}
   109  		case *BufferType:
   110  			switch a.Kind {
   111  			case BufferBlobRand, BufferBlobRange, BufferText, BufferCompressed:
   112  			case BufferString, BufferGlob:
   113  				if a.SubKind != "" {
   114  					noteUsage(uses, c, 2, ctx.Dir, fmt.Sprintf("str-%v", a.SubKind))
   115  				}
   116  			case BufferFilename:
   117  				noteUsage(uses, c, 10, DirIn, "filename")
   118  			default:
   119  				panic("unknown buffer kind")
   120  			}
   121  		case *VmaType:
   122  			noteUsage(uses, c, 5, ctx.Dir, "vma")
   123  		case *IntType:
   124  			switch a.Kind {
   125  			case IntPlain, IntRange:
   126  			default:
   127  				panic("unknown int kind")
   128  			}
   129  		}
   130  	})
   131  	return uses
   132  }
   133  
// weights records how strongly a single syscall uses one usage key; it is the
// value type of the per-key maps built by calcResourceUsage/noteUsage.
type weights struct {
	call  int   // ID of the syscall (noteUsage stores c.ID here)
	in    int32 // highest weight seen for usages whose direction is not DirOut
	inout int32 // highest weight seen for usages of any direction
}
   139  
   140  func noteUsage(uses map[string]map[int]weights, c *Syscall, weight int32, dir Dir, str string, args ...interface{}) {
   141  	id := fmt.Sprintf(str, args...)
   142  	if uses[id] == nil {
   143  		uses[id] = make(map[int]weights)
   144  	}
   145  	callWeight := uses[id][c.ID]
   146  	callWeight.call = c.ID
   147  	if dir != DirOut {
   148  		if weight > uses[id][c.ID].in {
   149  			callWeight.in = weight
   150  		}
   151  	}
   152  	if weight > uses[id][c.ID].inout {
   153  		callWeight.inout = weight
   154  	}
   155  	uses[id][c.ID] = callWeight
   156  }
   157  
   158  func (target *Target) calcDynamicPrio(corpus []*Prog) [][]int32 {
   159  	prios := make([][]int32, len(target.Syscalls))
   160  	for i := range prios {
   161  		prios[i] = make([]int32, len(target.Syscalls))
   162  	}
   163  	for _, p := range corpus {
   164  		for idx0, c0 := range p.Calls {
   165  			for _, c1 := range p.Calls[idx0+1:] {
   166  				prios[c0.Meta.ID][c1.Meta.ID]++
   167  			}
   168  		}
   169  	}
   170  	for i := range prios {
   171  		for j, val := range prios[i] {
   172  			// It's more important that some calls do coexist than whether
   173  			// it happened 50 or 100 times.
   174  			// Let's use sqrt() to lessen the effect of large counts.
   175  			prios[i][j] = int32(2.0 * math.Sqrt(float64(val)))
   176  		}
   177  	}
   178  	normalizePrios(prios)
   179  	return prios
   180  }
   181  
   182  // normalizePrio distributes |N| * 10 points proportional to the values in the matrix.
   183  func normalizePrios(prios [][]int32) {
   184  	total := 10 * int32(len(prios))
   185  	for _, prio := range prios {
   186  		sum := int32(0)
   187  		for _, p := range prio {
   188  			sum += p
   189  		}
   190  		if sum == 0 {
   191  			continue
   192  		}
   193  		for i, p := range prio {
   194  			prio[i] = p * total / sum
   195  		}
   196  	}
   197  }
   198  
// ChoiceTable supports a weighted choice of the next syscall given a bias
// syscall, based on call-to-call priorities and on the set of enabled and
// generatable syscalls.
type ChoiceTable struct {
	target *Target
	// runs[i] holds running (cumulative) sums of the priorities of enabled
	// calls relative to call i, indexed by syscall ID. It is nil for calls
	// that were not enabled-and-generatable when the table was built.
	runs   [][]int32
	// calls lists the enabled, generatable syscalls sorted by ID.
	calls  []*Syscall
}
   206  
   207  func (target *Target) BuildChoiceTable(corpus []*Prog, enabled map[*Syscall]bool) *ChoiceTable {
   208  	if enabled == nil {
   209  		enabled = make(map[*Syscall]bool)
   210  		for _, c := range target.Syscalls {
   211  			enabled[c] = true
   212  		}
   213  	}
   214  	noGenerateCalls := make(map[int]bool)
   215  	enabledCalls := make(map[*Syscall]bool)
   216  	for call := range enabled {
   217  		if call.Attrs.NoGenerate {
   218  			noGenerateCalls[call.ID] = true
   219  		} else if !call.Attrs.Disabled {
   220  			enabledCalls[call] = true
   221  		}
   222  	}
   223  	var generatableCalls []*Syscall
   224  	for c := range enabledCalls {
   225  		generatableCalls = append(generatableCalls, c)
   226  	}
   227  	if len(generatableCalls) == 0 {
   228  		panic("no syscalls enabled and generatable")
   229  	}
   230  	sort.Slice(generatableCalls, func(i, j int) bool {
   231  		return generatableCalls[i].ID < generatableCalls[j].ID
   232  	})
   233  	for _, p := range corpus {
   234  		for _, call := range p.Calls {
   235  			if !enabledCalls[call.Meta] && !noGenerateCalls[call.Meta.ID] {
   236  				fmt.Printf("corpus contains disabled syscall %v\n", call.Meta.Name)
   237  				for call := range enabled {
   238  					fmt.Printf("%s: enabled\n", call.Name)
   239  				}
   240  				panic("disabled syscall")
   241  			}
   242  		}
   243  	}
   244  	prios := target.CalculatePriorities(corpus)
   245  	run := make([][]int32, len(target.Syscalls))
   246  	// ChoiceTable.runs[][] contains cumulated sum of weighted priority numbers.
   247  	// This helps in quick binary search with biases when generating programs.
   248  	// This only applies for system calls that are enabled for the target.
   249  	for i := range run {
   250  		if !enabledCalls[target.Syscalls[i]] {
   251  			continue
   252  		}
   253  		run[i] = make([]int32, len(target.Syscalls))
   254  		var sum int32
   255  		for j := range run[i] {
   256  			if enabledCalls[target.Syscalls[j]] {
   257  				sum += prios[i][j]
   258  			}
   259  			run[i][j] = sum
   260  		}
   261  	}
   262  	return &ChoiceTable{target, run, generatableCalls}
   263  }
   264  
   265  func (ct *ChoiceTable) Generatable(call int) bool {
   266  	return ct.runs[call] != nil
   267  }
   268  
   269  func (ct *ChoiceTable) choose(r *rand.Rand, bias int) int {
   270  	if r.Intn(100) < 5 {
   271  		// Let's make 5% decisions totally at random.
   272  		return ct.calls[r.Intn(len(ct.calls))].ID
   273  	}
   274  	if bias < 0 {
   275  		bias = ct.calls[r.Intn(len(ct.calls))].ID
   276  	}
   277  	if !ct.Generatable(bias) {
   278  		fmt.Printf("bias to disabled or non-generatable syscall %v\n", ct.target.Syscalls[bias].Name)
   279  		panic("disabled or non-generatable syscall")
   280  	}
   281  	run := ct.runs[bias]
   282  	runSum := int(run[len(run)-1])
   283  	x := int32(r.Intn(runSum) + 1)
   284  	res := sort.Search(len(run), func(i int) bool {
   285  		return run[i] >= x
   286  	})
   287  	if !ct.Generatable(res) {
   288  		panic("selected disabled or non-generatable syscall")
   289  	}
   290  	return res
   291  }