github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/schemadsl/parser/parser_test.go (about)

     1  package parser
     2  
     3  import (
     4  	"container/list"
     5  	"fmt"
     6  	"os"
     7  	"sort"
     8  	"strings"
     9  	"testing"
    10  
    11  	"github.com/stretchr/testify/assert"
    12  
    13  	"github.com/authzed/spicedb/pkg/schemadsl/dslshape"
    14  	"github.com/authzed/spicedb/pkg/schemadsl/input"
    15  )
    16  
    17  type testNode struct {
    18  	nodeType   dslshape.NodeType
    19  	properties map[string]interface{}
    20  	children   map[string]*list.List
    21  }
    22  
    23  type parserTest struct {
    24  	name     string
    25  	filename string
    26  }
    27  
    28  func (pt *parserTest) input() string {
    29  	b, err := os.ReadFile(fmt.Sprintf("tests/%s.zed", pt.filename))
    30  	if err != nil {
    31  		panic(err)
    32  	}
    33  
    34  	return string(b)
    35  }
    36  
    37  func (pt *parserTest) tree() string {
    38  	b, err := os.ReadFile(fmt.Sprintf("tests/%s.zed.expected", pt.filename))
    39  	if err != nil {
    40  		panic(err)
    41  	}
    42  
    43  	return string(b)
    44  }
    45  
    46  func (pt *parserTest) writeTree(value string) {
    47  	err := os.WriteFile(fmt.Sprintf("tests/%s.zed.expected", pt.filename), []byte(value), 0o600)
    48  	if err != nil {
    49  		panic(err)
    50  	}
    51  }
    52  
    53  func createAstNode(_ input.Source, kind dslshape.NodeType) AstNode {
    54  	return &testNode{
    55  		nodeType:   kind,
    56  		properties: make(map[string]interface{}),
    57  		children:   make(map[string]*list.List),
    58  	}
    59  }
    60  
    61  func (tn *testNode) GetType() dslshape.NodeType {
    62  	return tn.nodeType
    63  }
    64  
    65  func (tn *testNode) Connect(predicate string, other AstNode) {
    66  	if tn.children[predicate] == nil {
    67  		tn.children[predicate] = list.New()
    68  	}
    69  
    70  	tn.children[predicate].PushBack(other)
    71  }
    72  
    73  func (tn *testNode) MustDecorate(property string, value string) AstNode {
    74  	if _, ok := tn.properties[property]; ok {
    75  		panic(fmt.Sprintf("Existing key for property %s\n\tNode: %v", property, tn.properties))
    76  	}
    77  
    78  	tn.properties[property] = value
    79  	return tn
    80  }
    81  
    82  func (tn *testNode) MustDecorateWithInt(property string, value int) AstNode {
    83  	if _, ok := tn.properties[property]; ok {
    84  		panic(fmt.Sprintf("Existing key for property %s\n\tNode: %v", property, tn.properties))
    85  	}
    86  
    87  	tn.properties[property] = value
    88  	return tn
    89  }
    90  
    91  func TestParser(t *testing.T) {
    92  	parserTests := []parserTest{
    93  		{"empty file test", "empty"},
    94  		{"basic definition test", "basic"},
    95  		{"doc comments test", "doccomments"},
    96  		{"arrow test", "arrow"},
    97  		{"multiple definition test", "multidef"},
    98  		{"broken test", "broken"},
    99  		{"relation missing type test", "relation_missing_type"},
   100  		{"permission missing expression test", "permission_missing_expression"},
   101  		{"relation invalid type test", "relation_invalid_type"},
   102  		{"permission invalid expression test", "permission_invalid_expression"},
   103  		{"cross tenant test", "crosstenant"},
   104  		{"indented comments test", "indentedcomments"},
   105  		{"parens test", "parens"},
   106  		{"multiple parens test", "multiparen"},
   107  		{"multiple slashes in object type", "multipleslashes"},
   108  		{"wildcard test", "wildcard"},
   109  		{"broken wildcard test", "brokenwildcard"},
   110  		{"nil test", "nil"},
   111  		{"caveats type test", "caveatstype"},
   112  		{"basic caveat test", "basiccaveat"},
   113  		{"complex caveat test", "complexcaveat"},
   114  		{"empty caveat test", "emptycaveat"},
   115  		{"unclosed caveat test", "unclosedcaveat"},
   116  		{"invalid caveat expr test", "invalidcaveatexpr"},
   117  		{"associativity test", "associativity"},
   118  		{"super large test", "superlarge"},
   119  		{"invalid permission name test", "invalid_perm_name"},
   120  		{"union positions test", "unionpos"},
   121  	}
   122  
   123  	for _, test := range parserTests {
   124  		test := test
   125  		t.Run(test.name, func(t *testing.T) {
   126  			root := Parse(createAstNode, input.Source(test.name), test.input())
   127  			parseTree := getParseTree((root).(*testNode), 0)
   128  			assert := assert.New(t)
   129  
   130  			found := strings.TrimSpace(parseTree)
   131  
   132  			if os.Getenv("REGEN") == "true" {
   133  				test.writeTree(found)
   134  			} else {
   135  				expected := strings.TrimSpace(test.tree())
   136  				if !assert.Equal(expected, found, test.name) {
   137  					t.Log(parseTree)
   138  				}
   139  			}
   140  		})
   141  	}
   142  }
   143  
   144  func getParseTree(currentNode *testNode, indentation int) string {
   145  	parseTree := ""
   146  	parseTree = parseTree + strings.Repeat(" ", indentation)
   147  	parseTree = parseTree + fmt.Sprintf("%v", currentNode.nodeType)
   148  	parseTree = parseTree + "\n"
   149  
   150  	keys := make([]string, 0)
   151  
   152  	for key := range currentNode.properties {
   153  		keys = append(keys, key)
   154  	}
   155  
   156  	sort.Strings(keys)
   157  
   158  	for _, key := range keys {
   159  		parseTree = parseTree + strings.Repeat(" ", indentation+2)
   160  		parseTree = parseTree + fmt.Sprintf("%s = %v", key, currentNode.properties[key])
   161  		parseTree = parseTree + "\n"
   162  	}
   163  
   164  	keys = make([]string, 0)
   165  
   166  	for key := range currentNode.children {
   167  		keys = append(keys, key)
   168  	}
   169  
   170  	sort.Strings(keys)
   171  
   172  	for _, key := range keys {
   173  		value := currentNode.children[key]
   174  		parseTree = parseTree + fmt.Sprintf("%s%v =>", strings.Repeat(" ", indentation+2), key)
   175  		parseTree = parseTree + "\n"
   176  
   177  		for e := value.Front(); e != nil; e = e.Next() {
   178  			parseTree = parseTree + getParseTree(e.Value.(*testNode), indentation+4)
   179  		}
   180  	}
   181  
   182  	return parseTree
   183  }