github.com/aakash4dev/cometbft@v0.38.2/types/part_set.go (about)

     1  package types
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  
     9  	"github.com/aakash4dev/cometbft/crypto/merkle"
    10  	"github.com/aakash4dev/cometbft/libs/bits"
    11  	cmtbytes "github.com/aakash4dev/cometbft/libs/bytes"
    12  	cmtjson "github.com/aakash4dev/cometbft/libs/json"
    13  	cmtmath "github.com/aakash4dev/cometbft/libs/math"
    14  	cmtsync "github.com/aakash4dev/cometbft/libs/sync"
    15  	cmtproto "github.com/aakash4dev/cometbft/proto/tendermint/types"
    16  )
    17  
    18  var (
    19  	ErrPartSetUnexpectedIndex = errors.New("error part set unexpected index")
    20  	ErrPartSetInvalidProof    = errors.New("error part set invalid proof")
    21  )
    22  
// Part is a single chunk of a PartSet.
//
// Index is the chunk's position within the set, Bytes holds the raw chunk
// data, and Proof is a Merkle proof for Bytes that PartSet.AddPart verifies
// against the part set's hash before accepting the part.
type Part struct {
	Index uint32            `json:"index"`
	Bytes cmtbytes.HexBytes `json:"bytes"`
	Proof merkle.Proof      `json:"proof"`
}
    28  
    29  // ValidateBasic performs basic validation.
    30  func (part *Part) ValidateBasic() error {
    31  	if len(part.Bytes) > int(BlockPartSizeBytes) {
    32  		return fmt.Errorf("too big: %d bytes, max: %d", len(part.Bytes), BlockPartSizeBytes)
    33  	}
    34  	if err := part.Proof.ValidateBasic(); err != nil {
    35  		return fmt.Errorf("wrong Proof: %w", err)
    36  	}
    37  	return nil
    38  }
    39  
    40  // String returns a string representation of Part.
    41  //
    42  // See StringIndented.
    43  func (part *Part) String() string {
    44  	return part.StringIndented("")
    45  }
    46  
    47  // StringIndented returns an indented Part.
    48  //
    49  // See merkle.Proof#StringIndented
    50  func (part *Part) StringIndented(indent string) string {
    51  	return fmt.Sprintf(`Part{#%v
    52  %s  Bytes: %X...
    53  %s  Proof: %v
    54  %s}`,
    55  		part.Index,
    56  		indent, cmtbytes.Fingerprint(part.Bytes),
    57  		indent, part.Proof.StringIndented(indent+"  "),
    58  		indent)
    59  }
    60  
    61  func (part *Part) ToProto() (*cmtproto.Part, error) {
    62  	if part == nil {
    63  		return nil, errors.New("nil part")
    64  	}
    65  	pb := new(cmtproto.Part)
    66  	proof := part.Proof.ToProto()
    67  
    68  	pb.Index = part.Index
    69  	pb.Bytes = part.Bytes
    70  	pb.Proof = *proof
    71  
    72  	return pb, nil
    73  }
    74  
    75  func PartFromProto(pb *cmtproto.Part) (*Part, error) {
    76  	if pb == nil {
    77  		return nil, errors.New("nil part")
    78  	}
    79  
    80  	part := new(Part)
    81  	proof, err := merkle.ProofFromProto(&pb.Proof)
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  	part.Index = pb.Index
    86  	part.Bytes = pb.Bytes
    87  	part.Proof = *proof
    88  
    89  	return part, part.ValidateBasic()
    90  }
    91  
    92  //-------------------------------------
    93  
