github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/caveats/eval_test.go (about) 1 package caveats 2 3 import ( 4 "testing" 5 "time" 6 7 "github.com/authzed/cel-go/cel" 8 "github.com/stretchr/testify/require" 9 10 "github.com/authzed/spicedb/pkg/caveats/types" 11 ) 12 13 var noMissingVars []string 14 15 func TestEvaluateCaveat(t *testing.T) { 16 wetTz, err := time.LoadLocation("WET") 17 require.NoError(t, err) 18 tcs := []struct { 19 name string 20 env *Environment 21 exprString string 22 23 context map[string]any 24 25 expectedError string 26 27 expectedValue bool 28 expectedPartialExpr string 29 missingVars []string 30 }{ 31 { 32 "static expression", 33 MustEnvForVariables(map[string]types.VariableType{}), 34 "true", 35 map[string]any{}, 36 "", 37 true, 38 "", 39 noMissingVars, 40 }, 41 { 42 "static false expression", 43 MustEnvForVariables(map[string]types.VariableType{}), 44 "false", 45 map[string]any{}, 46 "", 47 false, 48 "", 49 noMissingVars, 50 }, 51 { 52 "static numeric expression", 53 MustEnvForVariables(map[string]types.VariableType{}), 54 "1 + 2 == 3", 55 map[string]any{}, 56 "", 57 true, 58 "", 59 noMissingVars, 60 }, 61 { 62 "static false numeric expression", 63 MustEnvForVariables(map[string]types.VariableType{}), 64 "2 - 2 == 1", 65 map[string]any{}, 66 "", 67 false, 68 "", 69 noMissingVars, 70 }, 71 { 72 "computed expression", 73 MustEnvForVariables(map[string]types.VariableType{ 74 "a": types.IntType, 75 }), 76 "a + 2 == 4", 77 map[string]any{ 78 "a": 2, 79 }, 80 "", 81 true, 82 "", 83 noMissingVars, 84 }, 85 { 86 "missing variables for expression", 87 MustEnvForVariables(map[string]types.VariableType{ 88 "a": types.IntType, 89 }), 90 "a + 2 == 4", 91 map[string]any{}, 92 "", 93 false, 94 "a + 2 == 4", 95 []string{"a"}, 96 }, 97 { 98 "missing variables for right side of boolean expression", 99 MustEnvForVariables(map[string]types.VariableType{ 100 "a": types.IntType, 101 "b": types.IntType, 102 }), 103 "(a == 2) || (b == 6)", 104 map[string]any{ 105 "a": 2, 106 }, 107 "", 108 true, 109 "", 110 noMissingVars, 111 }, 112 { 113 "missing variables for left side of boolean expression", 114 MustEnvForVariables(map[string]types.VariableType{ 115 "a": types.IntType, 116 "b": types.IntType, 117 }), 118 "(a == 2) || (b == 6)", 119 map[string]any{ 120 "b": 6, 121 }, 122 "", 123 true, 124 "", 125 noMissingVars, 126 }, 127 { 128 "missing variables for both sides of boolean expression", 129 MustEnvForVariables(map[string]types.VariableType{ 130 "a": types.IntType, 131 "b": types.IntType, 132 }), 133 "(a == 2) || (b == 6)", 134 map[string]any{}, 135 "", 136 false, 137 "a == 2 || b == 6", 138 []string{"a", "b"}, 139 }, 140 { 141 "missing variable for left side of and boolean expression", 142 MustEnvForVariables(map[string]types.VariableType{ 143 "a": types.IntType, 144 "b": types.IntType, 145 }), 146 "(a == 2) && (b == 6)", 147 map[string]any{ 148 "b": 6, 149 }, 150 "", 151 false, 152 "a == 2", 153 []string{"a"}, 154 }, 155 { 156 "missing variable for right side of and boolean expression", 157 MustEnvForVariables(map[string]types.VariableType{ 158 "a": types.IntType, 159 "b": types.IntType, 160 }), 161 "(a == 2) && (b == 6)", 162 map[string]any{ 163 "a": 2, 164 }, 165 "", 166 false, 167 "b == 6", 168 []string{"b"}, 169 }, 170 { 171 "map evaluation", 172 MustEnvForVariables(map[string]types.VariableType{ 173 "m": types.MustMapType(types.BooleanType), 174 "idx": types.StringType, 175 }), 176 "m[idx]", 177 map[string]any{ 178 "m": map[string]bool{ 179 "1": true, 180 }, 181 "idx": "1", 182 }, 183 "", 184 true, 185 "", 186 noMissingVars, 187 }, 188 { 189 "map dot evaluation", 190 MustEnvForVariables(map[string]types.VariableType{ 191 "m": types.MustMapType(types.BooleanType), 192 }), 193 "m.foo", 194 map[string]any{ 195 "m": map[string]bool{ 196 "foo": true, 197 }, 198 }, 199 "", 200 true, 201 "", 202 noMissingVars, 203 }, 204 { 205 "missing map for evaluation", 206 MustEnvForVariables(map[string]types.VariableType{ 207 "m": types.MustMapType(types.BooleanType), 208 "idx": types.StringType, 209 }), 210 "m[idx]", 211 map[string]any{ 212 "idx": "1", 213 }, 214 "", 215 false, 216 "m[idx]", 217 []string{"m"}, 218 }, 219 { 220 "missing map for attribute evaluation", 221 MustEnvForVariables(map[string]types.VariableType{ 222 "m": types.MustMapType(types.BooleanType), 223 }), 224 "m.first", 225 map[string]any{}, 226 "", 227 false, 228 "m.first", 229 []string{"m"}, 230 }, 231 { 232 "nested evaluation", 233 MustEnvForVariables(map[string]types.VariableType{ 234 "metadata.l": types.MustListType(types.StringType), 235 "metadata.idx": types.IntType, 236 }), 237 "metadata.l[metadata.idx] == 'hello'", 238 map[string]any{ 239 "metadata.l": []string{"hi", "hello", "yo"}, 240 "metadata.idx": 1, 241 }, 242 "", 243 true, 244 "", 245 noMissingVars, 246 }, 247 { 248 "nested evaluation with missing value", 249 MustEnvForVariables(map[string]types.VariableType{ 250 "metadata.l": types.MustListType(types.StringType), 251 "metadata.idx": types.IntType, 252 }), 253 "metadata.l[metadata.idx] == 'hello'", 254 map[string]any{ 255 "metadata.l": []string{"hi", "hello", "yo"}, 256 }, 257 "", 258 false, 259 `metadata.l[metadata.idx] == "hello"`, 260 []string{"metadata.idx"}, 261 }, 262 { 263 "nested evaluation with missing list", 264 MustEnvForVariables(map[string]types.VariableType{ 265 "metadata.l": types.MustListType(types.StringType), 266 "metadata.idx": types.IntType, 267 }), 268 "metadata.l[metadata.idx] == 'hello'", 269 map[string]any{ 270 "metadata.idx": 1, 271 }, 272 "", 273 false, 274 `metadata.l[metadata.idx] == "hello"`, 275 []string{"metadata.l"}, 276 }, 277 { 278 "timestamp operations default to UTC", 279 MustEnvForVariables(map[string]types.VariableType{ 280 "a": types.TimestampType, 281 }), 282 "a.getHours() == 9", 283 map[string]any{ 284 "a": time.Date(2000, 10, 10, 10, 10, 10, 10, wetTz), 285 }, 286 "", 287 true, 288 "", 289 noMissingVars, 290 }, 291 { 292 "timestamp comparison", 293 MustEnvForVariables(map[string]types.VariableType{ 294 "a": types.TimestampType, 295 "b": types.TimestampType, 296 }), 297 "a < b", 298 map[string]any{ 299 "a": time.Date(2000, 10, 10, 10, 10, 10, 10, wetTz), 300 "b": time.Date(2000, 10, 10, 10, 10, 10, 10, wetTz), 301 }, 302 "", 303 false, 304 "", 305 noMissingVars, 306 }, 307 { 308 "timestamp comparison 2", 309 MustEnvForVariables(map[string]types.VariableType{ 310 "a": types.TimestampType, 311 "b": types.TimestampType, 312 }), 313 "a <= b", 314 map[string]any{ 315 "a": time.Date(2000, 10, 10, 10, 10, 10, 10, wetTz), 316 "b": time.Date(2000, 10, 10, 10, 10, 10, 10, wetTz), 317 }, 318 "", 319 true, 320 "", 321 noMissingVars, 322 }, 323 { 324 "optional types not found", 325 MustEnvForVariables(map[string]types.VariableType{ 326 "m": types.MustMapType(types.BooleanType), 327 "key": types.StringType, 328 }), 329 "m[?key].orValue(true)", 330 map[string]any{ 331 "m": map[string]bool{"foo": true, "bar": false}, 332 "key": "baz", 333 }, 334 "", 335 true, 336 "", 337 noMissingVars, 338 }, 339 { 340 "optional types found", 341 MustEnvForVariables(map[string]types.VariableType{ 342 "m": types.MustMapType(types.BooleanType), 343 "key": types.StringType, 344 }), 345 "m[?key].orValue(true)", 346 map[string]any{ 347 "m": map[string]bool{"foo": true, "bar": false}, 348 "key": "bar", 349 }, 350 "", 351 false, 352 "", 353 noMissingVars, 354 }, 355 } 356 357 for _, tc := range tcs { 358 tc := tc 359 t.Run(tc.name, func(t *testing.T) { 360 compiled, err := compileCaveat(tc.env, tc.exprString) 361 require.NoError(t, err) 362 363 result, err := EvaluateCaveat(compiled, tc.context) 364 if tc.expectedError != "" { 365 require.Error(t, err) 366 require.Contains(t, err.Error(), tc.expectedError) 367 require.Nil(t, result) 368 } else { 369 require.NoError(t, err) 370 require.NotNil(t, result) 371 require.Equal(t, tc.expectedValue, result.Value()) 372 373 if tc.expectedPartialExpr != "" { 374 require.True(t, result.IsPartial()) 375 376 partialValue, err := result.PartialValue() 377 require.NoError(t, err) 378 379 astExpr, err := cel.AstToString(partialValue.ast) 380 require.NoError(t, err) 381 382 require.Equal(t, tc.expectedPartialExpr, astExpr) 383 384 vars, err := result.MissingVarNames() 385 require.NoError(t, err) 386 require.EqualValues(t, tc.missingVars, vars) 387 } else { 388 require.False(t, result.IsPartial()) 389 _, partialErr := result.PartialValue() 390 require.Error(t, partialErr) 391 require.Nil(t, tc.missingVars) 392 require.Nil(t, result.missingVarNames) 393 } 394 } 395 }) 396 } 397 } 398 399 func TestPartialEvaluation(t *testing.T) { 400 compiled, err := compileCaveat(MustEnvForVariables(map[string]types.VariableType{ 401 "a": types.IntType, 402 "b": types.IntType, 403 }), "a + b > 47") 404 require.NoError(t, err) 405 406 result, err := EvaluateCaveat(compiled, map[string]any{ 407 "a": 42, 408 }) 409 require.NoError(t, err) 410 require.False(t, result.Value()) 411 require.True(t, result.IsPartial()) 412 413 partialValue, err := result.PartialValue() 414 require.NoError(t, err) 415 416 astExpr, err := cel.AstToString(partialValue.ast) 417 require.NoError(t, err) 418 require.Equal(t, "42 + b > 47", astExpr) 419 420 fullResult, err := EvaluateCaveat(partialValue, map[string]any{ 421 "b": 6, 422 }) 423 require.NoError(t, err) 424 require.True(t, fullResult.Value()) 425 require.False(t, fullResult.IsPartial()) 426 427 fullResult, err = EvaluateCaveat(partialValue, map[string]any{ 428 "b": 2, 429 }) 430 require.NoError(t, err) 431 require.False(t, fullResult.Value()) 432 require.False(t, fullResult.IsPartial()) 433 } 434 435 func TestEvalWithMaxCost(t *testing.T) { 436 compiled, err := compileCaveat(MustEnvForVariables(map[string]types.VariableType{ 437 "a": types.IntType, 438 "b": types.IntType, 439 }), "a + b > 47") 440 require.NoError(t, err) 441 442 _, err = EvaluateCaveatWithConfig(compiled, map[string]any{ 443 "a": 42, 444 "b": 4, 445 }, &EvaluationConfig{ 446 MaxCost: 1, 447 }) 448 require.Error(t, err) 449 require.Equal(t, "operation cancelled: actual cost limit exceeded", err.Error()) 450 } 451 452 func TestEvalWithNesting(t *testing.T) { 453 compiled, err := compileCaveat(MustEnvForVariables(map[string]types.VariableType{ 454 "foo.a": types.IntType, 455 "foo.b": types.IntType, 456 }), "foo.a + foo.b > 47") 457 require.NoError(t, err) 458 459 result, err := EvaluateCaveat(compiled, map[string]any{ 460 "foo.a": 42, 461 "foo.b": 4, 462 }) 463 require.NoError(t, err) 464 require.False(t, result.Value()) 465 require.False(t, result.IsPartial()) 466 }