github.com/DFWallet/tendermint-cosmos@v0.0.2/types/part_set.go (about)

     1  package types
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  
     9  	"github.com/DFWallet/tendermint-cosmos/crypto/merkle"
    10  	"github.com/DFWallet/tendermint-cosmos/libs/bits"
    11  	tmbytes "github.com/DFWallet/tendermint-cosmos/libs/bytes"
    12  	tmjson "github.com/DFWallet/tendermint-cosmos/libs/json"
    13  	tmmath "github.com/DFWallet/tendermint-cosmos/libs/math"
    14  	tmsync "github.com/DFWallet/tendermint-cosmos/libs/sync"
    15  	tmproto "github.com/DFWallet/tendermint-cosmos/proto/tendermint/types"
    16  )
    17  
var (
	// ErrPartSetUnexpectedIndex is returned by PartSet.AddPart when the
	// part's Index is not below the set's total number of parts.
	ErrPartSetUnexpectedIndex = errors.New("error part set unexpected index")
	// ErrPartSetInvalidProof is returned by PartSet.AddPart when the part's
	// proof fails to verify against the set's hash.
	ErrPartSetInvalidProof    = errors.New("error part set invalid proof")
)
    22  
// Part is one chunk of a PartSet: its position in the set, the chunk's raw
// bytes, and the merkle proof that PartSet.AddPart verifies those bytes with.
type Part struct {
	Index uint32           `json:"index"`
	Bytes tmbytes.HexBytes `json:"bytes"`
	Proof merkle.Proof     `json:"proof"`
}
    28  
    29  // ValidateBasic performs basic validation.
    30  func (part *Part) ValidateBasic() error {
    31  	if len(part.Bytes) > int(BlockPartSizeBytes) {
    32  		return fmt.Errorf("too big: %d bytes, max: %d", len(part.Bytes), BlockPartSizeBytes)
    33  	}
    34  	if err := part.Proof.ValidateBasic(); err != nil {
    35  		return fmt.Errorf("wrong Proof: %w", err)
    36  	}
    37  	return nil
    38  }
    39  
    40  // String returns a string representation of Part.
    41  //
    42  // See StringIndented.
    43  func (part *Part) String() string {
    44  	return part.StringIndented("")
    45  }
    46  
    47  // StringIndented returns an indented Part.
    48  //
    49  // See merkle.Proof#StringIndented
    50  func (part *Part) StringIndented(indent string) string {
    51  	return fmt.Sprintf(`Part{#%v
    52  %s  Bytes: %X...
    53  %s  Proof: %v
    54  %s}`,
    55  		part.Index,
    56  		indent, tmbytes.Fingerprint(part.Bytes),
    57  		indent, part.Proof.StringIndented(indent+"  "),
    58  		indent)
    59  }
    60  
    61  func (part *Part) ToProto() (*tmproto.Part, error) {
    62  	if part == nil {
    63  		return nil, errors.New("nil part")
    64  	}
    65  	pb := new(tmproto.Part)
    66  	proof := part.Proof.ToProto()
    67  
    68  	pb.Index = part.Index
    69  	pb.Bytes = part.Bytes
    70  	pb.Proof = *proof
    71  
    72  	return pb, nil
    73  }
    74  
    75  func PartFromProto(pb *tmproto.Part) (*Part, error) {
    76  	if pb == nil {
    77  		return nil, errors.New("nil part")
    78  	}
    79  
    80  	part := new(Part)
    81  	proof, err := merkle.ProofFromProto(&pb.Proof)
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  	part.Index = pb.Index
    86  	part.Bytes = pb.Bytes
    87  	part.Proof = *proof
    88  
    89  	return part, part.ValidateBasic()
    90  }
    91  
    92  //-------------------------------------
    93  
