github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/development/resolver.go (about)

     1  package development
     2  
import (
	"fmt"
	"sort"
	"strings"

	"github.com/authzed/spicedb/pkg/caveats"
	"github.com/authzed/spicedb/pkg/namespace"
	core "github.com/authzed/spicedb/pkg/proto/core/v1"
	"github.com/authzed/spicedb/pkg/schemadsl/compiler"
	"github.com/authzed/spicedb/pkg/schemadsl/dslshape"
	"github.com/authzed/spicedb/pkg/schemadsl/generator"
	"github.com/authzed/spicedb/pkg/schemadsl/input"
	"github.com/authzed/spicedb/pkg/typesystem"
)
    16  
// ReferenceType identifies the kind of schema element that a SchemaReference
// refers to.
type ReferenceType int

const (
	// ReferenceTypeUnknown is the zero value of ReferenceType.
	ReferenceTypeUnknown ReferenceType = iota
	// ReferenceTypeDefinition marks a reference to an object definition.
	ReferenceTypeDefinition
	// ReferenceTypeCaveat marks a reference to a caveat definition.
	ReferenceTypeCaveat
	// ReferenceTypeRelation marks a reference to a relation.
	ReferenceTypeRelation
	// ReferenceTypePermission marks a reference to a permission.
	ReferenceTypePermission
	// ReferenceTypeCaveatParameter marks a reference to a caveat parameter
	// used within a caveat expression.
	ReferenceTypeCaveatParameter
)
    28  
// SchemaReference represents a reference to a schema node, as resolved by
// Resolver.ReferenceAtPosition.
type SchemaReference struct {
	// Source is the source in which the reference was looked up.
	Source input.Source

	// Position is the position in Source at which the reference was looked up.
	Position input.Position

	// Text is the name of the referenced definition, caveat, relation,
	// permission or caveat parameter.
	Text string

	// ReferenceType is the kind of schema element being referenced.
	ReferenceType ReferenceType

	// ReferenceMarkdown is a short description of the reference, such as
	// "definition <name>" or "permission <name>".
	ReferenceMarkdown string

	// TargetSource is the source of the target node, if any.
	TargetSource *input.Source

	// TargetPosition is the position of the target node, if any. It is left
	// nil for caveat parameter references.
	TargetPosition *input.Position

	// TargetSourceCode is the source code representation of the target, if any.
	TargetSourceCode string

	// TargetNamePositionOffset is the offset from the target position at which
	// the *name* of the target is found, i.e. the length of the leading keyword
	// and its trailing space (for example len("definition ")).
	TargetNamePositionOffset int
}
    59  
