github.com/cilium/statedb@v0.3.2/table.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     3  
     4  package statedb
     5  
     6  import (
     7  	"fmt"
     8  	"iter"
     9  	"regexp"
    10  	"runtime"
    11  	"slices"
    12  	"sort"
    13  	"strings"
    14  	"sync"
    15  	"sync/atomic"
    16  
    17  	"github.com/cilium/statedb/internal"
    18  	"github.com/cilium/statedb/part"
    19  	"gopkg.in/yaml.v3"
    20  
    21  	"github.com/cilium/statedb/index"
    22  )
    23  
    24  // NewTable creates a new table with given name and indexes.
    25  // Can fail if the indexes or the name are malformed.
    26  // The name must match regex "^[a-z][a-z0-9_\\-]{0,30}$".
    27  //
    28  // To provide access to the table via Hive:
    29  //
    30  //	cell.Provide(
    31  //		// Provide statedb.RWTable[*MyObject]. Often only provided to the module with ProvidePrivate.
    32  //		statedb.NewTable[*MyObject]("my-objects", MyObjectIDIndex, MyObjectNameIndex),
    33  //		// Provide the read-only statedb.Table[*MyObject].
    34  //		statedb.RWTable[*MyObject].ToTable,
    35  //	)
    36  func NewTable[Obj any](
    37  	tableName TableName,
    38  	primaryIndexer Indexer[Obj],
    39  	secondaryIndexers ...Indexer[Obj],
    40  ) (RWTable[Obj], error) {
    41  	if err := validateTableName(tableName); err != nil {
    42  		return nil, err
    43  	}
    44  
    45  	toAnyIndexer := func(idx Indexer[Obj]) anyIndexer {
    46  		return anyIndexer{
    47  			name: idx.indexName(),
    48  			fromObject: func(iobj object) index.KeySet {
    49  				return idx.fromObject(iobj.data.(Obj))
    50  			},
    51  			fromString: idx.fromString,
    52  			unique:     idx.isUnique(),
    53  		}
    54  	}
    55  
    56  	table := &genTable[Obj]{
    57  		table:                tableName,
    58  		smu:                  internal.NewSortableMutex(),
    59  		primaryAnyIndexer:    toAnyIndexer(primaryIndexer),
    60  		primaryIndexer:       primaryIndexer,
    61  		secondaryAnyIndexers: make(map[string]anyIndexer, len(secondaryIndexers)),
    62  		indexPositions:       make(map[string]int),
    63  		pos:                  -1,
    64  	}
    65  
    66  	table.indexPositions[primaryIndexer.indexName()] = PrimaryIndexPos
    67  
    68  	// Internal indexes
    69  	table.indexPositions[RevisionIndex] = RevisionIndexPos
    70  	table.indexPositions[GraveyardIndex] = GraveyardIndexPos
    71  	table.indexPositions[GraveyardRevisionIndex] = GraveyardRevisionIndexPos
    72  
    73  	indexPos := SecondaryIndexStartPos
    74  	for _, indexer := range secondaryIndexers {
    75  		name := indexer.indexName()
    76  		anyIndexer := toAnyIndexer(indexer)
    77  		anyIndexer.pos = indexPos
    78  		table.secondaryAnyIndexers[name] = anyIndexer
    79  		table.indexPositions[name] = indexPos
    80  		indexPos++
    81  	}
    82  
    83  	// Primary index must always be unique
    84  	if !primaryIndexer.isUnique() {
    85  		return nil, tableError(tableName, ErrPrimaryIndexNotUnique)
    86  	}
    87  
    88  	// Validate that indexes have unique ids.
    89  	indexNames := map[string]struct{}{}
    90  	indexNames[primaryIndexer.indexName()] = struct{}{}
    91  	for _, indexer := range secondaryIndexers {
    92  		if _, ok := indexNames[indexer.indexName()]; ok {
    93  			return nil, tableError(tableName, fmt.Errorf("index %q: %w", indexer.indexName(), ErrDuplicateIndex))
    94  		}
    95  		indexNames[indexer.indexName()] = struct{}{}
    96  	}
    97  	for name := range indexNames {
    98  		if strings.HasPrefix(name, reservedIndexPrefix) {
    99  			return nil, tableError(tableName, fmt.Errorf("index %q: %w", name, ErrReservedPrefix))
   100  		}
   101  	}
   102  	return table, nil
   103  }
   104  
   105  // MustNewTable creates a new table with given name and indexes.
   106  // Panics if indexes are malformed.
   107  func MustNewTable[Obj any](
   108  	tableName TableName,
   109  	primaryIndexer Indexer[Obj],
   110  	secondaryIndexers ...Indexer[Obj]) RWTable[Obj] {
   111  	t, err := NewTable(tableName, primaryIndexer, secondaryIndexers...)
   112  	if err != nil {
   113  		panic(err)
   114  	}
   115  	return t
   116  }
   117  
   118  var nameRegex = regexp.MustCompile(`^[a-z][a-z0-9_\-]{0,30}$`)
   119  
   120  func validateTableName(name string) error {
   121  	if !nameRegex.MatchString(name) {
   122  		return fmt.Errorf("invalid table name %q, expected to match %q", name, nameRegex)
   123  	}
   124  	return nil
   125  }
   126  
   127  type genTable[Obj any] struct {
   128  	pos                  int
   129  	table                TableName
   130  	smu                  internal.SortableMutex
   131  	primaryIndexer       Indexer[Obj]
   132  	primaryAnyIndexer    anyIndexer
   133  	secondaryAnyIndexers map[string]anyIndexer
   134  	indexPositions       map[string]int
   135  	lastWriteTxn         atomic.Pointer[txn]
   136  }
   137  
   138  func (t *genTable[Obj]) acquired(txn *txn) {
   139  	t.lastWriteTxn.Store(txn)
   140  }
   141  
   142  func (t *genTable[Obj]) getAcquiredInfo() string {
   143  	return t.lastWriteTxn.Load().acquiredInfo()
   144  }
   145  
   146  func (t *genTable[Obj]) tableEntry() tableEntry {
   147  	var entry tableEntry
   148  	entry.meta = t
   149  	entry.deleteTrackers = part.New[anyDeleteTracker]()
   150  	entry.initWatchChan = make(chan struct{})
   151  	entry.indexes = make([]indexEntry, len(t.indexPositions))
   152  	entry.indexes[t.indexPositions[t.primaryIndexer.indexName()]] = indexEntry{part.New[object](), nil, true}
   153  
   154  	for index, indexer := range t.secondaryAnyIndexers {
   155  		entry.indexes[t.indexPositions[index]] = indexEntry{part.New[object](), nil, indexer.unique}
   156  	}
   157  	// For revision indexes we only need to watch the root.
   158  	entry.indexes[t.indexPositions[RevisionIndex]] = indexEntry{part.New[object](part.RootOnlyWatch), nil, true}
   159  	entry.indexes[t.indexPositions[GraveyardRevisionIndex]] = indexEntry{part.New[object](part.RootOnlyWatch), nil, true}
   160  	entry.indexes[t.indexPositions[GraveyardIndex]] = indexEntry{part.New[object](), nil, true}
   161  
   162  	return entry
   163  }
   164  
   165  func (t *genTable[Obj]) setTablePos(pos int) {
   166  	t.pos = pos
   167  }
   168  
   169  func (t *genTable[Obj]) tablePos() int {
   170  	return t.pos
   171  }
   172  
   173  func (t *genTable[Obj]) tableKey() []byte {
   174  	return []byte(t.table)
   175  }
   176  
   177  func (t *genTable[Obj]) indexPos(name string) int {
   178  	if t.primaryAnyIndexer.name == name {
   179  		return PrimaryIndexPos
   180  	}
   181  	return t.indexPositions[name]
   182  }
   183  
   184  func (t *genTable[Obj]) getIndexer(name string) *anyIndexer {
   185  	if name == "" || t.primaryAnyIndexer.name == name {
   186  		return &t.primaryAnyIndexer
   187  	}
   188  	if indexer, ok := t.secondaryAnyIndexers[name]; ok {
   189  		return &indexer
   190  	}
   191  	return nil
   192  }
   193  
   194  func (t *genTable[Obj]) PrimaryIndexer() Indexer[Obj] {
   195  	return t.primaryIndexer
   196  }
   197  
   198  func (t *genTable[Obj]) primary() anyIndexer {
   199  	return t.primaryAnyIndexer
   200  }
   201  
   202  func (t *genTable[Obj]) secondary() map[string]anyIndexer {
   203  	return t.secondaryAnyIndexers
   204  }
   205  
   206  func (t *genTable[Obj]) Name() string {
   207  	return t.table
   208  }
   209  
   210  func (t *genTable[Obj]) Indexes() []string {
   211  	idxs := make([]string, 0, 1+len(t.secondaryAnyIndexers))
   212  	idxs = append(idxs, t.primaryAnyIndexer.name)
   213  	for k := range t.secondaryAnyIndexers {
   214  		idxs = append(idxs, k)
   215  	}
   216  	sort.Strings(idxs)
   217  	return idxs
   218  }
   219  
   220  func (t *genTable[Obj]) ToTable() Table[Obj] {
   221  	return t
   222  }
   223  
   224  func (t *genTable[Obj]) Initialized(txn ReadTxn) (bool, <-chan struct{}) {
   225  	table := txn.getTxn().getTableEntry(t)
   226  	if len(table.pendingInitializers) == 0 {
   227  		return true, closedWatchChannel
   228  	}
   229  	return false, table.initWatchChan
   230  }
   231  
   232  func (t *genTable[Obj]) PendingInitializers(txn ReadTxn) []string {
   233  	return txn.getTxn().getTableEntry(t).pendingInitializers
   234  }
   235  
   236  func (t *genTable[Obj]) RegisterInitializer(txn WriteTxn, name string) func(WriteTxn) {
   237  	table := txn.getTxn().modifiedTables[t.pos]
   238  	if table != nil {
   239  		if slices.Contains(table.pendingInitializers, name) {
   240  			panic(fmt.Sprintf("RegisterInitializer: %q already registered", name))
   241  		}
   242  		table.pendingInitializers =
   243  			append(slices.Clone(table.pendingInitializers), name)
   244  		var once sync.Once
   245  		return func(txn WriteTxn) {
   246  			once.Do(func() {
   247  				if table := txn.getTxn().modifiedTables[t.pos]; table != nil {
   248  					table.pendingInitializers = slices.DeleteFunc(
   249  						slices.Clone(table.pendingInitializers),
   250  						func(n string) bool { return n == name },
   251  					)
   252  				}
   253  			})
   254  		}
   255  	} else {
   256  		panic(fmt.Sprintf("RegisterInitializer: Table %q not locked for writing", t.table))
   257  	}
   258  }
   259  
   260  func (t *genTable[Obj]) Revision(txn ReadTxn) Revision {
   261  	return txn.getTxn().getTableEntry(t).revision
   262  }
   263  
   264  func (t *genTable[Obj]) NumObjects(txn ReadTxn) int {
   265  	table := txn.getTxn().getTableEntry(t)
   266  	return table.numObjects()
   267  }
   268  
   269  func (t *genTable[Obj]) numDeletedObjects(txn ReadTxn) int {
   270  	table := txn.getTxn().getTableEntry(t)
   271  	return table.numDeletedObjects()
   272  }
   273  
   274  func (t *genTable[Obj]) Get(txn ReadTxn, q Query[Obj]) (obj Obj, revision uint64, ok bool) {
   275  	obj, revision, _, ok = t.GetWatch(txn, q)
   276  	return
   277  }
   278  
   279  func (t *genTable[Obj]) GetWatch(txn ReadTxn, q Query[Obj]) (obj Obj, revision uint64, watch <-chan struct{}, ok bool) {
   280  	// Since we're not returning an iterator here we can optimize and not use
   281  	// indexReadTxn which clones if this is a WriteTxn (to avoid invalidating iterators).
   282  	indexPos := t.indexPos(q.index)
   283  	itxn := txn.getTxn()
   284  	var (
   285  		ops    part.Ops[object]
   286  		unique bool
   287  	)
   288  	if itxn.modifiedTables != nil && itxn.modifiedTables[t.tablePos()] != nil {
   289  		var err error
   290  		iwtxn, err := itxn.indexWriteTxn(t, indexPos)
   291  		if err != nil {
   292  			panic(err)
   293  		}
   294  		ops = iwtxn.Txn
   295  		unique = iwtxn.unique
   296  	} else {
   297  		entry := itxn.root[t.tablePos()].indexes[indexPos]
   298  		ops = entry.tree
   299  		unique = entry.unique
   300  	}
   301  
   302  	var iobj object
   303  	if unique {
   304  		// On a unique index we can do a direct get rather than a prefix search.
   305  		iobj, watch, ok = ops.Get(q.key)
   306  		if !ok {
   307  			return
   308  		}
   309  		obj = iobj.data.(Obj)
   310  		revision = iobj.revision
   311  		return
   312  	}
   313  
   314  	// For a non-unique index we need to do a prefix search.
   315  	iter, watch := ops.Prefix(q.key)
   316  	for {
   317  		var key []byte
   318  		key, iobj, ok = iter.Next()
   319  		if !ok {
   320  			break
   321  		}
   322  
   323  		// Check that we have a full match on the key
   324  		secondary, _ := decodeNonUniqueKey(key)
   325  		if len(secondary) == len(q.key) {
   326  			break
   327  		}
   328  	}
   329  
   330  	if ok {
   331  		obj = iobj.data.(Obj)
   332  		revision = iobj.revision
   333  	}
   334  	return
   335  }
   336  
   337  func (t *genTable[Obj]) LowerBound(txn ReadTxn, q Query[Obj]) iter.Seq2[Obj, Revision] {
   338  	iter, _ := t.LowerBoundWatch(txn, q)
   339  	return iter
   340  }
   341  
   342  func (t *genTable[Obj]) LowerBoundWatch(txn ReadTxn, q Query[Obj]) (iter.Seq2[Obj, Revision], <-chan struct{}) {
   343  	indexTxn := txn.getTxn().mustIndexReadTxn(t, t.indexPos(q.index))
   344  	// Since LowerBound query may be invalidated by changes in another branch
   345  	// of the tree, we cannot just simply watch the node we seeked to. Instead
   346  	// we watch the whole table for changes.
   347  	watch := indexTxn.RootWatch()
   348  	iter := indexTxn.LowerBound(q.key)
   349  	if indexTxn.unique {
   350  		return partSeq[Obj](iter), watch
   351  	}
   352  	return nonUniqueLowerBoundSeq[Obj](iter, q.key), watch
   353  }
   354  
   355  func (t *genTable[Obj]) Prefix(txn ReadTxn, q Query[Obj]) iter.Seq2[Obj, Revision] {
   356  	iter, _ := t.PrefixWatch(txn, q)
   357  	return iter
   358  }
   359  
   360  func (t *genTable[Obj]) PrefixWatch(txn ReadTxn, q Query[Obj]) (iter.Seq2[Obj, Revision], <-chan struct{}) {
   361  	indexTxn := txn.getTxn().mustIndexReadTxn(t, t.indexPos(q.index))
   362  	iter, watch := indexTxn.Prefix(q.key)
   363  	if indexTxn.unique {
   364  		return partSeq[Obj](iter), watch
   365  	}
   366  	return nonUniqueSeq[Obj](iter, true, q.key), watch
   367  }
   368  
   369  func (t *genTable[Obj]) All(txn ReadTxn) iter.Seq2[Obj, Revision] {
   370  	iter, _ := t.AllWatch(txn)
   371  	return iter
   372  }
   373  
   374  func (t *genTable[Obj]) AllWatch(txn ReadTxn) (iter.Seq2[Obj, Revision], <-chan struct{}) {
   375  	indexTxn := txn.getTxn().mustIndexReadTxn(t, PrimaryIndexPos)
   376  	return partSeq[Obj](indexTxn.Iterator()), indexTxn.RootWatch()
   377  }
   378  
   379  func (t *genTable[Obj]) List(txn ReadTxn, q Query[Obj]) iter.Seq2[Obj, Revision] {
   380  	iter, _ := t.ListWatch(txn, q)
   381  	return iter
   382  }
   383  
   384  func (t *genTable[Obj]) ListWatch(txn ReadTxn, q Query[Obj]) (iter.Seq2[Obj, Revision], <-chan struct{}) {
   385  	indexTxn := txn.getTxn().mustIndexReadTxn(t, t.indexPos(q.index))
   386  	if indexTxn.unique {
   387  		// Unique index means that there can be only a single matching object.
   388  		// Doing a Get() is more efficient than constructing an iterator.
   389  		value, watch, ok := indexTxn.Get(q.key)
   390  		seq := func(yield func(Obj, Revision) bool) {
   391  			if ok {
   392  				yield(value.data.(Obj), value.revision)
   393  			}
   394  		}
   395  		return seq, watch
   396  	}
   397  
   398  	// For a non-unique index we do a prefix search. The keys are of
   399  	// form <secondary key><primary key><secondary key length>, and thus the
   400  	// iteration will continue until key length mismatches, e.g. we hit a
   401  	// longer key sharing the same prefix.
   402  	iter, watch := indexTxn.Prefix(q.key)
   403  	return nonUniqueSeq[Obj](iter, false, q.key), watch
   404  }
   405  
   406  func (t *genTable[Obj]) Insert(txn WriteTxn, obj Obj) (oldObj Obj, hadOld bool, err error) {
   407  	var old object
   408  	old, hadOld, err = txn.getTxn().insert(t, Revision(0), obj)
   409  	if hadOld {
   410  		oldObj = old.data.(Obj)
   411  	}
   412  	return
   413  }
   414  
   415  func (t *genTable[Obj]) Modify(txn WriteTxn, obj Obj, merge func(old, new Obj) Obj) (oldObj Obj, hadOld bool, err error) {
   416  	var old object
   417  	old, hadOld, err = txn.getTxn().modify(t, Revision(0), obj,
   418  		func(old any) any {
   419  			return merge(old.(Obj), obj)
   420  		})
   421  	if hadOld {
   422  		oldObj = old.data.(Obj)
   423  	}
   424  	return
   425  }
   426  
   427  func (t *genTable[Obj]) CompareAndSwap(txn WriteTxn, rev Revision, obj Obj) (oldObj Obj, hadOld bool, err error) {
   428  	var old object
   429  	old, hadOld, err = txn.getTxn().insert(t, rev, obj)
   430  	if hadOld {
   431  		oldObj = old.data.(Obj)
   432  	}
   433  	return
   434  }
   435  
   436  func (t *genTable[Obj]) Delete(txn WriteTxn, obj Obj) (oldObj Obj, hadOld bool, err error) {
   437  	var old object
   438  	old, hadOld, err = txn.getTxn().delete(t, Revision(0), obj)
   439  	if hadOld {
   440  		oldObj = old.data.(Obj)
   441  	}
   442  	return
   443  }
   444  
   445  func (t *genTable[Obj]) CompareAndDelete(txn WriteTxn, rev Revision, obj Obj) (oldObj Obj, hadOld bool, err error) {
   446  	var old object
   447  	old, hadOld, err = txn.getTxn().delete(t, rev, obj)
   448  	if hadOld {
   449  		oldObj = old.data.(Obj)
   450  	}
   451  	return
   452  }
   453  
   454  func (t *genTable[Obj]) DeleteAll(txn WriteTxn) error {
   455  	itxn := txn.getTxn()
   456  	for obj := range t.All(txn) {
   457  		_, _, err := itxn.delete(t, Revision(0), obj)
   458  		if err != nil {
   459  			return err
   460  		}
   461  	}
   462  	return nil
   463  }
   464  
   465  func (t *genTable[Obj]) Changes(txn WriteTxn) (ChangeIterator[Obj], error) {
   466  	iter := &changeIterator[Obj]{
   467  		revision: 0,
   468  
   469  		// Don't observe any past deletions.
   470  		deleteRevision: t.Revision(txn),
   471  		table:          t,
   472  		watch:          closedWatchChannel,
   473  	}
   474  	// Set a finalizer to unregister the delete tracker when the iterator
   475  	// is dropped.
   476  	runtime.SetFinalizer(iter, func(iter *changeIterator[Obj]) {
   477  		iter.close()
   478  	})
   479  
   480  	itxn := txn.getTxn()
   481  	name := fmt.Sprintf("changes-%p", iter)
   482  	iter.dt = &deleteTracker[Obj]{
   483  		db:          itxn.db,
   484  		trackerName: name,
   485  		table:       t,
   486  	}
   487  
   488  	iter.dt.setRevision(iter.deleteRevision)
   489  	err := itxn.addDeleteTracker(t, name, iter.dt)
   490  	if err != nil {
   491  		return nil, err
   492  	}
   493  
   494  	// Prime it.
   495  	iter.refresh(txn)
   496  
   497  	return iter, nil
   498  }
   499  
   500  // anyChanges returns the anyChangeIterator. Used for implementing the /changes HTTP
   501  // API where we can't work with concrete object types as they're not known and thus
   502  // uninstantiatable.
   503  func (t *genTable[Obj]) anyChanges(txn WriteTxn) (anyChangeIterator, error) {
   504  	iter, err := t.Changes(txn)
   505  	if err != nil {
   506  		return nil, err
   507  	}
   508  	return iter.(*changeIterator[Obj]), err
   509  }
   510  
   511  func (t *genTable[Obj]) sortableMutex() internal.SortableMutex {
   512  	return t.smu
   513  }
   514  
   515  func (t *genTable[Obj]) proto() any {
   516  	var zero Obj
   517  	return zero
   518  }
   519  
   520  func (t *genTable[Obj]) unmarshalYAML(data []byte) (any, error) {
   521  	var obj Obj
   522  	if err := yaml.Unmarshal(data, &obj); err != nil {
   523  		return nil, err
   524  	}
   525  	return obj, nil
   526  }
   527  
   528  var _ Table[bool] = &genTable[bool]{}
   529  var _ RWTable[bool] = &genTable[bool]{}