github.com/MetalBlockchain/metalgo@v1.11.9/x/merkledb/key_test.go (about) 1 // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package merkledb 5 6 import ( 7 "fmt" 8 "strconv" 9 "testing" 10 11 "github.com/stretchr/testify/require" 12 ) 13 14 func TestBranchFactor_Valid(t *testing.T) { 15 require := require.New(t) 16 for _, bf := range validBranchFactors { 17 require.NoError(bf.Valid()) 18 } 19 var empty BranchFactor 20 err := empty.Valid() 21 require.ErrorIs(err, ErrInvalidBranchFactor) 22 } 23 24 func TestHasPartialByte(t *testing.T) { 25 for _, ts := range validTokenSizes { 26 t.Run(strconv.Itoa(ts), func(t *testing.T) { 27 require := require.New(t) 28 29 key := Key{} 30 require.False(key.hasPartialByte()) 31 32 if ts == 8 { 33 // Tokens are an entire byte so 34 // there is never a partial byte. 35 key = key.Extend(ToToken(1, ts)) 36 require.False(key.hasPartialByte()) 37 key = key.Extend(ToToken(0, ts)) 38 require.False(key.hasPartialByte()) 39 return 40 } 41 42 // Fill all but the last token of the first byte. 43 for i := 0; i < 8-ts; i += ts { 44 key = key.Extend(ToToken(1, ts)) 45 require.True(key.hasPartialByte()) 46 } 47 48 // Fill the last token of the first byte. 49 key = key.Extend(ToToken(0, ts)) 50 require.False(key.hasPartialByte()) 51 52 // Fill the first token of the second byte. 53 key = key.Extend(ToToken(0, ts)) 54 require.True(key.hasPartialByte()) 55 }) 56 } 57 } 58 59 func Test_Key_Has_Prefix(t *testing.T) { 60 type test struct { 61 name string 62 keyA func(ts int) Key 63 keyB func(ts int) Key 64 isStrictPrefix bool 65 isPrefix bool 66 } 67 68 key := "Key" 69 70 tests := []test{ 71 { 72 name: "equal keys", 73 keyA: func(int) Key { return ToKey([]byte(key)) }, 74 keyB: func(int) Key { return ToKey([]byte(key)) }, 75 isPrefix: true, 76 isStrictPrefix: false, 77 }, 78 { 79 name: "one key has one fewer token", 80 keyA: func(int) Key { return ToKey([]byte(key)) }, 81 keyB: func(ts int) Key { 82 return ToKey([]byte(key)).Take(len(key)*8 - ts) 83 }, 84 isPrefix: true, 85 isStrictPrefix: true, 86 }, 87 { 88 name: "equal keys, both have one fewer token", 89 keyA: func(ts int) Key { 90 return ToKey([]byte(key)).Take(len(key)*8 - ts) 91 }, 92 keyB: func(ts int) Key { 93 return ToKey([]byte(key)).Take(len(key)*8 - ts) 94 }, 95 isPrefix: true, 96 isStrictPrefix: false, 97 }, 98 { 99 name: "different keys", 100 keyA: func(int) Key { return ToKey([]byte{0xF7}) }, 101 keyB: func(int) Key { return ToKey([]byte{0xF0}) }, 102 isPrefix: false, 103 isStrictPrefix: false, 104 }, 105 { 106 name: "same bytes, different lengths", 107 keyA: func(ts int) Key { 108 return ToKey([]byte{0x10, 0x00}).Take(ts) 109 }, 110 keyB: func(ts int) Key { 111 return ToKey([]byte{0x10, 0x00}).Take(ts * 2) 112 }, 113 isPrefix: false, 114 isStrictPrefix: false, 115 }, 116 } 117 118 for _, tt := range tests { 119 for _, ts := range validTokenSizes { 120 t.Run(tt.name+" ts "+strconv.Itoa(ts), func(t *testing.T) { 121 require := require.New(t) 122 keyA := tt.keyA(ts) 123 keyB := tt.keyB(ts) 124 125 require.Equal(tt.isPrefix, keyA.HasPrefix(keyB)) 126 require.Equal(tt.isPrefix, keyA.iteratedHasPrefix(keyB, 0, ts)) 127 require.Equal(tt.isStrictPrefix, keyA.HasStrictPrefix(keyB)) 128 }) 129 } 130 } 131 } 132 133 func Test_Key_Skip(t *testing.T) { 134 require := require.New(t) 135 136 empty := Key{} 137 require.Equal(ToKey([]byte{0}).Skip(8), empty) 138 for _, ts := range validTokenSizes { 139 if ts == 8 { 140 continue 141 } 142 shortKey := ToKey([]byte{0b0101_0101}) 143 longKey := ToKey([]byte{0b0101_0101, 0b0101_0101}) 144 for shift := 0; shift < 8; shift += ts { 145 skipKey := shortKey.Skip(shift) 146 require.Equal(byte(0b0101_0101<<shift), skipKey.value[0]) 147 148 skipKey = longKey.Skip(shift) 149 require.Equal(byte(0b0101_0101<<shift+0b0101_0101>>(8-shift)), skipKey.value[0]) 150 require.Equal(byte(0b0101_0101<<shift), skipKey.value[1]) 151 } 152 } 153 154 skip := ToKey([]byte{0b0101_0101, 0b1010_1010}).Skip(8) 155 require.Len(skip.value, 1) 156 require.Equal(byte(0b1010_1010), skip.value[0]) 157 158 skip = ToKey([]byte{0b0101_0101, 0b1010_1010, 0b0101_0101}).Skip(8) 159 require.Len(skip.value, 2) 160 require.Equal(byte(0b1010_1010), skip.value[0]) 161 require.Equal(byte(0b0101_0101), skip.value[1]) 162 } 163 164 func Test_Key_Take(t *testing.T) { 165 require := require.New(t) 166 167 require.Equal(Key{}, ToKey([]byte{0}).Take(0)) 168 169 for _, ts := range validTokenSizes { 170 if ts == 8 { 171 continue 172 } 173 key := ToKey([]byte{0b0101_0101}) 174 for length := ts; length <= 8; length += ts { 175 take := key.Take(length) 176 require.Equal(length, take.length) 177 shift := 8 - length 178 require.Equal(byte((0b0101_0101>>shift)<<shift), take.value[0]) 179 } 180 } 181 182 take := ToKey([]byte{0b0101_0101, 0b1010_1010}).Take(8) 183 require.Len(take.value, 1) 184 require.Equal(byte(0b0101_0101), take.value[0]) 185 } 186 187 func Test_Key_Token(t *testing.T) { 188 type test struct { 189 name string 190 inputBytes []byte 191 ts int 192 assertTokens func(*require.Assertions, Key) 193 } 194 195 tests := []test{ 196 { 197 name: "branch factor 2", 198 inputBytes: []byte{0b0_1_0_1_0_1_0_1, 0b1_0_1_0_1_0_1_0}, 199 ts: 1, 200 assertTokens: func(require *require.Assertions, key Key) { 201 require.Equal(byte(0), key.Token(0, 1)) 202 require.Equal(byte(1), key.Token(1, 1)) 203 require.Equal(byte(0), key.Token(2, 1)) 204 require.Equal(byte(1), key.Token(3, 1)) 205 require.Equal(byte(0), key.Token(4, 1)) 206 require.Equal(byte(1), key.Token(5, 1)) 207 require.Equal(byte(0), key.Token(6, 1)) 208 require.Equal(byte(1), key.Token(7, 1)) // end first byte 209 require.Equal(byte(1), key.Token(8, 1)) 210 require.Equal(byte(0), key.Token(9, 1)) 211 require.Equal(byte(1), key.Token(10, 1)) 212 require.Equal(byte(0), key.Token(11, 1)) 213 require.Equal(byte(1), key.Token(12, 1)) 214 require.Equal(byte(0), key.Token(13, 1)) 215 require.Equal(byte(1), key.Token(14, 1)) 216 require.Equal(byte(0), key.Token(15, 1)) // end second byte 217 }, 218 }, 219 { 220 name: "branch factor 4", 221 inputBytes: []byte{0b00_01_10_11, 0b11_10_01_00}, 222 ts: 2, 223 assertTokens: func(require *require.Assertions, key Key) { 224 require.Equal(byte(0), key.Token(0, 2)) // 00 225 require.Equal(byte(1), key.Token(2, 2)) // 01 226 require.Equal(byte(2), key.Token(4, 2)) // 10 227 require.Equal(byte(3), key.Token(6, 2)) // 11 end first byte 228 require.Equal(byte(3), key.Token(8, 2)) // 11 229 require.Equal(byte(2), key.Token(10, 2)) // 10 230 require.Equal(byte(1), key.Token(12, 2)) // 01 231 require.Equal(byte(0), key.Token(14, 2)) // 00 end second byte 232 }, 233 }, 234 { 235 name: "branch factor 16", 236 inputBytes: []byte{ 237 0b0000_0001, 238 0b0010_0011, 239 0b0100_0101, 240 0b0110_0111, 241 0b1000_1001, 242 0b1010_1011, 243 0b1100_1101, 244 0b1110_1111, 245 }, 246 ts: 4, 247 assertTokens: func(require *require.Assertions, key Key) { 248 for i := 0; i < 16; i++ { 249 require.Equal(byte(i), key.Token(i*4, 4)) 250 } 251 }, 252 }, 253 } 254 255 for i := 0; i < 256; i++ { 256 i := i 257 tests = append(tests, test{ 258 name: fmt.Sprintf("branch factor 256, byte %d", i), 259 inputBytes: []byte{byte(i)}, 260 ts: 8, 261 assertTokens: func(require *require.Assertions, key Key) { 262 require.Equal(byte(i), key.Token(0, 8)) 263 }, 264 }) 265 } 266 267 for _, tt := range tests { 268 t.Run(tt.name, func(t *testing.T) { 269 require := require.New(t) 270 key := ToKey(tt.inputBytes) 271 tt.assertTokens(require, key) 272 }) 273 } 274 } 275 276 func Test_Key_Append(t *testing.T) { 277 require := require.New(t) 278 279 key := ToKey([]byte{}) 280 for _, bf := range validBranchFactors { 281 size := BranchFactorToTokenSize[bf] 282 for i := 0; i < int(bf); i++ { 283 appendedKey := key.Extend(ToToken(byte(i), size), ToToken(byte(i/2), size)) 284 require.Equal(byte(i), appendedKey.Token(0, size)) 285 require.Equal(byte(i/2), appendedKey.Token(size, size)) 286 } 287 } 288 } 289 290 func Test_Key_AppendExtend(t *testing.T) { 291 require := require.New(t) 292 293 key2 := ToKey([]byte{0b1000_0000}).Take(1) 294 p := ToKey([]byte{0b01010101}) 295 extendedP := key2.Extend(ToToken(0, 1), p) 296 require.Equal([]byte{0b10010101, 0b01000_000}, extendedP.Bytes()) 297 require.Equal(byte(1), extendedP.Token(0, 1)) 298 require.Equal(byte(0), extendedP.Token(1, 1)) 299 require.Equal(byte(0), extendedP.Token(2, 1)) 300 require.Equal(byte(1), extendedP.Token(3, 1)) 301 require.Equal(byte(0), extendedP.Token(4, 1)) 302 require.Equal(byte(1), extendedP.Token(5, 1)) 303 require.Equal(byte(0), extendedP.Token(6, 1)) 304 require.Equal(byte(1), extendedP.Token(7, 1)) 305 require.Equal(byte(0), extendedP.Token(8, 1)) 306 require.Equal(byte(1), extendedP.Token(9, 1)) 307 308 p = ToKey([]byte{0b0101_0101, 0b1000_0000}).Take(9) 309 extendedP = key2.Extend(ToToken(0, 1), p) 310 require.Equal([]byte{0b1001_0101, 0b0110_0000}, extendedP.Bytes()) 311 require.Equal(byte(1), extendedP.Token(0, 1)) 312 require.Equal(byte(0), extendedP.Token(1, 1)) 313 require.Equal(byte(0), extendedP.Token(2, 1)) 314 require.Equal(byte(1), extendedP.Token(3, 1)) 315 require.Equal(byte(0), extendedP.Token(4, 1)) 316 require.Equal(byte(1), extendedP.Token(5, 1)) 317 require.Equal(byte(0), extendedP.Token(6, 1)) 318 require.Equal(byte(1), extendedP.Token(7, 1)) 319 require.Equal(byte(0), extendedP.Token(8, 1)) 320 require.Equal(byte(1), extendedP.Token(9, 1)) 321 require.Equal(byte(1), extendedP.Token(10, 1)) 322 323 key4 := ToKey([]byte{0b0100_0000}).Take(2) 324 p = ToKey([]byte{0b0101_0101}) 325 extendedP = key4.Extend(ToToken(0, 2), p) 326 require.Equal([]byte{0b0100_0101, 0b0101_0000}, extendedP.Bytes()) 327 require.Equal(byte(1), extendedP.Token(0, 2)) 328 require.Equal(byte(0), extendedP.Token(2, 2)) 329 require.Equal(byte(1), extendedP.Token(4, 2)) 330 require.Equal(byte(1), extendedP.Token(6, 2)) 331 require.Equal(byte(1), extendedP.Token(8, 2)) 332 require.Equal(byte(1), extendedP.Token(10, 2)) 333 334 key16 := ToKey([]byte{0b0001_0000}).Take(4) 335 p = ToKey([]byte{0b0001_0001}) 336 extendedP = key16.Extend(ToToken(0, 4), p) 337 require.Equal([]byte{0b0001_0000, 0b0001_0001}, extendedP.Bytes()) 338 require.Equal(byte(1), extendedP.Token(0, 4)) 339 require.Equal(byte(0), extendedP.Token(4, 4)) 340 require.Equal(byte(1), extendedP.Token(8, 4)) 341 require.Equal(byte(1), extendedP.Token(12, 4)) 342 343 p = ToKey([]byte{0b0001_0001, 0b0001_0001}) 344 extendedP = key16.Extend(ToToken(0, 4), p) 345 require.Equal([]byte{0b0001_0000, 0b0001_0001, 0b0001_0001}, extendedP.Bytes()) 346 require.Equal(byte(1), extendedP.Token(0, 4)) 347 require.Equal(byte(0), extendedP.Token(4, 4)) 348 require.Equal(byte(1), extendedP.Token(8, 4)) 349 require.Equal(byte(1), extendedP.Token(12, 4)) 350 require.Equal(byte(1), extendedP.Token(16, 4)) 351 require.Equal(byte(1), extendedP.Token(20, 4)) 352 353 key256 := ToKey([]byte{0b0000_0001}) 354 p = ToKey([]byte{0b0000_0001}) 355 extendedP = key256.Extend(ToToken(0, 8), p) 356 require.Equal([]byte{0b0000_0001, 0b0000_0000, 0b0000_0001}, extendedP.Bytes()) 357 require.Equal(byte(1), extendedP.Token(0, 8)) 358 require.Equal(byte(0), extendedP.Token(8, 8)) 359 require.Equal(byte(1), extendedP.Token(16, 8)) 360 } 361 362 func TestKeyBytesNeeded(t *testing.T) { 363 type test struct { 364 bitLength int 365 bytesNeeded int 366 } 367 368 tests := []test{ 369 { 370 bitLength: 7, 371 bytesNeeded: 1, 372 }, 373 { 374 bitLength: 8, 375 bytesNeeded: 1, 376 }, 377 { 378 bitLength: 9, 379 bytesNeeded: 2, 380 }, 381 { 382 bitLength: 0, 383 bytesNeeded: 0, 384 }, 385 } 386 387 for _, tt := range tests { 388 t.Run(fmt.Sprintf("bit length %d", tt.bitLength), func(t *testing.T) { 389 require := require.New(t) 390 require.Equal(tt.bytesNeeded, bytesNeeded(tt.bitLength)) 391 }) 392 } 393 } 394 395 func FuzzKeyDoubleExtend_Tokens(f *testing.F) { 396 f.Fuzz(func( 397 t *testing.T, 398 first []byte, 399 second []byte, 400 tokenByte byte, 401 forceFirstOdd bool, 402 forceSecondOdd bool, 403 ) { 404 require := require.New(t) 405 for _, ts := range validTokenSizes { 406 key1 := ToKey(first) 407 if forceFirstOdd && key1.length > ts { 408 key1 = key1.Take(key1.length - ts) 409 } 410 key2 := ToKey(second) 411 if forceSecondOdd && key2.length > ts { 412 key2 = key2.Take(key2.length - ts) 413 } 414 token := byte(int(tokenByte) % int(tokenSizeToBranchFactor[ts])) 415 extendedP := key1.Extend(ToToken(token, ts), key2) 416 require.Equal(key1.length+key2.length+ts, extendedP.length) 417 firstIndex := 0 418 for ; firstIndex < key1.length; firstIndex += ts { 419 require.Equal(key1.Token(firstIndex, ts), extendedP.Token(firstIndex, ts)) 420 } 421 require.Equal(token, extendedP.Token(firstIndex, ts)) 422 firstIndex += ts 423 for secondIndex := 0; secondIndex < key2.length; secondIndex += ts { 424 require.Equal(key2.Token(secondIndex, ts), extendedP.Token(firstIndex+secondIndex, ts)) 425 } 426 } 427 }) 428 } 429 430 func FuzzKeyDoubleExtend_Any(f *testing.F) { 431 f.Fuzz(func( 432 t *testing.T, 433 baseKeyBytes []byte, 434 firstKeyBytes []byte, 435 secondKeyBytes []byte, 436 forceBaseOdd bool, 437 forceFirstOdd bool, 438 forceSecondOdd bool, 439 ) { 440 require := require.New(t) 441 for _, ts := range validTokenSizes { 442 baseKey := ToKey(baseKeyBytes) 443 if forceBaseOdd && baseKey.length > ts { 444 baseKey = baseKey.Take(baseKey.length - ts) 445 } 446 firstKey := ToKey(firstKeyBytes) 447 if forceFirstOdd && firstKey.length > ts { 448 firstKey = firstKey.Take(firstKey.length - ts) 449 } 450 451 secondKey := ToKey(secondKeyBytes) 452 if forceSecondOdd && secondKey.length > ts { 453 secondKey = secondKey.Take(secondKey.length - ts) 454 } 455 456 extendedP := baseKey.Extend(firstKey, secondKey) 457 require.Equal(baseKey.length+firstKey.length+secondKey.length, extendedP.length) 458 totalIndex := 0 459 for baseIndex := 0; baseIndex < baseKey.length; baseIndex += ts { 460 require.Equal(baseKey.Token(baseIndex, ts), extendedP.Token(baseIndex, ts)) 461 } 462 totalIndex += baseKey.length 463 for firstIndex := 0; firstIndex < firstKey.length; firstIndex += ts { 464 require.Equal(firstKey.Token(firstIndex, ts), extendedP.Token(totalIndex+firstIndex, ts)) 465 } 466 totalIndex += firstKey.length 467 for secondIndex := 0; secondIndex < secondKey.length; secondIndex += ts { 468 require.Equal(secondKey.Token(secondIndex, ts), extendedP.Token(totalIndex+secondIndex, ts)) 469 } 470 } 471 }) 472 } 473 474 func FuzzKeySkip(f *testing.F) { 475 f.Fuzz(func( 476 t *testing.T, 477 first []byte, 478 tokensToSkip uint, 479 ) { 480 require := require.New(t) 481 key1 := ToKey(first) 482 for _, ts := range validTokenSizes { 483 // need bits to be a multiple of token size 484 ubitsToSkip := tokensToSkip * uint(ts) 485 if ubitsToSkip >= uint(key1.length) { 486 t.SkipNow() 487 } 488 bitsToSkip := int(ubitsToSkip) 489 key2 := key1.Skip(bitsToSkip) 490 require.Equal(key1.length-bitsToSkip, key2.length) 491 for i := 0; i < key2.length; i += ts { 492 require.Equal(key1.Token(bitsToSkip+i, ts), key2.Token(i, ts)) 493 } 494 } 495 }) 496 } 497 498 func FuzzKeyTake(f *testing.F) { 499 f.Fuzz(func( 500 t *testing.T, 501 first []byte, 502 uTokensToTake uint, 503 ) { 504 require := require.New(t) 505 for _, ts := range validTokenSizes { 506 key1 := ToKey(first) 507 uBitsToTake := uTokensToTake * uint(ts) 508 if uBitsToTake >= uint(key1.length) { 509 t.SkipNow() 510 } 511 bitsToTake := int(uBitsToTake) 512 key2 := key1.Take(bitsToTake) 513 require.Equal(bitsToTake, key2.length) 514 if key2.hasPartialByte() { 515 paddingMask := byte(0xFF >> (key2.length % 8)) 516 require.Zero(key2.value[len(key2.value)-1] & paddingMask) 517 } 518 for i := 0; i < bitsToTake; i += ts { 519 require.Equal(key1.Token(i, ts), key2.Token(i, ts)) 520 } 521 } 522 }) 523 } 524 525 func TestShiftCopy(t *testing.T) { 526 type test struct { 527 dst []byte 528 src []byte 529 expected []byte 530 shift int 531 } 532 533 tests := []test{ 534 { 535 dst: []byte{}, 536 src: []byte{}, 537 expected: []byte{}, 538 shift: 0, 539 }, 540 { 541 dst: []byte{}, 542 src: []byte{}, 543 expected: []byte{}, 544 shift: 1, 545 }, 546 { 547 dst: make([]byte, 1), 548 src: []byte{0b0000_0001}, 549 expected: []byte{0b0000_0010}, 550 shift: 1, 551 }, 552 { 553 dst: make([]byte, 1), 554 src: []byte{0b0000_0001}, 555 expected: []byte{0b0000_0100}, 556 shift: 2, 557 }, 558 { 559 dst: make([]byte, 1), 560 src: []byte{0b0000_0001}, 561 expected: []byte{0b1000_0000}, 562 shift: 7, 563 }, 564 { 565 dst: make([]byte, 2), 566 src: []byte{0b0000_0001, 0b1000_0001}, 567 expected: []byte{0b0000_0011, 0b0000_0010}, 568 shift: 1, 569 }, 570 { 571 dst: make([]byte, 1), 572 src: []byte{0b0000_0001, 0b1000_0001}, 573 expected: []byte{0b0000_0011}, 574 shift: 1, 575 }, 576 { 577 dst: make([]byte, 2), 578 src: []byte{0b0000_0001, 0b1000_0001}, 579 expected: []byte{0b1100_0000, 0b1000_0000}, 580 shift: 7, 581 }, 582 { 583 dst: make([]byte, 1), 584 src: []byte{0b0000_0001, 0b1000_0001}, 585 expected: []byte{0b1100_0000}, 586 shift: 7, 587 }, 588 { 589 dst: make([]byte, 2), 590 src: []byte{0b0000_0001, 0b1000_0001}, 591 expected: []byte{0b1000_0001, 0b0000_0000}, 592 shift: 8, 593 }, 594 { 595 dst: make([]byte, 1), 596 src: []byte{0b0000_0001, 0b1000_0001}, 597 expected: []byte{0b1000_0001}, 598 shift: 8, 599 }, 600 { 601 dst: make([]byte, 2), 602 src: []byte{0b0000_0001, 0b1000_0001, 0b1111_0101}, 603 expected: []byte{0b0000_0110, 0b000_0111}, 604 shift: 2, 605 }, 606 } 607 608 for _, tt := range tests { 609 t.Run(fmt.Sprintf("dst: %v, src: %v", tt.dst, tt.src), func(t *testing.T) { 610 shiftCopy(tt.dst, string(tt.src), tt.shift) 611 require.Equal(t, tt.expected, tt.dst) 612 }) 613 } 614 }