github.com/google/syzkaller@v0.0.0-20240517125934-c0f1611a36d6/vm/vm.go (about)

     1  // Copyright 2015 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  // Package vm provides an abstract test machine (VM, physical machine, etc)
     5  // interface for the rest of the system.
     6  // For convenience test machines are subsequently collectively called VMs.
     7  // Package wraps vmimpl package interface with some common functionality
     8  // and higher-level interface.
     9  package vm
    10  
    11  import (
    12  	"bytes"
    13  	"fmt"
    14  	"io"
    15  	"os"
    16  	"path/filepath"
    17  	"strings"
    18  	"sync/atomic"
    19  	"time"
    20  
    21  	"github.com/google/syzkaller/pkg/mgrconfig"
    22  	"github.com/google/syzkaller/pkg/osutil"
    23  	"github.com/google/syzkaller/pkg/report"
    24  	"github.com/google/syzkaller/pkg/stats"
    25  	"github.com/google/syzkaller/sys/targets"
    26  	"github.com/google/syzkaller/vm/vmimpl"
    27  
    28  	// Import all VM implementations, so that users only need to import vm.
    29  	_ "github.com/google/syzkaller/vm/adb"
    30  	_ "github.com/google/syzkaller/vm/bhyve"
    31  	_ "github.com/google/syzkaller/vm/cuttlefish"
    32  	_ "github.com/google/syzkaller/vm/gce"
    33  	_ "github.com/google/syzkaller/vm/gvisor"
    34  	_ "github.com/google/syzkaller/vm/isolated"
    35  	_ "github.com/google/syzkaller/vm/proxyapp"
    36  	_ "github.com/google/syzkaller/vm/qemu"
    37  	_ "github.com/google/syzkaller/vm/starnix"
    38  	_ "github.com/google/syzkaller/vm/vmm"
    39  	_ "github.com/google/syzkaller/vm/vmware"
    40  )
    41  
// Pool is a set of VMs of a single configured type. It is constructed by
// Create, and individual VMs are obtained from it with Pool.Create.
type Pool struct {
	// impl is the backend pool returned by the VM type's constructor.
	impl               vmimpl.Pool
	// workdir is the directory under which per-instance temp dirs are created.
	workdir            string
	// template, if non-empty, is copied into each instance's workdir as "template".
	template           string
	// timeouts scales the monitor's ticker/wait periods and provides NoOutput.
	timeouts           targets.Timeouts
	// activeCount is the number of created but not yet closed instances;
	// it is updated with sync/atomic.
	activeCount        int32
	// hostFuzzer disables pprof port setup (see Instance.PprofPort).
	hostFuzzer         bool
	// statOutputReceived counts bytes of VM console output received.
	statOutputReceived *stats.Val
}
    51  
// Instance is a single VM created by Pool.Create.
type Instance struct {
	// pool is the pool this instance was created from.
	pool    *Pool
	// impl is the backend instance that does the actual work.
	impl    vmimpl.Instance
	// workdir is the instance's private temp dir; it is removed by Close.
	workdir string
	// index is the VM index passed to Pool.Create.
	index   int
	// onClose is run by Close; Pool.Create sets it to decrement the
	// pool's active instance count.
	onClose func()
}
    59  
var (
	// Shutdown is re-exported from vmimpl. Output monitoring in Run stops
	// (returning no report) once a receive from it succeeds.
	Shutdown                = vmimpl.Shutdown
	// ErrTimeout is re-exported from vmimpl. The monitor compares the
	// command's completion error against it to detect a timed-out command.
	ErrTimeout              = vmimpl.ErrTimeout
	// Compile-time checks that vmimpl's error types implement the
	// BootErrorer and InfraErrorer interfaces.
	_          BootErrorer  = vmimpl.BootError{}
	_          InfraErrorer = vmimpl.InfraError{}
)
    66  
// BootErrorer is implemented by vmimpl.BootError (checked at compile time in
// this file). The returned string and bytes are presumably a title and the
// associated console output — see vmimpl.BootError to confirm.
type BootErrorer interface {
	BootError() (string, []byte)
}
    70  
