github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/schemadsl/compiler/development.go (about)

     1  package compiler
     2  
     3  import (
     4  	"github.com/authzed/spicedb/pkg/schemadsl/dslshape"
     5  	"github.com/authzed/spicedb/pkg/schemadsl/input"
     6  )
     7  
     8  // DSLNode is a node in the DSL AST.
     9  type DSLNode interface {
    10  	GetType() dslshape.NodeType
    11  	GetString(predicateName string) (string, error)
    12  	GetInt(predicateName string) (int, error)
    13  	Lookup(predicateName string) (DSLNode, error)
    14  }
    15  
    16  // NodeChain is a chain of nodes in the DSL AST.
    17  type NodeChain struct {
    18  	nodes        []DSLNode
    19  	runePosition int
    20  }
    21  
    22  // Head returns the head node of the chain.
    23  func (nc *NodeChain) Head() DSLNode {
    24  	return nc.nodes[0]
    25  }
    26  
    27  // HasHeadType returns true if the head node of the chain is of the given type.
    28  func (nc *NodeChain) HasHeadType(nodeType dslshape.NodeType) bool {
    29  	return nc.nodes[0].GetType() == nodeType
    30  }
    31  
    32  // ForRunePosition returns the rune position of the chain.
    33  func (nc *NodeChain) ForRunePosition() int {
    34  	return nc.runePosition
    35  }
    36  
    37  // FindNodeOfType returns the first node of the given type in the chain, if any.
    38  func (nc *NodeChain) FindNodeOfType(nodeType dslshape.NodeType) DSLNode {
    39  	for _, node := range nc.nodes {
    40  		if node.GetType() == nodeType {
    41  			return node
    42  		}
    43  	}
    44  
    45  	return nil
    46  }
    47  
    48  func (nc *NodeChain) String() string {
    49  	var out string
    50  	for _, node := range nc.nodes {
    51  		out += node.GetType().String() + " "
    52  	}
    53  	return out
    54  }
    55  
    56  // PositionToAstNodeChain returns the AST node, and its parents (if any), found at the given position in the source, if any.
    57  func PositionToAstNodeChain(schema *CompiledSchema, source input.Source, position input.Position) (*NodeChain, error) {
    58  	rootSource, err := schema.rootNode.GetString(dslshape.NodePredicateSource)
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  
    63  	if rootSource != string(source) {
    64  		return nil, nil
    65  	}
    66  
    67  	// Map the position to a file rune.
    68  	runePosition, err := schema.mapper.LineAndColToRunePosition(position.LineNumber, position.ColumnPosition, source)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  
    73  	// Find the node at the rune position.
    74  	found, err := runePositionToAstNodeChain(schema.rootNode, runePosition)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  
    79  	if found == nil {
    80  		return nil, nil
    81  	}
    82  
    83  	return &NodeChain{nodes: found, runePosition: runePosition}, nil
    84  }
    85  
    86  func runePositionToAstNodeChain(node *dslNode, runePosition int) ([]DSLNode, error) {
    87  	if !node.Has(dslshape.NodePredicateStartRune) {
    88  		return nil, nil
    89  	}
    90  
    91  	startRune, err := node.GetInt(dslshape.NodePredicateStartRune)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  
    96  	endRune, err := node.GetInt(dslshape.NodePredicateEndRune)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  
   101  	if runePosition < startRune || runePosition > endRune {
   102  		return nil, nil
   103  	}
   104  
   105  	for _, child := range node.AllSubNodes() {
   106  		childChain, err := runePositionToAstNodeChain(child, runePosition)
   107  		if err != nil {
   108  			return nil, err
   109  		}
   110  
   111  		if childChain != nil {
   112  			return append(childChain, wrapper{node}), nil
   113  		}
   114  	}
   115  
   116  	return []DSLNode{wrapper{node}}, nil
   117  }
   118  
   119  type wrapper struct {
   120  	node *dslNode
   121  }
   122  
   123  func (w wrapper) GetType() dslshape.NodeType {
   124  	return w.node.GetType()
   125  }
   126  
   127  func (w wrapper) GetString(predicateName string) (string, error) {
   128  	return w.node.GetString(predicateName)
   129  }
   130  
   131  func (w wrapper) GetInt(predicateName string) (int, error) {
   132  	return w.node.GetInt(predicateName)
   133  }
   134  
   135  func (w wrapper) Lookup(predicateName string) (DSLNode, error) {
   136  	found, err := w.node.Lookup(predicateName)
   137  	if err != nil {
   138  		return nil, err
   139  	}
   140  
   141  	return wrapper{found}, nil
   142  }