github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/postgres/revisions_test.go (about) 1 package postgres 2 3 import ( 4 "fmt" 5 "strconv" 6 "testing" 7 8 "github.com/stretchr/testify/require" 9 ) 10 11 const ( 12 maxInt = ^uint64(0) >> 1 13 ) 14 15 func TestRevisionOrdering(t *testing.T) { 16 testCases := []struct { 17 lhsSnapshot pgSnapshot 18 rhsSnapshot pgSnapshot 19 relationship comparisonResult 20 }{ 21 {snap(0, 0), snap(0, 0), equal}, 22 {snap(0, 5, 1), snap(0, 5, 1), equal}, 23 {snap(0, 0), snap(1, 1), lt}, 24 {snap(1, 1), snap(0, 0), gt}, 25 {snap(1, 3, 1), snap(2, 3, 2), concurrent}, 26 {snap(1, 2, 1), snap(2, 2), lt}, 27 {snap(2, 2), snap(1, 2, 1), gt}, 28 } 29 30 for _, tc := range testCases { 31 tc := tc 32 t.Run(fmt.Sprintf("%s:%s", tc.lhsSnapshot, tc.rhsSnapshot), func(t *testing.T) { 33 require := require.New(t) 34 35 lhs := postgresRevision{tc.lhsSnapshot} 36 rhs := postgresRevision{tc.rhsSnapshot} 37 38 require.Equal(tc.relationship == equal, lhs.Equal(rhs)) 39 require.Equal(tc.relationship == equal, rhs.Equal(lhs)) 40 41 require.Equal(tc.relationship == lt, lhs.LessThan(rhs)) 42 require.Equal(tc.relationship == gt, lhs.GreaterThan(rhs)) 43 44 require.Equal(tc.relationship == concurrent, !lhs.LessThan(rhs) && !lhs.GreaterThan(rhs) && !lhs.Equal(rhs)) 45 }) 46 } 47 } 48 49 func TestRevisionSerDe(t *testing.T) { 50 maxSizeList := make([]uint64, 20) 51 for i := range maxSizeList { 52 maxSizeList[i] = maxInt - uint64(len(maxSizeList)) + uint64(i) 53 } 54 55 testCases := []struct { 56 snapshot pgSnapshot 57 expectedStr string 58 }{ 59 {snap(0, 0), ""}, 60 {snap(0, 5, 1), "EAUaAQE="}, 61 {snap(1, 1), "CAE="}, 62 {snap(1, 3, 1), "CAEQAhoBAA=="}, 63 {snap(1, 2, 1), "CAEQARoBAA=="}, 64 {snap(2, 2), "CAI="}, 65 {snap(123, 123), "CHs="}, 66 {snap(100, 150, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110), "CGQQMhoKAQIDBAUGBwgJCg=="}, 67 {snap(maxSizeList[0], maxSizeList[len(maxSizeList)-1], maxSizeList...), "COv/////////fxATGhQAAQIDBAUGBwgJCgsMDQ4PEBESEw=="}, 68 } 69 70 for _, tc := range testCases { 71 tc := tc 72 t.Run(tc.snapshot.String(), func(t *testing.T) { 73 require := require.New(t) 74 75 rev := postgresRevision{tc.snapshot} 76 serialized := rev.String() 77 require.Equal(tc.expectedStr, serialized) 78 79 parsed, err := parseRevisionProto(serialized) 80 require.NoError(err) 81 require.Equal(rev, parsed) 82 }) 83 } 84 } 85 86 func TestRevisionParseOldDecimalFormat(t *testing.T) { 87 testCases := []struct { 88 snapshot pgSnapshot 89 inputStr string 90 expectErr bool 91 }{ 92 {snap(1, 1), "0", false}, 93 {snap(1, 1), "0.0", false}, 94 {snap(501, 501), "500", false}, 95 {snap(499, 501, 499), "500.499", false}, 96 {snap(499, 507, 499, 500, 501, 502, 503, 504, 505), "506.499", false}, 97 {snap(maxInt+1, maxInt+1), "9223372036854775807", false}, 98 {snap(maxInt-1, maxInt+1, maxInt-1), "9223372036854775807.9223372036854775806", false}, 99 {snap(0, 0), "-500", true}, 100 {snap(0, 0), "", true}, 101 {snap(0, 0), "deadbeef", true}, 102 {snap(0, 0), "dead.beef", true}, 103 {snap(0, 0), ".12345", true}, 104 {snap(0, 0), "12345.", true}, 105 } 106 107 for _, tc := range testCases { 108 tc := tc 109 t.Run(tc.snapshot.String(), func(t *testing.T) { 110 require := require.New(t) 111 112 parsed, err := parseRevisionDecimal(tc.inputStr) 113 if tc.expectErr { 114 require.Error(err) 115 } else { 116 require.NoError(err) 117 require.Equal(postgresRevision{tc.snapshot}, parsed) 118 } 119 }) 120 } 121 } 122 123 func TestCombinedRevisionParsing(t *testing.T) { 124 testCases := []struct { 125 snapshot pgSnapshot 126 inputStr string 127 expectErr bool 128 }{ 129 {snap(0, 0), "", false}, 130 {snap(0, 5, 1), "EAUaAQE=", false}, 131 {snap(1, 1), "CAE=", false}, 132 {snap(1, 3, 1), "CAEQAhoBAA==", false}, 133 {snap(1, 2, 1), "CAEQARoBAA==", false}, 134 {snap(2, 2), "CAI=", false}, 135 {snap(123, 123), "CHs=", false}, 136 {snap(100, 150, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110), "CGQQMhoKAQIDBAUGBwgJCg==", false}, 137 {snap(1, 1), "0", false}, 138 {snap(1, 1), "0.0", false}, 139 {snap(501, 501), "500", false}, 140 {snap(499, 501, 499), "500.499", false}, 141 {snap(499, 507, 499, 500, 501, 502, 503, 504, 505), "506.499", false}, 142 {snap(maxInt+1, maxInt+1), "9223372036854775807", false}, 143 {snap(maxInt-1, maxInt+1, maxInt-1), "9223372036854775807.9223372036854775806", false}, 144 {snap(0, 0), "-500", true}, 145 {snap(0, 0), "gobbleygook", true}, 146 {snap(0, 0), "CAEQARoBAA", true}, 147 } 148 149 for _, tc := range testCases { 150 tc := tc 151 t.Run(tc.snapshot.String(), func(t *testing.T) { 152 require := require.New(t) 153 154 parsed, err := ParseRevisionString(tc.inputStr) 155 if tc.expectErr { 156 require.Error(err) 157 } else { 158 require.NoError(err) 159 require.Equal(postgresRevision{tc.snapshot}, parsed) 160 } 161 }) 162 } 163 } 164 165 func TestBrokenInvalidRevision(t *testing.T) { 166 _, err := ParseRevisionString("1693540940373045727.0000000001") 167 require.Error(t, err) 168 } 169 170 func FuzzRevision(f *testing.F) { 171 // Attempt to find a decimal revision that is a valid base64 encoded proto revision 172 f.Add(uint64(0), -1) 173 f.Add(uint64(0), 0) 174 f.Add(uint64(500), -1) 175 f.Add(uint64(500), 499) 176 f.Add(uint64(506), 499) 177 f.Add(uint64(9223372036854775807), -1) 178 f.Add(uint64(9223372036854775807), 9223372036854775806) 179 180 f.Fuzz(func(t *testing.T, a uint64, b int) { 181 decimalRev := strconv.FormatUint(a, 10) 182 if b >= 0 { 183 decimalRev = decimalRev + "." + strconv.Itoa(b) 184 } 185 rev, err := parseRevisionProto(decimalRev) 186 if err == nil && decimalRev != "" { 187 t.Errorf("decimal revision \"%s\" is a valid proto revision %#v", decimalRev, rev) 188 } 189 }) 190 }