github.com/ari-anchor/sei-tendermint@v0.0.0-20230519144642-dc826b7b56bb/types/part_set.go (about)

     1  package types
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"sync"
    10  
    11  	"github.com/ari-anchor/sei-tendermint/crypto/merkle"
    12  	"github.com/ari-anchor/sei-tendermint/libs/bits"
    13  	tmbytes "github.com/ari-anchor/sei-tendermint/libs/bytes"
    14  	tmmath "github.com/ari-anchor/sei-tendermint/libs/math"
    15  	tmproto "github.com/ari-anchor/sei-tendermint/proto/tendermint/types"
    16  )
    17  
var (
	// ErrPartSetUnexpectedIndex is returned by PartSet.AddPart when the
	// part's index is not below the set's total number of parts.
	ErrPartSetUnexpectedIndex = errors.New("error part set unexpected index")
	// ErrPartSetInvalidProof is returned by PartSet.AddPart when the part's
	// proof does not verify against the set's hash.
	ErrPartSetInvalidProof    = errors.New("error part set invalid proof")
)
    22  
// Part is one chunk of a larger piece of data, together with the merkle
// proof that PartSet.AddPart verifies against the hash of the set.
type Part struct {
	// Index is the position of the part within its set; PartSet.AddPart
	// rejects values that are not below the set's total.
	Index uint32           `json:"index"`
	// Bytes is the raw payload; ValidateBasic caps it at BlockPartSizeBytes.
	Bytes tmbytes.HexBytes `json:"bytes"`
	// Proof is checked against the part set hash in PartSet.AddPart.
	Proof merkle.Proof     `json:"proof"`
}
    28  
    29  // ValidateBasic performs basic validation.
    30  func (part *Part) ValidateBasic() error {
    31  	if len(part.Bytes) > int(BlockPartSizeBytes) {
    32  		return fmt.Errorf("too big: %d bytes, max: %d", len(part.Bytes), BlockPartSizeBytes)
    33  	}
    34  	if err := part.Proof.ValidateBasic(); err != nil {
    35  		return fmt.Errorf("wrong Proof: %w", err)
    36  	}
    37  	return nil
    38  }
    39  
    40  // String returns a string representation of Part.
    41  //
    42  // See StringIndented.
    43  func (part *Part) String() string {
    44  	return part.StringIndented("")
    45  }
    46  
    47  // StringIndented returns an indented Part.
    48  //
    49  // See merkle.Proof#StringIndented
    50  func (part *Part) StringIndented(indent string) string {
    51  	return fmt.Sprintf(`Part{#%v
    52  %s  Bytes: %X...
    53  %s  Proof: %v
    54  %s}`,
    55  		part.Index,
    56  		indent, tmbytes.Fingerprint(part.Bytes),
    57  		indent, part.Proof.StringIndented(indent+"  "),
    58  		indent)
    59  }
    60  
    61  func (part *Part) ToProto() (*tmproto.Part, error) {
    62  	if part == nil {
    63  		return nil, errors.New("nil part")
    64  	}
    65  	pb := new(tmproto.Part)
    66  	proof := part.Proof.ToProto()
    67  
    68  	pb.Index = part.Index
    69  	pb.Bytes = part.Bytes
    70  	pb.Proof = *proof
    71  
    72  	return pb, nil
    73  }
    74  
    75  func PartFromProto(pb *tmproto.Part) (*Part, error) {
    76  	if pb == nil {
    77  		return nil, errors.New("nil part")
    78  	}
    79  
    80  	part := new(Part)
    81  	proof, err := merkle.ProofFromProto(&pb.Proof)
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  	part.Index = pb.Index
    86  	part.Bytes = pb.Bytes
    87  	part.Proof = *proof
    88  
    89  	return part, part.ValidateBasic()
    90  }
    91  
    92  //-------------------------------------
    93  
