github.com/cilium/statedb@v0.3.2/fuzz_test.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     3  
     4  package statedb_test
     5  
     6  import (
     7  	"flag"
     8  	"fmt"
     9  	"log"
    10  	"maps"
    11  	"math/rand"
    12  	"os"
    13  	"runtime"
    14  	"slices"
    15  	"sync"
    16  	"testing"
    17  	"time"
    18  
    19  	"github.com/cilium/statedb"
    20  	"github.com/cilium/statedb/index"
    21  	"github.com/stretchr/testify/require"
    22  )
    23  
    24  // Run test with "--debug" for log output.
    25  var debug = flag.Bool("debug", false, "Enable debug logging")
    26  
    27  type debugLogger struct {
    28  	l *log.Logger
    29  }
    30  
    31  func (l *debugLogger) log(fmt string, args ...any) {
    32  	if l == nil {
    33  		return
    34  	}
    35  	l.l.Printf(fmt, args...)
    36  }
    37  
    38  func newDebugLogger(worker int) *debugLogger {
    39  	if !*debug {
    40  		return nil
    41  	}
    42  	logger := log.New(os.Stdout, fmt.Sprintf("worker[%03d] | ", worker), 0)
    43  	return &debugLogger{logger}
    44  }
    45  
    46  const (
    47  	numUniqueIDs    = 3000
    48  	numUniqueValues = 2000
    49  	numWorkers      = 20
    50  	numTrackers     = 5
    51  	numIterations   = 1000
    52  )
    53  
    54  type fuzzObj struct {
    55  	id    string
    56  	value uint64
    57  }
    58  
    59  func mkID() string {
    60  	// We use a string hex presentation instead of the raw uint64 so we get
    61  	// a wide range of different length keys and different prefixes.
    62  	return fmt.Sprintf("%x", 1+uint64(rand.Int63n(numUniqueIDs)))
    63  }
    64  
    65  func mkValue() uint64 {
    66  	return 1 + uint64(rand.Int63n(numUniqueValues))
    67  }
    68  
    69  var idIndex = statedb.Index[fuzzObj, string]{
    70  	Name: "id",
    71  	FromObject: func(obj fuzzObj) index.KeySet {
    72  		return index.NewKeySet(index.String(obj.id))
    73  	},
    74  	FromKey: index.String,
    75  	Unique:  true,
    76  }
    77  
    78  var valueIndex = statedb.Index[fuzzObj, uint64]{
    79  	Name: "value",
    80  	FromObject: func(obj fuzzObj) index.KeySet {
    81  		return index.NewKeySet(index.Uint64(obj.value))
    82  	},
    83  	FromKey: index.Uint64,
    84  	Unique:  false,
    85  }
    86  
