github.com/google/syzkaller@v0.0.0-20240517125934-c0f1611a36d6/pkg/db/db.go (about)

     1  // Copyright 2017 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  // Package db implements a simple key-value database.
     5  // The database is cached in memory and mirrored on disk.
     6  // It is used to store corpus in syz-manager and syz-hub.
     7  // The database strives to minimize the number of disk accesses
     8  // as they can be slow in virtualized environments (GCE).
     9  package db
    10  
    11  import (
    12  	"bufio"
    13  	"bytes"
    14  	"compress/flate"
    15  	"encoding/binary"
    16  	"fmt"
    17  	"io"
    18  	"os"
    19  	"sort"
    20  
    21  	"github.com/google/syzkaller/pkg/hash"
    22  	"github.com/google/syzkaller/pkg/osutil"
    23  	"github.com/google/syzkaller/prog"
    24  )
    25  
    // DB is a key-value database that is cached in memory and mirrored to a file.
type DB struct {
	// Version is an arbitrary user-defined version (0 for a new database).
	Version uint64
	// Records is the in-memory cache keyed by record key. It must not be
	// modified directly; go through Save and Delete instead.
	Records map[string]Record

	// filename is the path of the backing file.
	filename string
	// uncompacted counts the records stored in the file, including records
	// that were later overwritten or deleted.
	uncompacted int
	// pending holds serialized records that have not been written to the file yet.
	pending *bytes.Buffer
}

// Record is a single value stored in the database.
type Record struct {
	// Val is the record payload.
	Val []byte
	// Seq is a caller-supplied sequence number stored alongside Val.
	Seq uint64
}
    39  
    40  // Open opens the specified database file.
    41  // If the database is corrupted and reading failed, then it returns an non-nil db
    42  // with whatever records were recovered and a non-nil error at the same time.
    43  func Open(filename string, repair bool) (*DB, error) {
    44  	db := &DB{
    45  		filename: filename,
    46  	}
    47  	f, err := os.OpenFile(db.filename, os.O_RDONLY|os.O_CREATE, osutil.DefaultFilePerm)
    48  	if err != nil {
    49  		return nil, err
    50  	}
    51  	defer f.Close()
    52  	var deserializeErr error
    53  	db.Version, db.Records, db.uncompacted, deserializeErr = deserializeDB(bufio.NewReader(f))
    54  	// Deserialization error is considered a "soft" error if repair == true,
    55  	// but compact below ensures that the file is at least writable.
    56  	if deserializeErr != nil && !repair {
    57  		return nil, deserializeErr
    58  	}
    59  	f.Close() // compact will rewrite the file, so close our descriptor
    60  	if err := db.compact(); err != nil {
    61  		return nil, err
    62  	}
    63  	return db, deserializeErr
    64  }
    65  
    66  func (db *DB) Save(key string, val []byte, seq uint64) {
    67  	if seq == seqDeleted {
    68  		panic("reserved seq")
    69  	}
    70  	if rec, ok := db.Records[key]; ok && seq == rec.Seq && bytes.Equal(val, rec.Val) {
    71  		return
    72  	}
    73  	db.Records[key] = Record{val, seq}
    74  	db.serialize(key, val, seq)
    75  	db.uncompacted++
    76  }
    77  
    78  func (db *DB) Delete(key string) {
    79  	if _, ok := db.Records[key]; !ok {
    80  		return
    81  	}
    82  	delete(db.Records, key)
    83  	db.serialize(key, nil, seqDeleted)
    84  	db.uncompacted++
    85  }
    86  
    87  func (db *DB) Flush() error {
    88  	if db.uncompacted/10*9 > len(db.Records) {
    89  		return db.compact()
    90  	}
    91  	if db.pending == nil {
    92  		return nil
    93  	}
    94  	f, err := os.OpenFile(db.filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, osutil.DefaultFilePerm)
    95  	if err != nil {
    96  		return err
    97  	}
    98  	defer f.Close()
    99  	if _, err := f.Write(db.pending.Bytes()); err != nil {
   100  		return err
   101  	}
   102  	db.pending = nil
   103  	return nil
   104  }
   105  
   106  func (db *DB) BumpVersion(version uint64) error {
   107  	if db.Version == version {
   108  		return db.Flush()
   109  	}
   110  	db.Version = version
   111  	return db.compact()
   112  }
   113  
   114  func (db *DB) compact() error {
   115  	buf := new(bytes.Buffer)
   116  	serializeHeader(buf, db.Version)
   117  	for key, rec := range db.Records {
   118  		serializeRecord(buf, key, rec.Val, rec.Seq)
   119  	}
   120  	f, err := os.Create(db.filename + ".tmp")
   121  	if err != nil {
   122  		return err
   123  	}
   124  	defer f.Close()
   125  	if _, err := f.Write(buf.Bytes()); err != nil {
   126  		return err
   127  	}
   128  	f.Close()
   129  	if err := osutil.Rename(f.Name(), db.filename); err != nil {
   130  		return err
   131  	}
   132  	db.uncompacted = len(db.Records)
   133  	db.pending = nil
   134  	return nil
   135  }
   136  
   137  func (db *DB) serialize(key string, val []byte, seq uint64) {
   138  	if db.pending == nil {
   139  		db.pending = new(bytes.Buffer)
   140  	}
   141  	serializeRecord(db.pending, key, val, seq)
   142  }
   143  
   // On-disk format constants.