// PartSetHeader identifies a PartSet by its number of parts and its hash.
type PartSetHeader struct {
	// Total is the number of parts in the set.
	Total uint32           `json:"total"`
	// Hash is the hash of the set; NewPartSetFromData fills it with the
	// merkle root computed over the part bytes.
	Hash  tmbytes.HexBytes `json:"hash"`
}
    98  
    99  // String returns a string representation of PartSetHeader.
   100  //
   101  // 1. total number of parts
   102  // 2. first 6 bytes of the hash
   103  func (psh PartSetHeader) String() string {
   104  	return fmt.Sprintf("%v:%X", psh.Total, tmbytes.Fingerprint(psh.Hash))
   105  }
   106  
   107  func (psh PartSetHeader) IsZero() bool {
   108  	return psh.Total == 0 && len(psh.Hash) == 0
   109  }
   110  
   111  func (psh PartSetHeader) Equals(other PartSetHeader) bool {
   112  	return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash)
   113  }
   114  
   115  // ValidateBasic performs basic validation.
   116  func (psh PartSetHeader) ValidateBasic() error {
   117  	// Hash can be empty in case of POLBlockID.PartSetHeader in Proposal.
   118  	if err := ValidateHash(psh.Hash); err != nil {
   119  		return fmt.Errorf("wrong Hash: %w", err)
   120  	}
   121  	return nil
   122  }
   123  
   124  // ToProto converts PartSetHeader to protobuf
   125  func (psh *PartSetHeader) ToProto() tmproto.PartSetHeader {
   126  	if psh == nil {
   127  		return tmproto.PartSetHeader{}
   128  	}
   129  
   130  	return tmproto.PartSetHeader{
   131  		Total: psh.Total,
   132  		Hash:  psh.Hash,
   133  	}
   134  }
   135  
   136  // FromProto sets a protobuf PartSetHeader to the given pointer
   137  func PartSetHeaderFromProto(ppsh *tmproto.PartSetHeader) (*PartSetHeader, error) {
   138  	if ppsh == nil {
   139  		return nil, errors.New("nil PartSetHeader")
   140  	}
   141  	psh := new(PartSetHeader)
   142  	psh.Total = ppsh.Total
   143  	psh.Hash = ppsh.Hash
   144  
   145  	return psh, psh.ValidateBasic()
   146  }
   147  
   148  // ProtoPartSetHeaderIsZero is similar to the IsZero function for
   149  // PartSetHeader, but for the Protobuf representation.
   150  func ProtoPartSetHeaderIsZero(ppsh *tmproto.PartSetHeader) bool {
   151  	return ppsh.Total == 0 && len(ppsh.Hash) == 0
   152  }
   153  
   154  //-------------------------------------
   155  
// PartSet holds the parts of one piece of data, identified by the number of
// parts and a hash. It is either built complete by NewPartSetFromData or
// created empty by NewPartSetFromHeader and filled through AddPart.
type PartSet struct {
	// number of parts a complete set holds
	total uint32
	// hash that AddPart verifies every incoming part's proof against
	hash  []byte

	// mtx is taken by AddPart, BitArray, GetPart, StringShort and MarshalJSON.
	mtx           sync.Mutex
	// parts is indexed by Part.Index; a nil entry means "not yet added".
	parts         []*Part
	// partsBitArray has the bit for an index set once that part is present.
	partsBitArray *bits.BitArray
	// count is the number of parts currently present.
	count         uint32
	// a count of the total size (in bytes). Used to ensure that the
	// part set doesn't exceed the maximum block bytes
	byteSize int64
}
   168  
   169  // Returns an immutable, full PartSet from the data bytes.
   170  // The data bytes are split into "partSize" chunks, and merkle tree computed.
   171  // CONTRACT: partSize is greater than zero.
   172  func NewPartSetFromData(data []byte, partSize uint32) *PartSet {
   173  	// divide data into 4kb parts.
   174  	total := (uint32(len(data)) + partSize - 1) / partSize
   175  	parts := make([]*Part, total)
   176  	partsBytes := make([][]byte, total)
   177  	partsBitArray := bits.NewBitArray(int(total))
   178  	for i := uint32(0); i < total; i++ {
   179  		part := &Part{
   180  			Index: i,
   181  			Bytes: data[i*partSize : tmmath.MinInt(len(data), int((i+1)*partSize))],
   182  		}
   183  		parts[i] = part
   184  		partsBytes[i] = part.Bytes
   185  		partsBitArray.SetIndex(int(i), true)
   186  	}
   187  	// Compute merkle proofs
   188  	root, proofs := merkle.ProofsFromByteSlices(partsBytes)
   189  	for i := uint32(0); i < total; i++ {
   190  		parts[i].Proof = *proofs[i]
   191  	}
   192  	return &PartSet{
   193  		total:         total,
   194  		hash:          root,
   195  		parts:         parts,
   196  		partsBitArray: partsBitArray,
   197  		count:         total,
   198  		byteSize:      int64(len(data)),
   199  	}
   200  }
   201  
   202  // Returns an empty PartSet ready to be populated.
   203  func NewPartSetFromHeader(header PartSetHeader) *PartSet {
   204  	return &PartSet{
   205  		total:         header.Total,
   206  		hash:          header.Hash,
   207  		parts:         make([]*Part, header.Total),
   208  		partsBitArray: bits.NewBitArray(int(header.Total)),
   209  		count:         0,
   210  		byteSize:      0,
   211  	}
   212  }
   213  
   214  func (ps *PartSet) Header() PartSetHeader {
   215  	if ps == nil {
   216  		return PartSetHeader{}
   217  	}
   218  	return PartSetHeader{
   219  		Total: ps.total,
   220  		Hash:  ps.hash,
   221  	}
   222  }
   223  
   224  func (ps *PartSet) HasHeader(header PartSetHeader) bool {
   225  	if ps == nil {
   226  		return false
   227  	}
   228  	return ps.Header().Equals(header)
   229  }
   230  
   231  func (ps *PartSet) BitArray() *bits.BitArray {
   232  	ps.mtx.Lock()
   233  	defer ps.mtx.Unlock()
   234  	return ps.partsBitArray.Copy()
   235  }
   236  
   237  func (ps *PartSet) Hash() []byte {
   238  	if ps == nil {
   239  		return merkle.HashFromByteSlices(nil)
   240  	}
   241  	return ps.hash
   242  }
   243  
   244  func (ps *PartSet) HashesTo(hash []byte) bool {
   245  	if ps == nil {
   246  		return false
   247  	}
   248  	return bytes.Equal(ps.hash, hash)
   249  }
   250  
   251  func (ps *PartSet) Count() uint32 {
   252  	if ps == nil {
   253  		return 0
   254  	}
   255  	return ps.count
   256  }
   257  
   258  func (ps *PartSet) ByteSize() int64 {
   259  	if ps == nil {
   260  		return 0
   261  	}
   262  	return ps.byteSize
   263  }
   264  
   265  func (ps *PartSet) Total() uint32 {
   266  	if ps == nil {
   267  		return 0
   268  	}
   269  	return ps.total
   270  }
   271  
   272  func (ps *PartSet) AddPart(part *Part) (bool, error) {
   273  	if ps == nil {
   274  		return false, nil
   275  	}
   276  	ps.mtx.Lock()
   277  	defer ps.mtx.Unlock()
   278  
   279  	// Invalid part index
   280  	if part.Index >= ps.total {
   281  		return false, ErrPartSetUnexpectedIndex
   282  	}
   283  
   284  	// If part already exists, return false.
   285  	if ps.parts[part.Index] != nil {
   286  		return false, nil
   287  	}
   288  
   289  	// Check hash proof
   290  	if part.Proof.Verify(ps.Hash(), part.Bytes) != nil {
   291  		return false, ErrPartSetInvalidProof
   292  	}
   293  
   294  	// Add part
   295  	ps.parts[part.Index] = part
   296  	ps.partsBitArray.SetIndex(int(part.Index), true)
   297  	ps.count++
   298  	ps.byteSize += int64(len(part.Bytes))
   299  	return true, nil
   300  }
   301  
   302  func (ps *PartSet) GetPart(index int) *Part {
   303  	ps.mtx.Lock()
   304  	defer ps.mtx.Unlock()
   305  	return ps.parts[index]
   306  }
   307  
   308  func (ps *PartSet) IsComplete() bool {
   309  	return ps.count == ps.total
   310  }
   311  
   312  func (ps *PartSet) GetReader() io.Reader {
   313  	if !ps.IsComplete() {
   314  		panic("Cannot GetReader() on incomplete PartSet")
   315  	}
   316  	return NewPartSetReader(ps.parts)
   317  }
   318  
