github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/spanner/watch.go (about)

     1  package spanner
     2  
     3  import (
     4  	"context"
     5  	"encoding/base64"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"regexp"
    10  	"time"
    11  
    12  	"cloud.google.com/go/spanner"
    13  	sppb "cloud.google.com/go/spanner/apiv1/spannerpb"
    14  	"github.com/cloudspannerecosystem/spanner-change-streams-tail/changestreams"
    15  	"github.com/prometheus/client_golang/prometheus"
    16  	"google.golang.org/api/option"
    17  
    18  	"github.com/authzed/spicedb/internal/datastore/common"
    19  	"github.com/authzed/spicedb/internal/datastore/revisions"
    20  	"github.com/authzed/spicedb/pkg/datastore"
    21  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    22  	"github.com/authzed/spicedb/pkg/spiceerrors"
    23  )
    24  
const (
	// CombinedChangeStreamName is the name of the Spanner change stream that
	// the watch implementation in this file reads from.
	CombinedChangeStreamName = "combined_change_stream"
)

// retryHistogram is a Prometheus histogram whose help text describes it as the
// watch retry distribution. It is not observed anywhere in the code shown here.
var retryHistogram = prometheus.NewHistogram(prometheus.HistogramOpts{
	Namespace: "spicedb",
	Subsystem: "datastore",
	Name:      "spanner_watch_retries",
	Help:      "watch retry distribution",
	Buckets:   []float64{0, 1, 2, 5, 10, 20, 50},
})

