github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/testutil/subjects.go (about)

     1  package testutil
     2  
     3  import (
     4  	"cmp"
     5  	"fmt"
     6  	"maps"
     7  	"slices"
     8  	"strings"
     9  	"testing"
    10  
    11  	"github.com/stretchr/testify/require"
    12  
    13  	"github.com/authzed/spicedb/pkg/genutil/mapz"
    14  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    15  	v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
    16  	"github.com/authzed/spicedb/pkg/spiceerrors"
    17  	"github.com/authzed/spicedb/pkg/tuple"
    18  )
    19  
    20  // WrapFoundSubject wraps the given subject into a pointer to it, unless nil, in which case this method returns
    21  // nil.
    22  func WrapFoundSubject(sub *v1.FoundSubject) **v1.FoundSubject {
    23  	if sub == nil {
    24  		return nil
    25  	}
    26  
    27  	return &sub
    28  }
    29  
    30  // FoundSubject returns a FoundSubject with the given ID.
    31  func FoundSubject(subjectID string) *v1.FoundSubject {
    32  	return &v1.FoundSubject{
    33  		SubjectId: subjectID,
    34  	}
    35  }
    36  
    37  // CaveatedFoundSubject returns a FoundSubject with the given ID and caveat expression.
    38  func CaveatedFoundSubject(subjectID string, expr *core.CaveatExpression) *v1.FoundSubject {
    39  	return &v1.FoundSubject{
    40  		SubjectId:        subjectID,
    41  		CaveatExpression: expr,
    42  	}
    43  }
    44  
    45  // CaveatedWildcard returns a wildcard FoundSubject with the given caveat expression and exclusions.
    46  func CaveatedWildcard(expr *core.CaveatExpression, exclusions ...*v1.FoundSubject) *v1.FoundSubject {
    47  	return &v1.FoundSubject{
    48  		SubjectId:        tuple.PublicWildcard,
    49  		ExcludedSubjects: exclusions,
    50  		CaveatExpression: expr,
    51  	}
    52  }
    53  
    54  // Wildcard returns a FoundSubject with the given subject IDs as concrete exclusions.
    55  func Wildcard(exclusions ...string) *v1.FoundSubject {
    56  	excludedSubjects := make([]*v1.FoundSubject, 0, len(exclusions))
    57  	for _, excludedID := range exclusions {
    58  		excludedSubjects = append(excludedSubjects, &v1.FoundSubject{
    59  			SubjectId: excludedID,
    60  		})
    61  	}
    62  
    63  	return &v1.FoundSubject{
    64  		SubjectId:        tuple.PublicWildcard,
    65  		ExcludedSubjects: excludedSubjects,
    66  	}
    67  }
    68  
    69  // RequireEquivalentSets requires that the given sets of subjects are equivalent.
    70  func RequireEquivalentSets(t *testing.T, expected []*v1.FoundSubject, found []*v1.FoundSubject) {
    71  	t.Helper()
    72  	err := CheckEquivalentSets(expected, found)
    73  	require.NoError(t, err, "found different subject sets: %v", err)
    74  }
    75  
    76  // RequireExpectedSubject requires that the given expected and produced subjects match.
    77  func RequireExpectedSubject(t *testing.T, expected *v1.FoundSubject, produced **v1.FoundSubject) {
    78  	t.Helper()
    79  	if expected == nil {
    80  		require.Nil(t, produced)
    81  	} else {
    82  		require.NotNil(t, produced)
    83  
    84  		found := *produced
    85  		err := CheckEquivalentSubjects(expected, found)
    86  		require.NoError(t, err, "found different subjects: %v", err)
    87  	}
    88  }
    89  
    90  // CheckEquivalentSets checks if the sets of subjects are equivalent and returns an error if they are not.
    91  func CheckEquivalentSets(expected []*v1.FoundSubject, found []*v1.FoundSubject) error {
    92  	if len(expected) != len(found) {
    93  		return fmt.Errorf("found mismatch in number of elements:\n\texpected: %s\n\tfound: %s", FormatSubjects(expected), FormatSubjects(found))
    94  	}
    95  
    96  	slices.SortFunc(expected, CmpSubjects)
    97  	slices.SortFunc(found, CmpSubjects)
    98  
    99  	for index := range expected {
   100  		err := CheckEquivalentSubjects(expected[index], found[index])
   101  		if err != nil {
   102  			return fmt.Errorf("found mismatch for subject #%d: %w", index, err)
   103  		}
   104  	}
   105  
   106  	return nil
   107  }
   108  
   109  // CheckEquivalentSubjects checks if the given subjects are equivalent and returns an error if they are not.
   110  func CheckEquivalentSubjects(expected *v1.FoundSubject, found *v1.FoundSubject) error {
   111  	if expected.SubjectId != found.SubjectId {
   112  		return fmt.Errorf("expected subject %s, found %s", expected.SubjectId, found.SubjectId)
   113  	}
   114  
   115  	err := CheckEquivalentSets(expected.ExcludedSubjects, found.ExcludedSubjects)
   116  	if err != nil {
   117  		return fmt.Errorf("difference in exclusions: %w", err)
   118  	}
   119  
   120  	return checkEquivalentCaveatExprs(expected.CaveatExpression, found.CaveatExpression)
   121  }
   122  
   123  // FormatSubjects formats the given slice of subjects in a human-readable string.
   124  func FormatSubjects(subs []*v1.FoundSubject) string {
   125  	formatted := make([]string, 0, len(subs))
   126  	for _, sub := range subs {
   127  		formatted = append(formatted, FormatSubject(sub))
   128  	}
   129  	return strings.Join(formatted, ",")
   130  }
   131  
   132  // FormatSubject formats the given subject (which can be nil) into a human-readable string.
   133  func FormatSubject(sub *v1.FoundSubject) string {
   134  	if sub == nil {
   135  		return "[nil]"
   136  	}
   137  
   138  	if sub.GetSubjectId() == tuple.PublicWildcard {
   139  		exclusions := make([]string, 0, len(sub.GetExcludedSubjects()))
   140  		for _, excludedSubject := range sub.GetExcludedSubjects() {
   141  			exclusions = append(exclusions, FormatSubject(excludedSubject))
   142  		}
   143  
   144  		exclusionsStr := ""
   145  		if len(exclusions) > 0 {
   146  			exclusionsStr = fmt.Sprintf("- {%s}", strings.Join(exclusions, ","))
   147  		}
   148  
   149  		if sub.GetCaveatExpression() != nil {
   150  			return fmt.Sprintf("{*%s}[%s]", exclusionsStr, formatCaveatExpr(sub.GetCaveatExpression()))
   151  		}
   152  
   153  		return fmt.Sprintf("{*%s}", exclusionsStr)
   154  	}
   155  
   156  	if sub.GetCaveatExpression() != nil {
   157  		return fmt.Sprintf("%s[%s]", sub.GetSubjectId(), formatCaveatExpr(sub.GetCaveatExpression()))
   158  	}
   159  
   160  	return sub.GetSubjectId()
   161  }
   162  
   163  // formatCaveatExpr formats a caveat expression (which can be nil) into a human readable string.
   164  func formatCaveatExpr(expr *core.CaveatExpression) string {
   165  	if expr == nil {
   166  		return "[nil]"
   167  	}
   168  
   169  	if expr.GetCaveat() != nil {
   170  		return expr.GetCaveat().CaveatName
   171  	}
   172  
   173  	switch expr.GetOperation().Op {
   174  	case core.CaveatOperation_AND:
   175  		return fmt.Sprintf("(%s) && (%s)",
   176  			formatCaveatExpr(expr.GetOperation().GetChildren()[0]),
   177  			formatCaveatExpr(expr.GetOperation().GetChildren()[1]),
   178  		)
   179  
   180  	case core.CaveatOperation_OR:
   181  		return fmt.Sprintf("(%s) || (%s)",
   182  			formatCaveatExpr(expr.GetOperation().GetChildren()[0]),
   183  			formatCaveatExpr(expr.GetOperation().GetChildren()[1]),
   184  		)
   185  
   186  	case core.CaveatOperation_NOT:
   187  		return fmt.Sprintf("!(%s)",
   188  			formatCaveatExpr(expr.GetOperation().GetChildren()[0]),
   189  		)
   190  
   191  	default:
   192  		panic("unknown op")
   193  	}
   194  }
   195  
   196  // checkEquivalentCaveatExprs checks if the given caveat expressions are equivalent and returns an error if they are not.
   197  func checkEquivalentCaveatExprs(expected *core.CaveatExpression, found *core.CaveatExpression) error {
   198  	if expected == nil {
   199  		if found != nil {
   200  			return fmt.Errorf("found non-nil caveat expression `%s` where expected nil", formatCaveatExpr(found))
   201  		}
   202  		return nil
   203  	}
   204  
   205  	if found == nil {
   206  		if expected != nil {
   207  			return fmt.Errorf("expected non-nil caveat expression `%s` where found nil", formatCaveatExpr(expected))
   208  		}
   209  		return nil
   210  	}
   211  
   212  	// Caveat expressions that the subjectset generates can be different in structure but *logically* equivalent,
   213  	// so we compare by building a boolean table for each referenced caveat name and then checking all combinations
   214  	// of boolean inputs to ensure the expressions produce the same output. Note that while this isn't the most
   215  	// efficient means of comparison, it is logically correct.
   216  	referencedNamesSet := mapz.NewSet[string]()
   217  	collectReferencedNames(expected, referencedNamesSet)
   218  	collectReferencedNames(found, referencedNamesSet)
   219  
   220  	referencedNames := referencedNamesSet.AsSlice()
   221  	for _, values := range combinatorialValues(referencedNames) {
   222  		expectedResult, err := executeCaveatExprForTesting(expected, values)
   223  		if err != nil {
   224  			return err
   225  		}
   226  
   227  		foundResult, err := executeCaveatExprForTesting(found, values)
   228  		if err != nil {
   229  			return err
   230  		}
   231  
   232  		if expectedResult != foundResult {
   233  			return fmt.Errorf("found difference between caveats for values:\n\tvalues: %v\n\texpected caveat: %s\n\tfound caveat:%s", values, formatCaveatExpr(expected), formatCaveatExpr(found))
   234  		}
   235  	}
   236  	return nil
   237  }
   238  
   239  // executeCaveatExprForTesting "executes" the given caveat expression for testing. DO NOT USE OUTSIDE OF TESTING.
   240  // This method *ignores* caveat context and treats each caveat as just its name.
   241  func executeCaveatExprForTesting(expr *core.CaveatExpression, values map[string]bool) (bool, error) {
   242  	if expr.GetCaveat() != nil {
   243  		return values[expr.GetCaveat().CaveatName], nil
   244  	}
   245  
   246  	switch expr.GetOperation().Op {
   247  	case core.CaveatOperation_AND:
   248  		if len(expr.GetOperation().Children) != 2 {
   249  			return false, spiceerrors.MustBugf("found invalid child count for AND")
   250  		}
   251  
   252  		left, err := executeCaveatExprForTesting(expr.GetOperation().Children[0], values)
   253  		if err != nil {
   254  			return false, err
   255  		}
   256  
   257  		right, err := executeCaveatExprForTesting(expr.GetOperation().Children[1], values)
   258  		if err != nil {
   259  			return false, err
   260  		}
   261  
   262  		return left && right, nil
   263  
   264  	case core.CaveatOperation_OR:
   265  		if len(expr.GetOperation().Children) != 2 {
   266  			return false, spiceerrors.MustBugf("found invalid child count for OR")
   267  		}
   268  
   269  		left, err := executeCaveatExprForTesting(expr.GetOperation().Children[0], values)
   270  		if err != nil {
   271  			return false, err
   272  		}
   273  
   274  		right, err := executeCaveatExprForTesting(expr.GetOperation().Children[1], values)
   275  		if err != nil {
   276  			return false, err
   277  		}
   278  
   279  		return left || right, nil
   280  
   281  	case core.CaveatOperation_NOT:
   282  		if len(expr.GetOperation().Children) != 1 {
   283  			return false, spiceerrors.MustBugf("found invalid child count for NOT")
   284  		}
   285  
   286  		result, err := executeCaveatExprForTesting(expr.GetOperation().Children[0], values)
   287  		if err != nil {
   288  			return false, err
   289  		}
   290  		return !result, nil
   291  
   292  	default:
   293  		return false, spiceerrors.MustBugf("unknown caveat operation")
   294  	}
   295  }
   296  
   // combinatorialValues returns every assignment of true or false to the given names: 2^len(names) maps, each