// Resolver resolves references to schema nodes from source positions.
type Resolver struct {
	// schema is the compiled schema against which positions are resolved.
	schema *compiler.CompiledSchema
	// typeSystems holds one type system per object definition in schema,
	// keyed by the definition's name.
	typeSystems map[string]*typesystem.TypeSystem
}
    65  
    66  // NewResolver creates a new resolver for the given schema.
    67  func NewResolver(schema *compiler.CompiledSchema) (*Resolver, error) {
    68  	typeSystems := make(map[string]*typesystem.TypeSystem, len(schema.ObjectDefinitions))
    69  	tsResolver := typesystem.ResolverForSchema(*schema)
    70  	for _, def := range schema.ObjectDefinitions {
    71  		ts, err := typesystem.NewNamespaceTypeSystem(def, tsResolver)
    72  		if err != nil {
    73  			return nil, err
    74  		}
    75  
    76  		typeSystems[def.Name] = ts
    77  	}
    78  
    79  	return &Resolver{schema: schema, typeSystems: typeSystems}, nil
    80  }
    81  
    82  // ReferenceAtPosition returns the reference to the schema node at the given position in the source, if any.
    83  func (r *Resolver) ReferenceAtPosition(source input.Source, position input.Position) (*SchemaReference, error) {
    84  	nodeChain, err := compiler.PositionToAstNodeChain(r.schema, source, position)
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  
    89  	if nodeChain == nil {
    90  		return nil, nil
    91  	}
    92  
    93  	relationReference := func(relation *core.Relation, ts *typesystem.TypeSystem) (*SchemaReference, error) {
    94  		relationPosition := input.Position{
    95  			LineNumber:     int(relation.SourcePosition.ZeroIndexedLineNumber),
    96  			ColumnPosition: int(relation.SourcePosition.ZeroIndexedColumnPosition),
    97  		}
    98  
    99  		targetSourceCode, err := generator.GenerateRelationSource(relation)
   100  		if err != nil {
   101  			return nil, err
   102  		}
   103  
   104  		if ts.IsPermission(relation.Name) {
   105  			return &SchemaReference{
   106  				Source:   source,
   107  				Position: position,
   108  				Text:     relation.Name,
   109  
   110  				ReferenceType:     ReferenceTypePermission,
   111  				ReferenceMarkdown: fmt.Sprintf("permission %s", relation.Name),
   112  
   113  				TargetSource:             &source,
   114  				TargetPosition:           &relationPosition,
   115  				TargetSourceCode:         targetSourceCode,
   116  				TargetNamePositionOffset: len("permission "),
   117  			}, nil
   118  		}
   119  
   120  		return &SchemaReference{
   121  			Source:   source,
   122  			Position: position,
   123  			Text:     relation.Name,
   124  
   125  			ReferenceType:     ReferenceTypeRelation,
   126  			ReferenceMarkdown: fmt.Sprintf("relation %s", relation.Name),
   127  
   128  			TargetSource:             &source,
   129  			TargetPosition:           &relationPosition,
   130  			TargetSourceCode:         targetSourceCode,
   131  			TargetNamePositionOffset: len("relation "),
   132  		}, nil
   133  	}
   134  
   135  	// Type reference.
   136  	if ts, relation, ok := r.typeReferenceChain(nodeChain); ok {
   137  		if relation != nil {
   138  			return relationReference(relation, ts)
   139  		}
   140  
   141  		def := ts.Namespace()
   142  		defPosition := input.Position{
   143  			LineNumber:     int(def.SourcePosition.ZeroIndexedLineNumber),
   144  			ColumnPosition: int(def.SourcePosition.ZeroIndexedColumnPosition),
   145  		}
   146  
   147  		docComment := ""
   148  		comments := namespace.GetComments(def.Metadata)
   149  		if len(comments) > 0 {
   150  			docComment = strings.Join(comments, "\n") + "\n"
   151  		}
   152  
   153  		targetSourceCode := fmt.Sprintf("%sdefinition %s {\n\t// ...\n}", docComment, def.Name)
   154  		if len(def.Relation) == 0 {
   155  			targetSourceCode = fmt.Sprintf("%sdefinition %s {}", docComment, def.Name)
   156  		}
   157  
   158  		return &SchemaReference{
   159  			Source:   source,
   160  			Position: position,
   161  			Text:     def.Name,
   162  
   163  			ReferenceType:     ReferenceTypeDefinition,
   164  			ReferenceMarkdown: fmt.Sprintf("definition %s", def.Name),
   165  
   166  			TargetSource:             &source,
   167  			TargetPosition:           &defPosition,
   168  			TargetSourceCode:         targetSourceCode,
   169  			TargetNamePositionOffset: len("definition "),
   170  		}, nil
   171  	}
   172  
   173  	// Caveat Type reference.
   174  	if caveatDef, ok := r.caveatTypeReferenceChain(nodeChain); ok {
   175  		defPosition := input.Position{
   176  			LineNumber:     int(caveatDef.SourcePosition.ZeroIndexedLineNumber),
   177  			ColumnPosition: int(caveatDef.SourcePosition.ZeroIndexedColumnPosition),
   178  		}
   179  
   180  		var caveatSourceCode strings.Builder
   181  		caveatSourceCode.WriteString(fmt.Sprintf("caveat %s(", caveatDef.Name))
   182  		index := 0
   183  		for paramName, paramType := range caveatDef.ParameterTypes {
   184  			if index > 0 {
   185  				caveatSourceCode.WriteString(", ")
   186  			}
   187  
   188  			caveatSourceCode.WriteString(fmt.Sprintf("%s %s", paramName, caveats.ParameterTypeString(paramType)))
   189  			index++
   190  		}
   191  		caveatSourceCode.WriteString(") {\n\t// ...\n}")
   192  
   193  		return &SchemaReference{
   194  			Source:   source,
   195  			Position: position,
   196  			Text:     caveatDef.Name,
   197  
   198  			ReferenceType:     ReferenceTypeCaveat,
   199  			ReferenceMarkdown: fmt.Sprintf("caveat %s", caveatDef.Name),
   200  
   201  			TargetSource:             &source,
   202  			TargetPosition:           &defPosition,
   203  			TargetSourceCode:         caveatSourceCode.String(),
   204  			TargetNamePositionOffset: len("caveat "),
   205  		}, nil
   206  	}
   207  
   208  	// Relation reference.
   209  	if relation, ts, ok := r.relationReferenceChain(nodeChain); ok {
   210  		return relationReference(relation, ts)
   211  	}
   212  
   213  	// Caveat parameter used in expression.
   214  	if caveatParamName, caveatDef, ok := r.caveatParamChain(nodeChain, source, position); ok {
   215  		targetSourceCode := fmt.Sprintf("%s %s", caveatParamName, caveats.ParameterTypeString(caveatDef.ParameterTypes[caveatParamName]))
   216  
   217  		return &SchemaReference{
   218  			Source:   source,
   219  			Position: position,
   220  			Text:     caveatParamName,
   221  
   222  			ReferenceType:     ReferenceTypeCaveatParameter,
   223  			ReferenceMarkdown: targetSourceCode,
   224  
   225  			TargetSource:     &source,
   226  			TargetSourceCode: targetSourceCode,
   227  		}, nil
   228  	}
   229  
   230  	return nil, nil
   231  }
   232  
   233  func (r *Resolver) lookupDefinition(defName string) (*typesystem.TypeSystem, bool) {
   234  	ts, ok := r.typeSystems[defName]
   235  	return ts, ok
   236  }
   237  
   238  func (r *Resolver) lookupCaveat(caveatName string) (*core.CaveatDefinition, bool) {
   239  	for _, caveatDef := range r.schema.CaveatDefinitions {
   240  		if caveatDef.Name == caveatName {
   241  			return caveatDef, true
   242  		}
   243  	}
   244  
   245  	return nil, false
   246  }
   247  
   248  func (r *Resolver) lookupRelation(defName, relationName string) (*core.Relation, *typesystem.TypeSystem, bool) {
   249  	ts, ok := r.typeSystems[defName]
   250  	if !ok {
   251  		return nil, nil, false
   252  	}
   253  
   254  	rel, ok := ts.GetRelation(relationName)
   255  	if !ok {
   256  		return nil, nil, false
   257  	}
   258  
   259  	return rel, ts, true
   260  }
   261  
   262  func (r *Resolver) caveatParamChain(nodeChain *compiler.NodeChain, source input.Source, position input.Position) (string, *core.CaveatDefinition, bool) {
   263  	if !nodeChain.HasHeadType(dslshape.NodeTypeCaveatExpression) {
   264  		return "", nil, false
   265  	}
   266  
   267  	caveatDefNode := nodeChain.FindNodeOfType(dslshape.NodeTypeCaveatDefinition)
   268  	if caveatDefNode == nil {
   269  		return "", nil, false
   270  	}
   271  
   272  	caveatName, err := caveatDefNode.GetString(dslshape.NodeCaveatDefinitionPredicateName)
   273  	if err != nil {
   274  		return "", nil, false
   275  	}
   276  
   277  	caveatDef, ok := r.lookupCaveat(caveatName)
   278  	if !ok {
   279  		return "", nil, false
   280  	}
   281  
   282  	runePosition, err := r.schema.SourcePositionToRunePosition(source, position)
   283  	if err != nil {
   284  		return "", nil, false
   285  	}
   286  
   287  	exprRunePosition, err := nodeChain.Head().GetInt(dslshape.NodePredicateStartRune)
   288  	if err != nil {
   289  		return "", nil, false
   290  	}
   291  
   292  	if exprRunePosition > runePosition {
   293  		return "", nil, false
   294  	}
   295  
   296  	relationRunePosition := runePosition - exprRunePosition
   297  
   298  	caveatExpr, err := nodeChain.Head().GetString(dslshape.NodeCaveatExpressionPredicateExpression)
   299  	if err != nil {
   300  		return "", nil, false
   301  	}
   302  
   303  	// Split the expression into tokens and find the associated token.
   304  	tokens := strings.FieldsFunc(caveatExpr, splitCELToken)
   305  	currentIndex := 0
   306  	for _, token := range tokens {
   307  		if currentIndex <= relationRunePosition && currentIndex+len(token) >= relationRunePosition {
   308  			if _, ok := caveatDef.ParameterTypes[token]; ok {
   309  				return token, caveatDef, true
   310  			}
   311  		}
   312  	}
   313  
   314  	return "", caveatDef, true
   315  }
   316  
   317  func splitCELToken(r rune) bool {
   318  	return r == ' ' || r == '(' || r == ')' || r == '.' || r == ',' || r == '[' || r == ']' || r == '{' || r == '}' || r == ':' || r == '='
   319  }
   320  
   321  func (r *Resolver) caveatTypeReferenceChain(nodeChain *compiler.NodeChain) (*core.CaveatDefinition, bool) {
   322  	if !nodeChain.HasHeadType(dslshape.NodeTypeCaveatReference) {
   323  		return nil, false
   324  	}
   325  
   326  	caveatName, err := nodeChain.Head().GetString(dslshape.NodeCaveatPredicateCaveat)
   327  	if err != nil {
   328  		return nil, false
   329  	}
   330  
   331  	return r.lookupCaveat(caveatName)
   332  }
   333  
   334  func (r *Resolver) typeReferenceChain(nodeChain *compiler.NodeChain) (*typesystem.TypeSystem, *core.Relation, bool) {
   335  	if !nodeChain.HasHeadType(dslshape.NodeTypeSpecificTypeReference) {
   336  		return nil, nil, false
   337  	}
   338  
   339  	defName, err := nodeChain.Head().GetString(dslshape.NodeSpecificReferencePredicateType)
   340  	if err != nil {
   341  		return nil, nil, false
   342  	}
   343  
   344  	def, ok := r.lookupDefinition(defName)
   345  	if !ok {
   346  		return nil, nil, false
   347  	}
   348  
   349  	relationName, err := nodeChain.Head().GetString(dslshape.NodeSpecificReferencePredicateRelation)
   350  	if err != nil {
   351  		return def, nil, true
   352  	}
   353  
   354  	startingRune, err := nodeChain.Head().GetInt(dslshape.NodePredicateStartRune)
   355  	if err != nil {
   356  		return def, nil, true
   357  	}
   358  
   359  	// If hover over the definition name, return the definition.
   360  	if nodeChain.ForRunePosition() < startingRune+len(defName) {
   361  		return def, nil, true
   362  	}
   363  
   364  	relation, ok := def.GetRelation(relationName)
   365  	if !ok {
   366  		return nil, nil, false
   367  	}
   368  
   369  	return def, relation, true
   370  }
   371  
   372  func (r *Resolver) relationReferenceChain(nodeChain *compiler.NodeChain) (*core.Relation, *typesystem.TypeSystem, bool) {
   373  	if !nodeChain.HasHeadType(dslshape.NodeTypeIdentifier) {
   374  		return nil, nil, false
   375  	}
   376  
   377  	if arrowExpr := nodeChain.FindNodeOfType(dslshape.NodeTypeArrowExpression); arrowExpr != nil {
   378  		// Ensure this on the left side of the arrow.
   379  		rightExpr, err := arrowExpr.Lookup(dslshape.NodeExpressionPredicateRightExpr)
   380  		if err != nil {
   381  			return nil, nil, false
   382  		}
   383  
   384  		if rightExpr == nodeChain.Head() {
   385  			return nil, nil, false
   386  		}
   387  	}
   388  
   389  	relationName, err := nodeChain.Head().GetString(dslshape.NodeIdentiferPredicateValue)
   390  	if err != nil {
   391  		return nil, nil, false
   392  	}
   393  
   394  	parentDefNode := nodeChain.FindNodeOfType(dslshape.NodeTypeDefinition)
   395  	if parentDefNode == nil {
   396  		return nil, nil, false
   397  	}
   398  
   399  	defName, err := parentDefNode.GetString(dslshape.NodeDefinitionPredicateName)
   400  	if err != nil {
   401  		return nil, nil, false
   402  	}
   403  
   404  	return r.lookupRelation(defName, relationName)
   405  }