github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/pkg/rpcserver/rpcserver.go

// Copyright 2024 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.

package rpcserver

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"math/rand"
	"net/url"
	"slices"
	"sort"
	"strings"
	"sync"
	"sync/atomic"

	"github.com/google/syzkaller/pkg/cover"
	"github.com/google/syzkaller/pkg/cover/backend"
	"github.com/google/syzkaller/pkg/flatrpc"
	"github.com/google/syzkaller/pkg/fuzzer/queue"
	"github.com/google/syzkaller/pkg/log"
	"github.com/google/syzkaller/pkg/mgrconfig"
	"github.com/google/syzkaller/pkg/report"
	"github.com/google/syzkaller/pkg/signal"
	"github.com/google/syzkaller/pkg/stat"
	"github.com/google/syzkaller/pkg/vminfo"
	"github.com/google/syzkaller/prog"
	"github.com/google/syzkaller/sys/targets"
	"github.com/google/syzkaller/vm/dispatcher"
	"golang.org/x/sync/errgroup"
)

type Config struct {
	vminfo.Config
	Stats

	VMArch string
	VMType string
	RPC    string
	VMLess bool
	// Hash adjacent PCs to form fuzzing feedback signal (otherwise just use coverage PCs as signal).
	UseCoverEdges bool
	// Filter signal/comparisons against target kernel text/data ranges.
	// Disabled for gVisor/Starnix which are not Linux.
	FilterSignal      bool
	PrintMachineCheck bool
	// Abort early on syz-executor not replying to requests and print extra debugging information.
	DebugTimeouts bool
	Procs         int
	Slowdown      int
	pcBase        uint64
	localModules  []*vminfo.KernelModule

	// RPCServer closes the channel once the machine check has begun. Used for fault injection during testing.
	machineCheckStarted chan struct{}
}

type RemoteConfig struct {
	*mgrconfig.Config
	Manager Manager
	Stats   Stats
	Debug   bool
}

type Manager interface {
	MaxSignal() signal.Signal
	BugFrames() (leaks []string, races []string)
	MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) (queue.Source, error)
	CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error)
}
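
// The following is a minimal sketch, not part of the original file: a no-op
// Manager illustrating the contract above. A real manager (e.g. syz-manager's
// fuzzing loop) wires these callbacks into corpus and signal management.
type nopManager struct{}

func (nopManager) MaxSignal() signal.Signal           { return nil }
func (nopManager) BugFrames() (leaks, races []string) { return nil, nil }
func (nopManager) MachineChecked(flatrpc.Feature, map[*prog.Syscall]bool) (queue.Source, error) {
	// A real manager returns the request source that feeds test programs.
	return nil, nil
}
func (nopManager) CoverageFilter([]*vminfo.KernelModule) ([]uint64, error) {
	return nil, nil // nil means no coverage filtering
}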

type Server interface {
	Listen() error
	Close() error
	Port() int
	TriagedCorpus()
	Serve(context.Context) error
	CreateInstance(id int, injectExec chan<- bool, updInfo dispatcher.UpdateInfo) chan error
	ShutdownInstance(id int, crashed bool, extraExecs ...report.ExecutorInfo) ([]ExecRecord, []byte)
	StopFuzzing(id int)
	DistributeSignalDelta(plus signal.Signal)
}

type server struct {
	cfg       *Config
	mgr       Manager
	serv      *flatrpc.Serv
	target    *prog.Target
	sysTarget *targets.Target
	timeouts  targets.Timeouts
	checker   *vminfo.Checker

	infoOnce         sync.Once
	checkDone        atomic.Bool
	checkFailures    int
	onHandshake      chan *handshakeResult
	baseSource       *queue.DynamicSourceCtl
	setupFeatures    flatrpc.Feature
	canonicalModules *cover.Canonicalizer
	coverFilter      []uint64

	mu            sync.Mutex
	runners       map[int]*Runner
	execSource    *queue.Distributor
	triagedCorpus atomic.Bool

	Stats
	*runnerStats
}

type Stats struct {
	StatExecs            *stat.Val
	StatNumFuzzing       *stat.Val
	StatVMRestarts       *stat.Val
	StatModules          *stat.Val
	StatExecutorRestarts *stat.Val
}

func NewStats() Stats {
	return NewNamedStats("")
}

func NewNamedStats(name string) Stats {
	suffix, linkSuffix := "", ""
	if name != "" {
		suffix = " [" + name + "]"
		linkSuffix = "?pool=" + url.QueryEscape(name)
	}
	return Stats{
		StatExecs: stat.New("exec total"+suffix, "Total test program executions",
			stat.Console, stat.Rate{}, stat.Prometheus("syz_exec_total"+name),
		),
		StatNumFuzzing: stat.New("fuzzing VMs"+suffix,
			"Number of VMs that are currently fuzzing", stat.Graph("fuzzing VMs"),
			stat.Link("/vms"+linkSuffix),
		),
		StatVMRestarts: stat.New("vm restarts"+suffix, "Total number of VM starts",
			stat.Rate{}, stat.NoGraph),
		StatModules: stat.New("modules"+suffix, "Number of loaded kernel modules",
			stat.NoGraph, stat.Link("/modules"+linkSuffix)),
		StatExecutorRestarts: stat.New("executor restarts"+suffix,
			"Number of times executor process was restarted", stat.Rate{}, stat.Graph("executor")),
	}
}
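
// Minimal usage sketch, not part of the original file: named stats let several
// server pools co-exist in one process without colliding metric names. The
// pool name "pool-0" is hypothetical.
func exampleNamedStats() {
	stats := NewNamedStats("pool-0")
	stats.StatVMRestarts.Add(1) // reported as "vm restarts [pool-0]"
}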

func New(cfg *RemoteConfig) (Server, error) {
	var pcBase uint64
	if cfg.KernelObj != "" {
		var err error
		pcBase, err = cover.GetPCBase(cfg.Config)
		if err != nil {
			return nil, err
		}
	}
	sandbox, err := flatrpc.SandboxToFlags(cfg.Sandbox)
	if err != nil {
		return nil, err
	}
	features := flatrpc.AllFeatures
	if !cfg.Experimental.RemoteCover {
		features &= ^flatrpc.FeatureExtraCoverage
	}
	return newImpl(&Config{
		Config: vminfo.Config{
			Target:     cfg.Target,
			VMType:     cfg.Type,
			Features:   features,
			Syscalls:   cfg.Syscalls,
			Debug:      cfg.Debug,
			Cover:      cfg.Cover,
			Sandbox:    sandbox,
			SandboxArg: cfg.SandboxArg,
		},
		Stats:  cfg.Stats,
		VMArch: cfg.TargetVMArch,
		RPC:    cfg.RPC,
		VMLess: cfg.VMLess,
		// gVisor coverage is not a trace, so producing edges won't work.
		UseCoverEdges: cfg.Experimental.CoverEdges && cfg.Type != targets.GVisor,
		// gVisor/Starnix are not Linux, so filtering against Linux ranges won't work.
		FilterSignal:      cfg.Type != targets.GVisor && cfg.Type != targets.Starnix,
		PrintMachineCheck: true,
		Procs:             cfg.Procs,
		Slowdown:          cfg.Timeouts.Slowdown,
		pcBase:            pcBase,
		localModules:      cfg.LocalModules,
	}, cfg.Manager), nil
}

func newImpl(cfg *Config, mgr Manager) *server {
	// Note that we use VMArch, rather than Arch. We need the kernel address ranges and bitness.
	sysTarget := targets.Get(cfg.Target.OS, cfg.VMArch)
	cfg.Procs = min(cfg.Procs, prog.MaxPids)
	checker := vminfo.New(&cfg.Config)
	baseSource := queue.DynamicSource(checker)
	return &server{
		cfg:         cfg,
		mgr:         mgr,
		target:      cfg.Target,
		sysTarget:   sysTarget,
		timeouts:    sysTarget.Timeouts(cfg.Slowdown),
		runners:     make(map[int]*Runner),
		checker:     checker,
		baseSource:  baseSource,
		execSource:  queue.Distribute(queue.Retry(baseSource)),
		onHandshake: make(chan *handshakeResult, 1),

		Stats: cfg.Stats,
		runnerStats: &runnerStats{
			statExecRetries: stat.New("exec retries",
				"Number of times a test program was restarted because the first run failed",
				stat.Rate{}, stat.Graph("executor")),
			statExecutorRestarts:   cfg.Stats.StatExecutorRestarts,
			statExecBufferTooSmall: queue.StatExecBufferTooSmall,
			statExecs:              cfg.Stats.StatExecs,
			statNoExecRequests:     queue.StatNoExecRequests,
			statNoExecDuration:     queue.StatNoExecDuration,
		},
	}
}

func (serv *server) Close() error {
	return serv.serv.Close()
}

func (serv *server) Listen() error {
	s, err := flatrpc.Listen(serv.cfg.RPC)
	if err != nil {
		return err
	}
	serv.serv = s
	return nil
}

// Used for errors incompatible with further RPCServer operation.
var errFatal = errors.New("aborting RPC server")

func (serv *server) Serve(ctx context.Context) error {
	g, ctx := errgroup.WithContext(ctx)
	g.Go(func() error {
		return serv.serv.Serve(ctx, func(ctx context.Context, conn *flatrpc.Conn) error {
			err := serv.handleConn(ctx, conn)
			if err != nil && !errors.Is(err, errFatal) {
				log.Logf(2, "%v", err)
				return nil
			}
			return err
		})
	})
	g.Go(func() error {
		var info *handshakeResult
		select {
		case <-ctx.Done():
			return nil
		case info = <-serv.onHandshake:
		}
		// We run the machine check specifically from the top level context,
		// not from the per-connection one.
		return serv.runCheck(ctx, info)
	})
	return g.Wait()
}

func (serv *server) Port() int {
	return serv.serv.Addr.Port
}

// Must be simple enough to not require adding dependencies to the executor.
func authHash(value uint64) uint64 {
	prime1 := uint64(73856093)
	prime2 := uint64(83492791)
	hashValue := (value * prime1) ^ prime2

	return hashValue
}
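
// Hedged sketch, not part of the original file: the cookie check performed by
// handleConn below. The server sends a random cookie in ConnectHello and
// accepts the connection only if the reply equals authHash(cookie).
func verifyCookie(sentCookie, replyCookie uint64) bool {
	return replyCookie == authHash(sentCookie)
}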

func (serv *server) handleConn(ctx context.Context, conn *flatrpc.Conn) error {
	// Use a random cookie, because we do not want the fuzzer to accidentally guess it and DDoS multiple managers.
	helloCookie := rand.Uint64()
	expectCookie := authHash(helloCookie)
	connectHello := &flatrpc.ConnectHello{
		Cookie: helloCookie,
	}

	if err := flatrpc.Send(conn, connectHello); err != nil {
		// The other side is not an executor.
		return fmt.Errorf("failed to establish connection with a remote runner")
	}

	connectReq, err := flatrpc.Recv[*flatrpc.ConnectRequestRaw](conn)
	if err != nil {
		return err
	}
	id := int(connectReq.Id)

	if connectReq.Cookie != expectCookie {
		return fmt.Errorf("client failed to respond with a valid cookie: %v (expected %v)", connectReq.Cookie, expectCookie)
	}

	// From now on, assume that the client is well-behaving.
	log.Logf(1, "runner %v connected", id)

	if serv.cfg.VMLess {
		// There is no VM loop, so mimic what it would do.
		serv.CreateInstance(id, nil, nil)
		defer func() {
			serv.StopFuzzing(id)
			serv.ShutdownInstance(id, true)
		}()
	} else if err := checkRevisions(connectReq, serv.cfg.Target); err != nil {
		return err
	}
	serv.StatVMRestarts.Add(1)

	serv.mu.Lock()
	runner := serv.runners[id]
	serv.mu.Unlock()
	if runner == nil {
		return fmt.Errorf("unknown VM %v tries to connect", id)
	}

	err = serv.handleRunnerConn(ctx, runner, conn)
	log.Logf(2, "runner %v: %v", id, err)

	runner.resultCh <- err
	return nil
}

func (serv *server) handleRunnerConn(ctx context.Context, runner *Runner, conn *flatrpc.Conn) error {
	opts := &handshakeConfig{
		VMLess:   serv.cfg.VMLess,
		Files:    serv.checker.RequiredFiles(),
		Timeouts: serv.timeouts,
		Callback: serv.handleMachineInfo,
	}
	opts.LeakFrames, opts.RaceFrames = serv.mgr.BugFrames()
	if serv.checkDone.Load() {
		opts.Features = serv.setupFeatures
	} else {
		opts.Files = append(opts.Files, serv.checker.CheckFiles()...)
		opts.Features = serv.cfg.Features
	}

	info, err := runner.Handshake(conn, opts)
	if err != nil {
		log.Logf(1, "%v", err)
		return err
	}

	select {
	case serv.onHandshake <- &info:
	default:
	}

	if serv.triagedCorpus.Load() {
		if err := runner.SendCorpusTriaged(); err != nil {
			return err
		}
	}
	return serv.connectionLoop(ctx, runner)
}

func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handshakeResult, error) {
	modules, machineInfo, err := serv.checker.MachineInfo(infoReq.Files)
	if err != nil {
		log.Logf(0, "parsing of machine info failed: %v", err)
		if infoReq.Error == "" {
			infoReq.Error = err.Error()
		}
	}
	modules = backend.FixModules(serv.cfg.localModules, modules, serv.cfg.pcBase)
	if infoReq.Error != "" {
		log.Logf(0, "machine check failed: %v", infoReq.Error)
		serv.checkFailures++
		if serv.checkFailures == 10 {
			return handshakeResult{}, fmt.Errorf("%w: machine check failed too many times", errFatal)
		}
		return handshakeResult{}, errors.New("machine check failed")
	}
	var retErr error
	serv.infoOnce.Do(func() {
		serv.StatModules.Add(len(modules))
		serv.canonicalModules = cover.NewCanonicalizer(modules, serv.cfg.Cover)
		var err error
		serv.coverFilter, err = serv.mgr.CoverageFilter(modules)
		if err != nil {
			retErr = fmt.Errorf("%w: %w", errFatal, err)
			return
		}
	})
	if retErr != nil {
		return handshakeResult{}, retErr
	}
	// Flatbuffers don't do deep copy of byte slices,
	// so clone manually since we may later pass it to a goroutine.
	for _, file := range infoReq.Files {
		file.Data = slices.Clone(file.Data)
	}
	canonicalizer := serv.canonicalModules.NewInstance(modules)
	return handshakeResult{
		CovFilter:     canonicalizer.Decanonicalize(serv.coverFilter),
		MachineInfo:   machineInfo,
		Canonicalizer: canonicalizer,
		Files:         infoReq.Files,
		Features:      infoReq.Features,
	}, nil
}

func (serv *server) connectionLoop(baseCtx context.Context, runner *Runner) error {
	// To "cancel" the runner's loop we need to call runner.Stop().
	// At the same time, we don't want to leak the goroutine that monitors it,
	// so we derive a new context and cancel it on function exit.
	ctx, cancel := context.WithCancel(baseCtx)
	defer cancel()
	go func() {
		<-ctx.Done()
		runner.Stop()
	}()

	if serv.cfg.Cover {
		maxSignal := serv.mgr.MaxSignal().ToRaw()
		for len(maxSignal) != 0 {
			// Split coverage into batches to not grow the connection serialization
			// buffer too much (we don't want to grow it larger than what will be needed
			// to send programs).
			n := min(len(maxSignal), 50000)
			if err := runner.SendSignalUpdate(maxSignal[:n]); err != nil {
				return err
			}
			maxSignal = maxSignal[n:]
		}
	}

	serv.StatNumFuzzing.Add(1)
	defer serv.StatNumFuzzing.Add(-1)

	return runner.ConnectionLoop()
}
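
// Hedged sketch, not part of the original file, of the batching pattern used
// in connectionLoop above: send a large signal slice in fixed-size chunks so
// the connection's serialization buffer stays no larger than what programs
// already require. The send callback is a stand-in for runner.SendSignalUpdate.
func sendInBatches(sig []uint64, send func([]uint64) error) error {
	const batchSize = 50000 // same cap connectionLoop applies to max signal
	for len(sig) != 0 {
		n := min(len(sig), batchSize)
		if err := send(sig[:n]); err != nil {
			return err
		}
		sig = sig[n:]
	}
	return nil
}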

func checkRevisions(a *flatrpc.ConnectRequest, target *prog.Target) error {
	if target.Arch != a.Arch {
		return fmt.Errorf("%w: mismatching manager/executor arches: %v vs %v (full request: `%#v`)",
			errFatal, target.Arch, a.Arch, a)
	}
	if prog.GitRevision != a.GitRevision {
		return fmt.Errorf("%w: mismatching manager/executor git revisions: %v vs %v",
			errFatal, prog.GitRevision, a.GitRevision)
	}
	if target.Revision != a.SyzRevision {
		return fmt.Errorf("%w: mismatching manager/executor system call descriptions: %v vs %v",
			errFatal, target.Revision, a.SyzRevision)
	}
	return nil
}

func (serv *server) runCheck(ctx context.Context, info *handshakeResult) error {
	if serv.cfg.machineCheckStarted != nil {
		close(serv.cfg.machineCheckStarted)
	}
	enabledCalls, disabledCalls, features, checkErr := serv.checker.Run(ctx, info.Files, info.Features)
	if checkErr == vminfo.ErrAborted {
		return nil
	}

	enabledCalls, transitivelyDisabled := serv.target.TransitivelyEnabledCalls(enabledCalls)
	// Note: need to print disabled syscalls before failing due to an error.
	// This helps to debug "all system calls are disabled".
	if serv.cfg.PrintMachineCheck {
		serv.printMachineCheck(info.Files, enabledCalls, disabledCalls, transitivelyDisabled, features)
	}
	if checkErr != nil {
		return checkErr
	}
	enabledFeatures := features.Enabled()
	serv.setupFeatures = features.NeedSetup()
	newSource, err := serv.mgr.MachineChecked(enabledFeatures, enabledCalls)
	if err != nil {
		return err
	}
	serv.baseSource.Store(newSource)
	serv.checkDone.Store(true)
	return nil
}
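
// Hedged sketch, not part of the original file: checkRevisions and the machine
// check wrap errFatal via %w, so the connection handler in Serve can separate
// fatal mismatches (which abort the server) from transient per-connection
// failures using errors.Is, as below.
func isFatal(err error) bool {
	return errors.Is(err, errFatal)
}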

func (serv *server) printMachineCheck(checkFilesInfo []*flatrpc.FileInfo, enabledCalls map[*prog.Syscall]bool,
	disabledCalls, transitivelyDisabled map[*prog.Syscall]string, features vminfo.Features) {
	buf := new(bytes.Buffer)
	if len(serv.cfg.Syscalls) != 0 || log.V(1) {
		if len(disabledCalls) != 0 {
			var lines []string
			for call, reason := range disabledCalls {
				lines = append(lines, fmt.Sprintf("%-44v: %v\n", call.Name, reason))
			}
			sort.Strings(lines)
			fmt.Fprintf(buf, "disabled the following syscalls:\n%s\n", strings.Join(lines, ""))
		}
		if len(transitivelyDisabled) != 0 {
			var lines []string
			for call, reason := range transitivelyDisabled {
				lines = append(lines, fmt.Sprintf("%-44v: %v\n", call.Name, reason))
			}
			sort.Strings(lines)
			fmt.Fprintf(buf, "transitively disabled the following syscalls"+
				" (missing resource [creating syscalls]):\n%s\n",
				strings.Join(lines, ""))
		}
	}
	hasFileErrors := false
	for _, file := range checkFilesInfo {
		if file.Error == "" {
			continue
		}
		if !hasFileErrors {
			fmt.Fprintf(buf, "failed to read the following files in the VM:\n")
		}
		fmt.Fprintf(buf, "%-44v: %v\n", file.Name, file.Error)
		hasFileErrors = true
	}
	if hasFileErrors {
		fmt.Fprintf(buf, "\n")
	}
	var lines []string
	lines = append(lines, fmt.Sprintf("%-24v: %v/%v\n", "syscalls",
		len(enabledCalls), len(serv.cfg.Target.Syscalls)))
	for feat, info := range features {
		lines = append(lines, fmt.Sprintf("%-24v: %v\n",
			flatrpc.EnumNamesFeature[feat], info.Reason))
	}
	sort.Strings(lines)
	buf.WriteString(strings.Join(lines, ""))
	fmt.Fprintf(buf, "\n")
	log.Logf(0, "machine check:\n%s", buf.Bytes())
}

func (serv *server) CreateInstance(id int, injectExec chan<- bool, updInfo dispatcher.UpdateInfo) chan error {
	runner := &Runner{
		id:            id,
		source:        serv.execSource,
		cover:         serv.cfg.Cover,
		coverEdges:    serv.cfg.UseCoverEdges,
		filterSignal:  serv.cfg.FilterSignal,
		debug:         serv.cfg.Debug,
		debugTimeouts: serv.cfg.DebugTimeouts,
		sysTarget:     serv.sysTarget,
		injectExec:    injectExec,
		infoc:         make(chan chan []byte),
		requests:      make(map[int64]*queue.Request),
		executing:     make(map[int64]bool),
		hanged:        make(map[int64]bool),
		// Executor may report proc IDs that are larger than serv.cfg.Procs.
		lastExec: MakeLastExecuting(prog.MaxPids, 6),
		stats:    serv.runnerStats,
		procs:    serv.cfg.Procs,
		updInfo:  updInfo,
		resultCh: make(chan error, 1),
	}
	serv.mu.Lock()
	defer serv.mu.Unlock()
	if serv.runners[id] != nil {
		panic(fmt.Sprintf("duplicate instance %v", id))
	}
	serv.runners[id] = runner
	return runner.resultCh
}
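
// Hedged lifecycle sketch, not part of the original file, mirroring what the
// VMLess branch of handleConn mimics: the VM loop registers an instance before
// the runner connects, waits for the connection to finish, then tears it down.
// vmID is a hypothetical VM index.
func exampleInstanceLifecycle(serv Server, vmID int) {
	errc := serv.CreateInstance(vmID, nil, nil) // register before the VM boots
	<-errc                                      // runner connection finished (or failed)
	serv.StopFuzzing(vmID)                      // stop request exchange
	serv.ShutdownInstance(vmID, true)           // make the server forget the instance
}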

// StopFuzzing prevents further request exchange with the instance.
// To make the RPC server fully forget an instance, ShutdownInstance() must be called.
func (serv *server) StopFuzzing(id int) {
	serv.mu.Lock()
	runner := serv.runners[id]
	serv.mu.Unlock()
	if runner.updInfo != nil {
		runner.updInfo(func(info *dispatcher.Info) {
			info.Status = "fuzzing is stopped"
		})
	}
	runner.Stop()
}

func (serv *server) ShutdownInstance(id int, crashed bool, extraExecs ...report.ExecutorInfo) ([]ExecRecord, []byte) {
	serv.mu.Lock()
	runner := serv.runners[id]
	delete(serv.runners, id)
	serv.mu.Unlock()
	return runner.Shutdown(crashed, extraExecs...), runner.MachineInfo()
}

func (serv *server) DistributeSignalDelta(plus signal.Signal) {
	plusRaw := plus.ToRaw()
	serv.foreachRunnerAsync(func(runner *Runner) {
		runner.SendSignalUpdate(plusRaw)
	})
}

func (serv *server) TriagedCorpus() {
	serv.triagedCorpus.Store(true)
	serv.foreachRunnerAsync(func(runner *Runner) {
		runner.SendCorpusTriaged()
	})
}

// foreachRunnerAsync runs callback fn for each connected runner asynchronously.
// If a VM has hung without reading from the socket, we want to avoid blocking
// important goroutines on the send operations.
func (serv *server) foreachRunnerAsync(fn func(runner *Runner)) {
	serv.mu.Lock()
	defer serv.mu.Unlock()
	for _, runner := range serv.runners {
		if runner.Alive() {
			go fn(runner)
		}
	}
}
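
// Hedged end-to-end sketch, not part of the original file: the startup order
// implied by the Server interface above. The caller builds a RemoteConfig
// (including a Manager implementation) and listens before serving.
func exampleServe(ctx context.Context, cfg *RemoteConfig) error {
	serv, err := New(cfg)
	if err != nil {
		return err
	}
	if err := serv.Listen(); err != nil {
		return err
	}
	defer serv.Close()
	log.Logf(0, "rpc server listening on port %v", serv.Port())
	return serv.Serve(ctx) // blocks until ctx is canceled or a fatal error occurs
}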