github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/services/v1/schema.go (about)

     1  package v1
     2  
     3  import (
     4  	"context"
     5  
     6  	v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
     7  	grpcvalidate "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator"
     8  	"google.golang.org/grpc/codes"
     9  	"google.golang.org/grpc/status"
    10  
    11  	log "github.com/authzed/spicedb/internal/logging"
    12  	"github.com/authzed/spicedb/internal/middleware"
    13  	datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
    14  	"github.com/authzed/spicedb/internal/middleware/usagemetrics"
    15  	"github.com/authzed/spicedb/internal/services/shared"
    16  	"github.com/authzed/spicedb/pkg/datastore"
    17  	"github.com/authzed/spicedb/pkg/genutil"
    18  	dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
    19  	"github.com/authzed/spicedb/pkg/schemadsl/compiler"
    20  	"github.com/authzed/spicedb/pkg/schemadsl/generator"
    21  	"github.com/authzed/spicedb/pkg/schemadsl/input"
    22  	"github.com/authzed/spicedb/pkg/zedtoken"
    23  )
    24  
    25  // NewSchemaServer creates a SchemaServiceServer instance.
    26  func NewSchemaServer(additiveOnly bool) v1.SchemaServiceServer {
    27  	return &schemaServer{
    28  		WithServiceSpecificInterceptors: shared.WithServiceSpecificInterceptors{
    29  			Unary: middleware.ChainUnaryServer(
    30  				grpcvalidate.UnaryServerInterceptor(),
    31  				usagemetrics.UnaryServerInterceptor(),
    32  			),
    33  			Stream: middleware.ChainStreamServer(
    34  				grpcvalidate.StreamServerInterceptor(),
    35  				usagemetrics.StreamServerInterceptor(),
    36  			),
    37  		},
    38  		additiveOnly: additiveOnly,
    39  	}
    40  }
    41  
    42  type schemaServer struct {
    43  	v1.UnimplementedSchemaServiceServer
    44  	shared.WithServiceSpecificInterceptors
    45  
    46  	additiveOnly bool
    47  }
    48  
    49  func (ss *schemaServer) rewriteError(ctx context.Context, err error) error {
    50  	return shared.RewriteError(ctx, err, nil)
    51  }
    52  
    53  func (ss *schemaServer) ReadSchema(ctx context.Context, _ *v1.ReadSchemaRequest) (*v1.ReadSchemaResponse, error) {
    54  	// Schema is always read from the head revision.
    55  	ds := datastoremw.MustFromContext(ctx)
    56  	headRevision, err := ds.HeadRevision(ctx)
    57  	if err != nil {
    58  		return nil, ss.rewriteError(ctx, err)
    59  	}
    60  
    61  	reader := ds.SnapshotReader(headRevision)
    62  
    63  	nsDefs, err := reader.ListAllNamespaces(ctx)
    64  	if err != nil {
    65  		return nil, ss.rewriteError(ctx, err)
    66  	}
    67  
    68  	caveatDefs, err := reader.ListAllCaveats(ctx)
    69  	if err != nil {
    70  		return nil, ss.rewriteError(ctx, err)
    71  	}
    72  
    73  	if len(nsDefs) == 0 {
    74  		return nil, status.Errorf(codes.NotFound, "No schema has been defined; please call WriteSchema to start")
    75  	}
    76  
    77  	schemaDefinitions := make([]compiler.SchemaDefinition, 0, len(nsDefs)+len(caveatDefs))
    78  	for _, caveatDef := range caveatDefs {
    79  		schemaDefinitions = append(schemaDefinitions, caveatDef.Definition)
    80  	}
    81  
    82  	for _, nsDef := range nsDefs {
    83  		schemaDefinitions = append(schemaDefinitions, nsDef.Definition)
    84  	}
    85  
    86  	schemaText, _, err := generator.GenerateSchema(schemaDefinitions)
    87  	if err != nil {
    88  		return nil, ss.rewriteError(ctx, err)
    89  	}
    90  
    91  	dispatchCount, err := genutil.EnsureUInt32(len(nsDefs) + len(caveatDefs))
    92  	if err != nil {
    93  		return nil, ss.rewriteError(ctx, err)
    94  	}
    95  
    96  	usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{
    97  		DispatchCount: dispatchCount,
    98  	})
    99  
   100  	return &v1.ReadSchemaResponse{
   101  		SchemaText: schemaText,
   102  		ReadAt:     zedtoken.MustNewFromRevision(headRevision),
   103  	}, nil
   104  }
   105  
   106  func (ss *schemaServer) WriteSchema(ctx context.Context, in *v1.WriteSchemaRequest) (*v1.WriteSchemaResponse, error) {
   107  	log.Ctx(ctx).Trace().Str("schema", in.GetSchema()).Msg("requested Schema to be written")
   108  
   109  	ds := datastoremw.MustFromContext(ctx)
   110  
   111  	// Compile the schema into the namespace definitions.
   112  	compiled, err := compiler.Compile(compiler.InputSchema{
   113  		Source:       input.Source("schema"),
   114  		SchemaString: in.GetSchema(),
   115  	}, compiler.AllowUnprefixedObjectType())
   116  	if err != nil {
   117  		return nil, ss.rewriteError(ctx, err)
   118  	}
   119  	log.Ctx(ctx).Trace().Int("objectDefinitions", len(compiled.ObjectDefinitions)).Int("caveatDefinitions", len(compiled.CaveatDefinitions)).Msg("compiled namespace definitions")
   120  
   121  	// Do as much validation as we can before talking to the datastore.
   122  	validated, err := shared.ValidateSchemaChanges(ctx, compiled, ss.additiveOnly)
   123  	if err != nil {
   124  		return nil, ss.rewriteError(ctx, err)
   125  	}
   126  
   127  	// Update the schema.
   128  	revision, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
   129  		applied, err := shared.ApplySchemaChanges(ctx, rwt, validated)
   130  		if err != nil {
   131  			return err
   132  		}
   133  
   134  		dispatchCount, err := genutil.EnsureUInt32(applied.TotalOperationCount)
   135  		if err != nil {
   136  			return err
   137  		}
   138  
   139  		usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{
   140  			DispatchCount: dispatchCount,
   141  		})
   142  		return nil
   143  	})
   144  	if err != nil {
   145  		return nil, ss.rewriteError(ctx, err)
   146  	}
   147  
   148  	return &v1.WriteSchemaResponse{
   149  		WrittenAt: zedtoken.MustNewFromRevision(revision),
   150  	}, nil
   151  }