github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/caveats/eval_test.go (about)

     1  package caveats
     2  
     3  import (
     4  	"testing"
     5  	"time"
     6  
     7  	"github.com/authzed/cel-go/cel"
     8  	"github.com/stretchr/testify/require"
     9  
    10  	"github.com/authzed/spicedb/pkg/caveats/types"
    11  )
    12  
    13  var noMissingVars []string
    14  
    15  func TestEvaluateCaveat(t *testing.T) {
    16  	wetTz, err := time.LoadLocation("WET")
    17  	require.NoError(t, err)
    18  	tcs := []struct {
    19  		name       string
    20  		env        *Environment
    21  		exprString string
    22  
    23  		context map[string]any
    24  
    25  		expectedError string
    26  
    27  		expectedValue       bool
    28  		expectedPartialExpr string
    29  		missingVars         []string
    30  	}{
    31  		{
    32  			"static expression",
    33  			MustEnvForVariables(map[string]types.VariableType{}),
    34  			"true",
    35  			map[string]any{},
    36  			"",
    37  			true,
    38  			"",
    39  			noMissingVars,
    40  		},
    41  		{
    42  			"static false expression",
    43  			MustEnvForVariables(map[string]types.VariableType{}),
    44  			"false",
    45  			map[string]any{},
    46  			"",
    47  			false,
    48  			"",
    49  			noMissingVars,
    50  		},
    51  		{
    52  			"static numeric expression",
    53  			MustEnvForVariables(map[string]types.VariableType{}),
    54  			"1 + 2 == 3",
    55  			map[string]any{},
    56  			"",
    57  			true,
    58  			"",
    59  			noMissingVars,
    60  		},
    61  		{
    62  			"static false numeric expression",
    63  			MustEnvForVariables(map[string]types.VariableType{}),
    64  			"2 - 2 == 1",
    65  			map[string]any{},
    66  			"",
    67  			false,
    68  			"",
    69  			noMissingVars,
    70  		},
    71  		{
    72  			"computed expression",
    73  			MustEnvForVariables(map[string]types.VariableType{
    74  				"a": types.IntType,
    75  			}),
    76  			"a + 2 == 4",
    77  			map[string]any{
    78  				"a": 2,
    79  			},
    80  			"",
    81  			true,
    82  			"",
    83  			noMissingVars,
    84  		},
    85  		{
    86  			"missing variables for expression",
    87  			MustEnvForVariables(map[string]types.VariableType{
    88  				"a": types.IntType,
    89  			}),
    90  			"a + 2 == 4",
    91  			map[string]any{},
    92  			"",
    93  			false,
    94  			"a + 2 == 4",
    95  			[]string{"a"},
    96  		},
    97  		{
    98  			"missing variables for right side of boolean expression",
    99  			MustEnvForVariables(map[string]types.VariableType{
   100  				"a": types.IntType,
   101  				"b": types.IntType,
   102  			}),
   103  			"(a == 2) || (b == 6)",
   104  			map[string]any{
   105  				"a": 2,
   106  			},
   107  			"",
   108  			true,
   109  			"",
   110  			noMissingVars,
   111  		},
   112  		{
   113  			"missing variables for left side of boolean expression",
   114  			MustEnvForVariables(map[string]types.VariableType{
   115  				"a": types.IntType,
   116  				"b": types.IntType,
   117  			}),
   118  			"(a == 2) || (b == 6)",
   119  			map[string]any{
   120  				"b": 6,
   121  			},
   122  			"",
   123  			true,
   124  			"",
   125  			noMissingVars,
   126  		},
   127  		{
   128  			"missing variables for both sides of boolean expression",
   129  			MustEnvForVariables(map[string]types.VariableType{
   130  				"a": types.IntType,
   131  				"b": types.IntType,
   132  			}),
   133  			"(a == 2) || (b == 6)",
   134  			map[string]any{},
   135  			"",
   136  			false,
   137  			"a == 2 || b == 6",
   138  			[]string{"a", "b"},
   139  		},
   140  		{
   141  			"missing variable for left side of and boolean expression",
   142  			MustEnvForVariables(map[string]types.VariableType{
   143  				"a": types.IntType,
   144  				"b": types.IntType,
   145  			}),
   146  			"(a == 2) && (b == 6)",
   147  			map[string]any{
   148  				"b": 6,
   149  			},
   150  			"",
   151  			false,
   152  			"a == 2",
   153  			[]string{"a"},
   154  		},
   155  		{
   156  			"missing variable for right side of and boolean expression",
   157  			MustEnvForVariables(map[string]types.VariableType{
   158  				"a": types.IntType,
   159  				"b": types.IntType,
   160  			}),
   161  			"(a == 2) && (b == 6)",
   162  			map[string]any{
   163  				"a": 2,
   164  			},
   165  			"",
   166  			false,
   167  			"b == 6",
   168  			[]string{"b"},
   169  		},
   170  		{
   171  			"map evaluation",
   172  			MustEnvForVariables(map[string]types.VariableType{
   173  				"m":   types.MustMapType(types.BooleanType),
   174  				"idx": types.StringType,
   175  			}),
   176  			"m[idx]",
   177  			map[string]any{
   178  				"m": map[string]bool{
   179  					"1": true,
   180  				},
   181  				"idx": "1",
   182  			},
   183  			"",
   184  			true,
   185  			"",
   186  			noMissingVars,
   187  		},
   188  		{
   189  			"map dot evaluation",
   190  			MustEnvForVariables(map[string]types.VariableType{
   191  				"m": types.MustMapType(types.BooleanType),
   192  			}),
   193  			"m.foo",
   194  			map[string]any{
   195  				"m": map[string]bool{
   196  					"foo": true,
   197  				},
   198  			},
   199  			"",
   200  			true,
   201  			"",
   202  			noMissingVars,
   203  		},
   204  		{
   205  			"missing map for evaluation",
   206  			MustEnvForVariables(map[string]types.VariableType{
   207  				"m":   types.MustMapType(types.BooleanType),
   208  				"idx": types.StringType,
   209  			}),
   210  			"m[idx]",
   211  			map[string]any{
   212  				"idx": "1",
   213  			},
   214  			"",
   215  			false,
   216  			"m[idx]",
   217  			[]string{"m"},
   218  		},
   219  		{
   220  			"missing map for attribute evaluation",
   221  			MustEnvForVariables(map[string]types.VariableType{
   222  				"m": types.MustMapType(types.BooleanType),
   223  			}),
   224  			"m.first",
   225  			map[string]any{},
   226  			"",
   227  			false,
   228  			"m.first",
   229  			[]string{"m"},
   230  		},
   231  		{
   232  			"nested evaluation",
   233  			MustEnvForVariables(map[string]types.VariableType{
   234  				"metadata.l":   types.MustListType(types.StringType),
   235  				"metadata.idx": types.IntType,
   236  			}),
   237  			"metadata.l[metadata.idx] == 'hello'",
   238  			map[string]any{
   239  				"metadata.l":   []string{"hi", "hello", "yo"},
   240  				"metadata.idx": 1,
   241  			},
   242  			"",
   243  			true,
   244  			"",
   245  			noMissingVars,
   246  		},
   247  		{
   248  			"nested evaluation with missing value",
   249  			MustEnvForVariables(map[string]types.VariableType{
   250  				"metadata.l":   types.MustListType(types.StringType),
   251  				"metadata.idx": types.IntType,
   252  			}),
   253  			"metadata.l[metadata.idx] == 'hello'",
   254  			map[string]any{
   255  				"metadata.l": []string{"hi", "hello", "yo"},
   256  			},
   257  			"",
   258  			false,
   259  			`metadata.l[metadata.idx] == "hello"`,
   260  			[]string{"metadata.idx"},
   261  		},
   262  		{
   263  			"nested evaluation with missing list",
   264  			MustEnvForVariables(map[string]types.VariableType{
   265  				"metadata.l":   types.MustListType(types.StringType),
   266  				"metadata.idx": types.IntType,
   267  			}),
   268  			"metadata.l[metadata.idx] == 'hello'",
   269  			map[string]any{
   270  				"metadata.idx": 1,
   271  			},
   272  			"",
   273  			false,
   274  			`metadata.l[metadata.idx] == "hello"`,
   275  			[]string{"metadata.l"},
   276  		},
   277  		{
   278  			"timestamp operations default to UTC",
   279  			MustEnvForVariables(map[string]types.VariableType{
   280  				"a": types.TimestampType,
   281  			}),
   282  			"a.getHours() == 9",
   283  			map[string]any{
   284  				"a": time.Date(2000, 10, 10, 10, 10, 10, 10, wetTz),
   285  			},
   286  			"",
   287  			true,
   288  			"",
   289  			noMissingVars,
   290  		},
   291  		{
   292  			"timestamp comparison",
   293  			MustEnvForVariables(map[string]types.VariableType{
   294  				"a": types.TimestampType,
   295  				"b": types.TimestampType,
   296  			}),
   297  			"a < b",
   298  			map[string]any{
   299  				"a": time.Date(2000, 10, 10, 10, 10, 10, 10, wetTz),
   300  				"b": time.Date(2000, 10, 10, 10, 10, 10, 10, wetTz),
   301  			},
   302  			"",
   303  			false,
   304  			"",
   305  			noMissingVars,
   306  		},
   307  		{
   308  			"timestamp comparison 2",
   309  			MustEnvForVariables(map[string]types.VariableType{
   310  				"a": types.TimestampType,
   311  				"b": types.TimestampType,
   312  			}),
   313  			"a <= b",
   314  			map[string]any{
   315  				"a": time.Date(2000, 10, 10, 10, 10, 10, 10, wetTz),
   316  				"b": time.Date(2000, 10, 10, 10, 10, 10, 10, wetTz),
   317  			},
   318  			"",
   319  			true,
   320  			"",
   321  			noMissingVars,
   322  		},
   323  		{
   324  			"optional types not found",
   325  			MustEnvForVariables(map[string]types.VariableType{
   326  				"m":   types.MustMapType(types.BooleanType),
   327  				"key": types.StringType,
   328  			}),
   329  			"m[?key].orValue(true)",
   330  			map[string]any{
   331  				"m":   map[string]bool{"foo": true, "bar": false},
   332  				"key": "baz",
   333  			},
   334  			"",
   335  			true,
   336  			"",
   337  			noMissingVars,
   338  		},
   339  		{
   340  			"optional types found",
   341  			MustEnvForVariables(map[string]types.VariableType{
   342  				"m":   types.MustMapType(types.BooleanType),
   343  				"key": types.StringType,
   344  			}),
   345  			"m[?key].orValue(true)",
   346  			map[string]any{
   347  				"m":   map[string]bool{"foo": true, "bar": false},
   348  				"key": "bar",
   349  			},
   350  			"",
   351  			false,
   352  			"",
   353  			noMissingVars,
   354  		},
   355  	}
   356  
   357  	for _, tc := range tcs {
   358  		tc := tc
   359  		t.Run(tc.name, func(t *testing.T) {
   360  			compiled, err := compileCaveat(tc.env, tc.exprString)
   361  			require.NoError(t, err)
   362  
   363  			result, err := EvaluateCaveat(compiled, tc.context)
   364  			if tc.expectedError != "" {
   365  				require.Error(t, err)
   366  				require.Contains(t, err.Error(), tc.expectedError)
   367  				require.Nil(t, result)
   368  			} else {
   369  				require.NoError(t, err)
   370  				require.NotNil(t, result)
   371  				require.Equal(t, tc.expectedValue, result.Value())
   372  
   373  				if tc.expectedPartialExpr != "" {
   374  					require.True(t, result.IsPartial())
   375  
   376  					partialValue, err := result.PartialValue()
   377  					require.NoError(t, err)
   378  
   379  					astExpr, err := cel.AstToString(partialValue.ast)
   380  					require.NoError(t, err)
   381  
   382  					require.Equal(t, tc.expectedPartialExpr, astExpr)
   383  
   384  					vars, err := result.MissingVarNames()
   385  					require.NoError(t, err)
   386  					require.EqualValues(t, tc.missingVars, vars)
   387  				} else {
   388  					require.False(t, result.IsPartial())
   389  					_, partialErr := result.PartialValue()
   390  					require.Error(t, partialErr)
   391  					require.Nil(t, tc.missingVars)
   392  					require.Nil(t, result.missingVarNames)
   393  				}
   394  			}
   395  		})
   396  	}
   397  }
   398  
   399  func TestPartialEvaluation(t *testing.T) {
   400  	compiled, err := compileCaveat(MustEnvForVariables(map[string]types.VariableType{
   401  		"a": types.IntType,
   402  		"b": types.IntType,
   403  	}), "a + b > 47")
   404  	require.NoError(t, err)
   405  
   406  	result, err := EvaluateCaveat(compiled, map[string]any{
   407  		"a": 42,
   408  	})
   409  	require.NoError(t, err)
   410  	require.False(t, result.Value())
   411  	require.True(t, result.IsPartial())
   412  
   413  	partialValue, err := result.PartialValue()
   414  	require.NoError(t, err)
   415  
   416  	astExpr, err := cel.AstToString(partialValue.ast)
   417  	require.NoError(t, err)
   418  	require.Equal(t, "42 + b > 47", astExpr)
   419  
   420  	fullResult, err := EvaluateCaveat(partialValue, map[string]any{
   421  		"b": 6,
   422  	})
   423  	require.NoError(t, err)
   424  	require.True(t, fullResult.Value())
   425  	require.False(t, fullResult.IsPartial())
   426  
   427  	fullResult, err = EvaluateCaveat(partialValue, map[string]any{
   428  		"b": 2,
   429  	})
   430  	require.NoError(t, err)
   431  	require.False(t, fullResult.Value())
   432  	require.False(t, fullResult.IsPartial())
   433  }
   434  
   435  func TestEvalWithMaxCost(t *testing.T) {
   436  	compiled, err := compileCaveat(MustEnvForVariables(map[string]types.VariableType{
   437  		"a": types.IntType,
   438  		"b": types.IntType,
   439  	}), "a + b > 47")
   440  	require.NoError(t, err)
   441  
   442  	_, err = EvaluateCaveatWithConfig(compiled, map[string]any{
   443  		"a": 42,
   444  		"b": 4,
   445  	}, &EvaluationConfig{
   446  		MaxCost: 1,
   447  	})
   448  	require.Error(t, err)
   449  	require.Equal(t, "operation cancelled: actual cost limit exceeded", err.Error())
   450  }
   451  
   452  func TestEvalWithNesting(t *testing.T) {
   453  	compiled, err := compileCaveat(MustEnvForVariables(map[string]types.VariableType{
   454  		"foo.a": types.IntType,
   455  		"foo.b": types.IntType,
   456  	}), "foo.a + foo.b > 47")
   457  	require.NoError(t, err)
   458  
   459  	result, err := EvaluateCaveat(compiled, map[string]any{
   460  		"foo.a": 42,
   461  		"foo.b": 4,
   462  	})
   463  	require.NoError(t, err)
   464  	require.False(t, result.Value())
   465  	require.False(t, result.IsPartial())
   466  }