github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/postgres/revisions_test.go (about)

     1  package postgres
     2  
     3  import (
     4  	"fmt"
     5  	"strconv"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/require"
     9  )
    10  
    11  const (
    12  	maxInt = ^uint64(0) >> 1
    13  )
    14  
    15  func TestRevisionOrdering(t *testing.T) {
    16  	testCases := []struct {
    17  		lhsSnapshot  pgSnapshot
    18  		rhsSnapshot  pgSnapshot
    19  		relationship comparisonResult
    20  	}{
    21  		{snap(0, 0), snap(0, 0), equal},
    22  		{snap(0, 5, 1), snap(0, 5, 1), equal},
    23  		{snap(0, 0), snap(1, 1), lt},
    24  		{snap(1, 1), snap(0, 0), gt},
    25  		{snap(1, 3, 1), snap(2, 3, 2), concurrent},
    26  		{snap(1, 2, 1), snap(2, 2), lt},
    27  		{snap(2, 2), snap(1, 2, 1), gt},
    28  	}
    29  
    30  	for _, tc := range testCases {
    31  		tc := tc
    32  		t.Run(fmt.Sprintf("%s:%s", tc.lhsSnapshot, tc.rhsSnapshot), func(t *testing.T) {
    33  			require := require.New(t)
    34  
    35  			lhs := postgresRevision{tc.lhsSnapshot}
    36  			rhs := postgresRevision{tc.rhsSnapshot}
    37  
    38  			require.Equal(tc.relationship == equal, lhs.Equal(rhs))
    39  			require.Equal(tc.relationship == equal, rhs.Equal(lhs))
    40  
    41  			require.Equal(tc.relationship == lt, lhs.LessThan(rhs))
    42  			require.Equal(tc.relationship == gt, lhs.GreaterThan(rhs))
    43  
    44  			require.Equal(tc.relationship == concurrent, !lhs.LessThan(rhs) && !lhs.GreaterThan(rhs) && !lhs.Equal(rhs))
    45  		})
    46  	}
    47  }
    48  
    49  func TestRevisionSerDe(t *testing.T) {
    50  	maxSizeList := make([]uint64, 20)
    51  	for i := range maxSizeList {
    52  		maxSizeList[i] = maxInt - uint64(len(maxSizeList)) + uint64(i)
    53  	}
    54  
    55  	testCases := []struct {
    56  		snapshot    pgSnapshot
    57  		expectedStr string
    58  	}{
    59  		{snap(0, 0), ""},
    60  		{snap(0, 5, 1), "EAUaAQE="},
    61  		{snap(1, 1), "CAE="},
    62  		{snap(1, 3, 1), "CAEQAhoBAA=="},
    63  		{snap(1, 2, 1), "CAEQARoBAA=="},
    64  		{snap(2, 2), "CAI="},
    65  		{snap(123, 123), "CHs="},
    66  		{snap(100, 150, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110), "CGQQMhoKAQIDBAUGBwgJCg=="},
    67  		{snap(maxSizeList[0], maxSizeList[len(maxSizeList)-1], maxSizeList...), "COv/////////fxATGhQAAQIDBAUGBwgJCgsMDQ4PEBESEw=="},
    68  	}
    69  
    70  	for _, tc := range testCases {
    71  		tc := tc
    72  		t.Run(tc.snapshot.String(), func(t *testing.T) {
    73  			require := require.New(t)
    74  
    75  			rev := postgresRevision{tc.snapshot}
    76  			serialized := rev.String()
    77  			require.Equal(tc.expectedStr, serialized)
    78  
    79  			parsed, err := parseRevisionProto(serialized)
    80  			require.NoError(err)
    81  			require.Equal(rev, parsed)
    82  		})
    83  	}
    84  }
    85  
    86  func TestRevisionParseOldDecimalFormat(t *testing.T) {
    87  	testCases := []struct {
    88  		snapshot  pgSnapshot
    89  		inputStr  string
    90  		expectErr bool
    91  	}{
    92  		{snap(1, 1), "0", false},
    93  		{snap(1, 1), "0.0", false},
    94  		{snap(501, 501), "500", false},
    95  		{snap(499, 501, 499), "500.499", false},
    96  		{snap(499, 507, 499, 500, 501, 502, 503, 504, 505), "506.499", false},
    97  		{snap(maxInt+1, maxInt+1), "9223372036854775807", false},
    98  		{snap(maxInt-1, maxInt+1, maxInt-1), "9223372036854775807.9223372036854775806", false},
    99  		{snap(0, 0), "-500", true},
   100  		{snap(0, 0), "", true},
   101  		{snap(0, 0), "deadbeef", true},
   102  		{snap(0, 0), "dead.beef", true},
   103  		{snap(0, 0), ".12345", true},
   104  		{snap(0, 0), "12345.", true},
   105  	}
   106  
   107  	for _, tc := range testCases {
   108  		tc := tc
   109  		t.Run(tc.snapshot.String(), func(t *testing.T) {
   110  			require := require.New(t)
   111  
   112  			parsed, err := parseRevisionDecimal(tc.inputStr)
   113  			if tc.expectErr {
   114  				require.Error(err)
   115  			} else {
   116  				require.NoError(err)
   117  				require.Equal(postgresRevision{tc.snapshot}, parsed)
   118  			}
   119  		})
   120  	}
   121  }
   122  
   123  func TestCombinedRevisionParsing(t *testing.T) {
   124  	testCases := []struct {
   125  		snapshot  pgSnapshot
   126  		inputStr  string
   127  		expectErr bool
   128  	}{
   129  		{snap(0, 0), "", false},
   130  		{snap(0, 5, 1), "EAUaAQE=", false},
   131  		{snap(1, 1), "CAE=", false},
   132  		{snap(1, 3, 1), "CAEQAhoBAA==", false},
   133  		{snap(1, 2, 1), "CAEQARoBAA==", false},
   134  		{snap(2, 2), "CAI=", false},
   135  		{snap(123, 123), "CHs=", false},
   136  		{snap(100, 150, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110), "CGQQMhoKAQIDBAUGBwgJCg==", false},
   137  		{snap(1, 1), "0", false},
   138  		{snap(1, 1), "0.0", false},
   139  		{snap(501, 501), "500", false},
   140  		{snap(499, 501, 499), "500.499", false},
   141  		{snap(499, 507, 499, 500, 501, 502, 503, 504, 505), "506.499", false},
   142  		{snap(maxInt+1, maxInt+1), "9223372036854775807", false},
   143  		{snap(maxInt-1, maxInt+1, maxInt-1), "9223372036854775807.9223372036854775806", false},
   144  		{snap(0, 0), "-500", true},
   145  		{snap(0, 0), "gobbleygook", true},
   146  		{snap(0, 0), "CAEQARoBAA", true},
   147  	}
   148  
   149  	for _, tc := range testCases {
   150  		tc := tc
   151  		t.Run(tc.snapshot.String(), func(t *testing.T) {
   152  			require := require.New(t)
   153  
   154  			parsed, err := ParseRevisionString(tc.inputStr)
   155  			if tc.expectErr {
   156  				require.Error(err)
   157  			} else {
   158  				require.NoError(err)
   159  				require.Equal(postgresRevision{tc.snapshot}, parsed)
   160  			}
   161  		})
   162  	}
   163  }
   164  
   165  func TestBrokenInvalidRevision(t *testing.T) {
   166  	_, err := ParseRevisionString("1693540940373045727.0000000001")
   167  	require.Error(t, err)
   168  }
   169  
   170  func FuzzRevision(f *testing.F) {
   171  	// Attempt to find a decimal revision that is a valid base64 encoded proto revision
   172  	f.Add(uint64(0), -1)
   173  	f.Add(uint64(0), 0)
   174  	f.Add(uint64(500), -1)
   175  	f.Add(uint64(500), 499)
   176  	f.Add(uint64(506), 499)
   177  	f.Add(uint64(9223372036854775807), -1)
   178  	f.Add(uint64(9223372036854775807), 9223372036854775806)
   179  
   180  	f.Fuzz(func(t *testing.T, a uint64, b int) {
   181  		decimalRev := strconv.FormatUint(a, 10)
   182  		if b >= 0 {
   183  			decimalRev = decimalRev + "." + strconv.Itoa(b)
   184  		}
   185  		rev, err := parseRevisionProto(decimalRev)
   186  		if err == nil && decimalRev != "" {
   187  			t.Errorf("decimal revision \"%s\" is a valid proto revision %#v", decimalRev, rev)
   188  		}
   189  	})
   190  }