// init registers retryHistogram with Prometheus when the package is loaded.
func init() {
	prometheus.MustRegister(retryHistogram)
}
    40  
    41  // Copied from the spanner library: https://github.com/googleapis/google-cloud-go/blob/f03779538f949fb4ad93d5247d3c6b3e5b21091a/spanner/client.go#L67
    42  // License: Apache License, Version 2.0, Copyright 2017 Google LLC
    43  var validDBPattern = regexp.MustCompile("^projects/(?P<project>[^/]+)/instances/(?P<instance>[^/]+)/databases/(?P<database>[^/]+)$")
    44  
    45  func parseDatabaseName(db string) (project, instance, database string, err error) {
    46  	matches := validDBPattern.FindStringSubmatch(db)
    47  	if len(matches) == 0 {
    48  		return "", "", "", fmt.Errorf("failed to parse database name from %q according to pattern %q",
    49  			db, validDBPattern.String())
    50  	}
    51  	return matches[1], matches[2], matches[3], nil
    52  }
    53  
    54  func (sd *spannerDatastore) Watch(ctx context.Context, afterRevision datastore.Revision, opts datastore.WatchOptions) (<-chan *datastore.RevisionChanges, <-chan error) {
    55  	watchBufferLength := opts.WatchBufferLength
    56  	if watchBufferLength <= 0 {
    57  		watchBufferLength = sd.watchBufferLength
    58  	}
    59  
    60  	updates := make(chan *datastore.RevisionChanges, watchBufferLength)
    61  	errs := make(chan error, 1)
    62  
    63  	go sd.watch(ctx, afterRevision, opts, updates, errs)
    64  
    65  	return updates, errs
    66  }
    67  
    68  func (sd *spannerDatastore) watch(
    69  	ctx context.Context,
    70  	afterRevisionRaw datastore.Revision,
    71  	opts datastore.WatchOptions,
    72  	updates chan *datastore.RevisionChanges,
    73  	errs chan error,
    74  ) {
    75  	defer close(updates)
    76  	defer close(errs)
    77  
    78  	// NOTE: 100ms is the minimum allowed.
    79  	heartbeatInterval := opts.CheckpointInterval
    80  	if heartbeatInterval < 100*time.Millisecond {
    81  		heartbeatInterval = 100 * time.Millisecond
    82  	}
    83  
    84  	sendError := func(err error) {
    85  		if errors.Is(ctx.Err(), context.Canceled) {
    86  			errs <- datastore.NewWatchCanceledErr()
    87  			return
    88  		}
    89  
    90  		errs <- err
    91  	}
    92  
    93  	watchBufferWriteTimeout := opts.WatchBufferWriteTimeout
    94  	if watchBufferWriteTimeout <= 0 {
    95  		watchBufferWriteTimeout = sd.watchBufferWriteTimeout
    96  	}
    97  
    98  	sendChange := func(change *datastore.RevisionChanges) bool {
    99  		select {
   100  		case updates <- change:
   101  			return true
   102  
   103  		default:
   104  			// If we cannot immediately write, setup the timer and try again.
   105  		}
   106  
   107  		timer := time.NewTimer(watchBufferWriteTimeout)
   108  		defer timer.Stop()
   109  
   110  		select {
   111  		case updates <- change:
   112  			return true
   113  
   114  		case <-timer.C:
   115  			errs <- datastore.NewWatchDisconnectedErr()
   116  			return false
   117  		}
   118  	}
   119  
   120  	project, instance, database, err := parseDatabaseName(sd.database)
   121  	if err != nil {
   122  		sendError(err)
   123  		return
   124  	}
   125  
   126  	afterRevision, ok := afterRevisionRaw.(revisions.TimestampRevision)
   127  	if !ok {
   128  		sendError(datastore.NewInvalidRevisionErr(afterRevisionRaw, datastore.CouldNotDetermineRevision))
   129  		return
   130  	}
   131  
   132  	reader, err := changestreams.NewReaderWithConfig(
   133  		ctx,
   134  		project,
   135  		instance,
   136  		database,
   137  		CombinedChangeStreamName,
   138  		changestreams.Config{
   139  			StartTimestamp:    afterRevision.Time(),
   140  			HeartbeatInterval: heartbeatInterval,
   141  			SpannerClientOptions: []option.ClientOption{
   142  				option.WithCredentialsFile(sd.config.credentialsFilePath),
   143  			},
   144  			SpannerClientConfig: spanner.ClientConfig{
   145  				QueryOptions: spanner.QueryOptions{
   146  					Priority: sppb.RequestOptions_PRIORITY_LOW,
   147  				},
   148  				ApplyOptions: []spanner.ApplyOption{
   149  					spanner.Priority(sppb.RequestOptions_PRIORITY_LOW),
   150  				},
   151  			},
   152  		})
   153  	if err != nil {
   154  		sendError(err)
   155  		return
   156  	}
   157  	defer reader.Close()
   158  
   159  	err = reader.Read(ctx, func(result *changestreams.ReadResult) error {
   160  		// See: https://cloud.google.com/spanner/docs/change-streams/details
   161  		for _, record := range result.ChangeRecords {
   162  			tracked := common.NewChanges(revisions.TimestampIDKeyFunc, opts.Content)
   163  
   164  			for _, dcr := range record.DataChangeRecords {
   165  				changeRevision := revisions.NewForTime(dcr.CommitTimestamp)
   166  				modType := dcr.ModType // options are INSERT, UPDATE, DELETE
   167  
   168  				for _, mod := range dcr.Mods {
   169  					primaryKeyColumnValues, ok := mod.Keys.Value.(map[string]any)
   170  					if !ok {
   171  						return spiceerrors.MustBugf("error converting keys map")
   172  					}
   173  
   174  					switch modType {
   175  					case "DELETE":
   176  						switch dcr.TableName {
   177  						case tableRelationship:
   178  							relationTuple := relationTupleFromPrimaryKey(primaryKeyColumnValues)
   179  
   180  							oldValues, ok := mod.OldValues.Value.(map[string]any)
   181  							if !ok {
   182  								return spiceerrors.MustBugf("error converting old values map")
   183  							}
   184  
   185  							relationTuple.Caveat, err = contextualizedCaveatFromValues(oldValues)
   186  							if err != nil {
   187  								return err
   188  							}
   189  
   190  							err := tracked.AddRelationshipChange(ctx, changeRevision, relationTuple, core.RelationTupleUpdate_DELETE)
   191  							if err != nil {
   192  								return err
   193  							}
   194  
   195  						case tableNamespace:
   196  							namespaceNameValue, ok := primaryKeyColumnValues[colNamespaceName]
   197  							if !ok {
   198  								return spiceerrors.MustBugf("missing namespace name value")
   199  							}
   200  
   201  							namespaceName, ok := namespaceNameValue.(string)
   202  							if !ok {
   203  								return spiceerrors.MustBugf("error converting namespace name: %v", primaryKeyColumnValues[colNamespaceName])
   204  							}
   205  
   206  							tracked.AddDeletedNamespace(ctx, changeRevision, namespaceName)
   207  
   208  						case tableCaveat:
   209  							caveatNameValue, ok := primaryKeyColumnValues[colNamespaceName]
   210  							if !ok {
   211  								return spiceerrors.MustBugf("missing caveat name")
   212  							}
   213  
   214  							caveatName, ok := caveatNameValue.(string)
   215  							if !ok {
   216  								return spiceerrors.MustBugf("error converting caveat name: %v", primaryKeyColumnValues[colName])
   217  							}
   218  
   219  							tracked.AddDeletedCaveat(ctx, changeRevision, caveatName)
   220  
   221  						default:
   222  							return spiceerrors.MustBugf("unknown table name %s in delete of change stream", dcr.TableName)
   223  						}
   224  
   225  					case "INSERT":
   226  						fallthrough
   227  
   228  					case "UPDATE":
   229  						newValues, ok := mod.NewValues.Value.(map[string]any)
   230  						if !ok {
   231  							return spiceerrors.MustBugf("error new values keys map")
   232  						}
   233  
   234  						switch dcr.TableName {
   235  						case tableRelationship:
   236  							relationTuple := relationTupleFromPrimaryKey(primaryKeyColumnValues)
   237  
   238  							oldValues, ok := mod.OldValues.Value.(map[string]any)
   239  							if !ok {
   240  								return spiceerrors.MustBugf("error converting old values map")
   241  							}
   242  
   243  							// NOTE: Spanner's change stream will return a record for a TOUCH operation that does not
   244  							// change anything. Therefore, we check  to see if the caveat name or context has changed
   245  							// between the old and new values, and only raise the event in that case. This works for
   246  							// caveat context because Spanner will return either `nil` or a string value of the JSON.
   247  							newValues, ok := mod.NewValues.Value.(map[string]any)
   248  							if !ok {
   249  								return spiceerrors.MustBugf("error converting new values map")
   250  							}
   251  
   252  							if oldValues[colCaveatName] == newValues[colCaveatName] && oldValues[colCaveatContext] == newValues[colCaveatContext] {
   253  								continue
   254  							}
   255  
   256  							relationTuple.Caveat, err = contextualizedCaveatFromValues(newValues)
   257  							if err != nil {
   258  								return err
   259  							}
   260  
   261  							err := tracked.AddRelationshipChange(ctx, changeRevision, relationTuple, core.RelationTupleUpdate_TOUCH)
   262  							if err != nil {
   263  								return err
   264  							}
   265  
   266  						case tableNamespace:
   267  							namespaceConfigValue, ok := newValues[colNamespaceConfig]
   268  							if !ok {
   269  								return spiceerrors.MustBugf("missing namespace config value")
   270  							}
   271  
   272  							ns := &core.NamespaceDefinition{}
   273  							if err := unmarshalSchemaDefinition(ns, namespaceConfigValue); err != nil {
   274  								return err
   275  							}
   276  
   277  							tracked.AddChangedDefinition(ctx, changeRevision, ns)
   278  
   279  						case tableCaveat:
   280  							caveatDefValue, ok := newValues[colCaveatDefinition]
   281  							if !ok {
   282  								return spiceerrors.MustBugf("missing caveat definition value")
   283  							}
   284  
   285  							caveat := &core.CaveatDefinition{}
   286  							if err := unmarshalSchemaDefinition(caveat, caveatDefValue); err != nil {
   287  								return err
   288  							}
   289  
   290  							tracked.AddChangedDefinition(ctx, changeRevision, caveat)
   291  
   292  						default:
   293  							return spiceerrors.MustBugf("unknown table name %s in delete of change stream", dcr.TableName)
   294  						}
   295  
   296  					default:
   297  						return spiceerrors.MustBugf("unknown modtype in spanner change stream record")
   298  					}
   299  				}
   300  			}
   301  
   302  			if !tracked.IsEmpty() {
   303  				for _, revChange := range tracked.AsRevisionChanges(revisions.TimestampIDKeyLessThanFunc) {
   304  					revChange := revChange
   305  					if !sendChange(&revChange) {
   306  						return datastore.NewWatchDisconnectedErr()
   307  					}
   308  				}
   309  			}
   310  
   311  			if opts.Content&datastore.WatchCheckpoints == datastore.WatchCheckpoints {
   312  				for _, hbr := range record.HeartbeatRecords {
   313  					if !sendChange(&datastore.RevisionChanges{
   314  						Revision:     revisions.NewForTime(hbr.Timestamp),
   315  						IsCheckpoint: true,
   316  					}) {
   317  						return datastore.NewWatchDisconnectedErr()
   318  					}
   319  				}
   320  			}
   321  		}
   322  		return nil
   323  	})
   324  	if err != nil {
   325  		sendError(err)
   326  		return
   327  	}
   328  }
   329  
