
     1  package crdb
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"regexp"
     8  	"strconv"
     9  	"time"
    11  	""
    12  	sq ""
    13  	""
    14  	""
    15  	""
    16  	""
    17  	""
    18  	""
    19  	""
    20  	""
    22  	datastoreinternal ""
    23  	""
    24  	""
    25  	""
    26  	pgxcommon ""
    27  	""
    28  	log ""
    29  	""
    30  	""
    31  )
    33  func init() {
    34  	datastore.Engines = append(datastore.Engines, Engine)
    35  }
    37  var ParseRevisionString = revisions.RevisionParser(revisions.HybridLogicalClock)
    39  var (
    40  	psql = sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
    42  	gcTTLRegex = regexp.MustCompile(`gc\.ttlseconds\s*=\s*([1-9][0-9]+)`)
    44  	tracer = otel.Tracer("spicedb/internal/datastore/crdb")
    45  )
    47  const (
    48  	Engine            = "cockroachdb"
    49  	tableNamespace    = "namespace_config"
    50  	tableTuple        = "relation_tuple"
    51  	tableTransactions = "transactions"
    52  	tableCaveat       = "caveat"
    54  	colNamespace         = "namespace"
    55  	colConfig            = "serialized_config"
    56  	colTimestamp         = "timestamp"
    57  	colTransactionKey    = "key"
    58  	colObjectID          = "object_id"
    59  	colRelation          = "relation"
    60  	colUsersetNamespace  = "userset_namespace"
    61  	colUsersetObjectID   = "userset_object_id"
    62  	colUsersetRelation   = "userset_relation"
    63  	colCaveatName        = "name"
    64  	colCaveatDefinition  = "definition"
    65  	colCaveatContextName = "caveat_name"
    66  	colCaveatContext     = "caveat_context"
    68  	errUnableToInstantiate = "unable to instantiate datastore"
    69  	errRevision            = "unable to find revision: %w"
    71  	querySelectNow            = "SELECT cluster_logical_timestamp()"
    72  	queryTransactionNowPreV23 = querySelectNow
    73  	queryTransactionNow       = "SHOW COMMIT TIMESTAMP"
    74  	queryShowZoneConfig       = "SHOW ZONE CONFIGURATION FOR RANGE default;"
    75  )
    77  var livingTupleConstraints = []string{"pk_relation_tuple"}
    79  func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datastore.Datastore, error) {
    80  	config, err := generateConfig(options)
    81  	if err != nil {
    82  		return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, url)
    83  	}
    85  	readPoolConfig, err := pgxpool.ParseConfig(url)
    86  	if err != nil {
    87  		return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, url)
    88  	}
    89  	config.readPoolOpts.ConfigurePgx(readPoolConfig)
    91  	writePoolConfig, err := pgxpool.ParseConfig(url)
    92  	if err != nil {
    93  		return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, url)
    94  	}
    95  	config.writePoolOpts.ConfigurePgx(writePoolConfig)
    97  	initCtx, initCancel := context.WithTimeout(context.Background(), 5*time.Minute)
    98  	defer initCancel()
   100  	healthChecker, err := pool.NewNodeHealthChecker(url)
   101  	if err != nil {
   102  		return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, url)
   103  	}
   105  	// The initPool is a 1-connection pool that is only used for setup tasks.
   106  	// The actual pools are not given the initCtx, since cancellation can
   107  	// interfere with pool setup.
   108  	initPoolConfig := readPoolConfig.Copy()
   109  	initPoolConfig.MinConns = 1
   110  	initPool, err := pool.NewRetryPool(initCtx, "init", initPoolConfig, healthChecker, config.maxRetries, config.connectRate)
   111  	if err != nil {
   112  		return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, url)
   113  	}
   114  	defer initPool.Close()
   116  	var version crdbVersion
   117  	if err := queryServerVersion(initCtx, initPool, &version); err != nil {
   118  		return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, url)
   119  	}
   121  	changefeedQuery := queryChangefeed
   122  	if version.Major < 22 {
   123  		log.Info().Object("version", version).Msg("using changefeed query for CRDB version < 22")
   124  		changefeedQuery = queryChangefeedPreV22
   125  	}
   127  	transactionNowQuery := queryTransactionNow
   128  	if version.Major < 23 {
   129  		log.Info().Object("version", version).Msg("using transaction now query for CRDB version < 23")
   130  		transactionNowQuery = queryTransactionNowPreV23
   131  	}
   133  	clusterTTLNanos, err := readClusterTTLNanos(initCtx, initPool)
   134  	if err != nil {
   135  		return nil, fmt.Errorf("unable to read cluster gc window: %w", err)
   136  	}
   138  	gcWindowNanos := config.gcWindow.Nanoseconds()
   139  	if clusterTTLNanos < gcWindowNanos {
   140  		log.Warn().
   141  			Int64("cockroach_cluster_gc_window_nanos", clusterTTLNanos).
   142  			Int64("spicedb_gc_window_nanos", gcWindowNanos).
   143  			Msg("configured CockroachDB cluster gc window is less than configured SpiceDB gc window, falling back to CRDB value - see")
   144  		config.gcWindow = time.Duration(clusterTTLNanos) * time.Nanosecond
   145  	}
   147  	keySetInit := newKeySet
   148  	var keyer overlapKeyer
   149  	switch config.overlapStrategy {
   150  	case overlapStrategyStatic:
   151  		if len(config.overlapKey) == 0 {
   152  			return nil, fmt.Errorf("static tx overlap strategy specified without an overlap key")
   153  		}
   154  		keyer = appendStaticKey(config.overlapKey)
   155  	case overlapStrategyPrefix:
   156  		keyer = prefixKeyer
   157  	case overlapStrategyRequest:
   158  		// overlap keys are computed over requests and not data
   159  		keyer = noOverlapKeyer
   160  		keySetInit = overlapKeysFromContext
   161  	case overlapStrategyInsecure:
   162  		log.Warn().Str("strategy", overlapStrategyInsecure).
   163  			Msg("running in this mode is only safe when replicas == nodes")
   164  		keyer = noOverlapKeyer
   165  	}
   167  	maxRevisionStaleness := time.Duration(float64(config.revisionQuantization.Nanoseconds())*
   168  		config.maxRevisionStalenessPercent) * time.Nanosecond
   170  	ds := &crdbDatastore{
   171  		RemoteClockRevisions: revisions.NewRemoteClockRevisions(
   172  			config.gcWindow,
   173  			maxRevisionStaleness,
   174  			config.followerReadDelay,
   175  			config.revisionQuantization,
   176  		),
   177  		CommonDecoder:           revisions.CommonDecoder{Kind: revisions.HybridLogicalClock},
   178  		dburl:                   url,
   179  		watchBufferLength:       config.watchBufferLength,
   180  		watchBufferWriteTimeout: config.watchBufferWriteTimeout,
   181  		writeOverlapKeyer:       keyer,
   182  		overlapKeyInit:          keySetInit,
   183  		beginChangefeedQuery:    changefeedQuery,
   184  		transactionNowQuery:     transactionNowQuery,
   185  		analyzeBeforeStatistics: config.analyzeBeforeStatistics,
   186  	}
   187  	ds.RemoteClockRevisions.SetNowFunc(ds.headRevisionInternal)
   189  	// this ctx and cancel is tied to the lifetime of the datastore
   190  	ds.ctx, ds.cancel = context.WithCancel(context.Background())
   191  	ds.writePool, err = pool.NewRetryPool(ds.ctx, "write", writePoolConfig, healthChecker, config.maxRetries, config.connectRate)
   192  	if err != nil {
   193  		ds.cancel()
   194  		return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, url)
   195  	}
   196  	ds.readPool, err = pool.NewRetryPool(ds.ctx, "read", readPoolConfig, healthChecker, config.maxRetries, config.connectRate)
   197  	if err != nil {
   198  		ds.cancel()
   199  		return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, url)
   200  	}
   202  	if config.enablePrometheusStats {
   203  		if err := prometheus.Register(pgxpoolprometheus.NewCollector(ds.writePool, map[string]string{
   204  			"db_name":    "spicedb",
   205  			"pool_usage": "write",
   206  		})); err != nil {
   207  			ds.cancel()
   208  			return nil, err
   209  		}
   211  		if err := prometheus.Register(pgxpoolprometheus.NewCollector(ds.readPool, map[string]string{
   212  			"db_name":    "spicedb",
   213  			"pool_usage": "read",
   214  		})); err != nil {
   215  			ds.cancel()
   216  			return nil, err
   217  		}
   218  	}
   220  	// TODO: this (and the GC startup that it's based on for mysql/pg) should
   221  	// be removed and have the lifetimes tied to server start/stop.
   223  	// Start goroutines for pruning
   224  	if config.enableConnectionBalancing {
   225  		log.Ctx(initCtx).Info().Msg("starting cockroach connection balancer")
   226  		ds.pruneGroup, ds.ctx = errgroup.WithContext(ds.ctx)
   227  		writePoolBalancer := pool.NewNodeConnectionBalancer(ds.writePool, healthChecker, 5*time.Second)
   228  		readPoolBalancer := pool.NewNodeConnectionBalancer(ds.readPool, healthChecker, 5*time.Second)
   229  		ds.pruneGroup.Go(func() error {
   230  			writePoolBalancer.Prune(ds.ctx)
   231  			return nil
   232  		})
   233  		ds.pruneGroup.Go(func() error {
   234  			readPoolBalancer.Prune(ds.ctx)
   235  			return nil
   236  		})
   237  		ds.pruneGroup.Go(func() error {
   238  			healthChecker.Poll(ds.ctx, 5*time.Second)
   239  			return nil
   240  		})
   241  	}
   243  	return ds, nil
   244  }
   246  // NewCRDBDatastore initializes a SpiceDB datastore that uses a CockroachDB
   247  // database while leveraging its AOST functionality.
   248  func NewCRDBDatastore(ctx context.Context, url string, options ...Option) (datastore.Datastore, error) {
   249  	ds, err := newCRDBDatastore(ctx, url, options...)
   250  	if err != nil {
   251  		return nil, err
   252  	}
   253  	return datastoreinternal.NewSeparatingContextDatastoreProxy(ds), nil
   254  }
