
     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     4  package merkledb
     6  import (
     7  	"cmp"
     8  	"errors"
     9  	"fmt"
    10  	"slices"
    11  	"strings"
    12  	"unsafe"
    14  	""
    15  )
var (
	// ErrInvalidBranchFactor is the sentinel error wrapped by
	// BranchFactor.Valid when a branch factor is not one of the predefined ones.
	ErrInvalidBranchFactor = errors.New("branch factor must match one of the predefined branch factors")

	// BranchFactorToTokenSize maps each predefined BranchFactor to its token
	// size in bits (the base-2 logarithm of the branch factor).
	BranchFactorToTokenSize = map[BranchFactor]int{
		BranchFactor2:   1,
		BranchFactor4:   2,
		BranchFactor16:  4,
		BranchFactor256: 8,
	}

	// tokenSizeToBranchFactor is the inverse of BranchFactorToTokenSize:
	// it maps a token size in bits back to its BranchFactor.
	tokenSizeToBranchFactor = map[int]BranchFactor{
		1: BranchFactor2,
		2: BranchFactor4,
		4: BranchFactor16,
		8: BranchFactor256,
	}

	// validTokenSizes is derived from the keys of tokenSizeToBranchFactor.
	// NOTE(review): the import providing [maps] is not visible in this file,
	// so the exact type and ordering of the result should be confirmed.
	validTokenSizes = maps.Keys(tokenSizeToBranchFactor)

	// validBranchFactors lists every BranchFactor accepted by
	// BranchFactor.Valid.
	validBranchFactors = []BranchFactor{
		BranchFactor2,
		BranchFactor4,
		BranchFactor16,
		BranchFactor256,
	}
)
// BranchFactor is an integer naming a branching factor. Only the predefined
// BranchFactor constants are accepted by Valid, and each of them has a
// corresponding token size in BranchFactorToTokenSize.
type BranchFactor int
const (
	// The predefined branch factors. Each one has an entry in
	// BranchFactorToTokenSize giving its token size in bits.
	BranchFactor2   = BranchFactor(2)
	BranchFactor4   = BranchFactor(4)
	BranchFactor16  = BranchFactor(16)
	BranchFactor256 = BranchFactor(256)

	// BranchFactorLargest is the largest of the predefined branch factors.
	BranchFactorLargest = BranchFactor256
)
    55  // Valid checks if BranchFactor [b] is one of the predefined valid options for BranchFactor
    56  func (b BranchFactor) Valid() error {
    57  	for _, validBF := range validBranchFactors {
    58  		if validBF == b {
    59  			return nil
    60  		}
    61  	}
    62  	return fmt.Errorf("%w: %d", ErrInvalidBranchFactor, b)
    63  }
    65  // ToToken creates a key version of the passed byte with bit length equal to tokenSize
    66  func ToToken(val byte, tokenSize int) Key {
    67  	return Key{
    68  		value:  string([]byte{val << dualBitIndex(tokenSize)}),
    69  		length: tokenSize,
    70  	}
    71  }
    73  // Token returns the token at the specified index,
    74  // Assumes that bitIndex + tokenSize doesn't cross a byte boundary
    75  func (k Key) Token(bitIndex int, tokenSize int) byte {
    76  	storageByte := k.value[bitIndex/8]
    77  	// Shift the byte right to get the last bit to the rightmost position.
    78  	storageByte >>= dualBitIndex((bitIndex + tokenSize) % 8)
    79  	// Apply a mask to remove any other bits in the byte.
    80  	return storageByte & (0xFF >> dualBitIndex(tokenSize))
    81  }
    83  // iteratedHasPrefix checks if the provided prefix key is a prefix of the current key starting after the [bitsOffset]th bit
    84  // this has better performance than constructing the actual key via Skip() then calling HasPrefix because it avoids an allocation
    85  func (k Key) iteratedHasPrefix(prefix Key, bitsOffset int, tokenSize int) bool {
    86  	if k.length-bitsOffset < prefix.length {
    87  		return false
    88  	}
    89  	for i := 0; i < prefix.length; i += tokenSize {
    90  		if k.Token(bitsOffset+i, tokenSize) != prefix.Token(i, tokenSize) {
    91  			return false
    92  		}
    93  	}
    94  	return true
    95  }