var (
	// Four tables with identical indexes. Each write transaction in
	// fuzzWorker locks a random subset of them.
	tableFuzz1  = statedb.MustNewTable("fuzz1", idIndex, valueIndex)
	tableFuzz2  = statedb.MustNewTable("fuzz2", idIndex, valueIndex)
	tableFuzz3  = statedb.MustNewTable("fuzz3", idIndex, valueIndex)
	tableFuzz4  = statedb.MustNewTable("fuzz4", idIndex, valueIndex)
	fuzzTables  = []statedb.TableMeta{tableFuzz1, tableFuzz2, tableFuzz3, tableFuzz4}
	fuzzMetrics = statedb.NewExpVarMetrics(false)
	// fuzzDB is assigned by TestDB_Fuzz and shared with all worker and
	// tracker goroutines.
	fuzzDB      *statedb.DB
)
    96  
    97  func randomSubset[T any](xs []T) []T {
    98  	xs = slices.Clone(xs)
    99  	rand.Shuffle(len(xs), func(i, j int) {
   100  		xs[i], xs[j] = xs[j], xs[i]
   101  	})
   102  	// Pick random subset
   103  	n := 1 + rand.Intn(len(xs))
   104  	return xs[:n]
   105  }
   106  
   107  type actionLog interface {
   108  	append(actionLogEntry)
   109  	validateTable(txn statedb.ReadTxn, table statedb.Table[fuzzObj]) error
   110  }
   111  
   112  type realActionLog struct {
   113  	sync.Mutex
   114  	log map[string][]actionLogEntry
   115  }
   116  
   117  func (a *realActionLog) append(e actionLogEntry) {
   118  	a.Lock()
   119  	a.log[e.table.Name()] = append(a.log[e.table.Name()], e)
   120  	a.Unlock()
   121  }
   122  
   123  func (a *realActionLog) validateTable(txn statedb.ReadTxn, table statedb.Table[fuzzObj]) error {
   124  	a.Lock()
   125  	defer a.Unlock()
   126  
   127  	// Collapse the log down to objects that are alive at the end.
   128  	alive := map[string]struct{}{}
   129  	for _, e := range a.log[table.Name()] {
   130  		switch e.act {
   131  		case actInsert:
   132  			alive[e.id] = struct{}{}
   133  		case actDelete:
   134  			delete(alive, e.id)
   135  		case actDeleteAll:
   136  			clear(alive)
   137  		}
   138  	}
   139  
   140  	// Since everything was deleted we can clear the log entries for this table now
   141  	a.log[table.Name()] = nil
   142  
   143  	actual := map[string]struct{}{}
   144  	for obj := range table.All(txn) {
   145  		actual[obj.id] = struct{}{}
   146  	}
   147  	diff := setSymmetricDifference(actual, alive)
   148  	if len(diff) != 0 {
   149  		return fmt.Errorf("validate failed, mismatching ids: %v", maps.Keys(diff))
   150  	}
   151  	return nil
   152  }
   153  
   154  func setSymmetricDifference[T comparable, M map[T]struct{}](s1, s2 M) M {
   155  	counts := make(map[T]int, len(s1)+len(s2))
   156  	for k1 := range s1 {
   157  		counts[k1] = 1
   158  	}
   159  	for k2 := range s2 {
   160  		counts[k2]++
   161  	}
   162  	result := M{}
   163  	for k, count := range counts {
   164  		if count == 1 {
   165  			result[k] = struct{}{}
   166  		}
   167  	}
   168  	return result
   169  }
   170  
   171  type nopActionLog struct {
   172  }
   173  
   174  func (nopActionLog) append(e actionLogEntry) {}
   175  
   176  func (nopActionLog) validateTable(txn statedb.ReadTxn, table statedb.Table[fuzzObj]) error {
   177  	return nil
   178  }
   179  
// Kinds of actions recorded in an actionLogEntry.
const (
	actInsert = iota
	actDelete
	actDeleteAll
)

// actionLogEntry records one mutating action against a table. Entries are
// built with positional literals in this file, so field order matters.
type actionLogEntry struct {
	table statedb.Table[fuzzObj]
	act   int    // one of actInsert, actDelete, actDeleteAll
	id    string // empty for actDeleteAll
	value uint64 // zero for deletes
}

// tableAndID identifies an object within a specific table; it is the key of
// txnActionLog.latest.
type tableAndID struct {
	table string // table name
	id    string
}

// txnActionLog holds, for the current write transaction, the most recent
// action applied to each (table, id). The read actions consult it to check
// that the transaction observes its own writes.
type txnActionLog struct {
	latest map[tableAndID]actionLogEntry
}
   201  
// actionContext bundles everything a single random action needs: the write
// transaction, the table to act on, and the logs to record into.
type actionContext struct {
	// NOTE(review): t is never set or read in this file; possibly a leftover.
	t      *testing.T
	log    *debugLogger
	actLog actionLog
	txnLog *txnActionLog
	txn    statedb.WriteTxn
	table  statedb.RWTable[fuzzObj]
}