// crdbDatastore is the CockroachDB-backed datastore.Datastore built by
// newCRDBDatastore.
type crdbDatastore struct {
	// Revision bookkeeping. newCRDBDatastore configures it with the GC window,
	// max staleness, follower-read delay and quantization, and installs
	// headRevisionInternal as its "now" function.
	*revisions.RemoteClockRevisions
	// Decoder configured for the hybrid-logical-clock revision kind.
	revisions.CommonDecoder

	// dburl is the connection string; ReadyState uses it to open a migration
	// driver.
	dburl string
	// readPool serves snapshot reads and head-revision queries; writePool
	// serves ReadWriteTx and the Watch feature probe in features.
	readPool, writePool *pool.RetryPool
	// Copied from the config. NOTE(review): presumably consumed by the Watch
	// implementation elsewhere in the package — confirm.
	watchBufferLength       uint16
	watchBufferWriteTimeout time.Duration
	// writeOverlapKeyer and overlapKeyInit implement the configured transaction
	// overlap strategy for read-write transactions.
	writeOverlapKeyer overlapKeyer
	overlapKeyInit    func(ctx context.Context) keySet
	// Copied from the config. NOTE(review): not read in this file; presumably
	// used by the statistics code elsewhere in the package.
	analyzeBeforeStatistics bool

	// Queries chosen in newCRDBDatastore based on the server's major version.
	beginChangefeedQuery string
	transactionNowQuery  string

	// featureGroup lets concurrent Features calls share one in-flight probe.
	featureGroup singleflight.Group[string, *datastore.Features]

	// pruneGroup runs the connection-balancer and health-check goroutines. It
	// is nil unless connection balancing is enabled.
	pruneGroup *errgroup.Group
	// ctx bounds the datastore's background work; Close cancels it via cancel.
	ctx    context.Context
	cancel context.CancelFunc
}
   278  func (cds *crdbDatastore) SnapshotReader(rev datastore.Revision) datastore.Reader {
   279  	executor := common.QueryExecutor{
   280  		Executor: pgxcommon.NewPGXExecutor(cds.readPool),
   281  	}
   283  	fromBuilder := func(query sq.SelectBuilder, fromStr string) sq.SelectBuilder {
   284  		return query.From(fromStr + " AS OF SYSTEM TIME " + rev.String())
   285  	}
   287  	return &crdbReader{cds.readPool, executor, noOverlapKeyer, nil, fromBuilder}
   288  }
