github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/pkg/db/db.go (about)

     1  // Copyright 2017 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  // Package db implements a simple key-value database.
     5  // The database is cached in memory and mirrored on disk.
     6  // It is used to store corpus in syz-manager and syz-hub.
     7  // The database strives to minimize number of disk accesses
     8  // as they can be slow in virtualized environments (GCE).
     9  package db
    10  
    11  import (
    12  	"bufio"
    13  	"bytes"
    14  	"compress/flate"
    15  	"encoding/binary"
    16  	"fmt"
    17  	"io"
    18  	"os"
    19  	"sort"
    20  
    21  	"github.com/google/syzkaller/pkg/hash"
    22  	"github.com/google/syzkaller/pkg/osutil"
    23  	"github.com/google/syzkaller/prog"
    24  )
    25  
// DB is a key-value database cached in memory and mirrored to a file on disk.
// Changes are buffered in pending and appended to the file by Flush;
// compact rewrites the file so that it contains only the live records.
type DB struct {
	Version uint64            // arbitrary user version (0 for new database)
	Records map[string]Record // in-memory cache, must not be modified directly

	filename      string        // path of the database file
	uncompacted   int           // number of records in the file, including superseded and deleted ones
	pending       *bytes.Buffer // pending writes to the file, nil if there are none
	dataDiscarded bool          // set by DiscardData: Records hold no values, compact re-reads them from disk
}
    35  
