github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/testutil/subjects.go (about) 1 package testutil 2 3 import ( 4 "cmp" 5 "fmt" 6 "maps" 7 "slices" 8 "strings" 9 "testing" 10 11 "github.com/stretchr/testify/require" 12 13 "github.com/authzed/spicedb/pkg/genutil/mapz" 14 core "github.com/authzed/spicedb/pkg/proto/core/v1" 15 v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" 16 "github.com/authzed/spicedb/pkg/spiceerrors" 17 "github.com/authzed/spicedb/pkg/tuple" 18 ) 19 20 // WrapFoundSubject wraps the given subject into a pointer to it, unless nil, in which case this method returns 21 // nil. 22 func WrapFoundSubject(sub *v1.FoundSubject) **v1.FoundSubject { 23 if sub == nil { 24 return nil 25 } 26 27 return &sub 28 } 29 30 // FoundSubject returns a FoundSubject with the given ID. 31 func FoundSubject(subjectID string) *v1.FoundSubject { 32 return &v1.FoundSubject{ 33 SubjectId: subjectID, 34 } 35 } 36 37 // CaveatedFoundSubject returns a FoundSubject with the given ID and caveat expression. 38 func CaveatedFoundSubject(subjectID string, expr *core.CaveatExpression) *v1.FoundSubject { 39 return &v1.FoundSubject{ 40 SubjectId: subjectID, 41 CaveatExpression: expr, 42 } 43 } 44 45 // CaveatedWildcard returns a wildcard FoundSubject with the given caveat expression and exclusions. 46 func CaveatedWildcard(expr *core.CaveatExpression, exclusions ...*v1.FoundSubject) *v1.FoundSubject { 47 return &v1.FoundSubject{ 48 SubjectId: tuple.PublicWildcard, 49 ExcludedSubjects: exclusions, 50 CaveatExpression: expr, 51 } 52 } 53 54 // Wildcard returns a FoundSubject with the given subject IDs as concrete exclusions. 55 func Wildcard(exclusions ...string) *v1.FoundSubject { 56 excludedSubjects := make([]*v1.FoundSubject, 0, len(exclusions)) 57 for _, excludedID := range exclusions { 58 excludedSubjects = append(excludedSubjects, &v1.FoundSubject{ 59 SubjectId: excludedID, 60 }) 61 } 62 63 return &v1.FoundSubject{ 64 SubjectId: tuple.PublicWildcard, 65 ExcludedSubjects: excludedSubjects, 66 } 67 } 68 69 // RequireEquivalentSets requires that the given sets of subjects are equivalent. 70 func RequireEquivalentSets(t *testing.T, expected []*v1.FoundSubject, found []*v1.FoundSubject) { 71 t.Helper() 72 err := CheckEquivalentSets(expected, found) 73 require.NoError(t, err, "found different subject sets: %v", err) 74 } 75 76 // RequireExpectedSubject requires that the given expected and produced subjects match. 77 func RequireExpectedSubject(t *testing.T, expected *v1.FoundSubject, produced **v1.FoundSubject) { 78 t.Helper() 79 if expected == nil { 80 require.Nil(t, produced) 81 } else { 82 require.NotNil(t, produced) 83 84 found := *produced 85 err := CheckEquivalentSubjects(expected, found) 86 require.NoError(t, err, "found different subjects: %v", err) 87 } 88 } 89 90 // CheckEquivalentSets checks if the sets of subjects are equivalent and returns an error if they are not. 91 func CheckEquivalentSets(expected []*v1.FoundSubject, found []*v1.FoundSubject) error { 92 if len(expected) != len(found) { 93 return fmt.Errorf("found mismatch in number of elements:\n\texpected: %s\n\tfound: %s", FormatSubjects(expected), FormatSubjects(found)) 94 } 95 96 slices.SortFunc(expected, CmpSubjects) 97 slices.SortFunc(found, CmpSubjects) 98 99 for index := range expected { 100 err := CheckEquivalentSubjects(expected[index], found[index]) 101 if err != nil { 102 return fmt.Errorf("found mismatch for subject #%d: %w", index, err) 103 } 104 } 105 106 return nil 107 } 108 109 // CheckEquivalentSubjects checks if the given subjects are equivalent and returns an error if they are not. 110 func CheckEquivalentSubjects(expected *v1.FoundSubject, found *v1.FoundSubject) error { 111 if expected.SubjectId != found.SubjectId { 112 return fmt.Errorf("expected subject %s, found %s", expected.SubjectId, found.SubjectId) 113 } 114 115 err := CheckEquivalentSets(expected.ExcludedSubjects, found.ExcludedSubjects) 116 if err != nil { 117 return fmt.Errorf("difference in exclusions: %w", err) 118 } 119 120 return checkEquivalentCaveatExprs(expected.CaveatExpression, found.CaveatExpression) 121 } 122 123 // FormatSubjects formats the given slice of subjects in a human-readable string. 124 func FormatSubjects(subs []*v1.FoundSubject) string { 125 formatted := make([]string, 0, len(subs)) 126 for _, sub := range subs { 127 formatted = append(formatted, FormatSubject(sub)) 128 } 129 return strings.Join(formatted, ",") 130 } 131 132 // FormatSubject formats the given subject (which can be nil) into a human-readable string. 133 func FormatSubject(sub *v1.FoundSubject) string { 134 if sub == nil { 135 return "[nil]" 136 } 137 138 if sub.GetSubjectId() == tuple.PublicWildcard { 139 exclusions := make([]string, 0, len(sub.GetExcludedSubjects())) 140 for _, excludedSubject := range sub.GetExcludedSubjects() { 141 exclusions = append(exclusions, FormatSubject(excludedSubject)) 142 } 143 144 exclusionsStr := "" 145 if len(exclusions) > 0 { 146 exclusionsStr = fmt.Sprintf("- {%s}", strings.Join(exclusions, ",")) 147 } 148 149 if sub.GetCaveatExpression() != nil { 150 return fmt.Sprintf("{*%s}[%s]", exclusionsStr, formatCaveatExpr(sub.GetCaveatExpression())) 151 } 152 153 return fmt.Sprintf("{*%s}", exclusionsStr) 154 } 155 156 if sub.GetCaveatExpression() != nil { 157 return fmt.Sprintf("%s[%s]", sub.GetSubjectId(), formatCaveatExpr(sub.GetCaveatExpression())) 158 } 159 160 return sub.GetSubjectId() 161 } 162 163 // formatCaveatExpr formats a caveat expression (which can be nil) into a human readable string. 164 func formatCaveatExpr(expr *core.CaveatExpression) string { 165 if expr == nil { 166 return "[nil]" 167 } 168 169 if expr.GetCaveat() != nil { 170 return expr.GetCaveat().CaveatName 171 } 172 173 switch expr.GetOperation().Op { 174 case core.CaveatOperation_AND: 175 return fmt.Sprintf("(%s) && (%s)", 176 formatCaveatExpr(expr.GetOperation().GetChildren()[0]), 177 formatCaveatExpr(expr.GetOperation().GetChildren()[1]), 178 ) 179 180 case core.CaveatOperation_OR: 181 return fmt.Sprintf("(%s) || (%s)", 182 formatCaveatExpr(expr.GetOperation().GetChildren()[0]), 183 formatCaveatExpr(expr.GetOperation().GetChildren()[1]), 184 ) 185 186 case core.CaveatOperation_NOT: 187 return fmt.Sprintf("!(%s)", 188 formatCaveatExpr(expr.GetOperation().GetChildren()[0]), 189 ) 190 191 default: 192 panic("unknown op") 193 } 194 } 195 196 // checkEquivalentCaveatExprs checks if the given caveat expressions are equivalent and returns an error if they are not. 197 func checkEquivalentCaveatExprs(expected *core.CaveatExpression, found *core.CaveatExpression) error { 198 if expected == nil { 199 if found != nil { 200 return fmt.Errorf("found non-nil caveat expression `%s` where expected nil", formatCaveatExpr(found)) 201 } 202 return nil 203 } 204 205 if found == nil { 206 if expected != nil { 207 return fmt.Errorf("expected non-nil caveat expression `%s` where found nil", formatCaveatExpr(expected)) 208 } 209 return nil 210 } 211 212 // Caveat expressions that the subjectset generates can be different in structure but *logically* equivalent, 213 // so we compare by building a boolean table for each referenced caveat name and then checking all combinations 214 // of boolean inputs to ensure the expressions produce the same output. Note that while this isn't the most 215 // efficient means of comparison, it is logically correct. 216 referencedNamesSet := mapz.NewSet[string]() 217 collectReferencedNames(expected, referencedNamesSet) 218 collectReferencedNames(found, referencedNamesSet) 219 220 referencedNames := referencedNamesSet.AsSlice() 221 for _, values := range combinatorialValues(referencedNames) { 222 expectedResult, err := executeCaveatExprForTesting(expected, values) 223 if err != nil { 224 return err 225 } 226 227 foundResult, err := executeCaveatExprForTesting(found, values) 228 if err != nil { 229 return err 230 } 231 232 if expectedResult != foundResult { 233 return fmt.Errorf("found difference between caveats for values:\n\tvalues: %v\n\texpected caveat: %s\n\tfound caveat:%s", values, formatCaveatExpr(expected), formatCaveatExpr(found)) 234 } 235 } 236 return nil 237 } 238 239 // executeCaveatExprForTesting "executes" the given caveat expression for testing. DO NOT USE OUTSIDE OF TESTING. 240 // This method *ignores* caveat context and treats each caveat as just its name. 241 func executeCaveatExprForTesting(expr *core.CaveatExpression, values map[string]bool) (bool, error) { 242 if expr.GetCaveat() != nil { 243 return values[expr.GetCaveat().CaveatName], nil 244 } 245 246 switch expr.GetOperation().Op { 247 case core.CaveatOperation_AND: 248 if len(expr.GetOperation().Children) != 2 { 249 return false, spiceerrors.MustBugf("found invalid child count for AND") 250 } 251 252 left, err := executeCaveatExprForTesting(expr.GetOperation().Children[0], values) 253 if err != nil { 254 return false, err 255 } 256 257 right, err := executeCaveatExprForTesting(expr.GetOperation().Children[1], values) 258 if err != nil { 259 return false, err 260 } 261 262 return left && right, nil 263 264 case core.CaveatOperation_OR: 265 if len(expr.GetOperation().Children) != 2 { 266 return false, spiceerrors.MustBugf("found invalid child count for OR") 267 } 268 269 left, err := executeCaveatExprForTesting(expr.GetOperation().Children[0], values) 270 if err != nil { 271 return false, err 272 } 273 274 right, err := executeCaveatExprForTesting(expr.GetOperation().Children[1], values) 275 if err != nil { 276 return false, err 277 } 278 279 return left || right, nil 280 281 case core.CaveatOperation_NOT: 282 if len(expr.GetOperation().Children) != 1 { 283 return false, spiceerrors.MustBugf("found invalid child count for NOT") 284 } 285 286 result, err := executeCaveatExprForTesting(expr.GetOperation().Children[0], values) 287 if err != nil { 288 return false, err 289 } 290 return !result, nil 291 292 default: 293 return false, spiceerrors.MustBugf("unknown caveat operation") 294 } 295 } 296 297 // combinatorialValues returns the combinatorial set of values where each name is either true and false. 298 func combinatorialValues(names []string) []map[string]bool { 299 if len(names) == 0 { 300 return nil 301 } 302 303 name := names[0] 304 childMaps := combinatorialValues(names[1:]) 305 306 cmaps := make([]map[string]bool, 0, len(childMaps)*2) 307 if len(childMaps) == 0 { 308 for _, v := range []bool{true, false} { 309 cloned := map[string]bool{} 310 cloned[name] = v 311 cmaps = append(cmaps, cloned) 312 } 313 } else { 314 for _, childMap := range childMaps { 315 for _, v := range []bool{true, false} { 316 cloned := maps.Clone(childMap) 317 cloned[name] = v 318 cmaps = append(cmaps, cloned) 319 } 320 } 321 } 322 323 return cmaps 324 } 325 326 // collectReferencedNames collects all referenced caveat names into the given set. 327 func collectReferencedNames(expr *core.CaveatExpression, nameSet *mapz.Set[string]) { 328 if expr.GetCaveat() != nil { 329 nameSet.Insert(expr.GetCaveat().CaveatName) 330 return 331 } 332 333 for _, child := range expr.GetOperation().GetChildren() { 334 collectReferencedNames(child, nameSet) 335 } 336 } 337 338 // CmpSubjects compares FoundSubjects such that they can be sorted. 339 func CmpSubjects(a, b *v1.FoundSubject) int { 340 return cmp.Compare(a.SubjectId, b.SubjectId) 341 }