github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/schemadsl/compiler/compiler.go (about)

     1  package compiler
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  
     7  	"google.golang.org/protobuf/proto"
     8  
     9  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    10  	"github.com/authzed/spicedb/pkg/schemadsl/dslshape"
    11  	"github.com/authzed/spicedb/pkg/schemadsl/input"
    12  	"github.com/authzed/spicedb/pkg/schemadsl/parser"
    13  )
    14  
    15  // InputSchema defines the input for a Compile.
    16  type InputSchema struct {
    17  	// Source is the source of the schema being compiled.
    18  	Source input.Source
    19  
    20  	// Schema is the contents being compiled.
    21  	SchemaString string
    22  }
    23  
    24  // SchemaDefinition represents an object or caveat definition in a schema.
    25  type SchemaDefinition interface {
    26  	proto.Message
    27  
    28  	GetName() string
    29  }
    30  
// CompiledSchema is the result of compiling a schema when there are no errors.
type CompiledSchema struct {
	// ObjectDefinitions holds the object definitions in the schema.
	ObjectDefinitions []*core.NamespaceDefinition

	// CaveatDefinitions holds the caveat definitions in the schema.
	CaveatDefinitions []*core.CaveatDefinition

	// OrderedDefinitions holds the object and caveat definitions in the schema, in the
	// order in which they were found.
	OrderedDefinitions []SchemaDefinition

	// rootNode is presumably the root of the parsed DSL tree produced in
	// Compile — NOTE(review): it is populated outside this file; confirm.
	// mapper converts line/column positions into rune positions and is what
	// SourcePositionToRunePosition delegates to; it must be set before that
	// method is called.
	rootNode *dslNode
	mapper   input.PositionMapper
}
    46  
    47  // SourcePositionToRunePosition converts a source position to a rune position.
    48  func (cs CompiledSchema) SourcePositionToRunePosition(source input.Source, position input.Position) (int, error) {
    49  	return cs.mapper.LineAndColToRunePosition(position.LineNumber, position.ColumnPosition, source)
    50  }
    51  
// config collects the settings produced by the ObjectPrefixOption and Option
// values passed to Compile.
type config struct {
	// skipValidation is set by SkipValidation and forwarded to translation as
	// skipValidate. objectTypePrefix is nil after RequirePrefixedObjectType,
	// points to an empty string after AllowUnprefixedObjectType, and points to
	// the supplied prefix after ObjectTypePrefix.
	skipValidation   bool
	objectTypePrefix *string
}
    56  
    57  func SkipValidation() Option { return func(cfg *config) { cfg.skipValidation = true } }
    58  
    59  func ObjectTypePrefix(prefix string) ObjectPrefixOption {
    60  	return func(cfg *config) { cfg.objectTypePrefix = &prefix }
    61  }
    62  
    63  func RequirePrefixedObjectType() ObjectPrefixOption {
    64  	return func(cfg *config) { cfg.objectTypePrefix = nil }
    65  }
    66  
    67  func AllowUnprefixedObjectType() ObjectPrefixOption {
    68  	return func(cfg *config) { cfg.objectTypePrefix = new(string) }
    69  }
    70  
    71  type Option func(*config)
    72  
    73  type ObjectPrefixOption func(*config)
    74  
    75  // Compile compilers the input schema into a set of namespace definition protos.
    76  func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*CompiledSchema, error) {
    77  	cfg := &config{}
    78  	prefix(cfg) // required option
    79  
    80  	for _, fn := range opts {
    81  		fn(cfg)
    82  	}
    83  
    84  	mapper := newPositionMapper(schema)
    85  	root := parser.Parse(createAstNode, schema.Source, schema.SchemaString).(*dslNode)
    86  	errs := root.FindAll(dslshape.NodeTypeError)
    87  	if len(errs) > 0 {
    88  		err := errorNodeToError(errs[0], mapper)
    89  		return nil, err
    90  	}
    91  
    92  	compiled, err := translate(translationContext{
    93  		objectTypePrefix: cfg.objectTypePrefix,
    94  		mapper:           mapper,
    95  		schemaString:     schema.SchemaString,
    96  		skipValidate:     cfg.skipValidation,
    97  	}, root)
    98  	if err != nil {
    99  		var errorWithNode errorWithNode
   100  		if errors.As(err, &errorWithNode) {
   101  			err = toContextError(errorWithNode.error.Error(), errorWithNode.errorSourceCode, errorWithNode.node, mapper)
   102  		}
   103  
   104  		return nil, err
   105  	}
   106  
   107  	return compiled, nil
   108  }
   109  
   110  func errorNodeToError(node *dslNode, mapper input.PositionMapper) error {
   111  	if node.GetType() != dslshape.NodeTypeError {
   112  		return fmt.Errorf("given none error node")
   113  	}
   114  
   115  	errMessage, err := node.GetString(dslshape.NodePredicateErrorMessage)
   116  	if err != nil {
   117  		return fmt.Errorf("could not get error message for error node: %w", err)
   118  	}
   119  
   120  	errorSourceCode := ""
   121  	if node.Has(dslshape.NodePredicateErrorSource) {
   122  		es, err := node.GetString(dslshape.NodePredicateErrorSource)
   123  		if err != nil {
   124  			return fmt.Errorf("could not get error source for error node: %w", err)
   125  		}
   126  
   127  		errorSourceCode = es
   128  	}
   129  
   130  	return toContextError(errMessage, errorSourceCode, node, mapper)
   131  }
   132  
   133  func toContextError(errMessage string, errorSourceCode string, node *dslNode, mapper input.PositionMapper) error {
   134  	sourceRange, err := node.Range(mapper)
   135  	if err != nil {
   136  		return fmt.Errorf("could not get range for error node: %w", err)
   137  	}
   138  
   139  	formattedRange, err := formatRange(sourceRange)
   140  	if err != nil {
   141  		return err
   142  	}
   143  
   144  	source, err := node.GetString(dslshape.NodePredicateSource)
   145  	if err != nil {
   146  		return fmt.Errorf("missing source for node: %w", err)
   147  	}
   148  
   149  	return ErrorWithContext{
   150  		BaseCompilerError: BaseCompilerError{
   151  			error:       fmt.Errorf("parse error in %s: %s", formattedRange, errMessage),
   152  			BaseMessage: errMessage,
   153  		},
   154  		SourceRange:     sourceRange,
   155  		Source:          input.Source(source),
   156  		ErrorSourceCode: errorSourceCode,
   157  	}
   158  }
   159  
   160  func formatRange(rnge input.SourceRange) (string, error) {
   161  	startLine, startCol, err := rnge.Start().LineAndColumn()
   162  	if err != nil {
   163  		return "", err
   164  	}
   165  
   166  	return fmt.Sprintf("`%s`, line %v, column %v", rnge.Source(), startLine+1, startCol+1), nil
   167  }