// Record is a single database entry.
type Record struct {
	Val []byte // stored value; nil after DiscardData
	Seq uint64 // caller-supplied sequence number; ^uint64(0) is reserved to mark deletions
}
    40  
    41  // Open opens the specified database file.
    42  // If the database is corrupted and reading failed, then it returns an non-nil db
    43  // with whatever records were recovered and a non-nil error at the same time.
    44  func Open(filename string, repair bool) (*DB, error) {
    45  	db := &DB{
    46  		filename: filename,
    47  	}
    48  	var deserializeErr error
    49  	db.Version, db.Records, db.uncompacted, deserializeErr = deserializeFile(db.filename)
    50  	// Deserialization error is considered a "soft" error if repair == true,
    51  	// but compact below ensures that the file is at least writable.
    52  	if deserializeErr != nil && !repair {
    53  		return nil, deserializeErr
    54  	}
    55  	if err := db.compact(); err != nil {
    56  		return nil, err
    57  	}
    58  	return db, deserializeErr
    59  }
    60  
    61  func (db *DB) Save(key string, val []byte, seq uint64) {
    62  	if seq == seqDeleted {
    63  		panic("reserved seq")
    64  	}
    65  	// If data is discarded, we assume key identifies data (data hash).
    66  	if rec, ok := db.Records[key]; ok && seq == rec.Seq && (db.dataDiscarded || bytes.Equal(val, rec.Val)) {
    67  		return
    68  	}
    69  	db.serialize(key, val, seq)
    70  	if db.dataDiscarded {
    71  		val = nil
    72  	}
    73  	db.Records[key] = Record{val, seq}
    74  	db.uncompacted++
    75  }
    76  
    77  func (db *DB) Delete(key string) {
    78  	if _, ok := db.Records[key]; !ok {
    79  		return
    80  	}
    81  	delete(db.Records, key)
    82  	db.serialize(key, nil, seqDeleted)
    83  	db.uncompacted++
    84  }
    85  
    86  // DiscardData discards all record's values from memory.
    87  // This allows to save memory if values are not needed anymore,
    88  // but in exchange every compaction will need to re-read all data from disk.
    89  func (db *DB) DiscardData() {
    90  	db.dataDiscarded = true
    91  	for key, rec := range db.Records {
    92  		rec.Val = nil
    93  		db.Records[key] = rec
    94  	}
    95  }
    96  
    97  func (db *DB) Flush() error {
    98  	if db.pending == nil {
    99  		return nil
   100  	}
   101  	f, err := os.OpenFile(db.filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, osutil.DefaultFilePerm)
   102  	if err != nil {
   103  		return err
   104  	}
   105  	defer f.Close()
   106  	if _, err := f.Write(db.pending.Bytes()); err != nil {
   107  		return err
   108  	}
   109  	db.pending = nil
   110  	if db.uncompacted/10*9 < len(db.Records) {
   111  		return nil
   112  	}
   113  	return db.compact()
   114  }
   115  
   116  func (db *DB) BumpVersion(version uint64) error {
   117  	if err := db.Flush(); err != nil {
   118  		return err
   119  	}
   120  	if db.Version == version {
   121  		return nil
   122  	}
   123  	db.Version = version
   124  	return db.compact()
   125  }
   126  
   127  func (db *DB) compact() error {
   128  	if db.pending != nil {
   129  		panic("compacting with pending records")
   130  	}
   131  	records := db.Records
   132  	if db.dataDiscarded {
   133  		var err error
   134  		_, records, _, err = deserializeFile(db.filename)
   135  		if err != nil {
   136  			return err
   137  		}
   138  	}
   139  	buf := new(bytes.Buffer)
   140  	serializeHeader(buf, db.Version)
   141  	for key, rec := range records {
   142  		serializeRecord(buf, key, rec.Val, rec.Seq)
   143  	}
   144  	f, err := os.Create(db.filename + ".tmp")
   145  	if err != nil {
   146  		return err
   147  	}
   148  	defer f.Close()
   149  	if _, err := f.Write(buf.Bytes()); err != nil {
   150  		return err
   151  	}
   152  	f.Close()
   153  	if err := osutil.Rename(f.Name(), db.filename); err != nil {
   154  		return err
   155  	}
   156  	db.uncompacted = len(records)
   157  	return nil
   158  }
   159  
   160  func (db *DB) serialize(key string, val []byte, seq uint64) {
   161  	if db.pending == nil {
   162  		db.pending = new(bytes.Buffer)
   163  	}
   164  	serializeRecord(db.pending, key, val, seq)
   165  }
   166  
   167  const (
   168  	dbMagic    = uint32(0xbaddb)
   169  	recMagic   = uint32(0xfee1bad)
   170  	curVersion = uint32(2)
   171  	seqDeleted = ^uint64(0)
   172  )
   173  
   174  func serializeHeader(w *bytes.Buffer, version uint64) {
   175  	binary.Write(w, binary.LittleEndian, dbMagic)
   176  	binary.Write(w, binary.LittleEndian, curVersion)
   177  	binary.Write(w, binary.LittleEndian, version)
   178  }
   179  
   180  func serializeRecord(w *bytes.Buffer, key string, val []byte, seq uint64) {
   181  	binary.Write(w, binary.LittleEndian, recMagic)
   182  	binary.Write(w, binary.LittleEndian, uint32(len(key)))
   183  	w.WriteString(key)
   184  	binary.Write(w, binary.LittleEndian, seq)
   185  	if seq == seqDeleted {
   186  		if len(val) != 0 {
   187  			panic("deleting record with value")
   188  		}
   189  		return
   190  	}
   191  	if len(val) == 0 {
   192  		binary.Write(w, binary.LittleEndian, uint32(len(val)))
   193  	} else {
   194  		lenPos := len(w.Bytes())
   195  		binary.Write(w, binary.LittleEndian, uint32(0))
   196  		startPos := len(w.Bytes())
   197  		fw, err := flate.NewWriter(w, flate.BestCompression)
   198  		if err != nil {
   199  			panic(err)
   200  		}
   201  		if _, err := fw.Write(val); err != nil {
   202  			panic(err)
   203  		}
   204  		fw.Close()
   205  		binary.Write(bytes.NewBuffer(w.Bytes()[lenPos:lenPos:lenPos+8]), binary.LittleEndian, uint32(len(w.Bytes())-startPos))
   206  	}
   207  }
   208  
   209  func deserializeFile(filename string) (version uint64, records map[string]Record, uncompacted int, err error) {
   210  	f, err := os.OpenFile(filename, os.O_RDONLY|os.O_CREATE, osutil.DefaultFilePerm)
   211  	if err != nil {
   212  		return 0, nil, 0, err
   213  	}
   214  	defer f.Close()
   215  	return deserializeDB(bufio.NewReader(f))
   216  }
   217  
   218  func deserializeDB(r *bufio.Reader) (version uint64, records map[string]Record, uncompacted int, err0 error) {
   219  	records = make(map[string]Record)
   220  	ver, err := deserializeHeader(r)
   221  	if err != nil {
   222  		err0 = fmt.Errorf("failed to deserialize database header: %w", err)
   223  		return
   224  	}
   225  	version = ver
   226  	for {
   227  		key, val, seq, err := deserializeRecord(r)
   228  		if err == io.EOF {
   229  			return
   230  		}
   231  		if err != nil {
   232  			err0 = fmt.Errorf("failed to deserialize database record: %w", err)
   233  			return
   234  		}
   235  		uncompacted++
   236  		if seq == seqDeleted {
   237  			delete(records, key)
   238  		} else {
   239  			records[key] = Record{val, seq}
   240  		}
   241  	}
   242  }
   243  
   244  func deserializeHeader(r *bufio.Reader) (uint64, error) {
   245  	var magic, ver uint32
   246  	if err := binary.Read(r, binary.LittleEndian, &magic); err != nil {
   247  		if err == io.EOF {
   248  			return 0, nil
   249  		}
   250  		return 0, err
   251  	}
   252  	if magic != dbMagic {
   253  		return 0, fmt.Errorf("bad db header: 0x%x", magic)
   254  	}
   255  	if err := binary.Read(r, binary.LittleEndian, &ver); err != nil {
   256  		return 0, err
   257  	}
   258  	if ver == 0 || ver > curVersion {
   259  		return 0, fmt.Errorf("bad db version: %v", ver)
   260  	}
   261  	var userVer uint64
   262  	if ver >= 2 {
   263  		if err := binary.Read(r, binary.LittleEndian, &userVer); err != nil {
   264  			return 0, err
   265  		}
   266  	}
   267  	return userVer, nil
   268  }
   269  
   270  func deserializeRecord(r *bufio.Reader) (key string, val []byte, seq uint64, err error) {
   271  	var magic uint32
   272  	if err = binary.Read(r, binary.LittleEndian, &magic); err != nil {
   273  		return
   274  	}
   275  	if magic != recMagic {
   276  		err = fmt.Errorf("bad record header: 0x%x", magic)
   277  		return
   278  	}
   279  	var keyLen uint32
   280  	if err = binary.Read(r, binary.LittleEndian, &keyLen); err != nil {
   281  		return
   282  	}
   283  	keyBuf := make([]byte, keyLen)
   284  	if _, err = io.ReadFull(r, keyBuf); err != nil {
   285  		return
   286  	}
   287  	key = string(keyBuf)
   288  	if err = binary.Read(r, binary.LittleEndian, &seq); err != nil {
   289  		return
   290  	}
   291  	if seq == seqDeleted {
   292  		return
   293  	}
   294  	var valLen uint32
   295  	if err = binary.Read(r, binary.LittleEndian, &valLen); err != nil {
   296  		return
   297  	}
   298  	if valLen != 0 {
   299  		fr := flate.NewReader(&io.LimitedReader{R: r, N: int64(valLen)})
   300  		if val, err = io.ReadAll(fr); err != nil {
   301  			return
   302  		}
   303  		fr.Close()
   304  	}
   305  	return
   306  }
   307  
   308  // Create creates a new database in the specified file with the specified records.
   309  func Create(filename string, version uint64, records []Record) error {
   310  	os.Remove(filename)
   311  	db, err := Open(filename, true)
   312  	if err != nil {
   313  		return fmt.Errorf("failed to open database file: %w", err)
   314  	}
   315  	if err := db.BumpVersion(version); err != nil {
   316  		return fmt.Errorf("failed to bump database version: %w", err)
   317  	}
   318  	for _, rec := range records {
   319  		db.Save(hash.String(rec.Val), rec.Val, rec.Seq)
   320  	}
   321  	if err := db.Flush(); err != nil {
   322  		return fmt.Errorf("failed to save database file: %w", err)
   323  	}
   324  	return nil
   325  }
   326  
   327  func ReadCorpus(filename string, target *prog.Target) (progs []*prog.Prog, err error) {
   328  	if filename == "" {
   329  		return
   330  	}
   331  	db, err := Open(filename, false)
   332  	if err != nil {
   333  		return nil, fmt.Errorf("failed to open database file: %w", err)
   334  	}
   335  	recordKeys := make([]string, 0, len(db.Records))
   336  	for key := range db.Records {
   337  		recordKeys = append(recordKeys, key)
   338  	}
   339  	sort.Strings(recordKeys)
   340  	for _, key := range recordKeys {
   341  		p, err := target.Deserialize(db.Records[key].Val, prog.NonStrict)
   342  		if err != nil {
   343  			return nil, fmt.Errorf("failed to deserialize corpus program: %w", err)
   344  		}
   345  		progs = append(progs, p)
   346  	}
   347  	return progs, nil
   348  }
   349  
