github.com/MetalBlockchain/metalgo@v1.11.9/x/merkledb/codec.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package merkledb
     5  
     6  import (
     7  	"bytes"
     8  	"encoding/binary"
     9  	"errors"
    10  	"io"
    11  	"math"
    12  	"math/bits"
    13  	"slices"
    14  
    15  	"github.com/MetalBlockchain/metalgo/ids"
    16  	"github.com/MetalBlockchain/metalgo/utils/maybe"
    17  )
    18  
    19  const (
    20  	boolLen   = 1
    21  	trueByte  = 1
    22  	falseByte = 0
    23  )
    24  
    25  var (
    26  	trueBytes  = []byte{trueByte}
    27  	falseBytes = []byte{falseByte}
    28  
    29  	errChildIndexTooLarge = errors.New("invalid child index. Must be less than branching factor")
    30  	errLeadingZeroes      = errors.New("varint has leading zeroes")
    31  	errInvalidBool        = errors.New("decoded bool is neither true nor false")
    32  	errNonZeroKeyPadding  = errors.New("key partial byte should be padded with 0s")
    33  	errExtraSpace         = errors.New("trailing buffer space")
    34  	errIntOverflow        = errors.New("value overflows int")
    35  	errTooManyChildren    = errors.New("too many children")
    36  )
    37  
    38  func childSize(index byte, childEntry *child) int {
    39  	// * index
    40  	// * child ID
    41  	// * child key
    42  	// * bool indicating whether the child has a value
    43  	return uintSize(uint64(index)) + ids.IDLen + keySize(childEntry.compressedKey) + boolLen
    44  }
    45  
    46  // based on the implementation of encodeUint which uses binary.PutUvarint
    47  func uintSize(value uint64) int {
    48  	if value == 0 {
    49  		return 1
    50  	}
    51  	return (bits.Len64(value) + 6) / 7
    52  }
    53  
    54  func keySize(p Key) int {
    55  	return uintSize(uint64(p.length)) + bytesNeeded(p.length)
    56  }
    57  
    58  // Assumes [n] is non-nil.
    59  func encodedDBNodeSize(n *dbNode) int {
    60  	// * number of children
    61  	// * bool indicating whether [n] has a value
    62  	// * the value (optional)
    63  	// * children
    64  	size := uintSize(uint64(len(n.children))) + boolLen
    65  	if n.value.HasValue() {
    66  		valueLen := len(n.value.Value())
    67  		size += uintSize(uint64(valueLen)) + valueLen
    68  	}
    69  	// for each non-nil entry, we add the additional size of the child entry
    70  	for index, entry := range n.children {
    71  		size += childSize(index, entry)
    72  	}
    73  	return size
    74  }
    75  
    76  // Assumes [n] is non-nil.
    77  func encodeDBNode(n *dbNode) []byte {
    78  	length := encodedDBNodeSize(n)
    79  	w := codecWriter{
    80  		b: make([]byte, 0, length),
    81  	}
    82  
    83  	w.MaybeBytes(n.value)
    84  
    85  	numChildren := len(n.children)
    86  	w.Uvarint(uint64(numChildren))
    87  
    88  	// Avoid allocating keys entirely if the node doesn't have any children.
    89  	if numChildren == 0 {
    90  		return w.b
    91  	}
    92  
    93  	// By allocating BranchFactorLargest rather than [numChildren], this slice
    94  	// is allocated on the stack rather than the heap. BranchFactorLargest is
    95  	// at least [numChildren] which avoids memory allocations.
    96  	keys := make([]byte, numChildren, BranchFactorLargest)
    97  	i := 0
    98  	for k := range n.children {
    99  		keys[i] = k
   100  		i++
   101  	}
   102  
   103  	// Ensure that the order of entries is correct.
   104  	slices.Sort(keys)
   105  	for _, index := range keys {
   106  		entry := n.children[index]
   107  		w.Uvarint(uint64(index))
   108  		w.Key(entry.compressedKey)
   109  		w.ID(entry.id)
   110  		w.Bool(entry.hasValue)
   111  	}
   112  
   113  	return w.b
   114  }
   115  
   116  func encodeKey(key Key) []byte {
   117  	length := uintSize(uint64(key.length)) + len(key.Bytes())
   118  	w := codecWriter{
   119  		b: make([]byte, 0, length),
   120  	}
   121  	w.Key(key)
   122  	return w.b
   123  }
   124  
   125  type codecWriter struct {
   126  	b []byte
   127  }
   128  
   129  func (w *codecWriter) Bool(v bool) {
   130  	if v {
   131  		w.b = append(w.b, trueByte)
   132  	} else {
   133  		w.b = append(w.b, falseByte)
   134  	}
   135  }
   136  
   137  func (w *codecWriter) Uvarint(v uint64) {
   138  	w.b = binary.AppendUvarint(w.b, v)
   139  }
   140  
   141  func (w *codecWriter) ID(v ids.ID) {
   142  	w.b = append(w.b, v[:]...)
   143  }
   144  
   145  func (w *codecWriter) Bytes(v []byte) {
   146  	w.Uvarint(uint64(len(v)))
   147  	w.b = append(w.b, v...)
   148  }
   149  
   150  func (w *codecWriter) MaybeBytes(v maybe.Maybe[[]byte]) {
   151  	hasValue := v.HasValue()
   152  	w.Bool(hasValue)
   153  	if hasValue {
   154  		w.Bytes(v.Value())
   155  	}
   156  }
   157  
   158  func (w *codecWriter) Key(v Key) {
   159  	w.Uvarint(uint64(v.length))
   160  	w.b = append(w.b, v.Bytes()...)
   161  }
   162  
   163  // Assumes [n] is non-nil.
   164  func decodeDBNode(b []byte, n *dbNode) error {
   165  	r := codecReader{
   166  		b:    b,
   167  		copy: true,
   168  	}
   169  
   170  	var err error
   171  	n.value, err = r.MaybeBytes()
   172  	if err != nil {
   173  		return err
   174  	}
   175  
   176  	numChildren, err := r.Uvarint()
   177  	if err != nil {
   178  		return err
   179  	}
   180  	if numChildren > uint64(BranchFactorLargest) {
   181  		return errTooManyChildren
   182  	}
   183  
   184  	n.children = make(map[byte]*child, numChildren)
   185  	var previousChild uint64
   186  	for i := uint64(0); i < numChildren; i++ {
   187  		index, err := r.Uvarint()
   188  		if err != nil {
   189  			return err
   190  		}
   191  		if (i != 0 && index <= previousChild) || index > math.MaxUint8 {
   192  			return errChildIndexTooLarge
   193  		}
   194  		previousChild = index
   195  
   196  		compressedKey, err := r.Key()
   197  		if err != nil {
   198  			return err
   199  		}
   200  		childID, err := r.ID()
   201  		if err != nil {
   202  			return err
   203  		}
   204  		hasValue, err := r.Bool()
   205  		if err != nil {
   206  			return err
   207  		}
   208  		n.children[byte(index)] = &child{
   209  			compressedKey: compressedKey,
   210  			id:            childID,
   211  			hasValue:      hasValue,
   212  		}
   213  	}
   214  	if len(r.b) != 0 {
   215  		return errExtraSpace
   216  	}
   217  	return nil
   218  }
   219  
   220  func decodeKey(b []byte) (Key, error) {
   221  	r := codecReader{
   222  		b:    b,
   223  		copy: true,
   224  	}
   225  	key, err := r.Key()
   226  	if err != nil {
   227  		return Key{}, err
   228  	}
   229  	if len(r.b) != 0 {
   230  		return Key{}, errExtraSpace
   231  	}
   232  	return key, nil
   233  }
   234  
   235  type codecReader struct {
   236  	b []byte
   237  	// copy is used to flag to the reader if it is required to copy references
   238  	// to [b].
   239  	copy bool
   240  }
   241  
   242  func (r *codecReader) Bool() (bool, error) {
   243  	if len(r.b) < boolLen {
   244  		return false, io.ErrUnexpectedEOF
   245  	}
   246  	boolByte := r.b[0]
   247  	if boolByte > trueByte {
   248  		return false, errInvalidBool
   249  	}
   250  
   251  	r.b = r.b[boolLen:]
   252  	return boolByte == trueByte, nil
   253  }
   254  
   255  func (r *codecReader) Uvarint() (uint64, error) {
   256  	length, bytesRead := binary.Uvarint(r.b)
   257  	if bytesRead <= 0 {
   258  		return 0, io.ErrUnexpectedEOF
   259  	}
   260  
   261  	// To ensure decoding is canonical, we check for leading zeroes in the
   262  	// varint.
   263  	// The last byte of the varint includes the most significant bits.
   264  	// If the last byte is 0, then the number should have been encoded more
   265  	// efficiently by removing this leading zero.
   266  	if bytesRead > 1 && r.b[bytesRead-1] == 0x00 {
   267  		return 0, errLeadingZeroes
   268  	}
   269  
   270  	r.b = r.b[bytesRead:]
   271  	return length, nil
   272  }
   273  
   274  func (r *codecReader) ID() (ids.ID, error) {
   275  	if len(r.b) < ids.IDLen {
   276  		return ids.Empty, io.ErrUnexpectedEOF
   277  	}
   278  	id := ids.ID(r.b[:ids.IDLen])
   279  
   280  	r.b = r.b[ids.IDLen:]
   281  	return id, nil
   282  }
   283  
   284  func (r *codecReader) Bytes() ([]byte, error) {
   285  	length, err := r.Uvarint()
   286  	if err != nil {
   287  		return nil, err
   288  	}
   289  
   290  	if length > uint64(len(r.b)) {
   291  		return nil, io.ErrUnexpectedEOF
   292  	}
   293  	result := r.b[:length]
   294  	if r.copy {
   295  		result = bytes.Clone(result)
   296  	}
   297  
   298  	r.b = r.b[length:]
   299  	return result, nil
   300  }
   301  
   302  func (r *codecReader) MaybeBytes() (maybe.Maybe[[]byte], error) {
   303  	if hasValue, err := r.Bool(); err != nil || !hasValue {
   304  		return maybe.Nothing[[]byte](), err
   305  	}
   306  
   307  	bytes, err := r.Bytes()
   308  	return maybe.Some(bytes), err
   309  }
   310  
   311  func (r *codecReader) Key() (Key, error) {
   312  	bitLen, err := r.Uvarint()
   313  	if err != nil {
   314  		return Key{}, err
   315  	}
   316  	if bitLen > math.MaxInt {
   317  		return Key{}, errIntOverflow
   318  	}
   319  
   320  	result := Key{
   321  		length: int(bitLen),
   322  	}
   323  	byteLen := bytesNeeded(result.length)
   324  	if byteLen > len(r.b) {
   325  		return Key{}, io.ErrUnexpectedEOF
   326  	}
   327  	if result.hasPartialByte() {
   328  		// Confirm that the padding bits in the partial byte are 0.
   329  		// We want to only look at the bits to the right of the last token,
   330  		// which is at index length-1.
   331  		// Generate a mask where the (result.length % 8) left bits are 0.
   332  		paddingMask := byte(0xFF >> (result.length % 8))
   333  		if r.b[byteLen-1]&paddingMask != 0 {
   334  			return Key{}, errNonZeroKeyPadding
   335  		}
   336  	}
   337  	result.value = string(r.b[:byteLen])
   338  
   339  	r.b = r.b[byteLen:]
   340  	return result, nil
   341  }