// PartSetHeader identifies a PartSet.
//
// Total is the number of parts in the set, and Hash is the Merkle root that
// every part's proof must verify against.
type PartSetHeader struct {
	Total uint32            `json:"total"`
	Hash  cmtbytes.HexBytes `json:"hash"`
}
    98  
    99  // String returns a string representation of PartSetHeader.
   100  //
   101  // 1. total number of parts
   102  // 2. first 6 bytes of the hash
   103  func (psh PartSetHeader) String() string {
   104  	return fmt.Sprintf("%v:%X", psh.Total, cmtbytes.Fingerprint(psh.Hash))
   105  }
   106  
   107  func (psh PartSetHeader) IsZero() bool {
   108  	return psh.Total == 0 && len(psh.Hash) == 0
   109  }
   110  
   111  func (psh PartSetHeader) Equals(other PartSetHeader) bool {
   112  	return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash)
   113  }
   114  
   115  // ValidateBasic performs basic validation.
   116  func (psh PartSetHeader) ValidateBasic() error {
   117  	// Hash can be empty in case of POLBlockID.PartSetHeader in Proposal.
   118  	if err := ValidateHash(psh.Hash); err != nil {
   119  		return fmt.Errorf("wrong Hash: %w", err)
   120  	}
   121  	return nil
   122  }
   123  
   124  // ToProto converts PartSetHeader to protobuf
   125  func (psh *PartSetHeader) ToProto() cmtproto.PartSetHeader {
   126  	if psh == nil {
   127  		return cmtproto.PartSetHeader{}
   128  	}
   129  
   130  	return cmtproto.PartSetHeader{
   131  		Total: psh.Total,
   132  		Hash:  psh.Hash,
   133  	}
   134  }
   135  
   136  // FromProto sets a protobuf PartSetHeader to the given pointer
   137  func PartSetHeaderFromProto(ppsh *cmtproto.PartSetHeader) (*PartSetHeader, error) {
   138  	if ppsh == nil {
   139  		return nil, errors.New("nil PartSetHeader")
   140  	}
   141  	psh := new(PartSetHeader)
   142  	psh.Total = ppsh.Total
   143  	psh.Hash = ppsh.Hash
   144  
   145  	return psh, psh.ValidateBasic()
   146  }
   147  
   148  // ProtoPartSetHeaderIsZero is similar to the IsZero function for
   149  // PartSetHeader, but for the Protobuf representation.
   150  func ProtoPartSetHeaderIsZero(ppsh *cmtproto.PartSetHeader) bool {
   151  	return ppsh.Total == 0 && len(ppsh.Hash) == 0
   152  }
   153  
   154  //-------------------------------------
   155  
