github.com/google/syzkaller@v0.0.0-20240517125934-c0f1611a36d6/pkg/fuzzer/queue/queue.go (about)

     1  // Copyright 2024 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  package queue
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"encoding/gob"
    10  	"fmt"
    11  	"sync"
    12  	"sync/atomic"
    13  
    14  	"github.com/google/syzkaller/pkg/flatrpc"
    15  	"github.com/google/syzkaller/pkg/hash"
    16  	"github.com/google/syzkaller/pkg/ipc"
    17  	"github.com/google/syzkaller/pkg/signal"
    18  	"github.com/google/syzkaller/pkg/stats"
    19  	"github.com/google/syzkaller/prog"
    20  )
    21  
    22  type Request struct {
    23  	Prog     *prog.Prog
    24  	ExecOpts ipc.ExecOpts
    25  
    26  	// If specified, the resulting signal for call SignalFilterCall
    27  	// will include subset of it even if it's not new.
    28  	SignalFilter     signal.Signal
    29  	SignalFilterCall int
    30  
    31  	// By default, only the newly seen signal is returned.
    32  	// ReturnAllSignal tells the executor to return everything.
    33  	ReturnAllSignal bool
    34  	ReturnError     bool
    35  	ReturnOutput    bool
    36  
    37  	// This stat will be incremented on request completion.
    38  	Stat *stats.Val
    39  
    40  	// Options needed by runtest.
    41  	BinaryFile string // If set, it's executed instead of Prog.
    42  	Repeat     int    // Repeats in addition to the first run.
    43  
    44  	// Important requests will be retried even from crashed VMs.
    45  	Important bool
    46  
    47  	// The callback will be called on request completion in the LIFO order.
    48  	// If it returns false, all further processing will be stopped.
    49  	// It allows wrappers to intercept Done() requests.
    50  	callback DoneCallback
    51  
    52  	onceCrashed bool
    53  
    54  	mu     sync.Mutex
    55  	result *Result
    56  	done   chan struct{}
    57  }
    58  
    59  type DoneCallback func(*Request, *Result) bool
    60  
    61  func (r *Request) OnDone(cb DoneCallback) {
    62  	oldCallback := r.callback
    63  	r.callback = func(req *Request, res *Result) bool {
    64  		r.callback = oldCallback
    65  		if !cb(req, res) {
    66  			return false
    67  		}
    68  		if oldCallback == nil {
    69  			return true
    70  		}
    71  		return oldCallback(req, res)
    72  	}
    73  }
    74  
    75  func (r *Request) Done(res *Result) {
    76  	if r.callback != nil {
    77  		if !r.callback(r, res) {
    78  			return
    79  		}
    80  	}
    81  	if r.Stat != nil {
    82  		r.Stat.Add(1)
    83  	}
    84  	r.initChannel()
    85  	r.result = res
    86  	close(r.done)
    87  }
    88  
    89  // Wait() blocks until we have the result.
    90  func (r *Request) Wait(ctx context.Context) *Result {
    91  	r.initChannel()
    92  	select {
    93  	case <-ctx.Done():
    94  		return &Result{Status: ExecFailure}
    95  	case <-r.done:
    96  		return r.result
    97  	}
    98  }
    99  
   100  // Risky() returns true if there's a substantial risk of the input crashing the VM.
   101  func (r *Request) Risky() bool {
   102  	return r.onceCrashed
   103  }
   104  
   105  func (r *Request) Validate() error {
   106  	collectSignal := r.ExecOpts.ExecFlags&flatrpc.ExecFlagCollectSignal > 0
   107  	if r.ReturnAllSignal && !collectSignal {
   108  		return fmt.Errorf("ReturnAllSignal is set, but FlagCollectSignal is not")
   109  	}
   110  	if r.SignalFilter != nil && !collectSignal {
   111  		return fmt.Errorf("SignalFilter must be used with FlagCollectSignal")
   112  	}
   113  	collectComps := r.ExecOpts.ExecFlags&flatrpc.ExecFlagCollectComps > 0
   114  	collectCover := r.ExecOpts.ExecFlags&flatrpc.ExecFlagCollectCover > 0
   115  	if (collectComps) && (collectSignal || collectCover) {
   116  		return fmt.Errorf("hint collection is mutually exclusive with signal/coverage")
   117  	}
   118  	return nil
   119  }
   120  
   121  func (r *Request) hash() hash.Sig {
   122  	buf := new(bytes.Buffer)
   123  	if err := gob.NewEncoder(buf).Encode(r.ExecOpts); err != nil {
   124  		panic(err)
   125  	}
   126  	return hash.Hash(r.Prog.Serialize(), buf.Bytes())
   127  }
   128  
   129  func (r *Request) initChannel() {
   130  	r.mu.Lock()
   131  	if r.done == nil {
   132  		r.done = make(chan struct{})
   133  	}
   134  	r.mu.Unlock()
   135  }
   136  
   137  type Result struct {
   138  	Info   *ipc.ProgInfo
   139  	Output []byte
   140  	Status Status
   141  	Err    error // More details in case of ExecFailure.
   142  }
   143  
   144  func (r *Result) clone() *Result {
   145  	ret := *r
   146  	if ret.Info != nil {
   147  		ret.Info = ret.Info.Clone()
   148  	}
   149  	return &ret
   150  }
   151  
   152  func (r *Result) Stop() bool {
   153  	return r.Status == ExecFailure || r.Status == Crashed
   154  }
   155  
   156  type Status int
   157  
   158  const (
   159  	Success     Status = iota
   160  	ExecFailure        // For e.g. serialization errors.
   161  	Crashed            // The VM crashed holding the request.
   162  	Restarted          // The VM was restarted holding the request.
   163  )
   164  
   165  // Executor describes the interface wanted by the producers of requests.
   166  // After a Request is submitted, it's expected that the consumer will eventually
   167  // take it and report the execution result via Done().
   168  type Executor interface {
   169  	Submit(req *Request)
   170  }
   171  
   172  // Source describes the interface wanted by the consumers of requests.
   173  type Source interface {
   174  	Next() *Request
   175  }
   176  
   177  // PlainQueue is a straighforward thread-safe Request queue implementation.
   178  type PlainQueue struct {
   179  	stat  *stats.Val
   180  	mu    sync.Mutex
   181  	queue []*Request
   182  	pos   int
   183  }
   184  
   185  func Plain() *PlainQueue {
   186  	return &PlainQueue{}
   187  }
   188  
   189  func PlainWithStat(val *stats.Val) *PlainQueue {
   190  	return &PlainQueue{stat: val}
   191  }
   192  
   193  func (pq *PlainQueue) Len() int {
   194  	pq.mu.Lock()
   195  	defer pq.mu.Unlock()
   196  	return len(pq.queue) - pq.pos
   197  }
   198  
   199  func (pq *PlainQueue) Submit(req *Request) {
   200  	if pq.stat != nil {
   201  		pq.stat.Add(1)
   202  	}
   203  	pq.mu.Lock()
   204  	defer pq.mu.Unlock()
   205  
   206  	// It doesn't make sense to compact the queue too often.
   207  	const minSizeToCompact = 128
   208  	if pq.pos > len(pq.queue)/2 && len(pq.queue) >= minSizeToCompact {
   209  		copy(pq.queue, pq.queue[pq.pos:])
   210  		for pq.pos > 0 {
   211  			newLen := len(pq.queue) - 1
   212  			pq.queue[newLen] = nil
   213  			pq.queue = pq.queue[:newLen]
   214  			pq.pos--
   215  		}
   216  	}
   217  	pq.queue = append(pq.queue, req)
   218  }
   219  
   220  func (pq *PlainQueue) Next() *Request {
   221  	pq.mu.Lock()
   222  	defer pq.mu.Unlock()
   223  	return pq.nextLocked()
   224  }
   225  
   226  func (pq *PlainQueue) tryNext() *Request {
   227  	if !pq.mu.TryLock() {
   228  		return nil
   229  	}
   230  	defer pq.mu.Unlock()
   231  	return pq.nextLocked()
   232  }
   233  
   234  func (pq *PlainQueue) nextLocked() *Request {
   235  	if pq.pos == len(pq.queue) {
   236  		return nil
   237  	}
   238  	ret := pq.queue[pq.pos]
   239  	pq.queue[pq.pos] = nil
   240  	pq.pos++
   241  	if pq.stat != nil {
   242  		pq.stat.Add(-1)
   243  	}
   244  	return ret
   245  }
   246  
   247  // Order combines several different sources in a particular order.
   248  type orderImpl struct {
   249  	sources []Source
   250  }
   251  
   252  func Order(sources ...Source) Source {
   253  	return &orderImpl{sources: sources}
   254  }
   255  
   256  func (o *orderImpl) Next() *Request {
   257  	for _, s := range o.sources {
   258  		req := s.Next()
   259  		if req != nil {
   260  			return req
   261  		}
   262  	}
   263  	return nil
   264  }
   265  
   266  type callback struct {
   267  	cb func() *Request
   268  }
   269  
   270  // Callback produces a source that calls the callback to serve every Next() request.
   271  func Callback(cb func() *Request) Source {
   272  	return &callback{cb}
   273  }
   274  
   275  func (cb *callback) Next() *Request {
   276  	return cb.cb()
   277  }
   278  
   279  type alternate struct {
   280  	base Source
   281  	nth  int
   282  	seq  atomic.Int64
   283  }
   284  
   285  // Alternate proxies base, but returns nil every nth Next() call.
   286  func Alternate(base Source, nth int) Source {
   287  	return &alternate{
   288  		base: base,
   289  		nth:  nth,
   290  	}
   291  }
   292  
   293  func (a *alternate) Next() *Request {
   294  	if a.seq.Add(1)%int64(a.nth) == 0 {
   295  		return nil
   296  	}
   297  	return a.base.Next()
   298  }
   299  
   300  type DynamicOrderer struct {
   301  	mu       sync.Mutex
   302  	currPrio int
   303  	ops      *priorityQueueOps[*Request]
   304  }
   305  
   306  // DynamicOrder() can be used to form nested queues dynamically.
   307  // That is, if
   308  // q1 := pq.Append()
   309  // q2 := pq.Append()
   310  // All elements added via q2.Submit() will always have a *lower* priority
   311  // than all elements added via q1.Submit().
   312  func DynamicOrder() *DynamicOrderer {
   313  	return &DynamicOrderer{
   314  		ops: &priorityQueueOps[*Request]{},
   315  	}
   316  }
   317  
   318  func (do *DynamicOrderer) Append() Executor {
   319  	do.mu.Lock()
   320  	defer do.mu.Unlock()
   321  	do.currPrio++
   322  	return &dynamicOrdererItem{
   323  		parent: do,
   324  		prio:   do.currPrio,
   325  	}
   326  }
   327  
   328  func (do *DynamicOrderer) submit(req *Request, prio int) {
   329  	do.mu.Lock()
   330  	defer do.mu.Unlock()
   331  	do.ops.Push(req, prio)
   332  }
   333  
   334  func (do *DynamicOrderer) Next() *Request {
   335  	do.mu.Lock()
   336  	defer do.mu.Unlock()
   337  	return do.ops.Pop()
   338  }
   339  
   340  type dynamicOrdererItem struct {
   341  	parent *DynamicOrderer
   342  	prio   int
   343  }
   344  
   345  func (doi *dynamicOrdererItem) Submit(req *Request) {
   346  	doi.parent.submit(req, doi.prio)
   347  }
   348  
   349  type DynamicSourceCtl struct {
   350  	value atomic.Pointer[Source]
   351  }
   352  
   353  // DynamicSource is assumed never to point to nil.
   354  func DynamicSource(source Source) *DynamicSourceCtl {
   355  	var ret DynamicSourceCtl
   356  	ret.Store(source)
   357  	return &ret
   358  }
   359  
   360  func (ds *DynamicSourceCtl) Store(source Source) {
   361  	ds.value.Store(&source)
   362  }
   363  
   364  func (ds *DynamicSourceCtl) Next() *Request {
   365  	return (*ds.value.Load()).Next()
   366  }
   367  
   368  // Deduplicator() keeps track of the previously run requests to avoid re-running them.
   369  type Deduplicator struct {
   370  	mu     sync.Mutex
   371  	ctx    context.Context
   372  	source Source
   373  	mm     map[hash.Sig]*duplicateState
   374  }
   375  
   376  type duplicateState struct {
   377  	res    *Result
   378  	queued []*Request // duplicate requests waiting for the result.
   379  }
   380  
   381  func Deduplicate(ctx context.Context, source Source) Source {
   382  	return &Deduplicator{
   383  		ctx:    ctx,
   384  		source: source,
   385  		mm:     map[hash.Sig]*duplicateState{},
   386  	}
   387  }
   388  
   389  func (d *Deduplicator) Next() *Request {
   390  	for {
   391  		req := d.source.Next()
   392  		if req == nil {
   393  			return nil
   394  		}
   395  		hash := req.hash()
   396  		d.mu.Lock()
   397  		entry, ok := d.mm[hash]
   398  		if !ok {
   399  			d.mm[hash] = &duplicateState{}
   400  		} else if entry.res == nil {
   401  			// There's no result yet, put the request to the queue.
   402  			entry.queued = append(entry.queued, req)
   403  		} else {
   404  			// We already know the result.
   405  			req.Done(entry.res.clone())
   406  		}
   407  		d.mu.Unlock()
   408  		if !ok {
   409  			// This is the first time we see such a request.
   410  			req.OnDone(d.onDone)
   411  			return req
   412  		}
   413  	}
   414  }
   415  
   416  func (d *Deduplicator) onDone(req *Request, res *Result) bool {
   417  	hash := req.hash()
   418  	clonedRes := res.clone()
   419  
   420  	d.mu.Lock()
   421  	entry := d.mm[hash]
   422  	queued := entry.queued
   423  	entry.queued = nil
   424  	entry.res = clonedRes
   425  	d.mu.Unlock()
   426  
   427  	// Broadcast the result.
   428  	for _, waitingReq := range queued {
   429  		waitingReq.Done(res.clone())
   430  	}
   431  	return true
   432  }