github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/testutil/proto.go (about)

     1  package testutil
     2  
     3  import (
     4  	"fmt"
     5  	"slices"
     6  	"strings"
     7  	"testing"
     8  
     9  	"github.com/google/go-cmp/cmp"
    10  	"google.golang.org/protobuf/encoding/prototext"
    11  	"google.golang.org/protobuf/proto"
    12  
    13  	"github.com/stretchr/testify/require"
    14  )
    15  
    16  // AreProtoEqual returns whether the expected and required protocol buffer messages are equal, via proto.Equal.
    17  // If the messages are not equal, returns an error.
    18  func AreProtoEqual[T proto.Message](expected, found T, message string, args ...any) error {
    19  	areEqual := proto.Equal(expected, found)
    20  	if areEqual {
    21  		return nil
    22  	}
    23  
    24  	formattedMessage := fmt.Sprintf(message, args...)
    25  
    26  	return fmt.Errorf("%s\n\nExpected:\n%s\nActual:\n%s\nDiff:%s",
    27  		formattedMessage,
    28  		indent(prototext.Format(expected)),
    29  		indent(prototext.Format(found)),
    30  		cmp.Diff(prototext.Format(expected), prototext.Format(found)))
    31  }
    32  
    33  func indent(value string) string {
    34  	lines := strings.Split(value, "\n")
    35  	newLines := make([]string, 0, len(lines))
    36  	for _, line := range lines {
    37  		newLines = append(newLines, "\t"+line)
    38  	}
    39  	return strings.Join(newLines, "\n")
    40  }
    41  
    42  // RequireProtoEqual ensures that the expected and required protocol buffer messages are equal, via proto.Equal.
    43  func RequireProtoEqual[T proto.Message](t testing.TB, expected, found T, message string, args ...any) {
    44  	areEqual := AreProtoEqual(expected, found, message, args...)
    45  	require.NoError(t, areEqual)
    46  }
    47  
    48  func formatMessages[T proto.Message](messages []T) string {
    49  	formatted := make([]string, 0, len(messages))
    50  	for _, message := range messages {
    51  		formatted = append(formatted, prototext.Format(message))
    52  	}
    53  	return strings.Join(formatted, ",")
    54  }
    55  
    56  // AreProtoSlicesEqual returns whether the slices of protocol buffers are equal via protocol buffer comparison.
    57  func AreProtoSlicesEqual[T proto.Message](expected, found []T, cmp func(a, b T) int, message string, args ...any) error {
    58  	formattedMessage := fmt.Sprintf(message, args...)
    59  
    60  	if len(expected) != len(found) {
    61  		return fmt.Errorf("%s\n\nFound different number of elements in slices: %d in expected, %d in actual\nExpected: %s\nActual: %s",
    62  			formattedMessage,
    63  			len(expected),
    64  			len(found),
    65  			formatMessages(expected),
    66  			formatMessages(found),
    67  		)
    68  	}
    69  
    70  	if cmp != nil {
    71  		slices.SortFunc(expected, cmp)
    72  		slices.SortFunc(found, cmp)
    73  	}
    74  
    75  	for index := range expected {
    76  		err := AreProtoEqual(expected[index], found[index], "%s\n\nFound mismatch for element at index %d", formattedMessage, index)
    77  		if err != nil {
    78  			return err
    79  		}
    80  	}
    81  
    82  	return nil
    83  }
    84  
    85  // RequireProtoSlicesEqual ensures that the expected slices of protocol buffers are equal. The
    86  // sort function is used to sort the messages before comparison.
    87  func RequireProtoSlicesEqual[T proto.Message](t testing.TB, expected, found []T, cmp func(a, b T) int, message string, args ...any) {
    88  	err := AreProtoSlicesEqual(expected, found, cmp, message, args...)
    89  	require.NoError(t, err)
    90  }