github.com/adoriasoft/tendermint@v0.34.0-dev1.0.20200722151356-96d84601a75a/types/part_set.go (about)

     1  package types
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  
     9  	"github.com/tendermint/tendermint/crypto/merkle"
    10  	"github.com/tendermint/tendermint/libs/bits"
    11  	tmbytes "github.com/tendermint/tendermint/libs/bytes"
    12  	tmjson "github.com/tendermint/tendermint/libs/json"
    13  	tmmath "github.com/tendermint/tendermint/libs/math"
    14  	tmsync "github.com/tendermint/tendermint/libs/sync"
    15  	tmproto "github.com/tendermint/tendermint/proto/tendermint/types"
    16  )
    17  
    18  var (
    19  	ErrPartSetUnexpectedIndex = errors.New("error part set unexpected index")
    20  	ErrPartSetInvalidProof    = errors.New("error part set invalid proof")
    21  )
    22  
// Part is one chunk of a PartSet's data. Index is the chunk's position in the
// set (AddPart stores it at parts[Index]), Bytes is the chunk payload, and
// Proof is the merkle proof that PartSet.AddPart verifies against the set's
// root hash before accepting the part.
type Part struct {
	Index uint32           `json:"index"`
	Bytes tmbytes.HexBytes `json:"bytes"`
	Proof merkle.Proof     `json:"proof"`
}
    28  
    29  // ValidateBasic performs basic validation.
    30  func (part *Part) ValidateBasic() error {
    31  	if len(part.Bytes) > int(BlockPartSizeBytes) {
    32  		return fmt.Errorf("too big: %d bytes, max: %d", len(part.Bytes), BlockPartSizeBytes)
    33  	}
    34  	if err := part.Proof.ValidateBasic(); err != nil {
    35  		return fmt.Errorf("wrong Proof: %w", err)
    36  	}
    37  	return nil
    38  }
    39  
    40  func (part *Part) String() string {
    41  	return part.StringIndented("")
    42  }
    43  
    44  func (part *Part) StringIndented(indent string) string {
    45  	return fmt.Sprintf(`Part{#%v
    46  %s  Bytes: %X...
    47  %s  Proof: %v
    48  %s}`,
    49  		part.Index,
    50  		indent, tmbytes.Fingerprint(part.Bytes),
    51  		indent, part.Proof.StringIndented(indent+"  "),
    52  		indent)
    53  }
    54  
    55  func (part *Part) ToProto() (*tmproto.Part, error) {
    56  	if part == nil {
    57  		return nil, errors.New("nil part")
    58  	}
    59  	pb := new(tmproto.Part)
    60  	proof := part.Proof.ToProto()
    61  
    62  	pb.Index = part.Index
    63  	pb.Bytes = part.Bytes
    64  	pb.Proof = *proof
    65  
    66  	return pb, nil
    67  }
    68  
    69  func PartFromProto(pb *tmproto.Part) (*Part, error) {
    70  	if pb == nil {
    71  		return nil, errors.New("nil part")
    72  	}
    73  
    74  	part := new(Part)
    75  	proof, err := merkle.ProofFromProto(&pb.Proof)
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  	part.Index = pb.Index
    80  	part.Bytes = pb.Bytes
    81  	part.Proof = *proof
    82  
    83  	return part, part.ValidateBasic()
    84  }
    85  
    86  //-------------------------------------
    87  
