github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/memdb/memdb.go (about)

     1  package memdb
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"math"
     8  	"sort"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/authzed/spicedb/internal/datastore/common"
    13  	"github.com/authzed/spicedb/pkg/spiceerrors"
    14  
    15  	"github.com/google/uuid"
    16  	"github.com/hashicorp/go-memdb"
    17  
    18  	"github.com/authzed/spicedb/internal/datastore/revisions"
    19  	"github.com/authzed/spicedb/pkg/datastore"
    20  	"github.com/authzed/spicedb/pkg/datastore/options"
    21  	corev1 "github.com/authzed/spicedb/pkg/proto/core/v1"
    22  )
    23  
    24  const (
    25  	Engine                   = "memory"
    26  	defaultWatchBufferLength = 128
    27  	numAttempts              = 10
    28  )
    29  
    30  var ErrSerialization = errors.New("serialization error")
    31  
    32  // DisableGC is a convenient constant for setting the garbage collection
    33  // interval high enough that it will never run.
    34  const DisableGC = time.Duration(math.MaxInt64)
    35  
    36  // NewMemdbDatastore creates a new Datastore compliant datastore backed by memdb.
    37  //
    38  // If the watchBufferLength value of 0 is set then a default value of 128 will be used.
    39  func NewMemdbDatastore(
    40  	watchBufferLength uint16,
    41  	revisionQuantization,
    42  	gcWindow time.Duration,
    43  ) (datastore.Datastore, error) {
    44  	if revisionQuantization > gcWindow {
    45  		return nil, errors.New("gc window must be larger than quantization interval")
    46  	}
    47  
    48  	if revisionQuantization <= 1 {
    49  		revisionQuantization = 1
    50  	}
    51  
    52  	db, err := memdb.NewMemDB(schema)
    53  	if err != nil {
    54  		return nil, err
    55  	}
    56  
    57  	if watchBufferLength == 0 {
    58  		watchBufferLength = defaultWatchBufferLength
    59  	}
    60  
    61  	uniqueID := uuid.NewString()
    62  	return &memdbDatastore{
    63  		CommonDecoder: revisions.CommonDecoder{
    64  			Kind: revisions.Timestamp,
    65  		},
    66  		db: db,
    67  		revisions: []snapshot{
    68  			{
    69  				revision: nowRevision(),
    70  				db:       db,
    71  			},
    72  		},
    73  
    74  		negativeGCWindow:        gcWindow.Nanoseconds() * -1,
    75  		quantizationPeriod:      revisionQuantization.Nanoseconds(),
    76  		watchBufferLength:       watchBufferLength,
    77  		watchBufferWriteTimeout: 100 * time.Millisecond,
    78  		uniqueID:                uniqueID,
    79  	}, nil
    80  }
    81  
// memdbDatastore implements datastore.Datastore on top of an in-memory
// go-memdb database, keeping a database snapshot per revision.
type memdbDatastore struct {
	// The embedded RWMutex is held in this file whenever db, revisions or
	// activeWriteTxn are read or written.
	sync.RWMutex

	// CommonDecoder is configured by NewMemdbDatastore with Kind
	// revisions.Timestamp.
	revisions.CommonDecoder

	// db is the live database that write transactions are opened against;
	// Close sets it to nil.
	db *memdb.MemDB

	// revisions lists snapshots in the order they were appended.
	// SnapshotReader binary-searches it, so it is expected to be sorted by
	// ascending revision.
	revisions []snapshot

	// activeWriteTxn is the write transaction currently open, if any. While
	// it is set, attempts to open another one fail with ErrSerialization.
	activeWriteTxn *memdb.Txn

	// Configuration set by NewMemdbDatastore: negativeGCWindow is the GC
	// window negated, in nanoseconds; quantizationPeriod is the revision
	// quantization interval in nanoseconds.
	negativeGCWindow        int64
	quantizationPeriod      int64
	watchBufferLength       uint16
	watchBufferWriteTimeout time.Duration
	uniqueID                string
}
    96  