// holding one entry per name (names are expected to be distinct). With no names it returns a single empty
// assignment rather than none. Callers that evaluate an expression once per assignment therefore still evaluate
// it once, instead of silently skipping the evaluation.
//
// Ordering: the first name varies fastest, and for each name true comes before false.
func combinatorialValues(names []string) []map[string]bool {
	if len(names) == 0 {
		// The only assignment over zero variables is the empty one.
		return []map[string]bool{{}}
	}

	name := names[0]
	childMaps := combinatorialValues(names[1:])

	cmaps := make([]map[string]bool, 0, len(childMaps)*2)
	for _, childMap := range childMaps {
		for _, v := range []bool{true, false} {
			// Clone so that every returned assignment is an independent map.
			cloned := maps.Clone(childMap)
			cloned[name] = v
			cmaps = append(cmaps, cloned)
		}
	}

	return cmaps
}
   325  
   326  // collectReferencedNames collects all referenced caveat names into the given set.
   327  func collectReferencedNames(expr *core.CaveatExpression, nameSet *mapz.Set[string]) {
   328  	if expr.GetCaveat() != nil {
   329  		nameSet.Insert(expr.GetCaveat().CaveatName)
   330  		return
   331  	}
   332  
   333  	for _, child := range expr.GetOperation().GetChildren() {
   334  		collectReferencedNames(child, nameSet)
   335  	}
   336  }
   337  
   338  // CmpSubjects compares FoundSubjects such that they can be sorted.
   339  func CmpSubjects(a, b *v1.FoundSubject) int {
   340  	return cmp.Compare(a.SubjectId, b.SubjectId)
   341  }