// unmarshallable is implemented by any type that can populate itself from a
// byte slice through an UnmarshalVT method. unmarshalSchemaDefinition accepts
// it so that one helper can decode different definition types.
type unmarshallable interface {
	UnmarshalVT([]byte) error
}
   333  
   334  func unmarshalSchemaDefinition(def unmarshallable, configValue any) error {
   335  	base64SerializedConfig, ok := configValue.(string)
   336  	if !ok {
   337  		return spiceerrors.MustBugf("error converting config value")
   338  	}
   339  
   340  	serializedConfig, err := base64.StdEncoding.DecodeString(base64SerializedConfig)
   341  	if err != nil {
   342  		return fmt.Errorf(errUnableToReadConfig, err)
   343  	}
   344  
   345  	if err := def.UnmarshalVT(serializedConfig); err != nil {
   346  		return fmt.Errorf(errUnableToReadConfig, err)
   347  	}
   348  
   349  	return nil
   350  }
   351  
   352  func relationTupleFromPrimaryKey(primaryKeyColumnValues map[string]any) *core.RelationTuple {
   353  	return &core.RelationTuple{
   354  		ResourceAndRelation: &core.ObjectAndRelation{
   355  			Namespace: primaryKeyColumnValues[colNamespace].(string),
   356  			ObjectId:  primaryKeyColumnValues[colObjectID].(string),
   357  			Relation:  primaryKeyColumnValues[colRelation].(string),
   358  		},
   359  		Subject: &core.ObjectAndRelation{
   360  			Namespace: primaryKeyColumnValues[colUsersetNamespace].(string),
   361  			ObjectId:  primaryKeyColumnValues[colUsersetObjectID].(string),
   362  			Relation:  primaryKeyColumnValues[colUsersetRelation].(string),
   363  		},
   364  	}
   365  }
   366  
   367  func contextualizedCaveatFromValues(values map[string]any) (*core.ContextualizedCaveat, error) {
   368  	name := values[colCaveatName].(string)
   369  	if name != "" {
   370  		contextString := values[colCaveatContext]
   371  
   372  		// NOTE: spanner returns the JSON field as a string here.
   373  		var context map[string]any
   374  		if contextString != nil {
   375  			if err := json.Unmarshal([]byte(contextString.(string)), &context); err != nil {
   376  				return nil, err
   377  			}
   378  		}
   379  
   380  		return common.ContextualizedCaveatFrom(name, context)
   381  	}
   382  	return nil, nil
   383  }