// snapshot pairs a revision with the memdb database that readers at that
// revision should see (in this file, either the initial database or a
// Snapshot taken right after a write transaction was committed).
type snapshot struct {
	revision revisions.TimestampRevision
	db       *memdb.MemDB
}
   101  
   102  func (mdb *memdbDatastore) SnapshotReader(dr datastore.Revision) datastore.Reader {
   103  	mdb.RLock()
   104  	defer mdb.RUnlock()
   105  
   106  	if len(mdb.revisions) == 0 {
   107  		return &memdbReader{nil, nil, fmt.Errorf("memdb datastore is not ready")}
   108  	}
   109  
   110  	if err := mdb.checkRevisionLocalCallerMustLock(dr); err != nil {
   111  		return &memdbReader{nil, nil, err}
   112  	}
   113  
   114  	revIndex := sort.Search(len(mdb.revisions), func(i int) bool {
   115  		return mdb.revisions[i].revision.GreaterThan(dr) || mdb.revisions[i].revision.Equal(dr)
   116  	})
   117  
   118  	// handle the case when there is no revision snapshot newer than the requested revision
   119  	if revIndex == len(mdb.revisions) {
   120  		revIndex = len(mdb.revisions) - 1
   121  	}
   122  
   123  	rev := mdb.revisions[revIndex]
   124  	if rev.db == nil {
   125  		return &memdbReader{nil, nil, fmt.Errorf("memdb datastore is already closed")}
   126  	}
   127  
   128  	roTxn := rev.db.Txn(false)
   129  	txSrc := func() (*memdb.Txn, error) {
   130  		return roTxn, nil
   131  	}
   132  
   133  	return &memdbReader{noopTryLocker{}, txSrc, nil}
   134  }
   135  
   136  func (mdb *memdbDatastore) ReadWriteTx(
   137  	ctx context.Context,
   138  	f datastore.TxUserFunc,
   139  	opts ...options.RWTOptionsOption,
   140  ) (datastore.Revision, error) {
   141  	config := options.NewRWTOptionsWithOptions(opts...)
   142  	txNumAttempts := numAttempts
   143  	if config.DisableRetries {
   144  		txNumAttempts = 1
   145  	}
   146  
   147  	for i := 0; i < txNumAttempts; i++ {
   148  		var tx *memdb.Txn
   149  		createTxOnce := sync.Once{}
   150  		txSrc := func() (*memdb.Txn, error) {
   151  			var err error
   152  			createTxOnce.Do(func() {
   153  				mdb.Lock()
   154  				defer mdb.Unlock()
   155  
   156  				if mdb.activeWriteTxn != nil {
   157  					err = ErrSerialization
   158  					return
   159  				}
   160  
   161  				if mdb.db == nil {
   162  					err = fmt.Errorf("datastore is closed")
   163  					return
   164  				}
   165  
   166  				tx = mdb.db.Txn(true)
   167  				tx.TrackChanges()
   168  				mdb.activeWriteTxn = tx
   169  			})
   170  
   171  			return tx, err
   172  		}
   173  
   174  		newRevision := mdb.newRevisionID()
   175  		rwt := &memdbReadWriteTx{memdbReader{&sync.Mutex{}, txSrc, nil}, newRevision}
   176  		if err := f(ctx, rwt); err != nil {
   177  			mdb.Lock()
   178  			if tx != nil {
   179  				tx.Abort()
   180  				mdb.activeWriteTxn = nil
   181  			}
   182  
   183  			// If the error was a serialization error, retry the transaction
   184  			if errors.Is(err, ErrSerialization) {
   185  				mdb.Unlock()
   186  
   187  				// If we don't sleep here, we run out of retries instantaneously
   188  				time.Sleep(1 * time.Millisecond)
   189  				continue
   190  			}
   191  			defer mdb.Unlock()
   192  
   193  			// We *must* return the inner error unmodified in case it's not an error type
   194  			// that supports unwrapping (e.g. gRPC errors)
   195  			return datastore.NoRevision, err
   196  		}
   197  
   198  		mdb.Lock()
   199  		defer mdb.Unlock()
   200  
   201  		tracked := common.NewChanges(revisions.TimestampIDKeyFunc, datastore.WatchRelationships|datastore.WatchSchema)
   202  		if tx != nil {
   203  			for _, change := range tx.Changes() {
   204  				switch change.Table {
   205  				case tableRelationship:
   206  					if change.After != nil {
   207  						rt, err := change.After.(*relationship).RelationTuple()
   208  						if err != nil {
   209  							return datastore.NoRevision, err
   210  						}
   211  
   212  						if err := tracked.AddRelationshipChange(ctx, newRevision, rt, corev1.RelationTupleUpdate_TOUCH); err != nil {
   213  							return datastore.NoRevision, err
   214  						}
   215  					} else if change.After == nil && change.Before != nil {
   216  						rt, err := change.Before.(*relationship).RelationTuple()
   217  						if err != nil {
   218  							return datastore.NoRevision, err
   219  						}
   220  
   221  						if err := tracked.AddRelationshipChange(ctx, newRevision, rt, corev1.RelationTupleUpdate_DELETE); err != nil {
   222  							return datastore.NoRevision, err
   223  						}
   224  					} else {
   225  						return datastore.NoRevision, spiceerrors.MustBugf("unexpected relationship change")
   226  					}
   227  				case tableNamespace:
   228  					if change.After != nil {
   229  						loaded := &corev1.NamespaceDefinition{}
   230  						if err := loaded.UnmarshalVT(change.After.(*namespace).configBytes); err != nil {
   231  							return datastore.NoRevision, err
   232  						}
   233  
   234  						tracked.AddChangedDefinition(ctx, newRevision, loaded)
   235  					} else if change.After == nil && change.Before != nil {
   236  						tracked.AddDeletedNamespace(ctx, newRevision, change.Before.(*namespace).name)
   237  					} else {
   238  						return datastore.NoRevision, spiceerrors.MustBugf("unexpected namespace change")
   239  					}
   240  				case tableCaveats:
   241  					if change.After != nil {
   242  						loaded := &corev1.CaveatDefinition{}
   243  						if err := loaded.UnmarshalVT(change.After.(*caveat).definition); err != nil {
   244  							return datastore.NoRevision, err
   245  						}
   246  
   247  						tracked.AddChangedDefinition(ctx, newRevision, loaded)
   248  					} else if change.After == nil && change.Before != nil {
   249  						tracked.AddDeletedCaveat(ctx, newRevision, change.Before.(*caveat).name)
   250  					} else {
   251  						return datastore.NoRevision, spiceerrors.MustBugf("unexpected namespace change")
   252  					}
   253  				}
   254  			}
   255  
   256  			var rc datastore.RevisionChanges
   257  			changes := tracked.AsRevisionChanges(revisions.TimestampIDKeyLessThanFunc)
   258  			if len(changes) > 1 {
   259  				return datastore.NoRevision, spiceerrors.MustBugf("unexpected MemDB transaction with multiple revision changes")
   260  			} else if len(changes) == 1 {
   261  				rc = changes[0]
   262  			}
   263  
   264  			change := &changelog{
   265  				revisionNanos: newRevision.TimestampNanoSec(),
   266  				changes:       rc,
   267  			}
   268  			if err := tx.Insert(tableChangelog, change); err != nil {
   269  				return datastore.NoRevision, fmt.Errorf("error writing changelog: %w", err)
   270  			}
   271  
   272  			tx.Commit()
   273  		}
   274  		mdb.activeWriteTxn = nil
   275  
   276  		// Create a snapshot and add it to the revisions slice
   277  		if mdb.db == nil {
   278  			return datastore.NoRevision, fmt.Errorf("datastore has been closed")
   279  		}
   280  
   281  		snap := mdb.db.Snapshot()
   282  		mdb.revisions = append(mdb.revisions, snapshot{newRevision, snap})
   283  		return newRevision, nil
   284  	}
   285  
   286  	return datastore.NoRevision, NewSerializationMaxRetriesReachedErr(errors.New("serialization max retries exceeded; please reduce your parallel writes"))
   287  }
   288  
   289  func (mdb *memdbDatastore) ReadyState(_ context.Context) (datastore.ReadyState, error) {
   290  	mdb.RLock()
   291  	defer mdb.RUnlock()
   292  
   293  	return datastore.ReadyState{
   294  		Message: "missing expected initial revision",
   295  		IsReady: len(mdb.revisions) > 0,
   296  	}, nil
   297  }
   298  
   299  func (mdb *memdbDatastore) Features(_ context.Context) (*datastore.Features, error) {
   300  	return &datastore.Features{Watch: datastore.Feature{Enabled: true}}, nil
   301  }
   302  
   303  func (mdb *memdbDatastore) Close() error {
   304  	mdb.Lock()
   305  	defer mdb.Unlock()
   306  
   307  	if db := mdb.db; db != nil {
   308  		mdb.revisions = []snapshot{
   309  			{
   310  				revision: nowRevision(),
   311  				db:       db,
   312  			},
   313  		}
   314  	} else {
   315  		mdb.revisions = []snapshot{}
   316  	}
   317  
   318  	mdb.db = nil
   319  
   320  	return nil
   321  }
   322  
   323  var _ datastore.Datastore = &memdbDatastore{}