// InfraErrorer is implemented by vmimpl.InfraError (checked at compile time in
// this file). The returned string and bytes are presumably a title and the
// associated output — see vmimpl.InfraError to confirm.
type InfraErrorer interface {
	InfraError() (string, []byte)
}
    74  
    75  // vmType splits the VM type from any suffix (separated by ":"). This is mostly
    76  // useful for the "proxyapp" type, where pkg/build needs to specify/handle
    77  // sub-types.
    78  func vmType(fullName string) string {
    79  	name, _, _ := strings.Cut(fullName, ":")
    80  	return name
    81  }
    82  
    83  // AllowsOvercommit returns if the instance type allows overcommit of instances
    84  // (i.e. creation of instances out-of-thin-air). Overcommit is used during image
    85  // and patch testing in syz-ci when it just asks for more than specified in config
    86  // instances. Generally virtual machines (qemu, gce) support overcommit,
    87  // while physical machines (adb, isolated) do not. Strictly speaking, we should
    88  // never use overcommit and use only what's specified in config, because we
    89  // override resource limits specified in config (e.g. can OOM). But it works and
    90  // makes lots of things much simpler.
    91  func AllowsOvercommit(typ string) bool {
    92  	return vmimpl.Types[vmType(typ)].Overcommit
    93  }
    94  
    95  // Create creates a VM pool that can be used to create individual VMs.
    96  func Create(cfg *mgrconfig.Config, debug bool) (*Pool, error) {
    97  	typ, ok := vmimpl.Types[vmType(cfg.Type)]
    98  	if !ok {
    99  		return nil, fmt.Errorf("unknown instance type '%v'", cfg.Type)
   100  	}
   101  	env := &vmimpl.Env{
   102  		Name:      cfg.Name,
   103  		OS:        cfg.TargetOS,
   104  		Arch:      cfg.TargetVMArch,
   105  		Workdir:   cfg.Workdir,
   106  		Image:     cfg.Image,
   107  		SSHKey:    cfg.SSHKey,
   108  		SSHUser:   cfg.SSHUser,
   109  		Timeouts:  cfg.Timeouts,
   110  		Debug:     debug,
   111  		Config:    cfg.VM,
   112  		KernelSrc: cfg.KernelSrc,
   113  	}
   114  	impl, err := typ.Ctor(env)
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  	return &Pool{
   119  		impl:       impl,
   120  		workdir:    env.Workdir,
   121  		template:   cfg.WorkdirTemplate,
   122  		timeouts:   cfg.Timeouts,
   123  		hostFuzzer: cfg.SysTarget.HostFuzzer,
   124  		statOutputReceived: stats.Create("vm output", "Bytes of VM console output received",
   125  			stats.Graph("traffic"), stats.Rate{}, stats.FormatMB),
   126  	}, nil
   127  }
   128  
   129  func (pool *Pool) Count() int {
   130  	return pool.impl.Count()
   131  }
   132  
   133  func (pool *Pool) Create(index int) (*Instance, error) {
   134  	if index < 0 || index >= pool.Count() {
   135  		return nil, fmt.Errorf("invalid VM index %v (count %v)", index, pool.Count())
   136  	}
   137  	workdir, err := osutil.ProcessTempDir(pool.workdir)
   138  	if err != nil {
   139  		return nil, fmt.Errorf("failed to create instance temp dir: %w", err)
   140  	}
   141  	if pool.template != "" {
   142  		if err := osutil.CopyDirRecursively(pool.template, filepath.Join(workdir, "template")); err != nil {
   143  			return nil, err
   144  		}
   145  	}
   146  	impl, err := pool.impl.Create(workdir, index)
   147  	if err != nil {
   148  		os.RemoveAll(workdir)
   149  		return nil, err
   150  	}
   151  	atomic.AddInt32(&pool.activeCount, 1)
   152  	return &Instance{
   153  		pool:    pool,
   154  		impl:    impl,
   155  		workdir: workdir,
   156  		index:   index,
   157  		onClose: func() { atomic.AddInt32(&pool.activeCount, -1) },
   158  	}, nil
   159  }
   160  
   161  // TODO: Integration or end-to-end testing is needed.
   162  //
   163  //	https://github.com/google/syzkaller/pull/3269#discussion_r967650801
   164  func (pool *Pool) Close() error {
   165  	if pool.activeCount != 0 {
   166  		panic("all the instances should be closed before pool.Close()")
   167  	}
   168  	if closer, ok := pool.impl.(io.Closer); ok {
   169  		return closer.Close()
   170  	}
   171  	return nil
   172  }
   173  
   174  func (inst *Instance) Copy(hostSrc string) (string, error) {
   175  	return inst.impl.Copy(hostSrc)
   176  }
   177  
   178  func (inst *Instance) Forward(port int) (string, error) {
   179  	return inst.impl.Forward(port)
   180  }
   181  