// PartSetHeader identifies a PartSet by its number of parts and its hash.
type PartSetHeader struct {
	Total uint32           `json:"total"`
	Hash  tmbytes.HexBytes `json:"hash"`
}
    98  
    99  // String returns a string representation of PartSetHeader.
   100  //
   101  // 1. total number of parts
   102  // 2. first 6 bytes of the hash
   103  func (psh PartSetHeader) String() string {
   104  	return fmt.Sprintf("%v:%X", psh.Total, tmbytes.Fingerprint(psh.Hash))
   105  }
   106  
   107  func (psh PartSetHeader) IsZero() bool {
   108  	return psh.Total == 0 && len(psh.Hash) == 0
   109  }
   110  
   111  func (psh PartSetHeader) Equals(other PartSetHeader) bool {
   112  	return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash)
   113  }
   114  
   115  // ValidateBasic performs basic validation.
   116  func (psh PartSetHeader) ValidateBasic() error {
   117  	// Hash can be empty in case of POLBlockID.PartSetHeader in Proposal.
   118  	if err := ValidateHash(psh.Hash); err != nil {
   119  		return fmt.Errorf("wrong Hash: %w", err)
   120  	}
   121  	return nil
   122  }
   123  
   124  // ToProto converts PartSetHeader to protobuf
   125  func (psh *PartSetHeader) ToProto() tmproto.PartSetHeader {
   126  	if psh == nil {
   127  		return tmproto.PartSetHeader{}
   128  	}
   129  
   130  	return tmproto.PartSetHeader{
   131  		Total: psh.Total,
   132  		Hash:  psh.Hash,
   133  	}
   134  }
   135  
   136  // FromProto sets a protobuf PartSetHeader to the given pointer
   137  func PartSetHeaderFromProto(ppsh *tmproto.PartSetHeader) (*PartSetHeader, error) {
   138  	if ppsh == nil {
   139  		return nil, errors.New("nil PartSetHeader")
   140  	}
   141  	psh := new(PartSetHeader)
   142  	psh.Total = ppsh.Total
   143  	psh.Hash = ppsh.Hash
   144  
   145  	return psh, psh.ValidateBasic()
   146  }
   147  
   148  //-------------------------------------
   149  