// Key is a sequence of bits stored as packed bytes together with an explicit
// bit length, so a key does not have to end on a byte boundary.
type Key struct {
	// The number of bits in the key.
	length int
	// The bytes of the key, held as a string.
	// Bit 0 of the key is the most significant bit of value[0] (see Token).
	value string
}
   104  // ToKey returns [keyBytes] as a new key
   105  // Assumes all bits of the keyBytes are part of the Key, call Key.Take if that is not the case
   106  // Creates a copy of [keyBytes], so keyBytes are safe to edit after the call
   107  func ToKey(keyBytes []byte) Key {
   108  	return toKey(slices.Clone(keyBytes))
   109  }
   111  // toKey returns [keyBytes] as a new key
   112  // Assumes all bits of the keyBytes are part of the Key, call Key.Take if that is not the case
   113  // Caller must not modify [keyBytes] after this call.
   114  func toKey(keyBytes []byte) Key {
   115  	return Key{
   116  		value:  byteSliceToString(keyBytes),
   117  		length: len(keyBytes) * 8,
   118  	}
   119  }
   121  // hasPartialByte returns true iff the key fits into a non-whole number of bytes
   122  func (k Key) hasPartialByte() bool {
   123  	return k.length%8 > 0
   124  }
   126  // HasPrefix returns true iff [prefix] is a prefix of [k] or equal to it.
   127  func (k Key) HasPrefix(prefix Key) bool {
   128  	// [prefix] must be shorter than [k] to be a prefix.
   129  	if k.length < prefix.length {
   130  		return false
   131  	}
   133  	// The number of tokens in the last byte of [prefix], or zero
   134  	// if [prefix] fits into a whole number of bytes.
   135  	remainderBitCount := prefix.length % 8
   136  	if remainderBitCount == 0 {
   137  		return strings.HasPrefix(k.value, prefix.value)
   138  	}
   140  	// check that the tokens in the partially filled final byte of [prefix] are
   141  	// equal to the tokens in the final byte of [k].
   142  	remainderBitsMask := byte(0xFF >> remainderBitCount)
   143  	prefixRemainderTokens := prefix.value[len(prefix.value)-1] | remainderBitsMask
   144  	remainderTokens := k.value[len(prefix.value)-1] | remainderBitsMask
   146  	if prefixRemainderTokens != remainderTokens {
   147  		return false
   148  	}
   150  	// Note that this will never be an index OOB because len(prefix.value) > 0.
   151  	// If len(prefix.value) == 0 were true, [remainderTokens] would be 0, so we
   152  	// would have returned above.
   153  	prefixWithoutPartialByte := prefix.value[:len(prefix.value)-1]
   154  	return strings.HasPrefix(k.value, prefixWithoutPartialByte)
   155  }
   157  // HasStrictPrefix returns true iff [prefix] is a prefix of [k]
   158  // but is not equal to it.
   159  func (k Key) HasStrictPrefix(prefix Key) bool {
   160  	return k != prefix && k.HasPrefix(prefix)
   161  }
