github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/caveats/compile_test.go (about)

     1  package caveats
     2  
     3  import (
     4  	"errors"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/require"
     8  
     9  	"github.com/authzed/spicedb/pkg/caveats/types"
    10  )
    11  
    12  func TestCompile(t *testing.T) {
    13  	tcs := []struct {
    14  		name           string
    15  		env            *Environment
    16  		exprString     string
    17  		expectedErrors []string
    18  	}{
    19  		{
    20  			"missing var",
    21  			NewEnvironment(),
    22  			"hiya",
    23  			[]string{"undeclared reference to 'hiya'"},
    24  		},
    25  		{
    26  			"empty expression",
    27  			NewEnvironment(),
    28  			"",
    29  			[]string{"mismatched input"},
    30  		},
    31  		{
    32  			"invalid expression",
    33  			NewEnvironment(),
    34  			"a +",
    35  			[]string{"mismatched input"},
    36  		},
    37  		{
    38  			"missing variable",
    39  			NewEnvironment(),
    40  			"a + 2",
    41  			[]string{"undeclared reference to 'a'"},
    42  		},
    43  		{
    44  			"missing variables",
    45  			NewEnvironment(),
    46  			"a + b",
    47  			[]string{"undeclared reference to 'a'", "undeclared reference to 'b'"},
    48  		},
    49  		{
    50  			"type mismatch",
    51  			MustEnvForVariables(map[string]types.VariableType{
    52  				"a": types.UIntType,
    53  				"b": types.BooleanType,
    54  			}),
    55  			"a + b",
    56  			[]string{"found no matching overload for '_+_'"},
    57  		},
    58  		{
    59  			"valid expression",
    60  			MustEnvForVariables(map[string]types.VariableType{
    61  				"a": types.IntType,
    62  				"b": types.IntType,
    63  			}),
    64  			"a + b == 2",
    65  			[]string{},
    66  		},
    67  		{
    68  			"invalid expression over an int",
    69  			MustEnvForVariables(map[string]types.VariableType{
    70  				"a": types.UIntType,
    71  			}),
    72  			"a[0]",
    73  			[]string{"found no matching overload for '_[_]'"},
    74  		},
    75  		{
    76  			"valid expression over a list",
    77  			MustEnvForVariables(map[string]types.VariableType{
    78  				"a": types.MustListType(types.IntType),
    79  			}),
    80  			"a[0] == 1",
    81  			[]string{},
    82  		},
    83  		{
    84  			"invalid expression over a list",
    85  			MustEnvForVariables(map[string]types.VariableType{
    86  				"a": types.MustListType(types.UIntType),
    87  			}),
    88  			"a['hi']",
    89  			[]string{"found no matching overload for '_[_]'"},
    90  		},
    91  		{
    92  			"valid expression over a map",
    93  			MustEnvForVariables(map[string]types.VariableType{
    94  				"a": types.MustMapType(types.IntType),
    95  			}),
    96  			"a['hi'] == 1",
    97  			[]string{},
    98  		},
    99  		{
   100  			"invalid expression over a map",
   101  			MustEnvForVariables(map[string]types.VariableType{
   102  				"a": types.MustMapType(types.UIntType),
   103  			}),
   104  			"a[42]",
   105  			[]string{"found no matching overload for '_[_]'"},
   106  		},
   107  		{
   108  			"non-boolean valid expression",
   109  			MustEnvForVariables(map[string]types.VariableType{
   110  				"a": types.IntType,
   111  				"b": types.IntType,
   112  			}),
   113  			"a + b",
   114  			[]string{"caveat expression must result in a boolean value: found `int`"},
   115  		},
   116  		{
   117  			"valid expression over a byte sequence",
   118  			MustEnvForVariables(map[string]types.VariableType{
   119  				"a": types.BytesType,
   120  			}),
   121  			"a == b\"abc\"",
   122  			[]string{},
   123  		},
   124  		{
   125  			"invalid expression over a byte sequence",
   126  			MustEnvForVariables(map[string]types.VariableType{
   127  				"a": types.BytesType,
   128  			}),
   129  			"a == \"abc\"",
   130  			[]string{"found no matching overload for '_==_'"},
   131  		},
   132  		{
   133  			"valid expression over a double",
   134  			MustEnvForVariables(map[string]types.VariableType{
   135  				"a": types.DoubleType,
   136  			}),
   137  			"a == 7.23",
   138  			[]string{},
   139  		},
   140  		{
   141  			"invalid expression over a double",
   142  			MustEnvForVariables(map[string]types.VariableType{
   143  				"a": types.DoubleType,
   144  			}),
   145  			"a == true",
   146  			[]string{"found no matching overload for '_==_'"},
   147  		},
   148  		{
   149  			"valid expression over a duration",
   150  			MustEnvForVariables(map[string]types.VariableType{
   151  				"a": types.DurationType,
   152  			}),
   153  			"a > duration(\"1h3m\")",
   154  			[]string{},
   155  		},
   156  		{
   157  			"invalid expression over a duration",
   158  			MustEnvForVariables(map[string]types.VariableType{
   159  				"a": types.DurationType,
   160  			}),
   161  			"a > \"1h3m\"",
   162  			[]string{"found no matching overload for '_>_'"},
   163  		},
   164  		{
   165  			"valid expression over a timestamp",
   166  			MustEnvForVariables(map[string]types.VariableType{
   167  				"a": types.TimestampType,
   168  			}),
   169  			"a == timestamp(\"1972-01-01T10:00:20.021-05:00\")",
   170  			[]string{},
   171  		},
   172  		{
   173  			"invalid expression over a timestamp",
   174  			MustEnvForVariables(map[string]types.VariableType{
   175  				"a": types.TimestampType,
   176  			}),
   177  			"a == \"1972-01-01T10:00:20.021-05:00\"",
   178  			[]string{"found no matching overload for '_==_'"},
   179  		},
   180  		{
   181  			"valid expression over any type",
   182  			MustEnvForVariables(map[string]types.VariableType{
   183  				"a": types.AnyType,
   184  			}),
   185  			"a == true",
   186  			[]string{},
   187  		},
   188  	}
   189  
   190  	for _, tc := range tcs {
   191  		tc := tc
   192  		t.Run(tc.name, func(t *testing.T) {
   193  			compiled, err := compileCaveat(tc.env, tc.exprString)
   194  			if len(tc.expectedErrors) == 0 {
   195  				require.NoError(t, err)
   196  				require.NotNil(t, compiled)
   197  			} else {
   198  				require.Error(t, err)
   199  				require.Nil(t, compiled)
   200  
   201  				isCompilationError := errors.As(err, &CompilationErrors{})
   202  				require.True(t, isCompilationError)
   203  
   204  				for _, expectedError := range tc.expectedErrors {
   205  					require.Contains(t, err.Error(), expectedError)
   206  				}
   207  			}
   208  		})
   209  	}
   210  }
   211  
   212  func TestDeserializeEmpty(t *testing.T) {
   213  	_, err := DeserializeCaveat([]byte{}, nil)
   214  	require.NotNil(t, err)
   215  }
   216  
   217  func TestSerialization(t *testing.T) {
   218  	exprs := []string{"a == 1", "a + b == 2", "b - a == 4", "l.all(i, i > 42)"}
   219  
   220  	for _, expr := range exprs {
   221  		expr := expr
   222  		t.Run(expr, func(t *testing.T) {
   223  			vars := map[string]types.VariableType{
   224  				"a": types.IntType,
   225  				"b": types.IntType,
   226  				"l": types.MustListType(types.IntType),
   227  			}
   228  
   229  			env := MustEnvForVariables(vars)
   230  			compiled, err := compileCaveat(env, expr)
   231  			require.NoError(t, err)
   232  
   233  			serialized, err := compiled.Serialize()
   234  			require.NoError(t, err)
   235  
   236  			deserialized, err := DeserializeCaveat(serialized, vars)
   237  			require.NoError(t, err)
   238  
   239  			astExpr, err := deserialized.ExprString()
   240  			require.NoError(t, err)
   241  			require.Equal(t, expr, astExpr)
   242  		})
   243  	}
   244  }
   245  
   246  func TestSerializeName(t *testing.T) {
   247  	vars := map[string]types.VariableType{
   248  		"a": types.IntType,
   249  		"b": types.IntType,
   250  	}
   251  
   252  	env := MustEnvForVariables(vars)
   253  	compiled, err := CompileCaveatWithName(env, "a == 1", "hi")
   254  	require.NoError(t, err)
   255  
   256  	serialized, err := compiled.Serialize()
   257  	require.NoError(t, err)
   258  
   259  	deserialized, err := DeserializeCaveat(serialized, vars)
   260  	require.NoError(t, err)
   261  
   262  	require.Equal(t, "hi", deserialized.name)
   263  }