github.com/franono/tendermint@v0.32.2-0.20200527150959-749313264ce9/types/part_set.go (about)

     1  package types
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"sync"
     9  
    10  	"github.com/franono/tendermint/crypto/merkle"
    11  	"github.com/franono/tendermint/libs/bits"
    12  	tmbytes "github.com/franono/tendermint/libs/bytes"
    13  	tmmath "github.com/franono/tendermint/libs/math"
    14  	tmproto "github.com/franono/tendermint/proto/types"
    15  )
    16  
    17  var (
    18  	ErrPartSetUnexpectedIndex = errors.New("error part set unexpected index")
    19  	ErrPartSetInvalidProof    = errors.New("error part set invalid proof")
    20  )
    21  
// Part is one chunk of a PartSet's data, together with the Merkle proof that
// links the chunk to the set's root hash.
type Part struct {
	// Index is the position of this part within its PartSet.
	Index int `json:"index"`
	// Bytes is the raw chunk payload.
	Bytes tmbytes.HexBytes `json:"bytes"`
	// Proof is checked against the PartSet hash by PartSet.AddPart.
	Proof merkle.SimpleProof `json:"proof"`
}
    27  
    28  // ValidateBasic performs basic validation.
    29  func (part *Part) ValidateBasic() error {
    30  	if part.Index < 0 {
    31  		return errors.New("negative Index")
    32  	}
    33  	if len(part.Bytes) > BlockPartSizeBytes {
    34  		return fmt.Errorf("too big: %d bytes, max: %d", len(part.Bytes), BlockPartSizeBytes)
    35  	}
    36  	if err := part.Proof.ValidateBasic(); err != nil {
    37  		return fmt.Errorf("wrong Proof: %w", err)
    38  	}
    39  	return nil
    40  }
    41  
    42  func (part *Part) String() string {
    43  	return part.StringIndented("")
    44  }
    45  
    46  func (part *Part) StringIndented(indent string) string {
    47  	return fmt.Sprintf(`Part{#%v
    48  %s  Bytes: %X...
    49  %s  Proof: %v
    50  %s}`,
    51  		part.Index,
    52  		indent, tmbytes.Fingerprint(part.Bytes),
    53  		indent, part.Proof.StringIndented(indent+"  "),
    54  		indent)
    55  }
    56  
    57  //-------------------------------------
    58  
// PartSetHeader identifies a PartSet by its number of parts (Total) and its
// Merkle root hash (Hash); PartSet.Header builds one from those two fields.
type PartSetHeader struct {
	Total int              `json:"total"`
	Hash  tmbytes.HexBytes `json:"hash"`
}
    63  
    64  func (psh PartSetHeader) String() string {
    65  	return fmt.Sprintf("%v:%X", psh.Total, tmbytes.Fingerprint(psh.Hash))
    66  }
    67  
    68  func (psh PartSetHeader) IsZero() bool {
    69  	return psh.Total == 0 && len(psh.Hash) == 0
    70  }
    71  
    72  func (psh PartSetHeader) Equals(other PartSetHeader) bool {
    73  	return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash)
    74  }
    75  
    76  // ValidateBasic performs basic validation.
    77  func (psh PartSetHeader) ValidateBasic() error {
    78  	if psh.Total < 0 {
    79  		return errors.New("negative Total")
    80  	}
    81  	// Hash can be empty in case of POLBlockID.PartsHeader in Proposal.
    82  	if err := ValidateHash(psh.Hash); err != nil {
    83  		return fmt.Errorf("wrong Hash: %w", err)
    84  	}
    85  	return nil
    86  }
    87  
    88  // ToProto converts BloPartSetHeaderckID to protobuf
    89  func (psh *PartSetHeader) ToProto() tmproto.PartSetHeader {
    90  	if psh == nil {
    91  		return tmproto.PartSetHeader{}
    92  	}
    93  
    94  	return tmproto.PartSetHeader{
    95  		Total: int64(psh.Total),
    96  		Hash:  psh.Hash,
    97  	}
    98  }
    99  
   100  // FromProto sets a protobuf PartSetHeader to the given pointer
   101  func PartSetHeaderFromProto(ppsh *tmproto.PartSetHeader) (*PartSetHeader, error) {
   102  	if ppsh == nil {
   103  		return nil, errors.New("nil PartSetHeader")
   104  	}
   105  	psh := new(PartSetHeader)
   106  	psh.Total = int(ppsh.Total)
   107  	psh.Hash = ppsh.Hash
   108  
   109  	return psh, psh.ValidateBasic()
   110  }
   111  
   112  //-------------------------------------
   113  