const (
	// dbMagic starts every database file.
	dbMagic uint32 = 0xbaddb
	// recMagic starts every record.
	recMagic uint32 = 0xfee1bad
	// curVersion is the format version written by this package; headers with a
	// version in 1..curVersion are accepted when reading.
	curVersion uint32 = 2
	// seqDeleted is a reserved Seq value that marks a deletion record.
	seqDeleted uint64 = ^uint64(0)
)
   150  
   151  func serializeHeader(w *bytes.Buffer, version uint64) {
   152  	binary.Write(w, binary.LittleEndian, dbMagic)
   153  	binary.Write(w, binary.LittleEndian, curVersion)
   154  	binary.Write(w, binary.LittleEndian, version)
   155  }
   156  
   157  func serializeRecord(w *bytes.Buffer, key string, val []byte, seq uint64) {
   158  	binary.Write(w, binary.LittleEndian, recMagic)
   159  	binary.Write(w, binary.LittleEndian, uint32(len(key)))
   160  	w.WriteString(key)
   161  	binary.Write(w, binary.LittleEndian, seq)
   162  	if seq == seqDeleted {
   163  		if len(val) != 0 {
   164  			panic("deleting record with value")
   165  		}
   166  		return
   167  	}
   168  	if len(val) == 0 {
   169  		binary.Write(w, binary.LittleEndian, uint32(len(val)))
   170  	} else {
   171  		lenPos := len(w.Bytes())
   172  		binary.Write(w, binary.LittleEndian, uint32(0))
   173  		startPos := len(w.Bytes())
   174  		fw, err := flate.NewWriter(w, flate.BestCompression)
   175  		if err != nil {
   176  			panic(err)
   177  		}
   178  		if _, err := fw.Write(val); err != nil {
   179  			panic(err)
   180  		}
   181  		fw.Close()
   182  		binary.Write(bytes.NewBuffer(w.Bytes()[lenPos:lenPos:lenPos+8]), binary.LittleEndian, uint32(len(w.Bytes())-startPos))
   183  	}
   184  }
   185  
   186  func deserializeDB(r *bufio.Reader) (version uint64, records map[string]Record, uncompacted int, err0 error) {
   187  	records = make(map[string]Record)
   188  	ver, err := deserializeHeader(r)
   189  	if err != nil {
   190  		err0 = fmt.Errorf("failed to deserialize database header: %w", err)
   191  		return
   192  	}
   193  	version = ver
   194  	for {
   195  		key, val, seq, err := deserializeRecord(r)
   196  		if err == io.EOF {
   197  			return
   198  		}
   199  		if err != nil {
   200  			err0 = fmt.Errorf("failed to deserialize database record: %w", err)
   201  			return
   202  		}
   203  		uncompacted++
   204  		if seq == seqDeleted {
   205  			delete(records, key)
   206  		} else {
   207  			records[key] = Record{val, seq}
   208  		}
   209  	}
   210  }
   211  
   212  func deserializeHeader(r *bufio.Reader) (uint64, error) {
   213  	var magic, ver uint32
   214  	if err := binary.Read(r, binary.LittleEndian, &magic); err != nil {
   215  		if err == io.EOF {
   216  			return 0, nil
   217  		}
   218  		return 0, err
   219  	}
   220  	if magic != dbMagic {
   221  		return 0, fmt.Errorf("bad db header: 0x%x", magic)
   222  	}
   223  	if err := binary.Read(r, binary.LittleEndian, &ver); err != nil {
   224  		return 0, err
   225  	}
   226  	if ver == 0 || ver > curVersion {
   227  		return 0, fmt.Errorf("bad db version: %v", ver)
   228  	}
   229  	var userVer uint64
   230  	if ver >= 2 {
   231  		if err := binary.Read(r, binary.LittleEndian, &userVer); err != nil {
   232  			return 0, err
   233  		}
   234  	}
   235  	return userVer, nil
   236  }
   237  
   238  func deserializeRecord(r *bufio.Reader) (key string, val []byte, seq uint64, err error) {
   239  	var magic uint32
   240  	if err = binary.Read(r, binary.LittleEndian, &magic); err != nil {
   241  		return
   242  	}
   243  	if magic != recMagic {
   244  		err = fmt.Errorf("bad record header: 0x%x", magic)
   245  		return
   246  	}
   247  	var keyLen uint32
   248  	if err = binary.Read(r, binary.LittleEndian, &keyLen); err != nil {
   249  		return
   250  	}
   251  	keyBuf := make([]byte, keyLen)
   252  	if _, err = io.ReadFull(r, keyBuf); err != nil {
   253  		return
   254  	}
   255  	key = string(keyBuf)
   256  	if err = binary.Read(r, binary.LittleEndian, &seq); err != nil {
   257  		return
   258  	}
   259  	if seq == seqDeleted {
   260  		return
   261  	}
   262  	var valLen uint32
   263  	if err = binary.Read(r, binary.LittleEndian, &valLen); err != nil {
   264  		return
   265  	}
   266  	if valLen != 0 {
   267  		fr := flate.NewReader(&io.LimitedReader{R: r, N: int64(valLen)})
   268  		if val, err = io.ReadAll(fr); err != nil {
   269  			return
   270  		}
   271  		fr.Close()
   272  	}
   273  	return
   274  }
   275  
   276  // Create creates a new database in the specified file with the specified records.
   277  func Create(filename string, version uint64, records []Record) error {
   278  	os.Remove(filename)
   279  	db, err := Open(filename, true)
   280  	if err != nil {
   281  		return fmt.Errorf("failed to open database file: %w", err)
   282  	}
   283  	if err := db.BumpVersion(version); err != nil {
   284  		return fmt.Errorf("failed to bump database version: %w", err)
   285  	}
   286  	for _, rec := range records {
   287  		db.Save(hash.String(rec.Val), rec.Val, rec.Seq)
   288  	}
   289  	if err := db.Flush(); err != nil {
   290  		return fmt.Errorf("failed to save database file: %w", err)
   291  	}
   292  	return nil
   293  }
   294  
   295  func ReadCorpus(filename string, target *prog.Target) (progs []*prog.Prog, err error) {
   296  	if filename == "" {
   297  		return
   298  	}
   299  	db, err := Open(filename, false)
   300  	if err != nil {
   301  		return nil, fmt.Errorf("failed to open database file: %w", err)
   302  	}
   303  	recordKeys := make([]string, 0, len(db.Records))
   304  	for key := range db.Records {
   305  		recordKeys = append(recordKeys, key)
   306  	}
   307  	sort.Strings(recordKeys)
   308  	for _, key := range recordKeys {
   309  		p, err := target.Deserialize(db.Records[key].Val, prog.NonStrict)
   310  		if err != nil {
   311  			return nil, fmt.Errorf("failed to deserialize corpus program: %w", err)
   312  		}
   313  		progs = append(progs, p)
   314  	}
   315  	return progs, nil
   316  }