github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/testutil/proto.go (about) 1 package testutil 2 3 import ( 4 "fmt" 5 "slices" 6 "strings" 7 "testing" 8 9 "github.com/google/go-cmp/cmp" 10 "google.golang.org/protobuf/encoding/prototext" 11 "google.golang.org/protobuf/proto" 12 13 "github.com/stretchr/testify/require" 14 ) 15 16 // AreProtoEqual returns whether the expected and required protocol buffer messages are equal, via proto.Equal. 17 // If the messages are not equal, returns an error. 18 func AreProtoEqual[T proto.Message](expected, found T, message string, args ...any) error { 19 areEqual := proto.Equal(expected, found) 20 if areEqual { 21 return nil 22 } 23 24 formattedMessage := fmt.Sprintf(message, args...) 25 26 return fmt.Errorf("%s\n\nExpected:\n%s\nActual:\n%s\nDiff:%s", 27 formattedMessage, 28 indent(prototext.Format(expected)), 29 indent(prototext.Format(found)), 30 cmp.Diff(prototext.Format(expected), prototext.Format(found))) 31 } 32 33 func indent(value string) string { 34 lines := strings.Split(value, "\n") 35 newLines := make([]string, 0, len(lines)) 36 for _, line := range lines { 37 newLines = append(newLines, "\t"+line) 38 } 39 return strings.Join(newLines, "\n") 40 } 41 42 // RequireProtoEqual ensures that the expected and required protocol buffer messages are equal, via proto.Equal. 43 func RequireProtoEqual[T proto.Message](t testing.TB, expected, found T, message string, args ...any) { 44 areEqual := AreProtoEqual(expected, found, message, args...) 45 require.NoError(t, areEqual) 46 } 47 48 func formatMessages[T proto.Message](messages []T) string { 49 formatted := make([]string, 0, len(messages)) 50 for _, message := range messages { 51 formatted = append(formatted, prototext.Format(message)) 52 } 53 return strings.Join(formatted, ",") 54 } 55 56 // AreProtoSlicesEqual returns whether the slices of protocol buffers are equal via protocol buffer comparison. 57 func AreProtoSlicesEqual[T proto.Message](expected, found []T, cmp func(a, b T) int, message string, args ...any) error { 58 formattedMessage := fmt.Sprintf(message, args...) 59 60 if len(expected) != len(found) { 61 return fmt.Errorf("%s\n\nFound different number of elements in slices: %d in expected, %d in actual\nExpected: %s\nActual: %s", 62 formattedMessage, 63 len(expected), 64 len(found), 65 formatMessages(expected), 66 formatMessages(found), 67 ) 68 } 69 70 if cmp != nil { 71 slices.SortFunc(expected, cmp) 72 slices.SortFunc(found, cmp) 73 } 74 75 for index := range expected { 76 err := AreProtoEqual(expected[index], found[index], "%s\n\nFound mismatch for element at index %d", formattedMessage, index) 77 if err != nil { 78 return err 79 } 80 } 81 82 return nil 83 } 84 85 // RequireProtoSlicesEqual ensures that the expected slices of protocol buffers are equal. The 86 // sort function is used to sort the messages before comparison. 87 func RequireProtoSlicesEqual[T proto.Message](t testing.TB, expected, found []T, cmp func(a, b T) int, message string, args ...any) { 88 err := AreProtoSlicesEqual(expected, found, cmp, message, args...) 89 require.NoError(t, err) 90 }