// ReadWriteTx runs f inside a transaction on the write pool and returns the
// revision at which that transaction committed.
//
// After f succeeds, queryTouchTransaction is executed once for each overlap
// key collected during f. The commit revision is then read with the
// version-specific transactionNowQuery before the transaction finishes. If the
// options set DisableRetries, ctx is tagged with pool.CtxDisableRetries
// (presumably disabling the pool's retry behavior). Errors are passed through
// wrapError, which maps violations of livingTupleConstraints to
// write-constraint errors. On error, datastore.NoRevision is returned.
func (cds *crdbDatastore) ReadWriteTx(
	ctx context.Context,
	f datastore.TxUserFunc,
	opts ...options.RWTOptionsOption,
) (datastore.Revision, error) {
	// Set from inside the transaction callback once the commit revision is read.
	var commitTimestamp datastore.Revision

	config := options.NewRWTOptionsWithOptions(opts...)
	if config.DisableRetries {
		ctx = context.WithValue(ctx, pool.CtxDisableRetries, true)
	}

	err := cds.writePool.BeginFunc(ctx, func(tx pgx.Tx) error {
		// The reader and overlap key set are built inside the callback, so if
		// the pool re-runs it, each attempt starts from fresh state.
		querier := pgxcommon.QuerierFuncsFor(tx)
		executor := common.QueryExecutor{
			Executor: pgxcommon.NewPGXExecutor(querier),
		}

		rwt := &crdbReadWriteTXN{
			&crdbReader{
				querier,
				executor,
				cds.writeOverlapKeyer,
				cds.overlapKeyInit(ctx),
				// Reads inside the transaction use no AS OF SYSTEM TIME clause,
				// unlike SnapshotReader.
				func(query sq.SelectBuilder, fromStr string) sq.SelectBuilder {
					return query.From(fromStr)
				},
			},
			tx,
			0,
		}

		if err := f(ctx, rwt); err != nil {
			return err
		}

		// Touching the transaction key happens last so that the "write intent" for
		// the transaction as a whole lands in a range for the affected tuples.
		for k := range rwt.overlapKeySet {
			if _, err := tx.Exec(ctx, queryTouchTransaction, k); err != nil {
				return fmt.Errorf("error writing overlapping keys: %w", err)
			}
		}

		// Read the commit revision while still inside the transaction.
		var err error
		commitTimestamp, err = cds.readTransactionCommitRev(ctx, querier)
		if err != nil {
			return fmt.Errorf("error getting commit timestamp: %w", err)
		}
		return nil
	})
	if err != nil {
		return datastore.NoRevision, wrapError(err)
	}

	return commitTimestamp, nil
}
   348  func wrapError(err error) error {
   349  	// If a unique constraint violation is returned, then its likely that the cause
   350  	// was an existing relationship.
   351  	if cerr := pgxcommon.ConvertToWriteConstraintError(livingTupleConstraints, err); cerr != nil {
   352  		return cerr
   353  	}
   354  	return err
   355  }
   357  func (cds *crdbDatastore) ReadyState(ctx context.Context) (datastore.ReadyState, error) {
   358  	headMigration, err := migrations.CRDBMigrations.HeadRevision()
   359  	if err != nil {
   360  		return datastore.ReadyState{}, fmt.Errorf("invalid head migration found for cockroach: %w", err)
   361  	}
   363  	currentRevision, err := migrations.NewCRDBDriver(cds.dburl)
   364  	if err != nil {
   365  		return datastore.ReadyState{}, err
   366  	}
   367  	defer currentRevision.Close(ctx)
   369  	version, err := currentRevision.Version(ctx)
   370  	if err != nil {
   371  		return datastore.ReadyState{}, err
   372  	}
   374  	// TODO(jschorr): Remove the check for the older migration once we are confident
   375  	// that all users have migrated past it.
   376  	if version != headMigration && version != "add-caveats" {
   377  		return datastore.ReadyState{
   378  			Message: fmt.Sprintf(
   379  				"datastore is not migrated: currently at revision `%s`, but requires `%s`. Please run `spicedb migrate`.",
   380  				version,
   381  				headMigration,
   382  			),
   383  			IsReady: false,
   384  		}, nil
   385  	}
   387  	readMin := cds.readPool.MinConns()
   388  	if readMin > 0 {
   389  		readMin--
   390  	}
   391  	writeMin := cds.writePool.MinConns()
   392  	if writeMin > 0 {
   393  		writeMin--
   394  	}
   395  	writeTotal := uint32(cds.writePool.Stat().TotalConns())
   396  	readTotal := uint32(cds.readPool.Stat().TotalConns())
   397  	if writeTotal < writeMin || readTotal < readMin {
   398  		return datastore.ReadyState{
   399  			Message: fmt.Sprintf(
   400  				"spicedb does not have the required minimum connection count to the datastore. Read: %d/%d, Write: %d/%d",
   401  				readTotal,
   402  				readMin,
   403  				writeTotal,
   404  				writeMin,
   405  			),
   406  			IsReady: false,
   407  		}, nil
   408  	}
   409  	return datastore.ReadyState{IsReady: true}, nil
   410  }
   412  func (cds *crdbDatastore) Close() error {
   413  	cds.cancel()
   414  	cds.readPool.Close()
   415  	cds.writePool.Close()
   416  	return nil
   417  }
   419  func (cds *crdbDatastore) HeadRevision(ctx context.Context) (datastore.Revision, error) {
   420  	return cds.headRevisionInternal(ctx)
   421  }
   423  func (cds *crdbDatastore) headRevisionInternal(ctx context.Context) (datastore.Revision, error) {
   424  	var hlcNow datastore.Revision
   426  	var fnErr error
   427  	hlcNow, fnErr = readCRDBNow(ctx, cds.readPool)
   428  	if fnErr != nil {
   429  		return datastore.NoRevision, fmt.Errorf(errRevision, fnErr)
   430  	}
   432  	return hlcNow, fnErr
   433  }
   435  func (cds *crdbDatastore) Features(ctx context.Context) (*datastore.Features, error) {
   436  	features, _, err := cds.featureGroup.Do(ctx, "", func(ictx context.Context) (*datastore.Features, error) {
   437  		return cds.features(ictx)
   438  	})
   439  	return features, err
   440  }
