github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/schemadsl/compiler/node.go (about)

     1  package compiler
     2  
     3  import (
     4  	"container/list"
     5  	"fmt"
     6  
     7  	"github.com/authzed/spicedb/pkg/schemadsl/dslshape"
     8  	"github.com/authzed/spicedb/pkg/schemadsl/input"
     9  	"github.com/authzed/spicedb/pkg/schemadsl/parser"
    10  )
    11  
    12  type dslNode struct {
    13  	nodeType   dslshape.NodeType
    14  	properties map[string]interface{}
    15  	children   map[string]*list.List
    16  }
    17  
    18  func createAstNode(_ input.Source, kind dslshape.NodeType) parser.AstNode {
    19  	return &dslNode{
    20  		nodeType:   kind,
    21  		properties: make(map[string]interface{}),
    22  		children:   make(map[string]*list.List),
    23  	}
    24  }
    25  
    26  func (tn *dslNode) GetType() dslshape.NodeType {
    27  	return tn.nodeType
    28  }
    29  
    30  func (tn *dslNode) Connect(predicate string, other parser.AstNode) {
    31  	if tn.children[predicate] == nil {
    32  		tn.children[predicate] = list.New()
    33  	}
    34  
    35  	tn.children[predicate].PushBack(other)
    36  }
    37  
    38  func (tn *dslNode) MustDecorate(property string, value string) parser.AstNode {
    39  	if _, ok := tn.properties[property]; ok {
    40  		panic(fmt.Sprintf("Existing key for property %s\n\tNode: %v", property, tn.properties))
    41  	}
    42  
    43  	tn.properties[property] = value
    44  	return tn
    45  }
    46  
    47  func (tn *dslNode) MustDecorateWithInt(property string, value int) parser.AstNode {
    48  	if _, ok := tn.properties[property]; ok {
    49  		panic(fmt.Sprintf("Existing key for property %s\n\tNode: %v", property, tn.properties))
    50  	}
    51  
    52  	tn.properties[property] = value
    53  	return tn
    54  }
    55  
    56  func (tn *dslNode) Range(mapper input.PositionMapper) (input.SourceRange, error) {
    57  	sourceStr, err := tn.GetString(dslshape.NodePredicateSource)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	source := input.Source(sourceStr)
    63  
    64  	startRune, err := tn.GetInt(dslshape.NodePredicateStartRune)
    65  	if err != nil {
    66  		return nil, err
    67  	}
    68  
    69  	endRune, err := tn.GetInt(dslshape.NodePredicateEndRune)
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  
    74  	return source.RangeForRunePositions(startRune, endRune, mapper), nil
    75  }
    76  
    77  func (tn *dslNode) Has(predicateName string) bool {
    78  	_, ok := tn.properties[predicateName]
    79  	return ok
    80  }
    81  
    82  func (tn *dslNode) GetInt(predicateName string) (int, error) {
    83  	predicate, ok := tn.properties[predicateName]
    84  	if !ok {
    85  		return 0, fmt.Errorf("unknown predicate %s", predicateName)
    86  	}
    87  
    88  	value, ok := predicate.(int)
    89  	if !ok {
    90  		return 0, fmt.Errorf("predicate %s is not an int", predicateName)
    91  	}
    92  
    93  	return value, nil
    94  }
    95  
    96  func (tn *dslNode) GetString(predicateName string) (string, error) {
    97  	predicate, ok := tn.properties[predicateName]
    98  	if !ok {
    99  		return "", fmt.Errorf("unknown predicate %s", predicateName)
   100  	}
   101  
   102  	value, ok := predicate.(string)
   103  	if !ok {
   104  		return "", fmt.Errorf("predicate %s is not a string", predicateName)
   105  	}
   106  
   107  	return value, nil
   108  }
   109  
   110  func (tn *dslNode) AllSubNodes() []*dslNode {
   111  	nodes := []*dslNode{}
   112  	for _, childList := range tn.children {
   113  		for e := childList.Front(); e != nil; e = e.Next() {
   114  			nodes = append(nodes, e.Value.(*dslNode))
   115  		}
   116  	}
   117  	return nodes
   118  }
   119  
   120  func (tn *dslNode) GetChildren() []*dslNode {
   121  	return tn.List(dslshape.NodePredicateChild)
   122  }
   123  
   124  func (tn *dslNode) FindAll(nodeType dslshape.NodeType) []*dslNode {
   125  	found := []*dslNode{}
   126  	if tn.nodeType == dslshape.NodeTypeError {
   127  		found = append(found, tn)
   128  	}
   129  
   130  	for _, childList := range tn.children {
   131  		for e := childList.Front(); e != nil; e = e.Next() {
   132  			childFound := e.Value.(*dslNode).FindAll(nodeType)
   133  			found = append(found, childFound...)
   134  		}
   135  	}
   136  	return found
   137  }
   138  
   139  func (tn *dslNode) List(predicateName string) []*dslNode {
   140  	children := []*dslNode{}
   141  	childList, ok := tn.children[predicateName]
   142  	if !ok {
   143  		return children
   144  	}
   145  
   146  	for e := childList.Front(); e != nil; e = e.Next() {
   147  		children = append(children, e.Value.(*dslNode))
   148  	}
   149  
   150  	return children
   151  }
   152  
   153  func (tn *dslNode) Lookup(predicateName string) (*dslNode, error) {
   154  	childList, ok := tn.children[predicateName]
   155  	if !ok {
   156  		return nil, fmt.Errorf("unknown predicate %s", predicateName)
   157  	}
   158  
   159  	for e := childList.Front(); e != nil; e = e.Next() {
   160  		return e.Value.(*dslNode), nil
   161  	}
   162  
   163  	return nil, fmt.Errorf("nothing in predicate %s", predicateName)
   164  }
   165  
   166  func (tn *dslNode) Errorf(message string, args ...interface{}) error {
   167  	return errorWithNode{
   168  		error:           fmt.Errorf(message, args...),
   169  		errorSourceCode: "",
   170  		node:            tn,
   171  	}
   172  }
   173  
   174  func (tn *dslNode) ErrorWithSourcef(sourceCode string, message string, args ...interface{}) error {
   175  	return errorWithNode{
   176  		error:           fmt.Errorf(message, args...),
   177  		errorSourceCode: sourceCode,
   178  		node:            tn,
   179  	}
   180  }