github.com/ari-anchor/sei-tendermint@v0.0.0-20230519144642-dc826b7b56bb/libs/bits/bit_array_test.go (about) 1 package bits 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "math" 7 "testing" 8 9 "github.com/stretchr/testify/assert" 10 "github.com/stretchr/testify/require" 11 12 tmrand "github.com/ari-anchor/sei-tendermint/libs/rand" 13 tmprotobits "github.com/ari-anchor/sei-tendermint/proto/tendermint/libs/bits" 14 ) 15 16 func randBitArray(bits int) *BitArray { 17 src := tmrand.Bytes((bits + 7) / 8) 18 bA := NewBitArray(bits) 19 for i := 0; i < len(src); i++ { 20 for j := 0; j < 8; j++ { 21 if i*8+j >= bits { 22 return bA 23 } 24 setBit := src[i]&(1<<uint(j)) > 0 25 bA.SetIndex(i*8+j, setBit) 26 } 27 } 28 return bA 29 } 30 31 func TestAnd(t *testing.T) { 32 33 bA1 := randBitArray(51) 34 bA2 := randBitArray(31) 35 bA3 := bA1.And(bA2) 36 37 var bNil *BitArray 38 require.Equal(t, bNil.And(bA1), (*BitArray)(nil)) 39 require.Equal(t, bA1.And(nil), (*BitArray)(nil)) 40 require.Equal(t, bNil.And(nil), (*BitArray)(nil)) 41 42 if bA3.Bits != 31 { 43 t.Error("Expected min bits", bA3.Bits) 44 } 45 if len(bA3.Elems) != len(bA2.Elems) { 46 t.Error("Expected min elems length") 47 } 48 for i := 0; i < bA3.Bits; i++ { 49 expected := bA1.GetIndex(i) && bA2.GetIndex(i) 50 if bA3.GetIndex(i) != expected { 51 t.Error("Wrong bit from bA3", i, bA1.GetIndex(i), bA2.GetIndex(i), bA3.GetIndex(i)) 52 } 53 } 54 } 55 56 func TestOr(t *testing.T) { 57 bA1 := randBitArray(51) 58 bA2 := randBitArray(31) 59 bA3 := bA1.Or(bA2) 60 61 bNil := (*BitArray)(nil) 62 require.Equal(t, bNil.Or(bA1), bA1) 63 require.Equal(t, bA1.Or(nil), bA1) 64 require.Equal(t, bNil.Or(nil), (*BitArray)(nil)) 65 66 if bA3.Bits != 51 { 67 t.Error("Expected max bits") 68 } 69 if len(bA3.Elems) != len(bA1.Elems) { 70 t.Error("Expected max elems length") 71 } 72 for i := 0; i < bA3.Bits; i++ { 73 expected := bA1.GetIndex(i) || bA2.GetIndex(i) 74 if bA3.GetIndex(i) != expected { 75 t.Error("Wrong bit from bA3", i, bA1.GetIndex(i), bA2.GetIndex(i), bA3.GetIndex(i)) 76 } 77 } 78 } 79 80 func TestSub(t *testing.T) { 81 testCases := []struct { 82 initBA string 83 subtractingBA string 84 expectedBA string 85 }{ 86 {`null`, `null`, `null`}, 87 {`"x"`, `null`, `null`}, 88 {`null`, `"x"`, `null`}, 89 {`"x"`, `"x"`, `"_"`}, 90 {`"xxxxxx"`, `"x_x_x_"`, `"_x_x_x"`}, 91 {`"x_x_x_"`, `"xxxxxx"`, `"______"`}, 92 {`"xxxxxx"`, `"x_x_x_xxxx"`, `"_x_x_x"`}, 93 {`"x_x_x_xxxx"`, `"xxxxxx"`, `"______xxxx"`}, 94 {`"xxxxxxxxxx"`, `"x_x_x_"`, `"_x_x_xxxxx"`}, 95 {`"x_x_x_"`, `"xxxxxxxxxx"`, `"______"`}, 96 } 97 for _, tc := range testCases { 98 var bA *BitArray 99 err := json.Unmarshal([]byte(tc.initBA), &bA) 100 require.NoError(t, err) 101 102 var o *BitArray 103 err = json.Unmarshal([]byte(tc.subtractingBA), &o) 104 require.NoError(t, err) 105 106 got, _ := json.Marshal(bA.Sub(o)) 107 require.Equal( 108 t, 109 tc.expectedBA, 110 string(got), 111 "%s minus %s doesn't equal %s", 112 tc.initBA, 113 tc.subtractingBA, 114 tc.expectedBA, 115 ) 116 } 117 } 118 119 func TestPickRandom(t *testing.T) { 120 empty16Bits := "________________" 121 empty64Bits := empty16Bits + empty16Bits + empty16Bits + empty16Bits 122 testCases := []struct { 123 bA string 124 ok bool 125 }{ 126 {`null`, false}, 127 {`"x"`, true}, 128 {`"` + empty16Bits + `"`, false}, 129 {`"x` + empty16Bits + `"`, true}, 130 {`"` + empty16Bits + `x"`, true}, 131 {`"x` + empty16Bits + `x"`, true}, 132 {`"` + empty64Bits + `"`, false}, 133 {`"x` + empty64Bits + `"`, true}, 134 {`"` + empty64Bits + `x"`, true}, 135 {`"x` + empty64Bits + `x"`, true}, 136 } 137 for _, tc := range testCases { 138 var bitArr *BitArray 139 err := json.Unmarshal([]byte(tc.bA), &bitArr) 140 require.NoError(t, err) 141 _, ok := bitArr.PickRandom() 142 require.Equal(t, tc.ok, ok, "PickRandom got an unexpected result on input %s", tc.bA) 143 } 144 } 145 146 func TestBytes(t *testing.T) { 147 bA := NewBitArray(4) 148 bA.SetIndex(0, true) 149 check := func(bA *BitArray, bz []byte) { 150 require.True(t, bytes.Equal(bA.Bytes(), bz), 151 "Expected %X but got %X", bz, bA.Bytes()) 152 } 153 check(bA, []byte{0x01}) 154 bA.SetIndex(3, true) 155 check(bA, []byte{0x09}) 156 157 bA = NewBitArray(9) 158 check(bA, []byte{0x00, 0x00}) 159 bA.SetIndex(7, true) 160 check(bA, []byte{0x80, 0x00}) 161 bA.SetIndex(8, true) 162 check(bA, []byte{0x80, 0x01}) 163 164 bA = NewBitArray(16) 165 check(bA, []byte{0x00, 0x00}) 166 bA.SetIndex(7, true) 167 check(bA, []byte{0x80, 0x00}) 168 bA.SetIndex(8, true) 169 check(bA, []byte{0x80, 0x01}) 170 bA.SetIndex(9, true) 171 check(bA, []byte{0x80, 0x03}) 172 173 require.False(t, bA.SetIndex(-1, true)) 174 } 175 176 func TestEmptyFull(t *testing.T) { 177 ns := []int{47, 123} 178 for _, n := range ns { 179 bA := NewBitArray(n) 180 if !bA.IsEmpty() { 181 t.Fatal("Expected bit array to be empty") 182 } 183 for i := 0; i < n; i++ { 184 bA.SetIndex(i, true) 185 } 186 if !bA.IsFull() { 187 t.Fatal("Expected bit array to be full") 188 } 189 } 190 } 191 192 func TestUpdateNeverPanics(t *testing.T) { 193 newRandBitArray := func(n int) *BitArray { return randBitArray(n) } 194 pairs := []struct { 195 a, b *BitArray 196 }{ 197 {nil, nil}, 198 {newRandBitArray(10), newRandBitArray(12)}, 199 {newRandBitArray(23), newRandBitArray(23)}, 200 {newRandBitArray(37), nil}, 201 {nil, NewBitArray(10)}, 202 } 203 204 for _, pair := range pairs { 205 a, b := pair.a, pair.b 206 a.Update(b) 207 b.Update(a) 208 } 209 } 210 211 func TestNewBitArrayNeverCrashesOnNegatives(t *testing.T) { 212 bitList := []int{-127, -128, -1 << 31} 213 for _, bits := range bitList { 214 _ = NewBitArray(bits) 215 } 216 } 217 218 func TestJSONMarshalUnmarshal(t *testing.T) { 219 220 bA1 := NewBitArray(0) 221 222 bA2 := NewBitArray(1) 223 224 bA3 := NewBitArray(1) 225 bA3.SetIndex(0, true) 226 227 bA4 := NewBitArray(5) 228 bA4.SetIndex(0, true) 229 bA4.SetIndex(1, true) 230 231 testCases := []struct { 232 bA *BitArray 233 marshalledBA string 234 }{ 235 {nil, `null`}, 236 {bA1, `null`}, 237 {bA2, `"_"`}, 238 {bA3, `"x"`}, 239 {bA4, `"xx___"`}, 240 } 241 242 for _, tc := range testCases { 243 tc := tc 244 t.Run(tc.bA.String(), func(t *testing.T) { 245 bz, err := json.Marshal(tc.bA) 246 require.NoError(t, err) 247 248 assert.Equal(t, tc.marshalledBA, string(bz)) 249 250 var unmarshalledBA *BitArray 251 err = json.Unmarshal(bz, &unmarshalledBA) 252 require.NoError(t, err) 253 254 if tc.bA == nil { 255 require.Nil(t, unmarshalledBA) 256 } else { 257 require.NotNil(t, unmarshalledBA) 258 assert.EqualValues(t, tc.bA.Bits, unmarshalledBA.Bits) 259 if assert.EqualValues(t, tc.bA.String(), unmarshalledBA.String()) { 260 assert.EqualValues(t, tc.bA.Elems, unmarshalledBA.Elems) 261 } 262 } 263 }) 264 } 265 } 266 267 func TestBitArrayToFromProto(t *testing.T) { 268 testCases := []struct { 269 msg string 270 bA1 *BitArray 271 expPass bool 272 }{ 273 {"success empty", &BitArray{}, true}, 274 {"success", NewBitArray(1), true}, 275 {"success", NewBitArray(2), true}, 276 {"negative", NewBitArray(-1), false}, 277 } 278 for _, tc := range testCases { 279 protoBA := tc.bA1.ToProto() 280 ba := new(BitArray) 281 err := ba.FromProto(protoBA) 282 if tc.expPass { 283 assert.NoError(t, err) 284 require.Equal(t, tc.bA1, ba, tc.msg) 285 } else { 286 require.NotEqual(t, tc.bA1, ba, tc.msg) 287 } 288 } 289 } 290 291 func TestBitArrayFromProto(t *testing.T) { 292 testCases := []struct { 293 pbA *tmprotobits.BitArray 294 resA *BitArray 295 expErr bool 296 }{ 297 0: {nil, &BitArray{}, false}, 298 1: {&tmprotobits.BitArray{}, &BitArray{Elems: []uint64{}}, false}, 299 300 2: {&tmprotobits.BitArray{Bits: 1, Elems: make([]uint64, 1)}, &BitArray{Bits: 1, Elems: make([]uint64, 1)}, false}, 301 302 3: {&tmprotobits.BitArray{Bits: -1, Elems: make([]uint64, 1)}, &BitArray{}, true}, 303 4: {&tmprotobits.BitArray{Bits: math.MaxInt32 + 1, Elems: make([]uint64, 1)}, &BitArray{}, true}, 304 5: {&tmprotobits.BitArray{Bits: 1, Elems: make([]uint64, 2)}, &BitArray{}, true}, 305 } 306 307 for i, tc := range testCases { 308 bA := new(BitArray) 309 err := bA.FromProto(tc.pbA) 310 if tc.expErr { 311 assert.Error(t, err, "#%d", i) 312 assert.Equal(t, tc.resA, bA, "#%d", i) 313 } else { 314 assert.NoError(t, err, "#%d", i) 315 assert.Equal(t, tc.resA, bA, "#%d", i) 316 } 317 } 318 }