// features probes the cluster for optional datastore features. Only Watch is
// probed here.
//
// Watch support is detected by opening a changefeed on the relationship table,
// starting at the current head revision, under a context canceled after one
// second. If the stream ends with context.Canceled, Watch is reported as
// enabled. Any other error disables it and records the error in Reason. The
// call always waits until that context is done, so it takes about one second
// even when the changefeed fails quickly.
//
// NOTE(review): if ctx itself is canceled during the probe, the resulting
// context.Canceled looks the same as the timer firing, and Watch would be
// reported as enabled — confirm this is acceptable.
//
// NOTE(review): the error returned by ExecFunc is discarded. If the callback
// is never invoked with a non-nil error, Watch keeps its zero value (disabled,
// empty reason).
func (cds *crdbDatastore) features(ctx context.Context) (*datastore.Features, error) {
	var features datastore.Features

	head, err := cds.HeadRevision(ctx)
	if err != nil {
		return nil, err
	}

	// streams don't return at all if they succeed, so the only way to know
	// it was created successfully is to wait a bit and then cancel
	streamCtx, cancel := context.WithCancel(ctx)
	defer cancel()
	time.AfterFunc(1*time.Second, cancel)

	_ = cds.writePool.ExecFunc(streamCtx, func(ctx context.Context, tag pgconn.CommandTag, err error) error {
		// Cancellation (from the timer above) means the changefeed was running.
		if err != nil && errors.Is(err, context.Canceled) {
			features.Watch.Enabled = true
			features.Watch.Reason = ""
		} else if err != nil {
			features.Watch.Enabled = false
			features.Watch.Reason = fmt.Sprintf("Range feeds must be enabled in CockroachDB and the user must have permission to create them in order to enable the Watch API: %s", err.Error())
		}
		return nil
		// Format arguments: table name, start revision, and "1s" (presumably the
		// changefeed's resolved/interval option — see beginChangefeedQuery).
	}, fmt.Sprintf(cds.beginChangefeedQuery, tableTuple, head, "1s"))

	// Block until the timer (or ctx) cancels streamCtx.
	<-streamCtx.Done()

	return &features, nil
}
   472  func (cds *crdbDatastore) readTransactionCommitRev(ctx context.Context, reader pgxcommon.DBFuncQuerier) (datastore.Revision, error) {
   473  	ctx, span := tracer.Start(ctx, "readTransactionCommitRev")
   474  	defer span.End()
   476  	var hlcNow decimal.Decimal
   477  	if err := reader.QueryRowFunc(ctx, func(ctx context.Context, row pgx.Row) error {
   478  		return row.Scan(&hlcNow)
   479  	}, cds.transactionNowQuery); err != nil {
   480  		return datastore.NoRevision, fmt.Errorf("unable to read timestamp: %w", err)
   481  	}
   483  	return revisions.NewForHLC(hlcNow)
   484  }
   486  func readCRDBNow(ctx context.Context, reader pgxcommon.DBFuncQuerier) (datastore.Revision, error) {
   487  	ctx, span := tracer.Start(ctx, "readCRDBNow")
   488  	defer span.End()
   490  	var hlcNow decimal.Decimal
   491  	if err := reader.QueryRowFunc(ctx, func(ctx context.Context, row pgx.Row) error {
   492  		return row.Scan(&hlcNow)
   493  	}, querySelectNow); err != nil {
   494  		return datastore.NoRevision, fmt.Errorf("unable to read timestamp: %w", err)
   495  	}
   497  	return revisions.NewForHLC(hlcNow)
   498  }
   500  func readClusterTTLNanos(ctx context.Context, conn pgxcommon.DBFuncQuerier) (int64, error) {
   501  	var target, configSQL string
   503  	if err := conn.QueryRowFunc(ctx, func(ctx context.Context, row pgx.Row) error {
   504  		return row.Scan(&target, &configSQL)
   505  	}, queryShowZoneConfig); err != nil {
   506  		return 0, err
   507  	}
   509  	groups := gcTTLRegex.FindStringSubmatch(configSQL)
   510  	if groups == nil || len(groups) != 2 {
   511  		return 0, fmt.Errorf("CRDB zone config unexpected format")
   512  	}
   514  	gcSeconds, err := strconv.ParseInt(groups[1], 10, 64)
   515  	if err != nil {
   516  		return 0, err
   517  	}
   519  	return gcSeconds * 1_000_000_000, nil
   520  }