// ExitCondition is a bit mask of command exit modes that Run treats as
// acceptable rather than as crashes. Run defaults to ExitNormal.
type ExitCondition int

const (
	// The program is allowed to exit after timeout.
	ExitTimeout = ExitCondition(1 << iota)
	// The program is allowed to exit with no errors.
	ExitNormal
	// The program is allowed to exit with errors.
	ExitError
)
   192  
// StopChan is a Run option: the channel is passed to the backend's Run as its
// stop channel and can be used to prematurely stop the command.
type StopChan <-chan bool

// InjectOutput is a Run option: chunks received from it are appended to the
// monitored output as if they came from the VM.
type InjectOutput <-chan []byte

// OutputSize is a Run option: how much output (in bytes) to keep before a
// crash; it also bounds the monitor's output buffer.
type OutputSize int

// An early notification that the command has finished / VM crashed.
// As a Run option it is called at most once per Run.
type EarlyFinishCb func()
   199  
   200  // Run runs cmd inside of the VM (think of ssh cmd) and monitors command execution
   201  // and the kernel console output. It detects kernel oopses in output, lost connections, hangs, etc.
   202  // Returns command+kernel output and a non-symbolized crash report (nil if no error happens).
   203  // Accepted options:
   204  //   - StopChan: stop channel can be used to prematurely stop the command
   205  //   - ExitCondition: says which exit modes should be considered as errors/OK
   206  //   - OutputSize: how much output to keep/return
   207  func (inst *Instance) Run(timeout time.Duration, reporter *report.Reporter, command string, opts ...any) (
   208  	[]byte, *report.Report, error) {
   209  	exit := ExitNormal
   210  	var stop <-chan bool
   211  	var injected <-chan []byte
   212  	var finished func()
   213  	outputSize := beforeContextDefault
   214  	for _, o := range opts {
   215  		switch opt := o.(type) {
   216  		case ExitCondition:
   217  			exit = opt
   218  		case StopChan:
   219  			stop = opt
   220  		case OutputSize:
   221  			outputSize = int(opt)
   222  		case InjectOutput:
   223  			injected = (<-chan []byte)(opt)
   224  		case EarlyFinishCb:
   225  			finished = opt
   226  		default:
   227  			panic(fmt.Sprintf("unknown option %#v", opt))
   228  		}
   229  	}
   230  	outc, errc, err := inst.impl.Run(timeout, stop, command)
   231  	if err != nil {
   232  		return nil, nil, err
   233  	}
   234  	mon := &monitor{
   235  		inst:            inst,
   236  		outc:            outc,
   237  		injected:        injected,
   238  		errc:            errc,
   239  		finished:        finished,
   240  		reporter:        reporter,
   241  		beforeContext:   outputSize,
   242  		exit:            exit,
   243  		lastExecuteTime: time.Now(),
   244  	}
   245  	rep := mon.monitorExecution()
   246  	return mon.output, rep, nil
   247  }
   248  
   249  func (inst *Instance) Info() ([]byte, error) {
   250  	if ii, ok := inst.impl.(vmimpl.Infoer); ok {
   251  		return ii.Info()
   252  	}
   253  	return nil, nil
   254  }
   255  
   256  func (inst *Instance) PprofPort() int {
   257  	if inst.pool.hostFuzzer {
   258  		// In the fuzzing on host mode, fuzzers are always on the same network.
   259  		// Don't set up pprof endpoints in this case.
   260  		return 0
   261  	}
   262  	if ii, ok := inst.impl.(vmimpl.PprofPortProvider); ok {
   263  		return ii.PprofPort()
   264  	}
   265  	return vmimpl.PprofPort
   266  }
   267  
   268  func (inst *Instance) diagnose(rep *report.Report) ([]byte, bool) {
   269  	if rep == nil {
   270  		panic("rep is nil")
   271  	}
   272  	return inst.impl.Diagnose(rep)
   273  }
   274  
   275  func (inst *Instance) Close() {
   276  	inst.impl.Close()
   277  	os.RemoveAll(inst.workdir)
   278  	inst.onClose()
   279  }
   280  