// Length returns the number of bits in the Key.
// Note that this is a bit count, not a byte count.
func (k Key) Length() int {
	return k.length
}
   168  // Greater returns true if current Key is greater than other Key
   169  func (k Key) Greater(other Key) bool {
   170  	return k.Compare(other) == 1
   171  }
   173  // Less will return true if current Key is less than other Key
   174  func (k Key) Less(other Key) bool {
   175  	return k.Compare(other) == -1
   176  }
   178  func (k Key) Compare(other Key) int {
   179  	if valueCmp := cmp.Compare(k.value, other.value); valueCmp != 0 {
   180  		return valueCmp
   181  	}
   182  	return cmp.Compare(k.length, other.length)
   183  }
   185  // Extend returns a new Key that is the in-order aggregation of Key [k] with [keys]
   186  func (k Key) Extend(keys ...Key) Key {
   187  	totalBitLength := k.length
   188  	for _, key := range keys {
   189  		totalBitLength += key.length
   190  	}
   191  	buffer := make([]byte, bytesNeeded(totalBitLength))
   192  	copy(buffer, k.value)
   193  	currentTotal := k.length
   194  	for _, key := range keys {
   195  		extendIntoBuffer(buffer, key, currentTotal)
   196  		currentTotal += key.length
   197  	}
   199  	return Key{
   200  		value:  byteSliceToString(buffer),
   201  		length: totalBitLength,
   202  	}
   203  }
   205  func extendIntoBuffer(buffer []byte, val Key, bitsOffset int) {
   206  	if val.length == 0 {
   207  		return
   208  	}
   209  	bytesOffset := bytesNeeded(bitsOffset)
   210  	bitsRemainder := bitsOffset % 8
   211  	if bitsRemainder == 0 {
   212  		copy(buffer[bytesOffset:], val.value)
   213  		return
   214  	}
   216  	// Fill the partial byte with the first [shift] bits of the extension path
   217  	buffer[bytesOffset-1] |= val.value[0] >> bitsRemainder
   219  	// copy the rest of the extension path bytes into the buffer,
   220  	// shifted byte shift bits
   221  	shiftCopy(buffer[bytesOffset:], val.value, dualBitIndex(bitsRemainder))
   222  }
   224  // dualBitIndex gets the dual of the bit index
   225  // ex: in a byte, the bit 5 from the right is the same as the bit 3 from the left
   226  func dualBitIndex(shift int) int {
   227  	return (8 - shift) % 8
   228  }
   230  // Treats [src] as a bit array and copies it into [dst] shifted by [shift] bits.
   231  // For example, if [src] is [0b0000_0001, 0b0000_0010] and [shift] is 4,
   232  // we copy [0b0001_0000, 0b0010_0000] into [dst].
   233  // Assumes len(dst) >= len(src)-1.
   234  // If len(dst) == len(src)-1 the last byte of [src] is only partially copied
   235  // (i.e. the rightmost bits are not copied).
   236  func shiftCopy(dst []byte, src string, shift int) {
   237  	i := 0
   238  	dualShift := dualBitIndex(shift)
   239  	for ; i < len(src)-1; i++ {
   240  		dst[i] = src[i]<<shift | src[i+1]>>dualShift
   241  	}
   243  	if i < len(dst) {
   244  		// the last byte only has values from byte i, as there is no byte i+1
   245  		dst[i] = src[i] << shift
   246  	}
   247  }
   249  // Skip returns a new Key that contains the last
   250  // k.length-bitsToSkip bits of [k].
   251  func (k Key) Skip(bitsToSkip int) Key {
   252  	if k.length <= bitsToSkip {
   253  		return Key{}
   254  	}
   255  	result := Key{
   256  		value:  k.value[bitsToSkip/8:],
   257  		length: k.length - bitsToSkip,
   258  	}
   260  	// if the tokens to skip is a whole number of bytes,
   261  	// the remaining bytes exactly equals the new key.
   262  	if bitsToSkip%8 == 0 {
   263  		return result
   264  	}
   266  	// bitsToSkip does not remove a whole number of bytes.
   267  	// copy the remaining shifted bytes into a new buffer.
   268  	buffer := make([]byte, bytesNeeded(result.length))
   269  	bitsRemovedFromFirstRemainingByte := bitsToSkip % 8
   270  	shiftCopy(buffer, result.value, bitsRemovedFromFirstRemainingByte)
   272  	result.value = byteSliceToString(buffer)
   273  	return result
   274  }
   276  // Take returns a new Key that contains the first bitsToTake bits of the current Key
   277  func (k Key) Take(bitsToTake int) Key {
   278  	if k.length <= bitsToTake {
   279  		return k
   280  	}
   282  	result := Key{
   283  		length: bitsToTake,
   284  	}
   286  	remainderBits := result.length % 8
   287  	if remainderBits == 0 {
   288  		result.value = k.value[:bitsToTake/8]
   289  		return result
   290  	}
   292  	// We need to zero out some bits of the last byte so a simple slice will not work
   293  	// Create a new []byte to store the altered value
   294  	buffer := make([]byte, bytesNeeded(bitsToTake))
   295  	copy(buffer, k.value)
   297  	// We want to zero out everything to the right of the last token, which is at index bitsToTake-1
   298  	// Mask will be (8-remainderBits) number of 1's followed by (remainderBits) number of 0's
   299  	buffer[len(buffer)-1] &= byte(0xFF << dualBitIndex(remainderBits))
   301  	result.value = byteSliceToString(buffer)
   302  	return result
   303  }
// Bytes returns the raw bytes of the Key.
// The bytes are not copied: the result is produced by stringToByteSlice
// directly from the key's underlying string.
// Invariant: The returned value must not be modified.
func (k Key) Bytes() []byte {
	// avoid copying during the conversion
	// "safe" because we never edit the value, only used as DB key
	return stringToByteSlice(k.value)
}
   313  // byteSliceToString converts the []byte to a string
   314  // Invariant: The input []byte must not be modified.
   315  func byteSliceToString(bs []byte) string {
   316  	// avoid copying during the conversion
   317  	// "safe" because we never edit the []byte, and it is never returned by any functions except Bytes()
   318  	return unsafe.String(unsafe.SliceData(bs), len(bs))
   319  }
   321  // stringToByteSlice converts the string to a []byte
   322  // Invariant: The output []byte must not be modified.
   323  func stringToByteSlice(value string) []byte {
   324  	// avoid copying during the conversion
   325  	// "safe" because we never edit the []byte
   326  	return unsafe.Slice(unsafe.StringData(value), len(value))
   327  }
   329  // Returns the number of bytes needed to store [bits] bits.
   330  func bytesNeeded(bits int) int {
   331  	size := bits / 8
   332  	if bits%8 != 0 {
   333  		size++
   334  	}
   335  	return size
   336  }