github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/services/shared/schema.go (about)

     1  package shared
     2  
     3  import (
     4  	"context"
     5  
     6  	log "github.com/authzed/spicedb/internal/logging"
     7  	"github.com/authzed/spicedb/internal/namespace"
     8  	"github.com/authzed/spicedb/pkg/datastore"
     9  	"github.com/authzed/spicedb/pkg/datastore/options"
    10  	caveatdiff "github.com/authzed/spicedb/pkg/diff/caveats"
    11  	nsdiff "github.com/authzed/spicedb/pkg/diff/namespace"
    12  	"github.com/authzed/spicedb/pkg/genutil/mapz"
    13  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    14  	"github.com/authzed/spicedb/pkg/schemadsl/compiler"
    15  	"github.com/authzed/spicedb/pkg/spiceerrors"
    16  	"github.com/authzed/spicedb/pkg/tuple"
    17  	"github.com/authzed/spicedb/pkg/typesystem"
    18  )
    19  
    20  // ValidatedSchemaChanges is a set of validated schema changes that can be applied to the datastore.
    21  type ValidatedSchemaChanges struct {
    22  	compiled             *compiler.CompiledSchema
    23  	validatedTypeSystems map[string]*typesystem.ValidatedNamespaceTypeSystem
    24  	newCaveatDefNames    *mapz.Set[string]
    25  	newObjectDefNames    *mapz.Set[string]
    26  	additiveOnly         bool
    27  }
    28  
    29  // ValidateSchemaChanges validates the schema found in the compiled schema and returns a
    30  // ValidatedSchemaChanges, if fully validated.
    31  func ValidateSchemaChanges(ctx context.Context, compiled *compiler.CompiledSchema, additiveOnly bool) (*ValidatedSchemaChanges, error) {
    32  	// 1) Validate the caveats defined.
    33  	newCaveatDefNames := mapz.NewSet[string]()
    34  	for _, caveatDef := range compiled.CaveatDefinitions {
    35  		if err := namespace.ValidateCaveatDefinition(caveatDef); err != nil {
    36  			return nil, err
    37  		}
    38  
    39  		newCaveatDefNames.Insert(caveatDef.Name)
    40  	}
    41  
    42  	// 2) Validate the namespaces defined.
    43  	newObjectDefNames := mapz.NewSet[string]()
    44  	validatedTypeSystems := make(map[string]*typesystem.ValidatedNamespaceTypeSystem, len(compiled.ObjectDefinitions))
    45  
    46  	for _, nsdef := range compiled.ObjectDefinitions {
    47  		ts, err := typesystem.NewNamespaceTypeSystem(nsdef,
    48  			typesystem.ResolverForPredefinedDefinitions(typesystem.PredefinedElements{
    49  				Namespaces: compiled.ObjectDefinitions,
    50  				Caveats:    compiled.CaveatDefinitions,
    51  			}))
    52  		if err != nil {
    53  			return nil, err
    54  		}
    55  
    56  		vts, err := ts.Validate(ctx)
    57  		if err != nil {
    58  			return nil, err
    59  		}
    60  
    61  		validatedTypeSystems[nsdef.Name] = vts
    62  		newObjectDefNames.Insert(nsdef.Name)
    63  	}
    64  
    65  	return &ValidatedSchemaChanges{
    66  		compiled:             compiled,
    67  		validatedTypeSystems: validatedTypeSystems,
    68  		newCaveatDefNames:    newCaveatDefNames,
    69  		newObjectDefNames:    newObjectDefNames,
    70  		additiveOnly:         additiveOnly,
    71  	}, nil
    72  }
    73  
    74  // AppliedSchemaChanges holds information about the applied schema changes.
    75  type AppliedSchemaChanges struct {
    76  	// TotalOperationCount holds the total number of "dispatch" operations performed by the schema
    77  	// being applied.
    78  	TotalOperationCount int
    79  
    80  	// NewObjectDefNames contains the names of the newly added object definitions.
    81  	NewObjectDefNames []string
    82  
    83  	// RemovedObjectDefNames contains the names of the removed object definitions.
    84  	RemovedObjectDefNames []string
    85  
    86  	// NewCaveatDefNames contains the names of the newly added caveat definitions.
    87  	NewCaveatDefNames []string
    88  
    89  	// RemovedCaveatDefNames contains the names of the removed caveat definitions.
    90  	RemovedCaveatDefNames []string
    91  }
    92  
    93  // ApplySchemaChanges applies schema changes found in the validated changes struct, via the specified
    94  // ReadWriteTransaction.
    95  func ApplySchemaChanges(ctx context.Context, rwt datastore.ReadWriteTransaction, validated *ValidatedSchemaChanges) (*AppliedSchemaChanges, error) {
    96  	existingCaveats, err := rwt.ListAllCaveats(ctx)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  
   101  	existingObjectDefs, err := rwt.ListAllNamespaces(ctx)
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  
   106  	return ApplySchemaChangesOverExisting(ctx, rwt, validated, datastore.DefinitionsOf(existingCaveats), datastore.DefinitionsOf(existingObjectDefs))
   107  }
   108  
   109  // ApplySchemaChangesOverExisting applies schema changes found in the validated changes struct, against
   110  // existing caveat and object definitions given.
   111  func ApplySchemaChangesOverExisting(
   112  	ctx context.Context,
   113  	rwt datastore.ReadWriteTransaction,
   114  	validated *ValidatedSchemaChanges,
   115  	existingCaveats []*core.CaveatDefinition,
   116  	existingObjectDefs []*core.NamespaceDefinition,
   117  ) (*AppliedSchemaChanges, error) {
   118  	// Build a map of existing caveats to determine those being removed, if any.
   119  	existingCaveatDefMap := make(map[string]*core.CaveatDefinition, len(existingCaveats))
   120  	existingCaveatDefNames := mapz.NewSet[string]()
   121  
   122  	for _, existingCaveat := range existingCaveats {
   123  		existingCaveatDefMap[existingCaveat.Name] = existingCaveat
   124  		existingCaveatDefNames.Insert(existingCaveat.Name)
   125  	}
   126  
   127  	// For each caveat definition, perform a diff and ensure the changes will not result in type errors.
   128  	caveatDefsWithChanges := make([]*core.CaveatDefinition, 0, len(validated.compiled.CaveatDefinitions))
   129  	for _, caveatDef := range validated.compiled.CaveatDefinitions {
   130  		diff, err := sanityCheckCaveatChanges(ctx, rwt, caveatDef, existingCaveatDefMap)
   131  		if err != nil {
   132  			return nil, err
   133  		}
   134  
   135  		if len(diff.Deltas()) > 0 {
   136  			caveatDefsWithChanges = append(caveatDefsWithChanges, caveatDef)
   137  		}
   138  	}
   139  
   140  	removedCaveatDefNames := existingCaveatDefNames.Subtract(validated.newCaveatDefNames)
   141  
   142  	// Build a map of existing definitions to determine those being removed, if any.
   143  	existingObjectDefMap := make(map[string]*core.NamespaceDefinition, len(existingObjectDefs))
   144  	existingObjectDefNames := mapz.NewSet[string]()
   145  	for _, existingDef := range existingObjectDefs {
   146  		existingObjectDefMap[existingDef.Name] = existingDef
   147  		existingObjectDefNames.Insert(existingDef.Name)
   148  	}
   149  
   150  	// For each definition, perform a diff and ensure the changes will not result in any
   151  	// breaking changes.
   152  	objectDefsWithChanges := make([]*core.NamespaceDefinition, 0, len(validated.compiled.ObjectDefinitions))
   153  	for _, nsdef := range validated.compiled.ObjectDefinitions {
   154  		diff, err := sanityCheckNamespaceChanges(ctx, rwt, nsdef, existingObjectDefMap)
   155  		if err != nil {
   156  			return nil, err
   157  		}
   158  
   159  		if len(diff.Deltas()) > 0 {
   160  			objectDefsWithChanges = append(objectDefsWithChanges, nsdef)
   161  
   162  			vts, ok := validated.validatedTypeSystems[nsdef.Name]
   163  			if !ok {
   164  				return nil, spiceerrors.MustBugf("validated type system not found for namespace `%s`", nsdef.Name)
   165  			}
   166  
   167  			if err := namespace.AnnotateNamespace(vts); err != nil {
   168  				return nil, err
   169  			}
   170  		}
   171  	}
   172  
   173  	log.Ctx(ctx).
   174  		Trace().
   175  		Int("objectDefinitions", len(validated.compiled.ObjectDefinitions)).
   176  		Int("caveatDefinitions", len(validated.compiled.CaveatDefinitions)).
   177  		Int("objectDefsWithChanges", len(objectDefsWithChanges)).
   178  		Int("caveatDefsWithChanges", len(caveatDefsWithChanges)).
   179  		Msg("validated namespace definitions")
   180  
   181  	// Ensure that deleting namespaces will not result in any relationships left without associated
   182  	// schema.
   183  	removedObjectDefNames := existingObjectDefNames.Subtract(validated.newObjectDefNames)
   184  	if !validated.additiveOnly {
   185  		if err := removedObjectDefNames.ForEach(func(nsdefName string) error {
   186  			return ensureNoRelationshipsExist(ctx, rwt, nsdefName)
   187  		}); err != nil {
   188  			return nil, err
   189  		}
   190  	}
   191  
   192  	// Write the new/changes caveats.
   193  	if len(caveatDefsWithChanges) > 0 {
   194  		if err := rwt.WriteCaveats(ctx, caveatDefsWithChanges); err != nil {
   195  			return nil, err
   196  		}
   197  	}
   198  
   199  	// Write the new/changed namespaces.
   200  	if len(objectDefsWithChanges) > 0 {
   201  		if err := rwt.WriteNamespaces(ctx, objectDefsWithChanges...); err != nil {
   202  			return nil, err
   203  		}
   204  	}
   205  
   206  	if !validated.additiveOnly {
   207  		// Delete the removed namespaces.
   208  		if removedObjectDefNames.Len() > 0 {
   209  			if err := rwt.DeleteNamespaces(ctx, removedObjectDefNames.AsSlice()...); err != nil {
   210  				return nil, err
   211  			}
   212  		}
   213  
   214  		// Delete the removed caveats.
   215  		if !removedCaveatDefNames.IsEmpty() {
   216  			if err := rwt.DeleteCaveats(ctx, removedCaveatDefNames.AsSlice()); err != nil {
   217  				return nil, err
   218  			}
   219  		}
   220  	}
   221  
   222  	log.Ctx(ctx).Trace().
   223  		Interface("objectDefinitions", validated.compiled.ObjectDefinitions).
   224  		Interface("caveatDefinitions", validated.compiled.CaveatDefinitions).
   225  		Object("addedOrChangedObjectDefinitions", validated.newObjectDefNames).
   226  		Object("removedObjectDefinitions", removedObjectDefNames).
   227  		Object("addedOrChangedCaveatDefinitions", validated.newCaveatDefNames).
   228  		Object("removedCaveatDefinitions", removedCaveatDefNames).
   229  		Msg("completed schema update")
   230  
   231  	return &AppliedSchemaChanges{
   232  		TotalOperationCount:   len(validated.compiled.ObjectDefinitions) + len(validated.compiled.CaveatDefinitions) + removedObjectDefNames.Len() + removedCaveatDefNames.Len(),
   233  		NewObjectDefNames:     validated.newObjectDefNames.Subtract(existingObjectDefNames).AsSlice(),
   234  		RemovedObjectDefNames: removedObjectDefNames.AsSlice(),
   235  		NewCaveatDefNames:     validated.newCaveatDefNames.Subtract(existingCaveatDefNames).AsSlice(),
   236  		RemovedCaveatDefNames: removedCaveatDefNames.AsSlice(),
   237  	}, nil
   238  }
   239  
   240  // sanityCheckCaveatChanges ensures that a caveat definition being written does not break
   241  // the types of the parameters that may already exist on relationships.
   242  func sanityCheckCaveatChanges(
   243  	_ context.Context,
   244  	_ datastore.ReadWriteTransaction,
   245  	caveatDef *core.CaveatDefinition,
   246  	existingDefs map[string]*core.CaveatDefinition,
   247  ) (*caveatdiff.Diff, error) {
   248  	// Ensure that the updated namespace does not break the existing tuple data.
   249  	existing := existingDefs[caveatDef.Name]
   250  	diff, err := caveatdiff.DiffCaveats(existing, caveatDef)
   251  	if err != nil {
   252  		return nil, err
   253  	}
   254  
   255  	for _, delta := range diff.Deltas() {
   256  		switch delta.Type {
   257  		case caveatdiff.RemovedParameter:
   258  			return diff, NewSchemaWriteDataValidationError("cannot remove parameter `%s` on caveat `%s`", delta.ParameterName, caveatDef.Name)
   259  
   260  		case caveatdiff.ParameterTypeChanged:
   261  			return diff, NewSchemaWriteDataValidationError("cannot change the type of parameter `%s` on caveat `%s`", delta.ParameterName, caveatDef.Name)
   262  		}
   263  	}
   264  
   265  	return diff, nil
   266  }
   267  
   268  // ensureNoRelationshipsExist ensures that no relationships exist within the namespace with the given name.
   269  func ensureNoRelationshipsExist(ctx context.Context, rwt datastore.ReadWriteTransaction, namespaceName string) error {
   270  	qy, qyErr := rwt.QueryRelationships(
   271  		ctx,
   272  		datastore.RelationshipsFilter{OptionalResourceType: namespaceName},
   273  		options.WithLimit(options.LimitOne),
   274  	)
   275  	if err := errorIfTupleIteratorReturnsTuples(
   276  		ctx,
   277  		qy,
   278  		qyErr,
   279  		"cannot delete object definition `%s`, as a relationship exists under it",
   280  		namespaceName,
   281  	); err != nil {
   282  		return err
   283  	}
   284  
   285  	qy, qyErr = rwt.ReverseQueryRelationships(ctx, datastore.SubjectsFilter{
   286  		SubjectType: namespaceName,
   287  	}, options.WithLimitForReverse(options.LimitOne))
   288  	err := errorIfTupleIteratorReturnsTuples(
   289  		ctx,
   290  		qy,
   291  		qyErr,
   292  		"cannot delete object definition `%s`, as a relationship references it",
   293  		namespaceName,
   294  	)
   295  	qy.Close()
   296  	if err != nil {
   297  		return err
   298  	}
   299  
   300  	return nil
   301  }
   302  
   303  // sanityCheckNamespaceChanges ensures that a namespace definition being written does not result
   304  // in breaking changes, such as relationships without associated defined schema object definitions
   305  // and relations.
   306  func sanityCheckNamespaceChanges(
   307  	ctx context.Context,
   308  	rwt datastore.ReadWriteTransaction,
   309  	nsdef *core.NamespaceDefinition,
   310  	existingDefs map[string]*core.NamespaceDefinition,
   311  ) (*nsdiff.Diff, error) {
   312  	// Ensure that the updated namespace does not break the existing tuple data.
   313  	existing := existingDefs[nsdef.Name]
   314  	diff, err := nsdiff.DiffNamespaces(existing, nsdef)
   315  	if err != nil {
   316  		return nil, err
   317  	}
   318  
   319  	for _, delta := range diff.Deltas() {
   320  		switch delta.Type {
   321  		case nsdiff.RemovedRelation:
   322  			qy, qyErr := rwt.QueryRelationships(ctx, datastore.RelationshipsFilter{
   323  				OptionalResourceType:     nsdef.Name,
   324  				OptionalResourceRelation: delta.RelationName,
   325  			})
   326  
   327  			err = errorIfTupleIteratorReturnsTuples(
   328  				ctx,
   329  				qy,
   330  				qyErr,
   331  				"cannot delete relation `%s` in object definition `%s`, as a relationship exists under it", delta.RelationName, nsdef.Name)
   332  			if err != nil {
   333  				return diff, err
   334  			}
   335  
   336  			// Also check for right sides of tuples.
   337  			qy, qyErr = rwt.ReverseQueryRelationships(ctx, datastore.SubjectsFilter{
   338  				SubjectType: nsdef.Name,
   339  				RelationFilter: datastore.SubjectRelationFilter{
   340  					NonEllipsisRelation: delta.RelationName,
   341  				},
   342  			}, options.WithLimitForReverse(options.LimitOne))
   343  			err = errorIfTupleIteratorReturnsTuples(
   344  				ctx,
   345  				qy,
   346  				qyErr,
   347  				"cannot delete relation `%s` in object definition `%s`, as a relationship references it", delta.RelationName, nsdef.Name)
   348  			qy.Close()
   349  			if err != nil {
   350  				return diff, err
   351  			}
   352  
   353  		case nsdiff.RelationAllowedTypeRemoved:
   354  			var optionalSubjectIds []string
   355  			var relationFilter datastore.SubjectRelationFilter
   356  			optionalCaveatName := ""
   357  
   358  			if delta.AllowedType.GetPublicWildcard() != nil {
   359  				optionalSubjectIds = []string{tuple.PublicWildcard}
   360  			} else {
   361  				relationFilter = datastore.SubjectRelationFilter{
   362  					NonEllipsisRelation: delta.AllowedType.GetRelation(),
   363  				}
   364  			}
   365  
   366  			if delta.AllowedType.GetRequiredCaveat() != nil {
   367  				optionalCaveatName = delta.AllowedType.GetRequiredCaveat().CaveatName
   368  			}
   369  
   370  			qyr, qyrErr := rwt.QueryRelationships(
   371  				ctx,
   372  				datastore.RelationshipsFilter{
   373  					OptionalResourceType:     nsdef.Name,
   374  					OptionalResourceRelation: delta.RelationName,
   375  					OptionalSubjectsSelectors: []datastore.SubjectsSelector{
   376  						{
   377  							OptionalSubjectType: delta.AllowedType.Namespace,
   378  							OptionalSubjectIds:  optionalSubjectIds,
   379  							RelationFilter:      relationFilter,
   380  						},
   381  					},
   382  					OptionalCaveatName: optionalCaveatName,
   383  				},
   384  				options.WithLimit(options.LimitOne),
   385  			)
   386  			err = errorIfTupleIteratorReturnsTuples(
   387  				ctx,
   388  				qyr,
   389  				qyrErr,
   390  				"cannot remove allowed type `%s` from relation `%s` in object definition `%s`, as a relationship exists with it",
   391  				typesystem.SourceForAllowedRelation(delta.AllowedType), delta.RelationName, nsdef.Name)
   392  			qyr.Close()
   393  			if err != nil {
   394  				return diff, err
   395  			}
   396  		}
   397  	}
   398  	return diff, nil
   399  }
   400  
   401  // errorIfTupleIteratorReturnsTuples takes a tuple iterator and any error that was generated
   402  // when the original iterator was created, and returns an error if iterator contains any tuples.
   403  func errorIfTupleIteratorReturnsTuples(_ context.Context, qy datastore.RelationshipIterator, qyErr error, message string, args ...interface{}) error {
   404  	if qyErr != nil {
   405  		return qyErr
   406  	}
   407  	defer qy.Close()
   408  
   409  	if rt := qy.Next(); rt != nil {
   410  		if qy.Err() != nil {
   411  			return qy.Err()
   412  		}
   413  
   414  		return NewSchemaWriteDataValidationError(message, args...)
   415  	}
   416  	return nil
   417  }