// monitor holds the state of a single Instance.Run invocation while its
// command and console output are being watched.
type monitor struct {
	// inst is the instance being monitored (its pool provides timeouts and stats).
	inst            *Instance
	// outc delivers output chunks from the backend; set to nil once closed.
	outc            <-chan []byte
	// injected delivers extra output from the InjectOutput option (nil if unset).
	injected        <-chan []byte
	// finished is the EarlyFinishCb option; it is called at most once.
	finished        func()
	// errc delivers the command's completion status: nil (normal exit),
	// ErrTimeout, or another error (treated as a lost connection).
	errc            <-chan error
	// reporter detects and parses crashes in the output.
	reporter        *report.Reporter
	// exit is the set of exit modes that are not treated as crashes.
	exit            ExitCondition
	// output is the accumulated output (bounded in size); Run returns it.
	output          []byte
	// beforeContext is how much output to keep before a crash.
	beforeContext   int
	// matchPos is the offset in output where the next crash search starts.
	matchPos        int
	// lastExecuteTime is when a program-execution marker was last seen.
	lastExecuteTime time.Time
	// extractCalled guards against extractError being called twice.
	extractCalled   bool
}
   295  
   296  func (mon *monitor) monitorExecution() *report.Report {
   297  	ticker := time.NewTicker(tickerPeriod * mon.inst.pool.timeouts.Scale)
   298  	defer ticker.Stop()
   299  	defer func() {
   300  		if mon.finished != nil {
   301  			mon.finished()
   302  		}
   303  	}()
   304  	for {
   305  		select {
   306  		case err := <-mon.errc:
   307  			switch err {
   308  			case nil:
   309  				// The program has exited without errors,
   310  				// but wait for kernel output in case there is some delayed oops.
   311  				crash := ""
   312  				if mon.exit&ExitNormal == 0 {
   313  					crash = lostConnectionCrash
   314  				}
   315  				return mon.extractError(crash)
   316  			case ErrTimeout:
   317  				if mon.exit&ExitTimeout == 0 {
   318  					return mon.extractError(timeoutCrash)
   319  				}
   320  				return nil
   321  			default:
   322  				// Note: connection lost can race with a kernel oops message.
   323  				// In such case we want to return the kernel oops.
   324  				crash := ""
   325  				if mon.exit&ExitError == 0 {
   326  					crash = lostConnectionCrash
   327  				}
   328  				return mon.extractError(crash)
   329  			}
   330  		case out, ok := <-mon.outc:
   331  			if !ok {
   332  				mon.outc = nil
   333  				continue
   334  			}
   335  			mon.inst.pool.statOutputReceived.Add(len(out))
   336  			if rep, done := mon.appendOutput(out); done {
   337  				return rep
   338  			}
   339  		case out := <-mon.injected:
   340  			if rep, done := mon.appendOutput(out); done {
   341  				return rep
   342  			}
   343  		case <-ticker.C:
   344  			// Detect both "no output whatsoever" and "kernel episodically prints
   345  			// something to console, but fuzzer is not actually executing programs".
   346  			if time.Since(mon.lastExecuteTime) > mon.inst.pool.timeouts.NoOutput {
   347  				return mon.extractError(noOutputCrash)
   348  			}
   349  		case <-Shutdown:
   350  			return nil
   351  		}
   352  	}
   353  }
   354  
   355  func (mon *monitor) appendOutput(out []byte) (*report.Report, bool) {
   356  	lastPos := len(mon.output)
   357  	mon.output = append(mon.output, out...)
   358  	if bytes.Contains(mon.output[lastPos:], executingProgram1) ||
   359  		bytes.Contains(mon.output[lastPos:], executingProgram2) {
   360  		mon.lastExecuteTime = time.Now()
   361  	}
   362  	if mon.reporter.ContainsCrash(mon.output[mon.matchPos:]) {
   363  		return mon.extractError("unknown error"), true
   364  	}
   365  	if len(mon.output) > 2*mon.beforeContext {
   366  		copy(mon.output, mon.output[len(mon.output)-mon.beforeContext:])
   367  		mon.output = mon.output[:mon.beforeContext]
   368  	}
   369  	// Find the starting position for crash matching on the next iteration.
   370  	// We step back from the end of output by maxErrorLength to handle the case
   371  	// when a crash line is currently split/incomplete. And then we try to find
   372  	// the preceding '\n' to have a full line. This is required to handle
   373  	// the case when a particular pattern is ignored as crash, but a suffix
   374  	// of the pattern is detected as crash (e.g. "ODEBUG:" is trimmed to "BUG:").
   375  	mon.matchPos = len(mon.output) - maxErrorLength
   376  	for i := 0; i < maxErrorLength; i++ {
   377  		if mon.matchPos <= 0 || mon.output[mon.matchPos-1] == '\n' {
   378  			break
   379  		}
   380  		mon.matchPos--
   381  	}
   382  	if mon.matchPos < 0 {
   383  		mon.matchPos = 0
   384  	}
   385  	return nil, false
   386  }
   387  
   388  func (mon *monitor) extractError(defaultError string) *report.Report {
   389  	if mon.extractCalled {
   390  		panic("extractError called twice")
   391  	}
   392  	mon.extractCalled = true
   393  	if mon.finished != nil {
   394  		// If the caller wanted an early notification, provide it.
   395  		mon.finished()
   396  		mon.finished = nil
   397  	}
   398  	diagOutput, diagWait := []byte{}, false
   399  	if defaultError != "" {
   400  		diagOutput, diagWait = mon.inst.diagnose(mon.createReport(defaultError))
   401  	}
   402  	// Give it some time to finish writing the error message.
   403  	// But don't wait for "no output", we already waited enough.
   404  	if defaultError != noOutputCrash || diagWait {
   405  		mon.waitForOutput()
   406  	}
   407  	if bytes.Contains(mon.output, []byte(fuzzerPreemptedStr)) {
   408  		return nil
   409  	}
   410  	if defaultError == "" && mon.reporter.ContainsCrash(mon.output[mon.matchPos:]) {
   411  		// We did not call Diagnose above because we thought there is no error, so call it now.
   412  		diagOutput, diagWait = mon.inst.diagnose(mon.createReport(defaultError))
   413  		if diagWait {
   414  			mon.waitForOutput()
   415  		}
   416  	}
   417  	rep := mon.createReport(defaultError)
   418  	if rep == nil {
   419  		return nil
   420  	}
   421  	if len(diagOutput) > 0 {
   422  		rep.Output = append(rep.Output, vmDiagnosisStart...)
   423  		rep.Output = append(rep.Output, diagOutput...)
   424  	}
   425  	return rep
   426  }
   427  
   428  func (mon *monitor) createReport(defaultError string) *report.Report {
   429  	rep := mon.reporter.ParseFrom(mon.output, mon.matchPos)
   430  	if rep == nil {
   431  		if defaultError == "" {
   432  			return nil
   433  		}
   434  		return &report.Report{
   435  			Title:      defaultError,
   436  			Output:     mon.output,
   437  			Suppressed: report.IsSuppressed(mon.reporter, mon.output),
   438  		}
   439  	}
   440  	start := rep.StartPos - mon.beforeContext
   441  	if start < 0 {
   442  		start = 0
   443  	}
   444  	end := rep.EndPos + afterContext
   445  	if end > len(rep.Output) {
   446  		end = len(rep.Output)
   447  	}
   448  	rep.Output = rep.Output[start:end]
   449  	rep.StartPos -= start
   450  	rep.EndPos -= start
   451  	return rep
   452  }
   453  
   454  func (mon *monitor) waitForOutput() {
   455  	timer := time.NewTimer(waitForOutputTimeout * mon.inst.pool.timeouts.Scale)
   456  	defer timer.Stop()
   457  	for {
   458  		select {
   459  		case out, ok := <-mon.outc:
   460  			if !ok {
   461  				return
   462  			}
   463  			mon.output = append(mon.output, out...)
   464  		case <-timer.C:
   465  			return
   466  		case <-Shutdown:
   467  			return
   468  		}
   469  	}
   470  }
   471  