// PartSetHeader describes a PartSet by its number of parts (Total) and the
// merkle root hash of those parts (Hash). PartSet.Header produces it and
// NewPartSetFromHeader consumes it.
type PartSetHeader struct {
	Total uint32           `json:"total"`
	Hash  tmbytes.HexBytes `json:"hash"`
}
    92  
    93  func (psh PartSetHeader) String() string {
    94  	return fmt.Sprintf("%v:%X", psh.Total, tmbytes.Fingerprint(psh.Hash))
    95  }
    96  
    97  func (psh PartSetHeader) IsZero() bool {
    98  	return psh.Total == 0 && len(psh.Hash) == 0
    99  }
   100  
   101  func (psh PartSetHeader) Equals(other PartSetHeader) bool {
   102  	return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash)
   103  }
   104  
   105  // ValidateBasic performs basic validation.
   106  func (psh PartSetHeader) ValidateBasic() error {
   107  	// Hash can be empty in case of POLBlockID.PartSetHeader in Proposal.
   108  	if err := ValidateHash(psh.Hash); err != nil {
   109  		return fmt.Errorf("wrong Hash: %w", err)
   110  	}
   111  	return nil
   112  }
   113  
   114  // ToProto converts BloPartSetHeaderckID to protobuf
   115  func (psh *PartSetHeader) ToProto() tmproto.PartSetHeader {
   116  	if psh == nil {
   117  		return tmproto.PartSetHeader{}
   118  	}
   119  
   120  	return tmproto.PartSetHeader{
   121  		Total: psh.Total,
   122  		Hash:  psh.Hash,
   123  	}
   124  }
   125  
   126  // FromProto sets a protobuf PartSetHeader to the given pointer
   127  func PartSetHeaderFromProto(ppsh *tmproto.PartSetHeader) (*PartSetHeader, error) {
   128  	if ppsh == nil {
   129  		return nil, errors.New("nil PartSetHeader")
   130  	}
   131  	psh := new(PartSetHeader)
   132  	psh.Total = ppsh.Total
   133  	psh.Hash = ppsh.Hash
   134  
   135  	return psh, psh.ValidateBasic()
   136  }
   137  
   138  //-------------------------------------
   139  
// PartSet is a collection of Parts that together make up one byte slice
// (see NewPartSetFromData), identified by the number of parts (total) and the
// merkle root hash (hash) each part's proof must verify against.
type PartSet struct {
	// total and hash are set by the constructors and not modified by any
	// method in this file.
	total uint32
	hash  []byte

	// parts, partsBitArray and count are updated by AddPart while holding mtx.
	// NOTE(review): Count and IsComplete read count without taking mtx.
	mtx           tmsync.Mutex
	parts         []*Part
	partsBitArray *bits.BitArray
	count         uint32
}
   149  
   150  // Returns an immutable, full PartSet from the data bytes.
   151  // The data bytes are split into "partSize" chunks, and merkle tree computed.
   152  // CONTRACT: partSize is greater than zero.
   153  func NewPartSetFromData(data []byte, partSize uint32) *PartSet {
   154  	// divide data into 4kb parts.
   155  	total := (uint32(len(data)) + partSize - 1) / partSize
   156  	parts := make([]*Part, total)
   157  	partsBytes := make([][]byte, total)
   158  	partsBitArray := bits.NewBitArray(int(total))
   159  	for i := uint32(0); i < total; i++ {
   160  		part := &Part{
   161  			Index: i,
   162  			Bytes: data[i*partSize : tmmath.MinInt(len(data), int((i+1)*partSize))],
   163  		}
   164  		parts[i] = part
   165  		partsBytes[i] = part.Bytes
   166  		partsBitArray.SetIndex(int(i), true)
   167  	}
   168  	// Compute merkle proofs
   169  	root, proofs := merkle.ProofsFromByteSlices(partsBytes)
   170  	for i := uint32(0); i < total; i++ {
   171  		parts[i].Proof = *proofs[i]
   172  	}
   173  	return &PartSet{
   174  		total:         total,
   175  		hash:          root,
   176  		parts:         parts,
   177  		partsBitArray: partsBitArray,
   178  		count:         total,
   179  	}
   180  }
   181  
   182  // Returns an empty PartSet ready to be populated.
   183  func NewPartSetFromHeader(header PartSetHeader) *PartSet {
   184  	return &PartSet{
   185  		total:         header.Total,
   186  		hash:          header.Hash,
   187  		parts:         make([]*Part, header.Total),
   188  		partsBitArray: bits.NewBitArray(int(header.Total)),
   189  		count:         0,
   190  	}
   191  }
   192  
   193  func (ps *PartSet) Header() PartSetHeader {
   194  	if ps == nil {
   195  		return PartSetHeader{}
   196  	}
   197  	return PartSetHeader{
   198  		Total: ps.total,
   199  		Hash:  ps.hash,
   200  	}
   201  }
   202  
   203  func (ps *PartSet) HasHeader(header PartSetHeader) bool {
   204  	if ps == nil {
   205  		return false
   206  	}
   207  	return ps.Header().Equals(header)
   208  }
   209  
   210  func (ps *PartSet) BitArray() *bits.BitArray {
   211  	ps.mtx.Lock()
   212  	defer ps.mtx.Unlock()
   213  	return ps.partsBitArray.Copy()
   214  }
   215  
   216  func (ps *PartSet) Hash() []byte {
   217  	if ps == nil {
   218  		return nil
   219  	}
   220  	return ps.hash
   221  }
   222  
   223  func (ps *PartSet) HashesTo(hash []byte) bool {
   224  	if ps == nil {
   225  		return false
   226  	}
   227  	return bytes.Equal(ps.hash, hash)
   228  }
   229  
   230  func (ps *PartSet) Count() uint32 {
   231  	if ps == nil {
   232  		return 0
   233  	}
   234  	return ps.count
   235  }
   236  
   237  func (ps *PartSet) Total() uint32 {
   238  	if ps == nil {
   239  		return 0
   240  	}
   241  	return ps.total
   242  }
   243  
   244  func (ps *PartSet) AddPart(part *Part) (bool, error) {
   245  	if ps == nil {
   246  		return false, nil
   247  	}
   248  	ps.mtx.Lock()
   249  	defer ps.mtx.Unlock()
   250  
   251  	// Invalid part index
   252  	if part.Index >= ps.total {
   253  		return false, ErrPartSetUnexpectedIndex
   254  	}
   255  
   256  	// If part already exists, return false.
   257  	if ps.parts[part.Index] != nil {
   258  		return false, nil
   259  	}
   260  
   261  	// Check hash proof
   262  	if part.Proof.Verify(ps.Hash(), part.Bytes) != nil {
   263  		return false, ErrPartSetInvalidProof
   264  	}
   265  
   266  	// Add part
   267  	ps.parts[part.Index] = part
   268  	ps.partsBitArray.SetIndex(int(part.Index), true)
   269  	ps.count++
   270  	return true, nil
   271  }
   272  
   273  func (ps *PartSet) GetPart(index int) *Part {
   274  	ps.mtx.Lock()
   275  	defer ps.mtx.Unlock()
   276  	return ps.parts[index]
   277  }
   278  
   279  func (ps *PartSet) IsComplete() bool {
   280  	return ps.count == ps.total
   281  }
   282  
   283  func (ps *PartSet) GetReader() io.Reader {
   284  	if !ps.IsComplete() {
   285  		panic("Cannot GetReader() on incomplete PartSet")
   286  	}
   287  	return NewPartSetReader(ps.parts)
   288  }
   289  