// PartSet holds the parts of a payload that has been split into chunks and
// committed to by a Merkle root.
//
// A PartSet is either built complete from data (NewPartSetFromData) or
// created empty from a header (NewPartSetFromHeader) and filled with verified
// parts via AddPart.
type PartSet struct {
	total uint32 // number of parts in the complete set
	hash  []byte // Merkle root that every added part's proof must verify against

	// mtx is held by AddPart, GetPart, BitArray, StringShort and MarshalJSON.
	// NOTE(review): Count, ByteSize and IsComplete read count/byteSize without
	// taking mtx.
	mtx           cmtsync.Mutex
	parts         []*Part        // parts[i] is nil until part i has been added
	partsBitArray *bits.BitArray // bit i is set once parts[i] is present
	count         uint32         // number of parts currently present
	// a count of the total size (in bytes). Used to ensure that the
	// part set doesn't exceed the maximum block bytes
	byteSize int64
}
   168  
   169  // Returns an immutable, full PartSet from the data bytes.
   170  // The data bytes are split into "partSize" chunks, and merkle tree computed.
   171  // CONTRACT: partSize is greater than zero.
   172  func NewPartSetFromData(data []byte, partSize uint32) *PartSet {
   173  	// divide data into parts of size `partSize`
   174  	total := (uint32(len(data)) + partSize - 1) / partSize
   175  	parts := make([]*Part, total)
   176  	partsBytes := make([][]byte, total)
   177  	partsBitArray := bits.NewBitArray(int(total))
   178  	for i := uint32(0); i < total; i++ {
   179  		part := &Part{
   180  			Index: i,
   181  			Bytes: data[i*partSize : cmtmath.MinInt(len(data), int((i+1)*partSize))],
   182  		}
   183  		parts[i] = part
   184  		partsBytes[i] = part.Bytes
   185  		partsBitArray.SetIndex(int(i), true)
   186  	}
   187  	// Compute merkle proofs
   188  	root, proofs := merkle.ProofsFromByteSlices(partsBytes)
   189  	for i := uint32(0); i < total; i++ {
   190  		parts[i].Proof = *proofs[i]
   191  	}
   192  	return &PartSet{
   193  		total:         total,
   194  		hash:          root,
   195  		parts:         parts,
   196  		partsBitArray: partsBitArray,
   197  		count:         total,
   198  		byteSize:      int64(len(data)),
   199  	}
   200  }
   201  
   202  // Returns an empty PartSet ready to be populated.
   203  func NewPartSetFromHeader(header PartSetHeader) *PartSet {
   204  	return &PartSet{
   205  		total:         header.Total,
   206  		hash:          header.Hash,
   207  		parts:         make([]*Part, header.Total),
   208  		partsBitArray: bits.NewBitArray(int(header.Total)),
   209  		count:         0,
   210  		byteSize:      0,
   211  	}
   212  }
   213  
   214  func (ps *PartSet) Header() PartSetHeader {
   215  	if ps == nil {
   216  		return PartSetHeader{}
   217  	}
   218  	return PartSetHeader{
   219  		Total: ps.total,
   220  		Hash:  ps.hash,
   221  	}
   222  }
   223  
   224  func (ps *PartSet) HasHeader(header PartSetHeader) bool {
   225  	if ps == nil {
   226  		return false
   227  	}
   228  	return ps.Header().Equals(header)
   229  }
   230  
   231  func (ps *PartSet) BitArray() *bits.BitArray {
   232  	ps.mtx.Lock()
   233  	defer ps.mtx.Unlock()
   234  	return ps.partsBitArray.Copy()
   235  }
   236  
   237  func (ps *PartSet) Hash() []byte {
   238  	if ps == nil {
   239  		return merkle.HashFromByteSlices(nil)
   240  	}
   241  	return ps.hash
   242  }
   243  
   244  func (ps *PartSet) HashesTo(hash []byte) bool {
   245  	if ps == nil {
   246  		return false
   247  	}
   248  	return bytes.Equal(ps.hash, hash)
   249  }
   250  
   251  func (ps *PartSet) Count() uint32 {
   252  	if ps == nil {
   253  		return 0
   254  	}
   255  	return ps.count
   256  }
   257  
   258  func (ps *PartSet) ByteSize() int64 {
   259  	if ps == nil {
   260  		return 0
   261  	}
   262  	return ps.byteSize
   263  }
   264  
   265  func (ps *PartSet) Total() uint32 {
   266  	if ps == nil {
   267  		return 0
   268  	}
   269  	return ps.total
   270  }
   271  
   272  func (ps *PartSet) AddPart(part *Part) (bool, error) {
   273  	// TODO: remove this? would be preferable if this only returned (false, nil)
   274  	// when its a duplicate block part
   275  	if ps == nil {
   276  		return false, nil
   277  	}
   278  
   279  	ps.mtx.Lock()
   280  	defer ps.mtx.Unlock()
   281  
   282  	// Invalid part index
   283  	if part.Index >= ps.total {
   284  		return false, ErrPartSetUnexpectedIndex
   285  	}
   286  
   287  	// If part already exists, return false.
   288  	if ps.parts[part.Index] != nil {
   289  		return false, nil
   290  	}
   291  
   292  	// Check hash proof
   293  	if part.Proof.Verify(ps.Hash(), part.Bytes) != nil {
   294  		return false, ErrPartSetInvalidProof
   295  	}
   296  
   297  	// Add part
   298  	ps.parts[part.Index] = part
   299  	ps.partsBitArray.SetIndex(int(part.Index), true)
   300  	ps.count++
   301  	ps.byteSize += int64(len(part.Bytes))
   302  	return true, nil
   303  }
   304  
   305  func (ps *PartSet) GetPart(index int) *Part {
   306  	ps.mtx.Lock()
   307  	defer ps.mtx.Unlock()
   308  	return ps.parts[index]
   309  }
   310  
   311  func (ps *PartSet) IsComplete() bool {
   312  	return ps.count == ps.total
   313  }
   314  
   315  func (ps *PartSet) GetReader() io.Reader {
   316  	if !ps.IsComplete() {
   317  		panic("Cannot GetReader() on incomplete PartSet")
   318  	}
   319  	return NewPartSetReader(ps.parts)
   320  }
   321  
   322  type PartSetReader struct {
   323  	i      int
   324  	parts  []*Part
   325  	reader *bytes.Reader
   326  }
   327  
   328  func NewPartSetReader(parts []*Part) *PartSetReader {
   329  	return &PartSetReader{
   330  		i:      0,
   331  		parts:  parts,
   332  		reader: bytes.NewReader(parts[0].Bytes),
   333  	}
   334  }
   335  
   336  func (psr *PartSetReader) Read(p []byte) (n int, err error) {
   337  	readerLen := psr.reader.Len()
   338  	if readerLen >= len(p) {
   339  		return psr.reader.Read(p)
   340  	} else if readerLen > 0 {
   341  		n1, err := psr.Read(p[:readerLen])
   342  		if err != nil {
   343  			return n1, err
   344  		}
   345  		n2, err := psr.Read(p[readerLen:])
   346  		return n1 + n2, err
   347  	}
   348  
   349  	psr.i++
   350  	if psr.i >= len(psr.parts) {
   351  		return 0, io.EOF
   352  	}
   353  	psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes)
   354  	return psr.Read(p)
   355  }
   356  
   357  // StringShort returns a short version of String.
   358  //
   359  // (Count of Total)
   360  func (ps *PartSet) StringShort() string {
   361  	if ps == nil {
   362  		return "nil-PartSet"
   363  	}
   364  	ps.mtx.Lock()
   365  	defer ps.mtx.Unlock()
   366  	return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total())
   367  }
   368  
   369  func (ps *PartSet) MarshalJSON() ([]byte, error) {
   370  	if ps == nil {
   371  		return []byte("{}"), nil
   372  	}
   373  
   374  	ps.mtx.Lock()
   375  	defer ps.mtx.Unlock()
   376  
   377  	return cmtjson.Marshal(struct {
   378  		CountTotal    string         `json:"count/total"`
   379  		PartsBitArray *bits.BitArray `json:"parts_bit_array"`
   380  	}{
   381  		fmt.Sprintf("%d/%d", ps.Count(), ps.Total()),
   382  		ps.partsBitArray,
   383  	})
   384  }