const (
	// maxErrorLength is how many bytes before the end of the output the next
	// crash search starts (then extended back to a line start), so that a
	// crash line split across output chunks is still matched.
	maxErrorLength = 256

	// Report titles used when no kernel crash is found in the output.
	lostConnectionCrash = "lost connection to test machine"
	noOutputCrash       = "no output from test machine"
	timeoutCrash        = "timed out"
	// If this string is present in the output, extractError reports nothing.
	fuzzerPreemptedStr  = "SYZ-FUZZER: PREEMPTED"
	// Separator placed between report output and VM diagnosis output.
	vmDiagnosisStart    = "\nVM DIAGNOSIS:\n"
)
   481  
var (
	// Progress markers: seeing either one resets the "no output" timer.
	executingProgram1 = []byte("executing program")  // syz-fuzzer, syz-runner output
	executingProgram2 = []byte("executed programs:") // syz-execprog output

	// Default bytes of output kept before a crash (1 MiB), overridable with
	// the OutputSize Run option, and bytes kept after the crash end (128 KiB).
	beforeContextDefault = 1024 << 10
	afterContext         = 128 << 10

	// Monitor periods; both are multiplied by the pool's timeouts.Scale.
	tickerPeriod         = 10 * time.Second
	waitForOutputTimeout = 10 * time.Second
)