// PartSetReader reads the Bytes of a slice of parts back to back, in slice
// order.
type PartSetReader struct {
	// i is the index into parts of the part currently being read.
	i      int
	// parts are the parts whose Bytes are read.
	parts  []*Part
	// reader reads the Bytes of the current part.
	reader *bytes.Reader
}
   324  
   325  func NewPartSetReader(parts []*Part) *PartSetReader {
   326  	return &PartSetReader{
   327  		i:      0,
   328  		parts:  parts,
   329  		reader: bytes.NewReader(parts[0].Bytes),
   330  	}
   331  }
   332  
   333  func (psr *PartSetReader) Read(p []byte) (n int, err error) {
   334  	readerLen := psr.reader.Len()
   335  	if readerLen >= len(p) {
   336  		return psr.reader.Read(p)
   337  	} else if readerLen > 0 {
   338  		n1, err := psr.Read(p[:readerLen])
   339  		if err != nil {
   340  			return n1, err
   341  		}
   342  		n2, err := psr.Read(p[readerLen:])
   343  		return n1 + n2, err
   344  	}
   345  
   346  	psr.i++
   347  	if psr.i >= len(psr.parts) {
   348  		return 0, io.EOF
   349  	}
   350  	psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes)
   351  	return psr.Read(p)
   352  }
   353  
   354  // StringShort returns a short version of String.
   355  //
   356  // (Count of Total)
   357  func (ps *PartSet) StringShort() string {
   358  	if ps == nil {
   359  		return "nil-PartSet"
   360  	}
   361  	ps.mtx.Lock()
   362  	defer ps.mtx.Unlock()
   363  	return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total())
   364  }
   365  
   366  func (ps *PartSet) MarshalJSON() ([]byte, error) {
   367  	if ps == nil {
   368  		return []byte("{}"), nil
   369  	}
   370  
   371  	ps.mtx.Lock()
   372  	defer ps.mtx.Unlock()
   373  
   374  	return json.Marshal(struct {
   375  		CountTotal    string         `json:"count/total"`
   376  		PartsBitArray *bits.BitArray `json:"parts_bit_array"`
   377  	}{
   378  		fmt.Sprintf("%d/%d", ps.Count(), ps.Total()),
   379  		ps.partsBitArray,
   380  	})
   381  }