// action is one randomly chosen operation executed inside a write
// transaction.
type action func(ctx actionContext)
   212  
   213  func insertAction(ctx actionContext) {
   214  	id := mkID()
   215  	value := mkValue()
   216  	ctx.log.log("%s: Insert %s", ctx.table.Name(), id)
   217  	ctx.table.Insert(ctx.txn, fuzzObj{id, value})
   218  	e := actionLogEntry{ctx.table, actInsert, id, value}
   219  	ctx.actLog.append(e)
   220  	ctx.txnLog.latest[tableAndID{ctx.table.Name(), id}] = e
   221  }
   222  
   223  func deleteAction(ctx actionContext) {
   224  	id := mkID()
   225  	ctx.log.log("%s: Delete %s", ctx.table.Name(), id)
   226  	ctx.table.Delete(ctx.txn, fuzzObj{id, 0})
   227  	e := actionLogEntry{ctx.table, actDelete, id, 0}
   228  	ctx.actLog.append(e)
   229  	ctx.txnLog.latest[tableAndID{ctx.table.Name(), id}] = e
   230  }
   231  
   232  func deleteAllAction(ctx actionContext) {
   233  	ctx.log.log("%s: DeleteAll", ctx.table.Name())
   234  
   235  	// Validate the log before objects are wiped.
   236  	if err := ctx.actLog.validateTable(ctx.txn, ctx.table); err != nil {
   237  		panic(err)
   238  	}
   239  	ctx.table.DeleteAll(ctx.txn)
   240  	ctx.actLog.append(actionLogEntry{ctx.table, actDeleteAll, "", 0})
   241  	clear(ctx.txnLog.latest)
   242  }
   243  
   244  func deleteManyAction(ctx actionContext) {
   245  	// Delete third of the objects using iteration to test that
   246  	// nothing bad happens when the iterator is used while deleting.
   247  	toDelete := ctx.table.NumObjects(ctx.txn) / 3
   248  
   249  	n := 0
   250  	for obj := range ctx.table.All(ctx.txn) {
   251  		ctx.log.log("%s: DeleteMany %s (%d/%d)", ctx.table.Name(), obj.id, n+1, toDelete)
   252  		_, hadOld, _ := ctx.table.Delete(ctx.txn, obj)
   253  		if !hadOld {
   254  			panic("expected Delete of a known object to return the old object")
   255  		}
   256  		e := actionLogEntry{ctx.table, actDelete, obj.id, 0}
   257  		ctx.actLog.append(e)
   258  		ctx.txnLog.latest[tableAndID{ctx.table.Name(), obj.id}] = e
   259  
   260  		n++
   261  		if n >= toDelete {
   262  			break
   263  		}
   264  	}
   265  }
   266  
   267  func allAction(ctx actionContext) {
   268  	iter := ctx.table.All(ctx.txn)
   269  	ctx.log.log("%s: All => %d found", ctx.table.Name(), len(statedb.Collect(iter)))
   270  }
   271  
   272  func listAction(ctx actionContext) {
   273  	value := mkValue()
   274  	values := ctx.table.List(ctx.txn, valueIndex.Query(value))
   275  	ctx.log.log("%s: List(%d)", ctx.table.Name(), value)
   276  	for obj := range values {
   277  		if e, ok2 := ctx.txnLog.latest[tableAndID{ctx.table.Name(), obj.id}]; ok2 {
   278  			if e.act == actInsert {
   279  				if e.value != obj.value {
   280  					panic("List() did not return the last write")
   281  				}
   282  				if obj.value != value {
   283  					panic(fmt.Sprintf("Get() returned object with wrong value, expected %d, got %d", value, obj.value))
   284  				}
   285  			} else if e.act == actDelete {
   286  				panic("List() returned value even though it was deleted")
   287  			}
   288  		}
   289  	}
   290  }
   291  
   292  func getAction(ctx actionContext) {
   293  	id := mkID()
   294  	obj, rev, ok := ctx.table.Get(ctx.txn, idIndex.Query(id))
   295  
   296  	if e, ok2 := ctx.txnLog.latest[tableAndID{ctx.table.Name(), id}]; ok2 {
   297  		if e.act == actInsert {
   298  			if !ok {
   299  				panic("Get() returned not found, expected last inserted value")
   300  			}
   301  			if e.value != obj.value {
   302  				panic("Get() did not return the last write")
   303  			}
   304  		} else if e.act == actDelete {
   305  			if ok {
   306  				panic("Get() returned value even though it was deleted")
   307  			}
   308  		}
   309  	}
   310  	ctx.log.log("%s: Get(%s) => rev=%d, ok=%v", ctx.table.Name(), id, rev, ok)
   311  }
   312  
   313  func lowerboundAction(ctx actionContext) {
   314  	id := mkID()
   315  	iter, _ := ctx.table.LowerBoundWatch(ctx.txn, idIndex.Query(id))
   316  	ctx.log.log("%s: LowerBound(%s) => %d found", ctx.table.Name(), id, len(statedb.Collect(iter)))
   317  }
   318  
   319  func prefixAction(ctx actionContext) {
   320  	id := mkID()
   321  	iter := ctx.table.Prefix(ctx.txn, idIndex.Query(id))
   322  	ctx.log.log("%s: Prefix(%s) => %d found", ctx.table.Name(), id, len(statedb.Collect(iter)))
   323  }
   324  
   325  var actions = []action{
   326  	// Make inserts much more likely than deletions to build up larger tables.
   327  	insertAction, insertAction, insertAction, insertAction, insertAction,
   328  	insertAction, insertAction, insertAction, insertAction, insertAction,
   329  	insertAction, insertAction, insertAction, insertAction, insertAction,
   330  	insertAction, insertAction, insertAction, insertAction, insertAction,
   331  	insertAction, insertAction, insertAction, insertAction, insertAction,
   332  	insertAction, insertAction, insertAction, insertAction, insertAction,
   333  	insertAction, insertAction, insertAction, insertAction, insertAction,
   334  	insertAction, insertAction, insertAction, insertAction, insertAction,
   335  	insertAction, insertAction, insertAction, insertAction, insertAction,
   336  	insertAction, insertAction, insertAction, insertAction, insertAction,
   337  	insertAction, insertAction, insertAction, insertAction, insertAction,
   338  	insertAction, insertAction, insertAction, insertAction, insertAction,
   339  	insertAction, insertAction, insertAction, insertAction, insertAction,
   340  	insertAction, insertAction, insertAction, insertAction, insertAction,
   341  	insertAction, insertAction, insertAction, insertAction, insertAction,
   342  	insertAction, insertAction, insertAction, insertAction, insertAction,
   343  	insertAction, insertAction, insertAction, insertAction, insertAction,
   344  	insertAction, insertAction, insertAction, insertAction, insertAction,
   345  	insertAction, insertAction, insertAction, insertAction, insertAction,
   346  	insertAction, insertAction, insertAction, insertAction, insertAction,
   347  	insertAction, insertAction, insertAction, insertAction, insertAction,
   348  	insertAction, insertAction, insertAction, insertAction, insertAction,
   349  	insertAction, insertAction, insertAction, insertAction, insertAction,
   350  
   351  	deleteAction, deleteAction, deleteAction,
   352  	deleteManyAction, deleteAllAction,
   353  
   354  	getAction, getAction, getAction, getAction, getAction,
   355  	getAction, getAction, getAction, getAction, getAction,
   356  	getAction, getAction, getAction, getAction, getAction,
   357  	listAction, listAction, listAction, listAction, listAction,
   358  	allAction, allAction,
   359  	lowerboundAction, lowerboundAction, lowerboundAction,
   360  	prefixAction, prefixAction, prefixAction,
   361  }
   362  
   363  func randomAction() action {
   364  	return actions[rand.Intn(len(actions))]
   365  }
   366  
