github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/tools/syz-db/syz-db.go (about)

     1  // Copyright 2017 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  package main
     5  
     6  import (
     7  	"flag"
     8  	"fmt"
     9  	"os"
    10  	"path/filepath"
    11  	"runtime"
    12  	"sort"
    13  	"strconv"
    14  	"strings"
    15  	"time"
    16  
    17  	"github.com/google/syzkaller/pkg/db"
    18  	"github.com/google/syzkaller/pkg/hash"
    19  	"github.com/google/syzkaller/pkg/osutil"
    20  	"github.com/google/syzkaller/pkg/tool"
    21  	"github.com/google/syzkaller/prog"
    22  	_ "github.com/google/syzkaller/sys"
    23  	"golang.org/x/exp/maps"
    24  )
    25  
    26  func main() {
    27  	var (
    28  		flagVersion = flag.Uint64("version", 0, "database version")
    29  		flagOS      = flag.String("os", runtime.GOOS, "target OS")
    30  		flagArch    = flag.String("arch", runtime.GOARCH, "target arch")
    31  	)
    32  	flag.Parse()
    33  	args := flag.Args()
    34  	if len(args) == 0 {
    35  		usage()
    36  	}
    37  	if args[0] == "bench" {
    38  		if len(args) != 2 {
    39  			usage()
    40  		}
    41  		target, err := prog.GetTarget(*flagOS, *flagArch)
    42  		if err != nil {
    43  			tool.Failf("failed to find target: %v", err)
    44  		}
    45  		bench(target, args[1])
    46  		return
    47  	}
    48  	var target *prog.Target
    49  	if *flagOS != "" || *flagArch != "" {
    50  		var err error
    51  		target, err = prog.GetTarget(*flagOS, *flagArch)
    52  		if err != nil {
    53  			tool.Failf("failed to find target: %v", err)
    54  		}
    55  	}
    56  	switch args[0] {
    57  	case "pack":
    58  		if len(args) != 3 {
    59  			usage()
    60  		}
    61  		pack(args[1], args[2], target, *flagVersion)
    62  	case "unpack":
    63  		if len(args) != 3 {
    64  			usage()
    65  		}
    66  		unpack(args[1], args[2])
    67  	case "merge":
    68  		if len(args) < 3 {
    69  			usage()
    70  		}
    71  		merge(args[1], args[2:], target)
    72  	case "print":
    73  		if len(args) != 2 {
    74  			usage()
    75  		}
    76  		print(args[1])
    77  	case "rm":
    78  		if len(args) != 3 {
    79  			usage()
    80  		}
    81  		rm(args[1], args[2], target)
    82  	default:
    83  		usage()
    84  	}
    85  }
    86  
    87  func usage() {
    88  	fmt.Fprintf(os.Stderr, `usage: syz-db can be used to manipulate corpus
    89  databases that are used by syz-managers. The following generic arguments are
    90  offered:
    91    -arch string
    92    -os string
    93    -version uint
    94    -vv int
    95  
    96    they can be used for:
    97    packing a database:
    98      syz-db pack dir corpus.db
    99    unpacking a database. A file containing performed syscalls will be returned:
   100      syz-db unpack corpus.db dir
   101    merging databases. No additional file will be created: The first file will be replaced by the merged result:
   102      syz-db merge dst-corpus.db add-corpus.db* add-prog*
   103    running a deserialization benchmark and printing corpus stats:
   104      syz-db bench corpus.db
   105    print corpus db:
   106      syz-db print corpus.db
   107    remove a syscall from db
   108      syz-db rm corpus.db syscall_name
   109  `)
   110  	os.Exit(1)
   111  }
   112  
   113  func pack(dir, file string, target *prog.Target, version uint64) {
   114  	files, err := os.ReadDir(dir)
   115  	if err != nil {
   116  		tool.Failf("failed to read dir: %v", err)
   117  	}
   118  	var records []db.Record
   119  	for _, file := range files {
   120  		data, err := os.ReadFile(filepath.Join(dir, file.Name()))
   121  		if err != nil {
   122  			tool.Failf("failed to read file %v: %v", file.Name(), err)
   123  		}
   124  		var seq uint64
   125  		key := file.Name()
   126  		if parts := strings.Split(file.Name(), "-"); len(parts) == 2 {
   127  			var err error
   128  			if seq, err = strconv.ParseUint(parts[1], 10, 64); err == nil {
   129  				key = parts[0]
   130  			}
   131  		}
   132  		if sig := hash.String(data); key != sig {
   133  			if target != nil {
   134  				p, err := target.Deserialize(data, prog.NonStrict)
   135  				if err != nil {
   136  					tool.Failf("failed to deserialize %v: %v", file.Name(), err)
   137  				}
   138  				data = p.Serialize()
   139  				sig = hash.String(data)
   140  			}
   141  			fmt.Fprintf(os.Stderr, "fixing hash %v -> %v\n", key, sig)
   142  			key = sig
   143  		}
   144  		records = append(records, db.Record{
   145  			Val: data,
   146  			Seq: seq,
   147  		})
   148  	}
   149  	if err := db.Create(file, version, records); err != nil {
   150  		tool.Fail(err)
   151  	}
   152  }
   153  
   154  func unpack(file, dir string) {
   155  	db, err := db.Open(file, false)
   156  	if err != nil {
   157  		tool.Failf("failed to open database: %v", err)
   158  	}
   159  	osutil.MkdirAll(dir)
   160  	for key, rec := range db.Records {
   161  		fname := filepath.Join(dir, key)
   162  		if rec.Seq != 0 {
   163  			fname += fmt.Sprintf("-%v", rec.Seq)
   164  		}
   165  		if err := osutil.WriteFile(fname, rec.Val); err != nil {
   166  			tool.Failf("failed to output file: %v", err)
   167  		}
   168  	}
   169  }
   170  
   171  func merge(file string, adds []string, target *prog.Target) {
   172  	failures, err := db.Merge(file, adds, target)
   173  	if err != nil {
   174  		tool.Failf("%s", err)
   175  	}
   176  	if len(failures) > 0 {
   177  		for _, fail := range failures {
   178  			fmt.Printf("failed to deserialize a record from %s: %s\n", fail.File, fail.Err)
   179  		}
   180  		tool.Failf("there have been deserialization errors")
   181  	}
   182  }
   183  
   184  func bench(target *prog.Target, file string) {
   185  	start := time.Now()
   186  	db, err := db.Open(file, false)
   187  	if err != nil {
   188  		tool.Failf("failed to open database: %v", err)
   189  	}
   190  	var corpus []*prog.Prog
   191  	for _, rec := range db.Records {
   192  		p, err := target.Deserialize(rec.Val, prog.NonStrict)
   193  		if err != nil {
   194  			tool.Failf("failed to deserialize: %v\n%s", err, rec.Val)
   195  		}
   196  		corpus = append(corpus, p)
   197  	}
   198  	runtime.GC()
   199  	var stats runtime.MemStats
   200  	runtime.ReadMemStats(&stats)
   201  	fmt.Printf("allocs %v MB (%v M), next GC %v MB, sys heap %v MB, live allocs %v MB (%v M), time %v\n",
   202  		stats.TotalAlloc>>20,
   203  		stats.Mallocs>>20,
   204  		stats.NextGC>>20,
   205  		stats.HeapSys>>20,
   206  		stats.Alloc>>20,
   207  		(stats.Mallocs-stats.Frees)>>20,
   208  		time.Since(start))
   209  	n := len(corpus)
   210  	fmt.Printf("corpus size: %v\n", n)
   211  	if n == 0 {
   212  		return
   213  	}
   214  	sum := 0
   215  	lens := make([]int, n)
   216  	for i, p := range corpus {
   217  		sum += len(p.Calls)
   218  		lens[i] = len(p.Calls)
   219  	}
   220  	sort.Ints(lens)
   221  	fmt.Printf("program size: min=%v avg=%v max=%v 10%%=%v 50%%=%v 90%%=%v\n",
   222  		lens[0], sum/n, lens[n-1], lens[n/10], lens[n/2], lens[n*9/10])
   223  }
   224  
   225  func print(file string) {
   226  	db, err := db.Open(file, false)
   227  	if err != nil {
   228  		tool.Failf("failed to open database: %v", err)
   229  	}
   230  	keys := maps.Keys(db.Records)
   231  	sort.Strings(keys)
   232  	for _, key := range keys {
   233  		rec := db.Records[key]
   234  		fmt.Printf("%v\n%v\n", key, string(rec.Val))
   235  	}
   236  }
   237  
   238  func rm(file, syscall string, target *prog.Target) {
   239  	db, err := db.Open(file, false)
   240  	if err != nil {
   241  		tool.Failf("failed to open database: %w", err)
   242  	}
   243  	for key, rec := range db.Records {
   244  		p, err := target.Deserialize(rec.Val, prog.NonStrict)
   245  		if err != nil {
   246  			tool.Failf("failed to deserialize: %w\n%s", err, rec.Val)
   247  		}
   248  		for i := len(p.Calls) - 1; i >= 0; i-- {
   249  			if strings.Contains(p.Calls[i].Meta.Name, syscall) {
   250  				p.RemoveCall(i)
   251  			}
   252  		}
   253  		data := p.Serialize()
   254  		if len(data) > 0 {
   255  			db.Save(key, data, rec.Seq)
   256  		} else {
   257  			delete(db.Records, key)
   258  		}
   259  	}
   260  	if err := db.Flush(); err != nil {
   261  		tool.Fail(err)
   262  	}
   263  }