github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/postgres/watch.go (about)

     1  package postgres
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"time"
     8  
     9  	sq "github.com/Masterminds/squirrel"
    10  	"github.com/jackc/pgx/v5"
    11  	"google.golang.org/protobuf/types/known/structpb"
    12  
    13  	"github.com/authzed/spicedb/internal/datastore/common"
    14  	pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common"
    15  	"github.com/authzed/spicedb/pkg/datastore"
    16  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    17  )
    18  
    19  const (
    20  	minimumWatchSleep = 100 * time.Millisecond
    21  )
    22  
// revisionWithXid pairs a postgresRevision with the xid8 of the transaction
// it was built from, so loaded changes can be attributed to revisions by
// transaction id. Fields are positional in composite literals: keep the order.
type revisionWithXid struct {
	postgresRevision
	// tx is the id of the transaction this revision represents.
	tx xid8
}
    27  
var (
	// newRevisionsQuery selects the xid and stored snapshot of every transaction
	// that is not visible in the snapshot bound to $1: transactions at or beyond
	// the snapshot's xmax, plus those between its xmin and xmax that it does not
	// see. Rows are ordered by commit timestamp, then by xid.
	//
	// This query must cast an xid8 to xid, which is a safe operation as long as the
	// xid8 is one of the last ~2 billion transaction IDs generated. We should be garbage
	// collecting these transactions long before we get to that point.
	newRevisionsQuery = fmt.Sprintf(`
	SELECT %[1]s, %[2]s FROM %[3]s
	WHERE %[1]s >= pg_snapshot_xmax($1) OR (
		%[1]s >= pg_snapshot_xmin($1) AND NOT pg_visible_in_snapshot(%[1]s, $1)
	) ORDER BY pg_xact_commit_timestamp(%[1]s::xid), %[1]s;`, colXID, colSnapshot, tableTransaction)

	// queryChangedTuples selects relationship rows together with their created and
	// deleted xids. Callers add a WHERE clause restricting the xid range.
	queryChangedTuples = psql.Select(
		colNamespace,
		colObjectID,
		colRelation,
		colUsersetNamespace,
		colUsersetObjectID,
		colUsersetRelation,
		colCaveatContextName,
		colCaveatContext,
		colCreatedXid,
		colDeletedXid,
	).From(tableTuple)

	// queryChangedNamespaces selects serialized namespace configs with their
	// created and deleted xids. Callers add the xid-range WHERE clause.
	queryChangedNamespaces = psql.Select(
		colConfig,
		colCreatedXid,
		colDeletedXid,
	).From(tableNamespace)

	// queryChangedCaveats selects caveat names and serialized definitions with
	// their created and deleted xids. Callers add the xid-range WHERE clause.
	queryChangedCaveats = psql.Select(
		colCaveatName,
		colCaveatDefinition,
		colCreatedXid,
		colDeletedXid,
	).From(tableCaveat)
)
    64  
    65  func (pgd *pgDatastore) Watch(
    66  	ctx context.Context,
    67  	afterRevisionRaw datastore.Revision,
    68  	options datastore.WatchOptions,
    69  ) (<-chan *datastore.RevisionChanges, <-chan error) {
    70  	watchBufferLength := options.WatchBufferLength
    71  	if watchBufferLength <= 0 {
    72  		watchBufferLength = pgd.watchBufferLength
    73  	}
    74  
    75  	updates := make(chan *datastore.RevisionChanges, watchBufferLength)
    76  	errs := make(chan error, 1)
    77  
    78  	if !pgd.watchEnabled {
    79  		errs <- datastore.NewWatchDisabledErr("postgres must be run with track_commit_timestamp=on for watch to be enabled. See https://spicedb.dev/d/enable-watch-api-postgres")
    80  		return updates, errs
    81  	}
    82  
    83  	afterRevision := afterRevisionRaw.(postgresRevision)
    84  	watchSleep := options.CheckpointInterval
    85  	if watchSleep < minimumWatchSleep {
    86  		watchSleep = minimumWatchSleep
    87  	}
    88  
    89  	watchBufferWriteTimeout := options.WatchBufferWriteTimeout
    90  	if watchBufferWriteTimeout <= 0 {
    91  		watchBufferWriteTimeout = pgd.watchBufferWriteTimeout
    92  	}
    93  
    94  	sendChange := func(change *datastore.RevisionChanges) bool {
    95  		select {
    96  		case updates <- change:
    97  			return true
    98  
    99  		default:
   100  			// If we cannot immediately write, setup the timer and try again.
   101  		}
   102  
   103  		timer := time.NewTimer(watchBufferWriteTimeout)
   104  		defer timer.Stop()
   105  
   106  		select {
   107  		case updates <- change:
   108  			return true
   109  
   110  		case <-timer.C:
   111  			errs <- datastore.NewWatchDisconnectedErr()
   112  			return false
   113  		}
   114  	}
   115  
   116  	go func() {
   117  		defer close(updates)
   118  		defer close(errs)
   119  
   120  		currentTxn := afterRevision
   121  
   122  		for {
   123  			newTxns, err := pgd.getNewRevisions(ctx, currentTxn)
   124  			if err != nil {
   125  				if errors.Is(ctx.Err(), context.Canceled) {
   126  					errs <- datastore.NewWatchCanceledErr()
   127  				} else if pgxcommon.IsCancellationError(err) {
   128  					errs <- datastore.NewWatchCanceledErr()
   129  				} else {
   130  					errs <- err
   131  				}
   132  				return
   133  			}
   134  
   135  			if len(newTxns) > 0 {
   136  				changesToWrite, err := pgd.loadChanges(ctx, newTxns, options)
   137  				if err != nil {
   138  					if errors.Is(ctx.Err(), context.Canceled) {
   139  						errs <- datastore.NewWatchCanceledErr()
   140  					} else {
   141  						errs <- err
   142  					}
   143  					return
   144  				}
   145  
   146  				for _, changeToWrite := range changesToWrite {
   147  					changeToWrite := changeToWrite
   148  					if !sendChange(&changeToWrite) {
   149  						return
   150  					}
   151  				}
   152  
   153  				// In order to make progress, we need to ensure that any seen transactions here are
   154  				// marked as done in the revision given back to Postgres on the next iteration. We pick
   155  				// the *last* transaction to start, as it should encompass all completed transactions
   156  				// except those running concurrently, which is handled by calling markComplete on the other
   157  				// transactions.
   158  				currentTxn = newTxns[len(newTxns)-1].postgresRevision
   159  				for _, newTx := range newTxns {
   160  					currentTxn = postgresRevision{currentTxn.snapshot.markComplete(newTx.tx.Uint64)}
   161  				}
   162  
   163  				// If checkpoints were requested, output a checkpoint. While the Postgres datastore does not
   164  				// move revisions forward outside of changes, these could be necessary if the caller is
   165  				// watching only a *subset* of changes.
   166  				if options.Content&datastore.WatchCheckpoints == datastore.WatchCheckpoints {
   167  					if !sendChange(&datastore.RevisionChanges{
   168  						Revision:     currentTxn,
   169  						IsCheckpoint: true,
   170  					}) {
   171  						return
   172  					}
   173  				}
   174  			} else {
   175  				sleep := time.NewTimer(watchSleep)
   176  
   177  				select {
   178  				case <-sleep.C:
   179  					break
   180  				case <-ctx.Done():
   181  					errs <- datastore.NewWatchCanceledErr()
   182  					return
   183  				}
   184  			}
   185  		}
   186  	}()
   187  
   188  	return updates, errs
   189  }
   190  
   191  func (pgd *pgDatastore) getNewRevisions(ctx context.Context, afterTX postgresRevision) ([]revisionWithXid, error) {
   192  	var ids []revisionWithXid
   193  	if err := pgx.BeginTxFunc(ctx, pgd.readPool, pgx.TxOptions{IsoLevel: pgx.RepeatableRead}, func(tx pgx.Tx) error {
   194  		rows, err := tx.Query(ctx, newRevisionsQuery, afterTX.snapshot)
   195  		if err != nil {
   196  			return fmt.Errorf("unable to load new revisions: %w", err)
   197  		}
   198  		defer rows.Close()
   199  
   200  		for rows.Next() {
   201  			var nextXID xid8
   202  			var nextSnapshot pgSnapshot
   203  			if err := rows.Scan(&nextXID, &nextSnapshot); err != nil {
   204  				return fmt.Errorf("unable to decode new revision: %w", err)
   205  			}
   206  
   207  			ids = append(ids, revisionWithXid{
   208  				postgresRevision{nextSnapshot.markComplete(nextXID.Uint64)},
   209  				nextXID,
   210  			})
   211  		}
   212  		if rows.Err() != nil {
   213  			return fmt.Errorf("unable to load new revisions: %w", err)
   214  		}
   215  		return nil
   216  	}); err != nil {
   217  		return nil, fmt.Errorf("transaction error: %w", err)
   218  	}
   219  
   220  	return ids, nil
   221  }
   222  
   223  func (pgd *pgDatastore) loadChanges(ctx context.Context, revisions []revisionWithXid, options datastore.WatchOptions) ([]datastore.RevisionChanges, error) {
   224  	xmin := revisions[0].tx.Uint64
   225  	xmax := revisions[0].tx.Uint64
   226  	filter := make(map[uint64]int, len(revisions))
   227  	txidToRevision := make(map[uint64]revisionWithXid, len(revisions))
   228  
   229  	for i, rev := range revisions {
   230  		if rev.tx.Uint64 < xmin {
   231  			xmin = rev.tx.Uint64
   232  		}
   233  		if rev.tx.Uint64 > xmax {
   234  			xmax = rev.tx.Uint64
   235  		}
   236  		filter[rev.tx.Uint64] = i
   237  		txidToRevision[rev.tx.Uint64] = rev
   238  	}
   239  
   240  	tracked := common.NewChanges(revisionKeyFunc, options.Content)
   241  
   242  	// Load relationship changes.
   243  	if options.Content&datastore.WatchRelationships == datastore.WatchRelationships {
   244  		err := pgd.loadRelationshipChanges(ctx, xmin, xmax, txidToRevision, filter, tracked)
   245  		if err != nil {
   246  			return nil, err
   247  		}
   248  	}
   249  
   250  	// Load namespace changes.
   251  	if options.Content&datastore.WatchSchema == datastore.WatchSchema {
   252  		err := pgd.loadNamespaceChanges(ctx, xmin, xmax, txidToRevision, filter, tracked)
   253  		if err != nil {
   254  			return nil, err
   255  		}
   256  	}
   257  
   258  	// Load caveat changes.
   259  	if options.Content&datastore.WatchSchema == datastore.WatchSchema {
   260  		err := pgd.loadCaveatChanges(ctx, xmin, xmax, txidToRevision, filter, tracked)
   261  		if err != nil {
   262  			return nil, err
   263  		}
   264  	}
   265  
   266  	// Reconcile the changes.
   267  	reconciledChanges := tracked.AsRevisionChanges(func(lhs, rhs uint64) bool {
   268  		return filter[lhs] < filter[rhs]
   269  	})
   270  	return reconciledChanges, nil
   271  }
   272  
   273  func (pgd *pgDatastore) loadRelationshipChanges(ctx context.Context, xmin uint64, xmax uint64, txidToRevision map[uint64]revisionWithXid, filter map[uint64]int, tracked *common.Changes[revisionWithXid, uint64]) error {
   274  	sql, args, err := queryChangedTuples.Where(sq.Or{
   275  		sq.And{
   276  			sq.LtOrEq{colCreatedXid: xmax},
   277  			sq.GtOrEq{colCreatedXid: xmin},
   278  		},
   279  		sq.And{
   280  			sq.LtOrEq{colDeletedXid: xmax},
   281  			sq.GtOrEq{colDeletedXid: xmin},
   282  		},
   283  	}).ToSql()
   284  	if err != nil {
   285  		return fmt.Errorf("unable to prepare changes SQL: %w", err)
   286  	}
   287  
   288  	changes, err := pgd.readPool.Query(ctx, sql, args...)
   289  	if err != nil {
   290  		return fmt.Errorf("unable to load changes for XID: %w", err)
   291  	}
   292  
   293  	defer changes.Close()
   294  
   295  	for changes.Next() {
   296  		nextTuple := &core.RelationTuple{
   297  			ResourceAndRelation: &core.ObjectAndRelation{},
   298  			Subject:             &core.ObjectAndRelation{},
   299  		}
   300  
   301  		var createdXID, deletedXID xid8
   302  		var caveatName *string
   303  		var caveatContext map[string]any
   304  		if err := changes.Scan(
   305  			&nextTuple.ResourceAndRelation.Namespace,
   306  			&nextTuple.ResourceAndRelation.ObjectId,
   307  			&nextTuple.ResourceAndRelation.Relation,
   308  			&nextTuple.Subject.Namespace,
   309  			&nextTuple.Subject.ObjectId,
   310  			&nextTuple.Subject.Relation,
   311  			&caveatName,
   312  			&caveatContext,
   313  			&createdXID,
   314  			&deletedXID,
   315  		); err != nil {
   316  			return fmt.Errorf("unable to parse changed tuple: %w", err)
   317  		}
   318  
   319  		if caveatName != nil && *caveatName != "" {
   320  			contextStruct, err := structpb.NewStruct(caveatContext)
   321  			if err != nil {
   322  				return fmt.Errorf("failed to read caveat context from update: %w", err)
   323  			}
   324  			nextTuple.Caveat = &core.ContextualizedCaveat{
   325  				CaveatName: *caveatName,
   326  				Context:    contextStruct,
   327  			}
   328  		}
   329  
   330  		if _, found := filter[createdXID.Uint64]; found {
   331  			if err := tracked.AddRelationshipChange(ctx, txidToRevision[createdXID.Uint64], nextTuple, core.RelationTupleUpdate_TOUCH); err != nil {
   332  				return err
   333  			}
   334  		}
   335  		if _, found := filter[deletedXID.Uint64]; found {
   336  			if err := tracked.AddRelationshipChange(ctx, txidToRevision[deletedXID.Uint64], nextTuple, core.RelationTupleUpdate_DELETE); err != nil {
   337  				return err
   338  			}
   339  		}
   340  	}
   341  	if changes.Err() != nil {
   342  		return fmt.Errorf("unable to load changes for XID: %w", err)
   343  	}
   344  	return nil
   345  }
   346  
   347  func (pgd *pgDatastore) loadNamespaceChanges(ctx context.Context, xmin uint64, xmax uint64, txidToRevision map[uint64]revisionWithXid, filter map[uint64]int, tracked *common.Changes[revisionWithXid, uint64]) error {
   348  	sql, args, err := queryChangedNamespaces.Where(sq.Or{
   349  		sq.And{
   350  			sq.LtOrEq{colCreatedXid: xmax},
   351  			sq.GtOrEq{colCreatedXid: xmin},
   352  		},
   353  		sq.And{
   354  			sq.LtOrEq{colDeletedXid: xmax},
   355  			sq.GtOrEq{colDeletedXid: xmin},
   356  		},
   357  	}).ToSql()
   358  	if err != nil {
   359  		return fmt.Errorf("unable to prepare changes SQL: %w", err)
   360  	}
   361  
   362  	changes, err := pgd.readPool.Query(ctx, sql, args...)
   363  	if err != nil {
   364  		return fmt.Errorf("unable to load changes for XID: %w", err)
   365  	}
   366  
   367  	defer changes.Close()
   368  
   369  	for changes.Next() {
   370  		var createdXID, deletedXID xid8
   371  		var config []byte
   372  		if err := changes.Scan(
   373  			&config,
   374  			&createdXID,
   375  			&deletedXID,
   376  		); err != nil {
   377  			return fmt.Errorf("unable to parse changed namespace: %w", err)
   378  		}
   379  
   380  		loaded := &core.NamespaceDefinition{}
   381  		if err := loaded.UnmarshalVT(config); err != nil {
   382  			return fmt.Errorf(errUnableToReadConfig, err)
   383  		}
   384  
   385  		if _, found := filter[createdXID.Uint64]; found {
   386  			tracked.AddChangedDefinition(ctx, txidToRevision[deletedXID.Uint64], loaded)
   387  		}
   388  		if _, found := filter[deletedXID.Uint64]; found {
   389  			tracked.AddDeletedNamespace(ctx, txidToRevision[deletedXID.Uint64], loaded.Name)
   390  		}
   391  	}
   392  	if changes.Err() != nil {
   393  		return fmt.Errorf("unable to load changes for XID: %w", err)
   394  	}
   395  	return nil
   396  }
   397  
   398  func (pgd *pgDatastore) loadCaveatChanges(ctx context.Context, min uint64, max uint64, txidToRevision map[uint64]revisionWithXid, filter map[uint64]int, tracked *common.Changes[revisionWithXid, uint64]) error {
   399  	sql, args, err := queryChangedCaveats.Where(sq.Or{
   400  		sq.And{
   401  			sq.LtOrEq{colCreatedXid: max},
   402  			sq.GtOrEq{colCreatedXid: min},
   403  		},
   404  		sq.And{
   405  			sq.LtOrEq{colDeletedXid: max},
   406  			sq.GtOrEq{colDeletedXid: min},
   407  		},
   408  	}).ToSql()
   409  	if err != nil {
   410  		return fmt.Errorf("unable to prepare changes SQL: %w", err)
   411  	}
   412  
   413  	changes, err := pgd.readPool.Query(ctx, sql, args...)
   414  	if err != nil {
   415  		return fmt.Errorf("unable to load changes for XID: %w", err)
   416  	}
   417  
   418  	defer changes.Close()
   419  
   420  	for changes.Next() {
   421  		var createdXID, deletedXID xid8
   422  		var config []byte
   423  		var name string
   424  		if err := changes.Scan(
   425  			&name,
   426  			&config,
   427  			&createdXID,
   428  			&deletedXID,
   429  		); err != nil {
   430  			return fmt.Errorf("unable to parse changed caveat: %w", err)
   431  		}
   432  
   433  		loaded := &core.CaveatDefinition{}
   434  		if err := loaded.UnmarshalVT(config); err != nil {
   435  			return fmt.Errorf(errUnableToReadConfig, err)
   436  		}
   437  
   438  		if _, found := filter[createdXID.Uint64]; found {
   439  			tracked.AddChangedDefinition(ctx, txidToRevision[deletedXID.Uint64], loaded)
   440  		}
   441  		if _, found := filter[deletedXID.Uint64]; found {
   442  			tracked.AddDeletedCaveat(ctx, txidToRevision[deletedXID.Uint64], loaded.Name)
   443  		}
   444  	}
   445  	if changes.Err() != nil {
   446  		return fmt.Errorf("unable to load changes for XID: %w", err)
   447  	}
   448  	return nil
   449  }