// PartSet holds the parts of a piece of data split into chunks, the hash
// each part is verified against, and bookkeeping on which parts are present.
type PartSet struct {
	total uint32 // number of parts in the complete set
	hash  []byte // hash that AddPart verifies each part's proof against

	// mtx is held by AddPart, GetPart, BitArray, StringShort and MarshalJSON.
	mtx           tmsync.Mutex
	parts         []*Part // indexed by Part.Index; nil where no part was added yet
	partsBitArray *bits.BitArray // bit i is set once part i has been stored
	count         uint32 // number of parts currently stored
	// a count of the total size (in bytes). Used to ensure that the
	// part set doesn't exceed the maximum block bytes
	byteSize int64
}
   162  
   163  // Returns an immutable, full PartSet from the data bytes.
   164  // The data bytes are split into "partSize" chunks, and merkle tree computed.
   165  // CONTRACT: partSize is greater than zero.
   166  func NewPartSetFromData(data []byte, partSize uint32) *PartSet {
   167  	// divide data into 4kb parts.
   168  	total := (uint32(len(data)) + partSize - 1) / partSize
   169  	parts := make([]*Part, total)
   170  	partsBytes := make([][]byte, total)
   171  	partsBitArray := bits.NewBitArray(int(total))
   172  	for i := uint32(0); i < total; i++ {
   173  		part := &Part{
   174  			Index: i,
   175  			Bytes: data[i*partSize : tmmath.MinInt(len(data), int((i+1)*partSize))],
   176  		}
   177  		parts[i] = part
   178  		partsBytes[i] = part.Bytes
   179  		partsBitArray.SetIndex(int(i), true)
   180  	}
   181  	// Compute merkle proofs
   182  	root, proofs := merkle.ProofsFromByteSlices(partsBytes)
   183  	for i := uint32(0); i < total; i++ {
   184  		parts[i].Proof = *proofs[i]
   185  	}
   186  	return &PartSet{
   187  		total:         total,
   188  		hash:          root,
   189  		parts:         parts,
   190  		partsBitArray: partsBitArray,
   191  		count:         total,
   192  		byteSize:      int64(len(data)),
   193  	}
   194  }
   195  
   196  // Returns an empty PartSet ready to be populated.
   197  func NewPartSetFromHeader(header PartSetHeader) *PartSet {
   198  	return &PartSet{
   199  		total:         header.Total,
   200  		hash:          header.Hash,
   201  		parts:         make([]*Part, header.Total),
   202  		partsBitArray: bits.NewBitArray(int(header.Total)),
   203  		count:         0,
   204  		byteSize:      0,
   205  	}
   206  }
   207  
   208  func (ps *PartSet) Header() PartSetHeader {
   209  	if ps == nil {
   210  		return PartSetHeader{}
   211  	}
   212  	return PartSetHeader{
   213  		Total: ps.total,
   214  		Hash:  ps.hash,
   215  	}
   216  }
   217  
   218  func (ps *PartSet) HasHeader(header PartSetHeader) bool {
   219  	if ps == nil {
   220  		return false
   221  	}
   222  	return ps.Header().Equals(header)
   223  }
   224  
   225  func (ps *PartSet) BitArray() *bits.BitArray {
   226  	ps.mtx.Lock()
   227  	defer ps.mtx.Unlock()
   228  	return ps.partsBitArray.Copy()
   229  }
   230  
   231  func (ps *PartSet) Hash() []byte {
   232  	if ps == nil {
   233  		return merkle.HashFromByteSlices(nil)
   234  	}
   235  	return ps.hash
   236  }
   237  
   238  func (ps *PartSet) HashesTo(hash []byte) bool {
   239  	if ps == nil {
   240  		return false
   241  	}
   242  	return bytes.Equal(ps.hash, hash)
   243  }
   244  
   245  func (ps *PartSet) Count() uint32 {
   246  	if ps == nil {
   247  		return 0
   248  	}
   249  	return ps.count
   250  }
   251  
   252  func (ps *PartSet) ByteSize() int64 {
   253  	if ps == nil {
   254  		return 0
   255  	}
   256  	return ps.byteSize
   257  }
   258  
   259  func (ps *PartSet) Total() uint32 {
   260  	if ps == nil {
   261  		return 0
   262  	}
   263  	return ps.total
   264  }
   265  
   266  func (ps *PartSet) AddPart(part *Part) (bool, error) {
   267  	if ps == nil {
   268  		return false, nil
   269  	}
   270  	ps.mtx.Lock()
   271  	defer ps.mtx.Unlock()
   272  
   273  	// Invalid part index
   274  	if part.Index >= ps.total {
   275  		return false, ErrPartSetUnexpectedIndex
   276  	}
   277  
   278  	// If part already exists, return false.
   279  	if ps.parts[part.Index] != nil {
   280  		return false, nil
   281  	}
   282  
   283  	// Check hash proof
   284  	if part.Proof.Verify(ps.Hash(), part.Bytes) != nil {
   285  		return false, ErrPartSetInvalidProof
   286  	}
   287  
   288  	// Add part
   289  	ps.parts[part.Index] = part
   290  	ps.partsBitArray.SetIndex(int(part.Index), true)
   291  	ps.count++
   292  	ps.byteSize += int64(len(part.Bytes))
   293  	return true, nil
   294  }
   295  
   296  func (ps *PartSet) GetPart(index int) *Part {
   297  	ps.mtx.Lock()
   298  	defer ps.mtx.Unlock()
   299  	return ps.parts[index]
   300  }
   301  
   302  func (ps *PartSet) IsComplete() bool {
   303  	return ps.count == ps.total
   304  }
   305  
   306  func (ps *PartSet) GetReader() io.Reader {
   307  	if !ps.IsComplete() {
   308  		panic("Cannot GetReader() on incomplete PartSet")
   309  	}
   310  	return NewPartSetReader(ps.parts)
   311  }
   312  
   313  type PartSetReader struct {
   314  	i      int
   315  	parts  []*Part
   316  	reader *bytes.Reader
   317  }
   318  
   319  func NewPartSetReader(parts []*Part) *PartSetReader {
   320  	return &PartSetReader{
   321  		i:      0,
   322  		parts:  parts,
   323  		reader: bytes.NewReader(parts[0].Bytes),
   324  	}
   325  }
   326  
   327  func (psr *PartSetReader) Read(p []byte) (n int, err error) {
   328  	readerLen := psr.reader.Len()
   329  	if readerLen >= len(p) {
   330  		return psr.reader.Read(p)
   331  	} else if readerLen > 0 {
   332  		n1, err := psr.Read(p[:readerLen])
   333  		if err != nil {
   334  			return n1, err
   335  		}
   336  		n2, err := psr.Read(p[readerLen:])
   337  		return n1 + n2, err
   338  	}
   339  
   340  	psr.i++
   341  	if psr.i >= len(psr.parts) {
   342  		return 0, io.EOF
   343  	}
   344  	psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes)
   345  	return psr.Read(p)
   346  }
   347  
   348  // StringShort returns a short version of String.
   349  //
   350  // (Count of Total)
   351  func (ps *PartSet) StringShort() string {
   352  	if ps == nil {
   353  		return "nil-PartSet"
   354  	}
   355  	ps.mtx.Lock()
   356  	defer ps.mtx.Unlock()
   357  	return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total())
   358  }
   359  
   360  func (ps *PartSet) MarshalJSON() ([]byte, error) {
   361  	if ps == nil {
   362  		return []byte("{}"), nil
   363  	}
   364  
   365  	ps.mtx.Lock()
   366  	defer ps.mtx.Unlock()
   367  
   368  	return tmjson.Marshal(struct {
   369  		CountTotal    string         `json:"count/total"`
   370  		PartsBitArray *bits.BitArray `json:"parts_bit_array"`
   371  	}{
   372  		fmt.Sprintf("%d/%d", ps.Count(), ps.Total()),
   373  		ps.partsBitArray,
   374  	})
   375  }