// trackerWorker follows the change stream of tableFuzz1 until stop is closed
// and cross-checks what it observes against the database.
//
// Every observed change is replayed into a local map from id to the last
// change seen. For each change the worker asserts that:
//   - the revision reported by the iterator matches the change's revision,
//   - revisions strictly increase,
//   - the object has a non-empty id and a non-zero value.
//
// After each batch that contained at least one change, it compares that map
// with the table contents in the same ReadTxn snapshot:
//   - every object in the table must be known, with the same revision and
//     value,
//   - no tracked object may be missing from the table.
//
// Any violation panics. The worker number i only selects the debug log
// prefix (900+i).
func trackerWorker(i int, stop <-chan struct{}) {
	log := newDebugLogger(900 + i)
	// The change iterator is registered inside a write transaction on the table.
	wtxn := fuzzDB.WriteTxn(tableFuzz1)
	iter, err := tableFuzz1.Changes(wtxn)
	wtxn.Commit()
	if err != nil {
		panic(err)
	}

	// Keep track of what state the changes lead us to in order to validate it.
	state := map[string]*statedb.Change[fuzzObj]{}

	var txn statedb.ReadTxn
	var prevRev statedb.Revision
	for {
		newChanges := false
		txn = fuzzDB.ReadTxn()
		// Consume the changes visible in this snapshot. watch is waited on
		// below; presumably it is closed once further changes are available.
		changes, watch := iter.Next(txn)
		for change, rev := range changes {
			newChanges = true
			log.log("%d: %v", rev, change)

			if rev != change.Revision {
				panic(fmt.Sprintf("trackerWorker: event.Revision mismatch with actual revision: %d vs %d", change.Revision, rev))
			}

			if rev <= prevRev {
				panic(fmt.Sprintf("trackerWorker: revisions went backwards %d <= %d: %v", rev, prevRev, change))
			}
			prevRev = rev

			if change.Object.id == "" || change.Object.value == 0 {
				panic("trackerWorker: object with zero id/value")
			}

			if change.Deleted {
				delete(state, change.Object.id)
			} else {
				// Copy before taking the address so each map entry points at
				// its own Change value.
				change := change
				state[change.Object.id] = &change
			}
		}

		// NOTE(review): txn is always assigned above, so the nil check looks
		// redundant.
		if txn != nil && newChanges {
			// Validate that the observed changes match with the database state at this
			// snapshot.
			// state2 starts as a copy of everything observed; objects found in
			// the table are removed from it, so whatever remains afterwards was
			// observed but no longer exists.
			state2 := maps.Clone(state)
			// Lower bound from revision 0 on the revision index should visit
			// every object in the table.
			allObjects := tableFuzz1.LowerBound(txn, statedb.ByRevision[fuzzObj](0))
			for obj, rev := range allObjects {
				change, found := state[obj.id]
				if !found {
					panic(fmt.Sprintf("trackerWorker: object %s not found from state", obj.id))
				}

				if change.Revision != rev {
					panic(fmt.Sprintf("trackerWorker: last observed revision %d does not match real revision %d", change.Revision, rev))
				}

				if change.Object.value != obj.value {
					panic(fmt.Sprintf("trackerWorker: observed value %d does not match real value %d", change.Object.value, obj.value))
				}
				delete(state2, obj.id)
			}

			if len(state2) > 0 {
				for id := range state2 {
					log.log("%s should not exist\n", id)
				}
				panic(fmt.Sprintf("trackerWorker: %d orphan object(s)", len(state2)))
			}
		}

		// Block until more changes may be available or the test asks us to stop.
		select {
		case <-watch:
		case <-stop:
			log.log("final object count %d", len(state))
			return
		}
	}
}
   447  
   448  func fuzzWorker(realActionLog *realActionLog, worker int, iterations int) {
   449  	log := newDebugLogger(worker)
   450  	for iterations > 0 {
   451  		targets := randomSubset(fuzzTables)
   452  		txn := fuzzDB.WriteTxn(targets[0], targets[1:]...)
   453  		txnActionLog := &txnActionLog{
   454  			latest: map[tableAndID]actionLogEntry{},
   455  		}
   456  
   457  		// Try to run other goroutines with write lock held.
   458  		runtime.Gosched()
   459  
   460  		var actLog actionLog = realActionLog
   461  		abort := false
   462  		if rand.Intn(10) == 0 {
   463  			abort = true
   464  			actLog = nopActionLog{}
   465  		}
   466  
   467  		for _, target := range targets {
   468  			ctx := actionContext{
   469  				log:    log,
   470  				actLog: actLog,
   471  				txnLog: txnActionLog,
   472  				txn:    txn,
   473  				table:  target.(statedb.RWTable[fuzzObj]),
   474  			}
   475  			numActs := rand.Intn(20)
   476  			for i := 0; i < numActs; i++ {
   477  				randomAction()(ctx)
   478  				runtime.Gosched()
   479  			}
   480  		}
   481  		runtime.Gosched()
   482  
   483  		if abort {
   484  			log.log("Abort")
   485  			txn.Abort()
   486  		} else {
   487  			log.log("Commit")
   488  			txn.Commit()
   489  		}
   490  		iterations--
   491  	}
   492  }
   493  
   494  func TestDB_Fuzz(t *testing.T) {
   495  	t.Parallel()
   496  
   497  	fuzzDB = statedb.New(statedb.WithMetrics(fuzzMetrics))
   498  	for _, tbl := range fuzzTables {
   499  		require.NoError(t, fuzzDB.RegisterTable(tbl))
   500  	}
   501  
   502  	fuzzDB.Start()
   503  	defer fuzzDB.Stop()
   504  
   505  	actionLog := &realActionLog{
   506  		log: map[string][]actionLogEntry{},
   507  	}
   508  
   509  	// Start workers to mutate the tables.
   510  	var wg sync.WaitGroup
   511  	wg.Add(numWorkers)
   512  	for i := 0; i < numWorkers; i++ {
   513  		i := i
   514  		go func() {
   515  			fuzzWorker(actionLog, i, numIterations)
   516  			wg.Done()
   517  		}()
   518  	}
   519  
   520  	// Start change trackers to observe changes.
   521  	stop := make(chan struct{})
   522  	var wg2 sync.WaitGroup
   523  	wg2.Add(numTrackers)
   524  	for i := 0; i < numTrackers; i++ {
   525  		i := i
   526  		go func() {
   527  			trackerWorker(i, stop)
   528  			wg2.Done()
   529  		}()
   530  		// Delay a bit to start the trackers at different points in time
   531  		// so they will observe a different starting state.
   532  		time.Sleep(500 * time.Millisecond)
   533  	}
   534  
   535  	// Wait until the mutation workers stop and then stop
   536  	// the change observers.
   537  	wg.Wait()
   538  	close(stop)
   539  	wg2.Wait()
   540  
   541  	for _, table := range []statedb.Table[fuzzObj]{tableFuzz1, tableFuzz2, tableFuzz3, tableFuzz4} {
   542  		if err := actionLog.validateTable(fuzzDB.ReadTxn(), table); err != nil {
   543  			t.Fatal(err)
   544  		}
   545  	}
   546  
   547  	t.Logf("metrics:\n%s\n", fuzzMetrics.String())
   548  }