// PartSet is a set of Parts that together make up a larger blob of data,
// authenticated by a Merkle root hash.
//
// It is either built complete from raw data (NewPartSetFromData) or created
// empty from a header (NewPartSetFromHeader) and then filled via AddPart.
type PartSet struct {
	// total and hash describe the set. No method in this file modifies them
	// after construction.
	total int
	hash  []byte

	// mtx is taken by AddPart, GetPart, BitArray, StringShort and MarshalJSON.
	// NOTE(review): Count and IsComplete read count without holding mtx.
	mtx           sync.Mutex
	parts         []*Part        // parts[i] stays nil until part i is added
	partsBitArray *bits.BitArray // bit i is set once part i is present
	count         int            // number of parts added so far
}
   123  
   124  // Returns an immutable, full PartSet from the data bytes.
   125  // The data bytes are split into "partSize" chunks, and merkle tree computed.
   126  func NewPartSetFromData(data []byte, partSize int) *PartSet {
   127  	// divide data into 4kb parts.
   128  	total := (len(data) + partSize - 1) / partSize
   129  	parts := make([]*Part, total)
   130  	partsBytes := make([][]byte, total)
   131  	partsBitArray := bits.NewBitArray(total)
   132  	for i := 0; i < total; i++ {
   133  		part := &Part{
   134  			Index: i,
   135  			Bytes: data[i*partSize : tmmath.MinInt(len(data), (i+1)*partSize)],
   136  		}
   137  		parts[i] = part
   138  		partsBytes[i] = part.Bytes
   139  		partsBitArray.SetIndex(i, true)
   140  	}
   141  	// Compute merkle proofs
   142  	root, proofs := merkle.SimpleProofsFromByteSlices(partsBytes)
   143  	for i := 0; i < total; i++ {
   144  		parts[i].Proof = *proofs[i]
   145  	}
   146  	return &PartSet{
   147  		total:         total,
   148  		hash:          root,
   149  		parts:         parts,
   150  		partsBitArray: partsBitArray,
   151  		count:         total,
   152  	}
   153  }
   154  
   155  // Returns an empty PartSet ready to be populated.
   156  func NewPartSetFromHeader(header PartSetHeader) *PartSet {
   157  	return &PartSet{
   158  		total:         header.Total,
   159  		hash:          header.Hash,
   160  		parts:         make([]*Part, header.Total),
   161  		partsBitArray: bits.NewBitArray(header.Total),
   162  		count:         0,
   163  	}
   164  }
   165  
   166  func (ps *PartSet) Header() PartSetHeader {
   167  	if ps == nil {
   168  		return PartSetHeader{}
   169  	}
   170  	return PartSetHeader{
   171  		Total: ps.total,
   172  		Hash:  ps.hash,
   173  	}
   174  }
   175  
   176  func (ps *PartSet) HasHeader(header PartSetHeader) bool {
   177  	if ps == nil {
   178  		return false
   179  	}
   180  	return ps.Header().Equals(header)
   181  }
   182  
   183  func (ps *PartSet) BitArray() *bits.BitArray {
   184  	ps.mtx.Lock()
   185  	defer ps.mtx.Unlock()
   186  	return ps.partsBitArray.Copy()
   187  }
   188  
   189  func (ps *PartSet) Hash() []byte {
   190  	if ps == nil {
   191  		return nil
   192  	}
   193  	return ps.hash
   194  }
   195  
   196  func (ps *PartSet) HashesTo(hash []byte) bool {
   197  	if ps == nil {
   198  		return false
   199  	}
   200  	return bytes.Equal(ps.hash, hash)
   201  }
   202  
   203  func (ps *PartSet) Count() int {
   204  	if ps == nil {
   205  		return 0
   206  	}
   207  	return ps.count
   208  }
   209  
   210  func (ps *PartSet) Total() int {
   211  	if ps == nil {
   212  		return 0
   213  	}
   214  	return ps.total
   215  }
   216  
   217  func (ps *PartSet) AddPart(part *Part) (bool, error) {
   218  	if ps == nil {
   219  		return false, nil
   220  	}
   221  	ps.mtx.Lock()
   222  	defer ps.mtx.Unlock()
   223  
   224  	// Invalid part index
   225  	if part.Index >= ps.total {
   226  		return false, ErrPartSetUnexpectedIndex
   227  	}
   228  
   229  	// If part already exists, return false.
   230  	if ps.parts[part.Index] != nil {
   231  		return false, nil
   232  	}
   233  
   234  	// Check hash proof
   235  	if part.Proof.Verify(ps.Hash(), part.Bytes) != nil {
   236  		return false, ErrPartSetInvalidProof
   237  	}
   238  
   239  	// Add part
   240  	ps.parts[part.Index] = part
   241  	ps.partsBitArray.SetIndex(part.Index, true)
   242  	ps.count++
   243  	return true, nil
   244  }
   245  
   246  func (ps *PartSet) GetPart(index int) *Part {
   247  	ps.mtx.Lock()
   248  	defer ps.mtx.Unlock()
   249  	return ps.parts[index]
   250  }
   251  
   252  func (ps *PartSet) IsComplete() bool {
   253  	return ps.count == ps.total
   254  }
   255  
   256  func (ps *PartSet) GetReader() io.Reader {
   257  	if !ps.IsComplete() {
   258  		panic("Cannot GetReader() on incomplete PartSet")
   259  	}
   260  	return NewPartSetReader(ps.parts)
   261  }
   262  
   263  type PartSetReader struct {
   264  	i      int
   265  	parts  []*Part
   266  	reader *bytes.Reader
   267  }
   268  
   269  func NewPartSetReader(parts []*Part) *PartSetReader {
   270  	return &PartSetReader{
   271  		i:      0,
   272  		parts:  parts,
   273  		reader: bytes.NewReader(parts[0].Bytes),
   274  	}
   275  }
   276  
   277  func (psr *PartSetReader) Read(p []byte) (n int, err error) {
   278  	readerLen := psr.reader.Len()
   279  	if readerLen >= len(p) {
   280  		return psr.reader.Read(p)
   281  	} else if readerLen > 0 {
   282  		n1, err := psr.Read(p[:readerLen])
   283  		if err != nil {
   284  			return n1, err
   285  		}
   286  		n2, err := psr.Read(p[readerLen:])
   287  		return n1 + n2, err
   288  	}
   289  
   290  	psr.i++
   291  	if psr.i >= len(psr.parts) {
   292  		return 0, io.EOF
   293  	}
   294  	psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes)
   295  	return psr.Read(p)
   296  }
   297  
   298  func (ps *PartSet) StringShort() string {
   299  	if ps == nil {
   300  		return "nil-PartSet"
   301  	}
   302  	ps.mtx.Lock()
   303  	defer ps.mtx.Unlock()
   304  	return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total())
   305  }
   306  
   307  func (ps *PartSet) MarshalJSON() ([]byte, error) {
   308  	if ps == nil {
   309  		return []byte("{}"), nil
   310  	}
   311  
   312  	ps.mtx.Lock()
   313  	defer ps.mtx.Unlock()
   314  
   315  	return cdc.MarshalJSON(struct {
   316  		CountTotal    string         `json:"count/total"`
   317  		PartsBitArray *bits.BitArray `json:"parts_bit_array"`
   318  	}{
   319  		fmt.Sprintf("%d/%d", ps.Count(), ps.Total()),
   320  		ps.partsBitArray,
   321  	})
   322  }