// PartSetReader reads the concatenated Bytes of a slice of parts in order.
// i is the index of the part currently being read, and reader reads that
// part's Bytes.
type PartSetReader struct {
	i      int
	parts  []*Part
	reader *bytes.Reader
}
   295  
   296  func NewPartSetReader(parts []*Part) *PartSetReader {
   297  	return &PartSetReader{
   298  		i:      0,
   299  		parts:  parts,
   300  		reader: bytes.NewReader(parts[0].Bytes),
   301  	}
   302  }
   303  
   304  func (psr *PartSetReader) Read(p []byte) (n int, err error) {
   305  	readerLen := psr.reader.Len()
   306  	if readerLen >= len(p) {
   307  		return psr.reader.Read(p)
   308  	} else if readerLen > 0 {
   309  		n1, err := psr.Read(p[:readerLen])
   310  		if err != nil {
   311  			return n1, err
   312  		}
   313  		n2, err := psr.Read(p[readerLen:])
   314  		return n1 + n2, err
   315  	}
   316  
   317  	psr.i++
   318  	if psr.i >= len(psr.parts) {
   319  		return 0, io.EOF
   320  	}
   321  	psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes)
   322  	return psr.Read(p)
   323  }
   324  
   325  func (ps *PartSet) StringShort() string {
   326  	if ps == nil {
   327  		return "nil-PartSet"
   328  	}
   329  	ps.mtx.Lock()
   330  	defer ps.mtx.Unlock()
   331  	return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total())
   332  }
   333  
   334  func (ps *PartSet) MarshalJSON() ([]byte, error) {
   335  	if ps == nil {
   336  		return []byte("{}"), nil
   337  	}
   338  
   339  	ps.mtx.Lock()
   340  	defer ps.mtx.Unlock()
   341  
   342  	return tmjson.Marshal(struct {
   343  		CountTotal    string         `json:"count/total"`
   344  		PartsBitArray *bits.BitArray `json:"parts_bit_array"`
   345  	}{
   346  		fmt.Sprintf("%d/%d", ps.Count(), ps.Total()),
   347  		ps.partsBitArray,
   348  	})
   349  }