github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/crdb/watch.go (about)

     1  package crdb
     2  
     3  import (
     4  	"context"
     5  	"encoding/hex"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/prometheus/client_golang/prometheus"
    14  
    15  	"github.com/authzed/spicedb/internal/datastore/common"
    16  	"github.com/authzed/spicedb/internal/datastore/crdb/pool"
    17  	pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common"
    18  	"github.com/authzed/spicedb/internal/datastore/revisions"
    19  	"github.com/authzed/spicedb/pkg/datastore"
    20  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    21  	"github.com/authzed/spicedb/pkg/spiceerrors"
    22  )
    23  
    24  const (
    25  	queryChangefeed       = "EXPERIMENTAL CHANGEFEED FOR %s WITH updated, cursor = '%s', resolved = '%s', min_checkpoint_frequency = '0';"
    26  	queryChangefeedPreV22 = "EXPERIMENTAL CHANGEFEED FOR %s WITH updated, cursor = '%s', resolved = '%s';"
    27  )
    28  
    29  var retryHistogram = prometheus.NewHistogram(prometheus.HistogramOpts{
    30  	Namespace: "spicedb",
    31  	Subsystem: "datastore",
    32  	Name:      "crdb_watch_retries",
    33  	Help:      "watch retry distribution",
    34  	Buckets:   []float64{0, 1, 2, 5, 10, 20, 50},
    35  })
    36  
    37  func init() {
    38  	prometheus.MustRegister(retryHistogram)
    39  }
    40  
    41  type changeDetails struct {
    42  	Resolved string
    43  	Updated  string
    44  	After    *struct {
    45  		Namespace                 string `json:"namespace"`
    46  		SerializedNamespaceConfig string `json:"serialized_config"`
    47  
    48  		CaveatName                 string `json:"name"`
    49  		SerializedCaveatDefinition string `json:"definition"`
    50  
    51  		RelationshipCaveatContext map[string]any `json:"caveat_context"`
    52  		RelationshipCaveatName    string         `json:"caveat_name"`
    53  	}
    54  }
    55  
    56  func (cds *crdbDatastore) Watch(ctx context.Context, afterRevision datastore.Revision, options datastore.WatchOptions) (<-chan *datastore.RevisionChanges, <-chan error) {
    57  	watchBufferLength := options.WatchBufferLength
    58  	if watchBufferLength <= 0 {
    59  		watchBufferLength = cds.watchBufferLength
    60  	}
    61  
    62  	updates := make(chan *datastore.RevisionChanges, watchBufferLength)
    63  	errs := make(chan error, 1)
    64  
    65  	features, err := cds.Features(ctx)
    66  	if err != nil {
    67  		errs <- err
    68  		return updates, errs
    69  	}
    70  
    71  	if !features.Watch.Enabled {
    72  		errs <- datastore.NewWatchDisabledErr(fmt.Sprintf("%s. See https://spicedb.dev/d/enable-watch-api-crdb", features.Watch.Reason))
    73  		return updates, errs
    74  	}
    75  
    76  	go cds.watch(ctx, afterRevision, options, updates, errs)
    77  
    78  	return updates, errs
    79  }
    80  
    81  func (cds *crdbDatastore) watch(
    82  	ctx context.Context,
    83  	afterRevision datastore.Revision,
    84  	opts datastore.WatchOptions,
    85  	updates chan *datastore.RevisionChanges,
    86  	errs chan error,
    87  ) {
    88  	defer close(updates)
    89  	defer close(errs)
    90  
    91  	// get non-pooled connection for watch
    92  	// "applications should explicitly create dedicated connections to consume
    93  	// changefeed data, instead of using a connection pool as most client
    94  	// drivers do by default."
    95  	// see: https://www.cockroachlabs.com/docs/v22.2/changefeed-for#considerations
    96  	conn, err := pgxcommon.ConnectWithInstrumentation(ctx, cds.dburl)
    97  	if err != nil {
    98  		errs <- err
    99  		return
   100  	}
   101  	defer func() { _ = conn.Close(ctx) }()
   102  
   103  	tableNames := make([]string, 0, 3)
   104  	if opts.Content&datastore.WatchRelationships == datastore.WatchRelationships {
   105  		tableNames = append(tableNames, tableTuple)
   106  	}
   107  	if opts.Content&datastore.WatchSchema == datastore.WatchSchema {
   108  		tableNames = append(tableNames, tableNamespace)
   109  		tableNames = append(tableNames, tableCaveat)
   110  	}
   111  
   112  	if len(tableNames) == 0 {
   113  		errs <- fmt.Errorf("at least relationships or schema must be specified")
   114  		return
   115  	}
   116  
   117  	if opts.CheckpointInterval < 0 {
   118  		errs <- fmt.Errorf("invalid checkpoint interval given")
   119  		return
   120  	}
   121  
   122  	// Default: 1s
   123  	resolvedDuration := 1 * time.Second
   124  	if opts.CheckpointInterval > 0 {
   125  		resolvedDuration = opts.CheckpointInterval
   126  	}
   127  
   128  	resolvedDurationString := strconv.FormatInt(resolvedDuration.Milliseconds(), 10) + "ms"
   129  	interpolated := fmt.Sprintf(cds.beginChangefeedQuery, strings.Join(tableNames, ","), afterRevision, resolvedDurationString)
   130  
   131  	sendError := func(err error) {
   132  		if errors.Is(ctx.Err(), context.Canceled) {
   133  			errs <- datastore.NewWatchCanceledErr()
   134  			return
   135  		}
   136  
   137  		if pool.IsResettableError(ctx, err) || pool.IsRetryableError(ctx, err) {
   138  			errs <- datastore.NewWatchTemporaryErr(err)
   139  			return
   140  		}
   141  
   142  		errs <- err
   143  	}
   144  
   145  	watchBufferWriteTimeout := opts.WatchBufferWriteTimeout
   146  	if watchBufferWriteTimeout <= 0 {
   147  		watchBufferWriteTimeout = cds.watchBufferWriteTimeout
   148  	}
   149  
   150  	sendChange := func(change *datastore.RevisionChanges) bool {
   151  		select {
   152  		case updates <- change:
   153  			return true
   154  
   155  		default:
   156  			// If we cannot immediately write, setup the timer and try again.
   157  		}
   158  
   159  		timer := time.NewTimer(watchBufferWriteTimeout)
   160  		defer timer.Stop()
   161  
   162  		select {
   163  		case updates <- change:
   164  			return true
   165  
   166  		case <-timer.C:
   167  			errs <- datastore.NewWatchDisconnectedErr()
   168  			return false
   169  		}
   170  	}
   171  
   172  	changes, err := conn.Query(ctx, interpolated)
   173  	if err != nil {
   174  		sendError(err)
   175  		return
   176  	}
   177  
   178  	// We call Close async here because it can be slow and blocks closing the channels. There is
   179  	// no return value so we're not really losing anything.
   180  	defer func() { go changes.Close() }()
   181  
   182  	tracked := common.NewChanges(revisions.HLCKeyFunc, opts.Content)
   183  
   184  	for changes.Next() {
   185  		var tableNameBytes []byte
   186  		var changeJSON []byte
   187  		var primaryKeyValuesJSON []byte
   188  
   189  		// Pull in the table name, the primary key(s) and change information.
   190  		if err := changes.Scan(&tableNameBytes, &primaryKeyValuesJSON, &changeJSON); err != nil {
   191  			sendError(err)
   192  			return
   193  		}
   194  
   195  		var details changeDetails
   196  		if err := json.Unmarshal(changeJSON, &details); err != nil {
   197  			sendError(err)
   198  			return
   199  		}
   200  
   201  		// Resolved indicates that the specified revision is "complete"; no additional updates can come in before or at it.
   202  		// Therefore, at this point, we issue tracked updates from before that time, and the checkpoint update.
   203  		if details.Resolved != "" {
   204  			rev, err := revisions.HLCRevisionFromString(details.Resolved)
   205  			if err != nil {
   206  				sendError(fmt.Errorf("malformed resolved timestamp: %w", err))
   207  				return
   208  			}
   209  
   210  			for _, revChange := range tracked.FilterAndRemoveRevisionChanges(revisions.HLCKeyLessThanFunc, rev) {
   211  				revChange := revChange
   212  				if !sendChange(&revChange) {
   213  					return
   214  				}
   215  			}
   216  
   217  			if opts.Content&datastore.WatchCheckpoints == datastore.WatchCheckpoints {
   218  				if !sendChange(&datastore.RevisionChanges{
   219  					Revision:     rev,
   220  					IsCheckpoint: true,
   221  				}) {
   222  					return
   223  				}
   224  			}
   225  			continue
   226  		}
   227  
   228  		// Otherwise, this a notification of a row change.
   229  		tableName := string(tableNameBytes)
   230  
   231  		var pkValues []string
   232  		if err := json.Unmarshal(primaryKeyValuesJSON, &pkValues); err != nil {
   233  			sendError(err)
   234  			return
   235  		}
   236  
   237  		switch tableName {
   238  		case tableTuple:
   239  			var caveatName string
   240  			var caveatContext map[string]any
   241  			if details.After != nil && details.After.RelationshipCaveatName != "" {
   242  				caveatName = details.After.RelationshipCaveatName
   243  				caveatContext = details.After.RelationshipCaveatContext
   244  			}
   245  			ctxCaveat, err := common.ContextualizedCaveatFrom(caveatName, caveatContext)
   246  			if err != nil {
   247  				sendError(err)
   248  				return
   249  			}
   250  
   251  			tuple := &core.RelationTuple{
   252  				ResourceAndRelation: &core.ObjectAndRelation{
   253  					Namespace: pkValues[0],
   254  					ObjectId:  pkValues[1],
   255  					Relation:  pkValues[2],
   256  				},
   257  				Subject: &core.ObjectAndRelation{
   258  					Namespace: pkValues[3],
   259  					ObjectId:  pkValues[4],
   260  					Relation:  pkValues[5],
   261  				},
   262  				Caveat: ctxCaveat,
   263  			}
   264  
   265  			rev, err := revisions.HLCRevisionFromString(details.Updated)
   266  			if err != nil {
   267  				sendError(fmt.Errorf("malformed update timestamp: %w", err))
   268  				return
   269  			}
   270  
   271  			if details.After == nil {
   272  				if err := tracked.AddRelationshipChange(ctx, rev, tuple, core.RelationTupleUpdate_DELETE); err != nil {
   273  					sendError(err)
   274  					return
   275  				}
   276  			} else {
   277  				if err := tracked.AddRelationshipChange(ctx, rev, tuple, core.RelationTupleUpdate_TOUCH); err != nil {
   278  					sendError(err)
   279  					return
   280  				}
   281  			}
   282  
   283  		case tableNamespace:
   284  			if len(pkValues) != 1 {
   285  				sendError(spiceerrors.MustBugf("expected a single definition name for the primary key in change feed. found: %s", string(primaryKeyValuesJSON)))
   286  				return
   287  			}
   288  
   289  			definitionName := pkValues[0]
   290  
   291  			rev, err := revisions.HLCRevisionFromString(details.Updated)
   292  			if err != nil {
   293  				sendError(fmt.Errorf("malformed update timestamp: %w", err))
   294  				return
   295  			}
   296  
   297  			if details.After != nil && details.After.SerializedNamespaceConfig != "" {
   298  				namespaceDef := &core.NamespaceDefinition{}
   299  				defBytes, err := hex.DecodeString(details.After.SerializedNamespaceConfig[2:]) // drop the \x
   300  				if err != nil {
   301  					sendError(fmt.Errorf("could not decode namespace definition: %w", err))
   302  					return
   303  				}
   304  
   305  				if err := namespaceDef.UnmarshalVT(defBytes); err != nil {
   306  					sendError(fmt.Errorf("could not unmarshal namespace definition: %w", err))
   307  					return
   308  				}
   309  				tracked.AddChangedDefinition(ctx, rev, namespaceDef)
   310  			} else {
   311  				tracked.AddDeletedNamespace(ctx, rev, definitionName)
   312  			}
   313  
   314  		case tableCaveat:
   315  			if len(pkValues) != 1 {
   316  				sendError(spiceerrors.MustBugf("expected a single definition name for the primary key in change feed. found: %s", string(primaryKeyValuesJSON)))
   317  				return
   318  			}
   319  
   320  			definitionName := pkValues[0]
   321  
   322  			rev, err := revisions.HLCRevisionFromString(details.Updated)
   323  			if err != nil {
   324  				sendError(fmt.Errorf("malformed update timestamp: %w", err))
   325  				return
   326  			}
   327  
   328  			if details.After != nil && details.After.SerializedCaveatDefinition != "" {
   329  				caveatDef := &core.CaveatDefinition{}
   330  				defBytes, err := hex.DecodeString(details.After.SerializedCaveatDefinition[2:]) // drop the \x
   331  				if err != nil {
   332  					sendError(fmt.Errorf("could not decode caveat definition: %w", err))
   333  					return
   334  				}
   335  
   336  				if err := caveatDef.UnmarshalVT(defBytes); err != nil {
   337  					sendError(fmt.Errorf("could not unmarshal caveat definition: %w", err))
   338  					return
   339  				}
   340  				tracked.AddChangedDefinition(ctx, rev, caveatDef)
   341  			} else {
   342  				tracked.AddDeletedCaveat(ctx, rev, definitionName)
   343  			}
   344  		}
   345  	}
   346  
   347  	if changes.Err() != nil {
   348  		if errors.Is(ctx.Err(), context.Canceled) {
   349  			closeCtx, closeCancel := context.WithTimeout(context.Background(), 5*time.Second)
   350  			defer closeCancel()
   351  			if err := conn.Close(closeCtx); err != nil {
   352  				errs <- err
   353  				return
   354  			}
   355  			errs <- datastore.NewWatchCanceledErr()
   356  		} else {
   357  			errs <- changes.Err()
   358  		}
   359  		return
   360  	}
   361  }