github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/caveats/compile_test.go (about) 1 package caveats 2 3 import ( 4 "errors" 5 "testing" 6 7 "github.com/stretchr/testify/require" 8 9 "github.com/authzed/spicedb/pkg/caveats/types" 10 ) 11 12 func TestCompile(t *testing.T) { 13 tcs := []struct { 14 name string 15 env *Environment 16 exprString string 17 expectedErrors []string 18 }{ 19 { 20 "missing var", 21 NewEnvironment(), 22 "hiya", 23 []string{"undeclared reference to 'hiya'"}, 24 }, 25 { 26 "empty expression", 27 NewEnvironment(), 28 "", 29 []string{"mismatched input"}, 30 }, 31 { 32 "invalid expression", 33 NewEnvironment(), 34 "a +", 35 []string{"mismatched input"}, 36 }, 37 { 38 "missing variable", 39 NewEnvironment(), 40 "a + 2", 41 []string{"undeclared reference to 'a'"}, 42 }, 43 { 44 "missing variables", 45 NewEnvironment(), 46 "a + b", 47 []string{"undeclared reference to 'a'", "undeclared reference to 'b'"}, 48 }, 49 { 50 "type mismatch", 51 MustEnvForVariables(map[string]types.VariableType{ 52 "a": types.UIntType, 53 "b": types.BooleanType, 54 }), 55 "a + b", 56 []string{"found no matching overload for '_+_'"}, 57 }, 58 { 59 "valid expression", 60 MustEnvForVariables(map[string]types.VariableType{ 61 "a": types.IntType, 62 "b": types.IntType, 63 }), 64 "a + b == 2", 65 []string{}, 66 }, 67 { 68 "invalid expression over an int", 69 MustEnvForVariables(map[string]types.VariableType{ 70 "a": types.UIntType, 71 }), 72 "a[0]", 73 []string{"found no matching overload for '_[_]'"}, 74 }, 75 { 76 "valid expression over a list", 77 MustEnvForVariables(map[string]types.VariableType{ 78 "a": types.MustListType(types.IntType), 79 }), 80 "a[0] == 1", 81 []string{}, 82 }, 83 { 84 "invalid expression over a list", 85 MustEnvForVariables(map[string]types.VariableType{ 86 "a": types.MustListType(types.UIntType), 87 }), 88 "a['hi']", 89 []string{"found no matching overload for '_[_]'"}, 90 }, 91 { 92 "valid expression over a map", 93 MustEnvForVariables(map[string]types.VariableType{ 94 "a": types.MustMapType(types.IntType), 95 }), 96 "a['hi'] == 1", 97 []string{}, 98 }, 99 { 100 "invalid expression over a map", 101 MustEnvForVariables(map[string]types.VariableType{ 102 "a": types.MustMapType(types.UIntType), 103 }), 104 "a[42]", 105 []string{"found no matching overload for '_[_]'"}, 106 }, 107 { 108 "non-boolean valid expression", 109 MustEnvForVariables(map[string]types.VariableType{ 110 "a": types.IntType, 111 "b": types.IntType, 112 }), 113 "a + b", 114 []string{"caveat expression must result in a boolean value: found `int`"}, 115 }, 116 { 117 "valid expression over a byte sequence", 118 MustEnvForVariables(map[string]types.VariableType{ 119 "a": types.BytesType, 120 }), 121 "a == b\"abc\"", 122 []string{}, 123 }, 124 { 125 "invalid expression over a byte sequence", 126 MustEnvForVariables(map[string]types.VariableType{ 127 "a": types.BytesType, 128 }), 129 "a == \"abc\"", 130 []string{"found no matching overload for '_==_'"}, 131 }, 132 { 133 "valid expression over a double", 134 MustEnvForVariables(map[string]types.VariableType{ 135 "a": types.DoubleType, 136 }), 137 "a == 7.23", 138 []string{}, 139 }, 140 { 141 "invalid expression over a double", 142 MustEnvForVariables(map[string]types.VariableType{ 143 "a": types.DoubleType, 144 }), 145 "a == true", 146 []string{"found no matching overload for '_==_'"}, 147 }, 148 { 149 "valid expression over a duration", 150 MustEnvForVariables(map[string]types.VariableType{ 151 "a": types.DurationType, 152 }), 153 "a > duration(\"1h3m\")", 154 []string{}, 155 }, 156 { 157 "invalid expression over a duration", 158 MustEnvForVariables(map[string]types.VariableType{ 159 "a": types.DurationType, 160 }), 161 "a > \"1h3m\"", 162 []string{"found no matching overload for '_>_'"}, 163 }, 164 { 165 "valid expression over a timestamp", 166 MustEnvForVariables(map[string]types.VariableType{ 167 "a": types.TimestampType, 168 }), 169 "a == timestamp(\"1972-01-01T10:00:20.021-05:00\")", 170 []string{}, 171 }, 172 { 173 "invalid expression over a timestamp", 174 MustEnvForVariables(map[string]types.VariableType{ 175 "a": types.TimestampType, 176 }), 177 "a == \"1972-01-01T10:00:20.021-05:00\"", 178 []string{"found no matching overload for '_==_'"}, 179 }, 180 { 181 "valid expression over any type", 182 MustEnvForVariables(map[string]types.VariableType{ 183 "a": types.AnyType, 184 }), 185 "a == true", 186 []string{}, 187 }, 188 } 189 190 for _, tc := range tcs { 191 tc := tc 192 t.Run(tc.name, func(t *testing.T) { 193 compiled, err := compileCaveat(tc.env, tc.exprString) 194 if len(tc.expectedErrors) == 0 { 195 require.NoError(t, err) 196 require.NotNil(t, compiled) 197 } else { 198 require.Error(t, err) 199 require.Nil(t, compiled) 200 201 isCompilationError := errors.As(err, &CompilationErrors{}) 202 require.True(t, isCompilationError) 203 204 for _, expectedError := range tc.expectedErrors { 205 require.Contains(t, err.Error(), expectedError) 206 } 207 } 208 }) 209 } 210 } 211 212 func TestDeserializeEmpty(t *testing.T) { 213 _, err := DeserializeCaveat([]byte{}, nil) 214 require.NotNil(t, err) 215 } 216 217 func TestSerialization(t *testing.T) { 218 exprs := []string{"a == 1", "a + b == 2", "b - a == 4", "l.all(i, i > 42)"} 219 220 for _, expr := range exprs { 221 expr := expr 222 t.Run(expr, func(t *testing.T) { 223 vars := map[string]types.VariableType{ 224 "a": types.IntType, 225 "b": types.IntType, 226 "l": types.MustListType(types.IntType), 227 } 228 229 env := MustEnvForVariables(vars) 230 compiled, err := compileCaveat(env, expr) 231 require.NoError(t, err) 232 233 serialized, err := compiled.Serialize() 234 require.NoError(t, err) 235 236 deserialized, err := DeserializeCaveat(serialized, vars) 237 require.NoError(t, err) 238 239 astExpr, err := deserialized.ExprString() 240 require.NoError(t, err) 241 require.Equal(t, expr, astExpr) 242 }) 243 } 244 } 245 246 func TestSerializeName(t *testing.T) { 247 vars := map[string]types.VariableType{ 248 "a": types.IntType, 249 "b": types.IntType, 250 } 251 252 env := MustEnvForVariables(vars) 253 compiled, err := CompileCaveatWithName(env, "a == 1", "hi") 254 require.NoError(t, err) 255 256 serialized, err := compiled.Serialize() 257 require.NoError(t, err) 258 259 deserialized, err := DeserializeCaveat(serialized, vars) 260 require.NoError(t, err) 261 262 require.Equal(t, "hi", deserialized.name) 263 }