// DeserializeFailure describes an input file passed to Merge that was stored
// in the database even though it failed to parse as a program.
type DeserializeFailure struct {
	File string // path of the input file
	Err  error  // error returned by the target's Deserialize
}
   354  
   355  func Merge(into string, other []string, target *prog.Target) ([]DeserializeFailure, error) {
   356  	dstDB, err := Open(into, false)
   357  	if err != nil {
   358  		return nil, fmt.Errorf("failed to open database: %w", err)
   359  	}
   360  	var failed []DeserializeFailure
   361  	for _, add := range other {
   362  		addDB, err := Open(add, false)
   363  		if err == nil {
   364  			// It's a DB file.
   365  			for key, rec := range addDB.Records {
   366  				dstDB.Save(key, rec.Val, rec.Seq)
   367  			}
   368  			continue
   369  		}
   370  		if target == nil {
   371  			// We were not given a target, so we cannot parse it as a seed file.
   372  			return nil, fmt.Errorf("failed to open db %v: %w", add, err)
   373  		}
   374  		data, err := os.ReadFile(add)
   375  		if err != nil {
   376  			return nil, err
   377  		}
   378  		if _, err := target.Deserialize(data, prog.NonStrict); err != nil {
   379  			failed = append(failed, DeserializeFailure{add, err})
   380  		}
   381  		dstDB.Save(hash.String(data), data, 0)
   382  	}
   383  	if err := dstDB.Flush(); err != nil {
   384  		return nil, fmt.Errorf("failed to save db: %w", err)
   385  	}
   386  	return failed, nil
   387  }