github.com/cilium/statedb@v0.3.2/derive_test.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     3  
     4  package statedb
     5  
     6  import (
     7  	"context"
     8  	"log/slog"
     9  	"slices"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/stretchr/testify/require"
    14  
    15  	"github.com/cilium/hive"
    16  	"github.com/cilium/hive/cell"
    17  	"github.com/cilium/hive/hivetest"
    18  	"github.com/cilium/hive/job"
    19  	"github.com/cilium/statedb/index"
    20  	"github.com/cilium/statedb/part"
    21  )
    22  
    23  type derived struct {
    24  	ID      uint64
    25  	Deleted bool
    26  }
    27  
    28  var derivedIdIndex = Index[derived, uint64]{
    29  	Name: "id",
    30  	FromObject: func(t derived) index.KeySet {
    31  		return index.NewKeySet(index.Uint64(t.ID))
    32  	},
    33  	FromKey: index.Uint64,
    34  	Unique:  true,
    35  }
    36  
    37  type nopHealth struct {
    38  }
    39  
    40  // Degraded implements cell.Health.
    41  func (*nopHealth) Degraded(reason string, err error) {
    42  }
    43  
    44  // NewScope implements cell.Health.
    45  func (h *nopHealth) NewScope(name string) cell.Health {
    46  	return h
    47  }
    48  
    49  // OK implements cell.Health.
    50  func (*nopHealth) OK(status string) {
    51  }
    52  
    53  // Stopped implements cell.Health.
    54  func (*nopHealth) Stopped(reason string) {
    55  }
    56  
    57  func (*nopHealth) Close() {}
    58  
    59  func newNopHealth() (cell.Health, *nopHealth) {
    60  	h := &nopHealth{}
    61  	return h, h
    62  }
    63  
    64  var _ cell.Health = &nopHealth{}
    65  
    66  func TestDerive(t *testing.T) {
    67  	var db *DB
    68  	inTable, err := NewTable("test", idIndex)
    69  	require.NoError(t, err)
    70  	outTable, err := NewTable("derived", derivedIdIndex)
    71  	require.NoError(t, err)
    72  
    73  	transform := func(obj testObject, deleted bool) (derived, DeriveResult) {
    74  		t.Logf("transform(%v, %v)", obj, deleted)
    75  
    76  		tags := slices.Collect(obj.Tags.All())
    77  		if obj.Tags.Len() > 0 && tags[0] == "skip" {
    78  			return derived{}, DeriveSkip
    79  		}
    80  		if deleted {
    81  			if obj.Tags.Len() > 0 && tags[0] == "delete" {
    82  				return derived{ID: obj.ID}, DeriveDelete
    83  			}
    84  			return derived{ID: obj.ID, Deleted: true}, DeriveUpdate
    85  		}
    86  		return derived{ID: obj.ID, Deleted: false}, DeriveInsert
    87  	}
    88  
    89  	h := hive.New(
    90  		Cell, // DB
    91  		job.Cell,
    92  		cell.Provide(newNopHealth),
    93  		cell.Module(
    94  			"test", "Test",
    95  
    96  			cell.Provide(func(db_ *DB) (Table[testObject], RWTable[derived], error) {
    97  				db = db_
    98  				if err := db.RegisterTable(inTable); err != nil {
    99  					return nil, nil, err
   100  				}
   101  				if err := db.RegisterTable(outTable); err != nil {
   102  					return nil, nil, err
   103  				}
   104  				return inTable, outTable, nil
   105  			}),
   106  
   107  			cell.Invoke(Derive("testObject-to-derived", transform)),
   108  		),
   109  	)
   110  	log := hivetest.Logger(t, hivetest.LogLevel(slog.LevelError))
   111  	require.NoError(t, h.Start(log, context.TODO()), "Start")
   112  
   113  	getDerived := func() []derived {
   114  		txn := db.ReadTxn()
   115  		objs := Collect(outTable.All(txn))
   116  		// Log so we can trace the failed eventually calls
   117  		t.Logf("derived: %+v", objs)
   118  		return objs
   119  	}
   120  
   121  	// Insert 1, 2 and 3 (skipped) and validate.
   122  	wtxn := db.WriteTxn(inTable)
   123  	_, _, err = inTable.Insert(wtxn, testObject{ID: 1})
   124  	require.NoError(t, err, "Insert failed")
   125  	_, _, err = inTable.Insert(wtxn, testObject{ID: 2})
   126  	require.NoError(t, err, "Insert failed")
   127  	_, _, err = inTable.Insert(wtxn, testObject{ID: 3, Tags: part.NewSet("skip")})
   128  	require.NoError(t, err, "Insert failed")
   129  	wtxn.Commit()
   130  
   131  	require.Eventually(t,
   132  		func() bool {
   133  			objs := getDerived()
   134  			return len(objs) == 2 && // 3 is skipped
   135  				objs[0].ID == 1 && objs[1].ID == 2
   136  		},
   137  		time.Second,
   138  		10*time.Millisecond,
   139  		"expected 1 & 2 to be derived",
   140  	)
   141  
   142  	// Delete 2 (testing DeriveUpdate)
   143  	wtxn = db.WriteTxn(inTable)
   144  	_, hadOld, err := inTable.Delete(wtxn, testObject{ID: 2})
   145  	require.NoError(t, err, "Delete failed")
   146  	require.True(t, hadOld, "Expected object to be deleted")
   147  	wtxn.Commit()
   148  
   149  	require.Eventually(t,
   150  		func() bool {
   151  			objs := getDerived()
   152  			return len(objs) == 2 && // 3 is skipped
   153  				objs[0].ID == 1 && !objs[0].Deleted &&
   154  				objs[1].ID == 2 && objs[1].Deleted
   155  		},
   156  		time.Second,
   157  		10*time.Millisecond,
   158  		"expected 1 & 2, with 2 marked deleted",
   159  	)
   160  
   161  	// Delete 1 (testing DeriveDelete)
   162  	wtxn = db.WriteTxn(inTable)
   163  	_, _, err = inTable.Insert(wtxn, testObject{ID: 1, Tags: part.NewSet("delete")})
   164  	require.NoError(t, err, "Insert failed")
   165  	wtxn.Commit()
   166  	wtxn = db.WriteTxn(inTable)
   167  	_, _, err = inTable.Delete(wtxn, testObject{ID: 1})
   168  	require.NoError(t, err, "Delete failed")
   169  	wtxn.Commit()
   170  
   171  	require.Eventually(t,
   172  		func() bool {
   173  			objs := getDerived()
   174  			return len(objs) == 1 &&
   175  				objs[0].ID == 2 && objs[0].Deleted
   176  		},
   177  		time.Second,
   178  		10*time.Millisecond,
   179  		"expected 1 to be gone, and 2 mark deleted",
   180  	)
   181  
   182  	require.NoError(t, h.Stop(log, context.TODO()), "Stop")
   183  }