github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/schemadsl/compiler/development.go (about) 1 package compiler 2 3 import ( 4 "github.com/authzed/spicedb/pkg/schemadsl/dslshape" 5 "github.com/authzed/spicedb/pkg/schemadsl/input" 6 ) 7 8 // DSLNode is a node in the DSL AST. 9 type DSLNode interface { 10 GetType() dslshape.NodeType 11 GetString(predicateName string) (string, error) 12 GetInt(predicateName string) (int, error) 13 Lookup(predicateName string) (DSLNode, error) 14 } 15 16 // NodeChain is a chain of nodes in the DSL AST. 17 type NodeChain struct { 18 nodes []DSLNode 19 runePosition int 20 } 21 22 // Head returns the head node of the chain. 23 func (nc *NodeChain) Head() DSLNode { 24 return nc.nodes[0] 25 } 26 27 // HasHeadType returns true if the head node of the chain is of the given type. 28 func (nc *NodeChain) HasHeadType(nodeType dslshape.NodeType) bool { 29 return nc.nodes[0].GetType() == nodeType 30 } 31 32 // ForRunePosition returns the rune position of the chain. 33 func (nc *NodeChain) ForRunePosition() int { 34 return nc.runePosition 35 } 36 37 // FindNodeOfType returns the first node of the given type in the chain, if any. 38 func (nc *NodeChain) FindNodeOfType(nodeType dslshape.NodeType) DSLNode { 39 for _, node := range nc.nodes { 40 if node.GetType() == nodeType { 41 return node 42 } 43 } 44 45 return nil 46 } 47 48 func (nc *NodeChain) String() string { 49 var out string 50 for _, node := range nc.nodes { 51 out += node.GetType().String() + " " 52 } 53 return out 54 } 55 56 // PositionToAstNodeChain returns the AST node, and its parents (if any), found at the given position in the source, if any. 57 func PositionToAstNodeChain(schema *CompiledSchema, source input.Source, position input.Position) (*NodeChain, error) { 58 rootSource, err := schema.rootNode.GetString(dslshape.NodePredicateSource) 59 if err != nil { 60 return nil, err 61 } 62 63 if rootSource != string(source) { 64 return nil, nil 65 } 66 67 // Map the position to a file rune. 68 runePosition, err := schema.mapper.LineAndColToRunePosition(position.LineNumber, position.ColumnPosition, source) 69 if err != nil { 70 return nil, err 71 } 72 73 // Find the node at the rune position. 74 found, err := runePositionToAstNodeChain(schema.rootNode, runePosition) 75 if err != nil { 76 return nil, err 77 } 78 79 if found == nil { 80 return nil, nil 81 } 82 83 return &NodeChain{nodes: found, runePosition: runePosition}, nil 84 } 85 86 func runePositionToAstNodeChain(node *dslNode, runePosition int) ([]DSLNode, error) { 87 if !node.Has(dslshape.NodePredicateStartRune) { 88 return nil, nil 89 } 90 91 startRune, err := node.GetInt(dslshape.NodePredicateStartRune) 92 if err != nil { 93 return nil, err 94 } 95 96 endRune, err := node.GetInt(dslshape.NodePredicateEndRune) 97 if err != nil { 98 return nil, err 99 } 100 101 if runePosition < startRune || runePosition > endRune { 102 return nil, nil 103 } 104 105 for _, child := range node.AllSubNodes() { 106 childChain, err := runePositionToAstNodeChain(child, runePosition) 107 if err != nil { 108 return nil, err 109 } 110 111 if childChain != nil { 112 return append(childChain, wrapper{node}), nil 113 } 114 } 115 116 return []DSLNode{wrapper{node}}, nil 117 } 118 119 type wrapper struct { 120 node *dslNode 121 } 122 123 func (w wrapper) GetType() dslshape.NodeType { 124 return w.node.GetType() 125 } 126 127 func (w wrapper) GetString(predicateName string) (string, error) { 128 return w.node.GetString(predicateName) 129 } 130 131 func (w wrapper) GetInt(predicateName string) (int, error) { 132 return w.node.GetInt(predicateName) 133 } 134 135 func (w wrapper) Lookup(predicateName string) (DSLNode, error) { 136 found, err := w.node.Lookup(predicateName) 137 if err != nil { 138 return nil, err 139 } 140 141 return wrapper{found}, nil 142 }