github.com/okex/exchain@v1.8.0/libs/tendermint/types/part_set.go (about)

     1  package types
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"sync"
     8  
     9  	"github.com/tendermint/go-amino"
    10  
    11  	"github.com/pkg/errors"
    12  
    13  	"github.com/okex/exchain/libs/tendermint/crypto/merkle"
    14  	"github.com/okex/exchain/libs/tendermint/libs/bits"
    15  	tmbytes "github.com/okex/exchain/libs/tendermint/libs/bytes"
    16  	tmmath "github.com/okex/exchain/libs/tendermint/libs/math"
    17  	tmproto "github.com/okex/exchain/libs/tendermint/proto/types"
    18  )
    19  
var (
	// ErrPartSetUnexpectedIndex is returned by PartSet.AddPart when a part's
	// Index is out of range for the set.
	ErrPartSetUnexpectedIndex = errors.New("error part set unexpected index")
	// ErrPartSetInvalidProof is returned by PartSet.AddPart when a part's merkle
	// proof does not verify against the set's root hash.
	ErrPartSetInvalidProof    = errors.New("error part set invalid proof")
)
    24  
// Part is one chunk of a PartSet: its position in the set, its raw bytes, and
// a merkle proof that ties those bytes to the set's root hash.
// Field order matches the amino field numbers decoded by UnmarshalFromAmino
// (1 = Index, 2 = Bytes, 3 = Proof).
type Part struct {
	Index int                `json:"index"`
	Bytes tmbytes.HexBytes   `json:"bytes"`
	Proof merkle.SimpleProof `json:"proof"`
}
    30  
    31  func (part *Part) UnmarshalFromAmino(cdc *amino.Codec, data []byte) error {
    32  	var dataLen uint64 = 0
    33  	var subData []byte
    34  
    35  	for {
    36  		data = data[dataLen:]
    37  
    38  		if len(data) == 0 {
    39  			break
    40  		}
    41  
    42  		pos, aminoType, err := amino.ParseProtoPosAndTypeMustOneByte(data[0])
    43  		if err != nil {
    44  			return err
    45  		}
    46  		data = data[1:]
    47  
    48  		if aminoType == amino.Typ3_ByteLength {
    49  			var n int
    50  			dataLen, n, err = amino.DecodeUvarint(data)
    51  			if err != nil {
    52  				return err
    53  			}
    54  
    55  			data = data[n:]
    56  			if len(data) < int(dataLen) {
    57  				return fmt.Errorf("not enough data for %s, need %d, have %d", aminoType, dataLen, len(data))
    58  			}
    59  			subData = data[:dataLen]
    60  		}
    61  
    62  		switch pos {
    63  		case 1:
    64  			uvint, n, err := amino.DecodeUvarint(data)
    65  			if err != nil {
    66  				return err
    67  			}
    68  			part.Index = int(uvint)
    69  			dataLen = uint64(n)
    70  		case 2:
    71  			part.Bytes = make([]byte, dataLen)
    72  			copy(part.Bytes, subData)
    73  		case 3:
    74  			err = part.Proof.UnmarshalFromAmino(cdc, subData)
    75  			if err != nil {
    76  				return err
    77  			}
    78  		default:
    79  			return fmt.Errorf("unexpect feild num %d", pos)
    80  		}
    81  	}
    82  	return nil
    83  }
    84  
    85  // ValidateBasic performs basic validation.
    86  func (part *Part) ValidateBasic() error {
    87  	if part.Index < 0 {
    88  		return errors.New("negative Index")
    89  	}
    90  	if len(part.Bytes) > BlockPartSizeBytes {
    91  		return errors.Errorf("too big: %d bytes, max: %d", len(part.Bytes), BlockPartSizeBytes)
    92  	}
    93  	if err := part.Proof.ValidateBasic(); err != nil {
    94  		return errors.Wrap(err, "wrong Proof")
    95  	}
    96  	return nil
    97  }
    98  
    99  func (part *Part) String() string {
   100  	return part.StringIndented("")
   101  }
   102  
   103  func (part *Part) StringIndented(indent string) string {
   104  	return fmt.Sprintf(`Part{#%v
   105  %s  Bytes: %X...
   106  %s  Proof: %v
   107  %s}`,
   108  		part.Index,
   109  		indent, tmbytes.Fingerprint(part.Bytes),
   110  		indent, part.Proof.StringIndented(indent+"  "),
   111  		indent)
   112  }
   113  
   114  //-------------------------------------
   115  
// PartSetHeader identifies a PartSet by its number of parts (Total) and the
// merkle root over those parts (Hash). Field order matches the amino field
// numbers decoded by UnmarshalFromAmino (1 = Total, 2 = Hash).
type PartSetHeader struct {
	Total int              `json:"total"`
	Hash  tmbytes.HexBytes `json:"hash"`
}
   120  
   121  func (psh PartSetHeader) AminoSize() int {
   122  	var size int
   123  	if psh.Total != 0 {
   124  		size += 1 + amino.UvarintSize(uint64(psh.Total))
   125  	}
   126  	if len(psh.Hash) != 0 {
   127  		size += 1 + amino.UvarintSize(uint64(len(psh.Hash))) + len(psh.Hash)
   128  	}
   129  	return size
   130  }
   131  
   132  func (psh *PartSetHeader) UnmarshalFromAmino(_ *amino.Codec, data []byte) error {
   133  	var dataLen uint64 = 0
   134  	var subData []byte
   135  
   136  	for {
   137  		data = data[dataLen:]
   138  
   139  		if len(data) == 0 {
   140  			break
   141  		}
   142  
   143  		pos, aminoType, err := amino.ParseProtoPosAndTypeMustOneByte(data[0])
   144  		if err != nil {
   145  			return err
   146  		}
   147  		data = data[1:]
   148  
   149  		if aminoType == amino.Typ3_ByteLength {
   150  			var n int
   151  			dataLen, n, err = amino.DecodeUvarint(data)
   152  			if err != nil {
   153  				return err
   154  			}
   155  
   156  			data = data[n:]
   157  			if len(data) < int(dataLen) {
   158  				return fmt.Errorf("not enough data for %s, need %d, have %d", aminoType, dataLen, len(data))
   159  			}
   160  			subData = data[:dataLen]
   161  		}
   162  
   163  		switch pos {
   164  		case 1:
   165  			var n int
   166  			var uvint uint64
   167  			uvint, n, err = amino.DecodeUvarint(data)
   168  			if err != nil {
   169  				return err
   170  			}
   171  			psh.Total = int(uvint)
   172  			dataLen = uint64(n)
   173  		case 2:
   174  			psh.Hash = make([]byte, dataLen)
   175  			copy(psh.Hash, subData)
   176  		default:
   177  			return fmt.Errorf("unexpect feild num %d", pos)
   178  		}
   179  	}
   180  	return nil
   181  }
   182  
   183  func (psh PartSetHeader) String() string {
   184  	return fmt.Sprintf("%v:%X", psh.Total, tmbytes.Fingerprint(psh.Hash))
   185  }
   186  
   187  func (psh PartSetHeader) IsZero() bool {
   188  	return psh.Total == 0 && len(psh.Hash) == 0
   189  }
   190  
   191  func (psh PartSetHeader) Equals(other PartSetHeader) bool {
   192  	return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash)
   193  }
   194  
   195  // ValidateBasic performs basic validation.
   196  func (psh PartSetHeader) ValidateBasic() error {
   197  	if psh.Total < 0 {
   198  		return errors.New("negative Total")
   199  	}
   200  	// Hash can be empty in case of POLBlockID.PartsHeader in Proposal.
   201  	if err := ValidateHash(psh.Hash); err != nil {
   202  		return errors.Wrap(err, "Wrong Hash")
   203  	}
   204  	return nil
   205  }
   206  
   207  // ToProto converts BloPartSetHeaderckID to protobuf
   208  func (psh *PartSetHeader) ToProto() tmproto.PartSetHeader {
   209  	if psh == nil {
   210  		return tmproto.PartSetHeader{}
   211  	}
   212  
   213  	return tmproto.PartSetHeader{
   214  		Total: int64(psh.Total),
   215  		Hash:  psh.Hash,
   216  	}
   217  }
   218  
   219  func (psh *PartSetHeader) ToIBCProto() tmproto.PartSetHeader {
   220  	if psh == nil {
   221  		return tmproto.PartSetHeader{}
   222  	}
   223  	return tmproto.PartSetHeader{
   224  		Total: int64(psh.Total),
   225  		Hash:  psh.Hash,
   226  	}
   227  }
   228  
   229  // FromProto sets a protobuf PartSetHeader to the given pointer
   230  func PartSetHeaderFromProto(ppsh *tmproto.PartSetHeader) (*PartSetHeader, error) {
   231  	if ppsh == nil {
   232  		return nil, errors.New("nil PartSetHeader")
   233  	}
   234  	psh := new(PartSetHeader)
   235  	psh.Total = int(ppsh.Total)
   236  	psh.Hash = ppsh.Hash
   237  
   238  	return psh, psh.ValidateBasic()
   239  }
   240  
   241  //-------------------------------------
   242  
// PartSet collects the Parts of a payload that has been split into chunks.
// total and hash (the merkle root over the chunks) are set by the
// constructors. mtx is held by AddPart while it mutates parts, partsBitArray
// and count, and by the other locking accessors that read them.
// NOTE(review): Count and IsComplete read count without taking mtx.
type PartSet struct {
	total int
	hash  []byte

	mtx           sync.Mutex
	parts         []*Part
	partsBitArray *bits.BitArray
	count         int
}
   252  
   253  // Returns an immutable, full PartSet from the data bytes.
   254  // The data bytes are split into "partSize" chunks, and merkle tree computed.
   255  func NewPartSetFromData(data []byte, partSize int) *PartSet {
   256  	// divide data into 4kb parts.
   257  	total := (len(data) + partSize - 1) / partSize
   258  	parts := make([]*Part, total)
   259  	partsBytes := make([][]byte, total)
   260  	partsBitArray := bits.NewBitArray(total)
   261  	for i := 0; i < total; i++ {
   262  		part := &Part{
   263  			Index: i,
   264  			Bytes: data[i*partSize : tmmath.MinInt(len(data), (i+1)*partSize)],
   265  		}
   266  		parts[i] = part
   267  		partsBytes[i] = part.Bytes
   268  		partsBitArray.SetIndex(i, true)
   269  	}
   270  	// Compute merkle proofs
   271  	root, proofs := merkle.SimpleProofsFromByteSlices(partsBytes)
   272  	for i := 0; i < total; i++ {
   273  		parts[i].Proof = *proofs[i]
   274  	}
   275  	return &PartSet{
   276  		total:         total,
   277  		hash:          root,
   278  		parts:         parts,
   279  		partsBitArray: partsBitArray,
   280  		count:         total,
   281  	}
   282  }
   283  
   284  // Returns an empty PartSet ready to be populated.
   285  func NewPartSetFromHeader(header PartSetHeader) *PartSet {
   286  	return &PartSet{
   287  		total:         header.Total,
   288  		hash:          header.Hash,
   289  		parts:         make([]*Part, header.Total),
   290  		partsBitArray: bits.NewBitArray(header.Total),
   291  		count:         0,
   292  	}
   293  }
   294  
   295  func (ps *PartSet) Header() PartSetHeader {
   296  	if ps == nil {
   297  		return PartSetHeader{}
   298  	}
   299  	return PartSetHeader{
   300  		Total: ps.total,
   301  		Hash:  ps.hash,
   302  	}
   303  }
   304  
   305  func (ps *PartSet) HasHeader(header PartSetHeader) bool {
   306  	if ps == nil {
   307  		return false
   308  	}
   309  	return ps.Header().Equals(header)
   310  }
   311  
   312  func (ps *PartSet) BitArray() *bits.BitArray {
   313  	ps.mtx.Lock()
   314  	defer ps.mtx.Unlock()
   315  	return ps.partsBitArray.Copy()
   316  }
   317  
   318  func (ps *PartSet) Hash() []byte {
   319  	if ps == nil {
   320  		return nil
   321  	}
   322  	return ps.hash
   323  }
   324  
   325  func (ps *PartSet) HashesTo(hash []byte) bool {
   326  	if ps == nil {
   327  		return false
   328  	}
   329  	return bytes.Equal(ps.hash, hash)
   330  }
   331  
   332  func (ps *PartSet) Count() int {
   333  	if ps == nil {
   334  		return 0
   335  	}
   336  	return ps.count
   337  }
   338  
   339  func (ps *PartSet) Total() int {
   340  	if ps == nil {
   341  		return 0
   342  	}
   343  	return ps.total
   344  }
   345  
   346  func (ps *PartSet) AddPart(part *Part) (bool, error) {
   347  	if ps == nil {
   348  		return false, nil
   349  	}
   350  	ps.mtx.Lock()
   351  	defer ps.mtx.Unlock()
   352  
   353  	// Invalid part index
   354  	if part.Index >= ps.total {
   355  		return false, ErrPartSetUnexpectedIndex
   356  	}
   357  
   358  	// If part already exists, return false.
   359  	if ps.parts[part.Index] != nil {
   360  		return false, nil
   361  	}
   362  
   363  	// Check hash proof
   364  	if part.Proof.Verify(ps.Hash(), part.Bytes) != nil {
   365  		return false, ErrPartSetInvalidProof
   366  	}
   367  
   368  	// Add part
   369  	ps.parts[part.Index] = part
   370  	ps.partsBitArray.SetIndex(part.Index, true)
   371  	ps.count++
   372  	return true, nil
   373  }
   374  
   375  func (ps *PartSet) GetPart(index int) *Part {
   376  	ps.mtx.Lock()
   377  	defer ps.mtx.Unlock()
   378  	return ps.parts[index]
   379  }
   380  
   381  func (ps *PartSet) IsComplete() bool {
   382  	return ps.count == ps.total
   383  }
   384  
   385  func (ps *PartSet) GetReader() io.Reader {
   386  	if !ps.IsComplete() {
   387  		panic("Cannot GetReader() on incomplete PartSet")
   388  	}
   389  	return NewPartSetReader(ps.parts)
   390  }
   391  
// PartSetReader reads the Bytes of a slice of Parts one after another; it is
// returned as an io.Reader by PartSet.GetReader.
type PartSetReader struct {
	i      int           // index of the part currently being read
	parts  []*Part       // parts to read from, in order
	reader *bytes.Reader // reader over parts[i].Bytes
}
   397  
   398  func NewPartSetReader(parts []*Part) *PartSetReader {
   399  	return &PartSetReader{
   400  		i:      0,
   401  		parts:  parts,
   402  		reader: bytes.NewReader(parts[0].Bytes),
   403  	}
   404  }
   405  
   406  func (psr *PartSetReader) Read(p []byte) (n int, err error) {
   407  	readerLen := psr.reader.Len()
   408  	if readerLen >= len(p) {
   409  		return psr.reader.Read(p)
   410  	} else if readerLen > 0 {
   411  		n1, err := psr.Read(p[:readerLen])
   412  		if err != nil {
   413  			return n1, err
   414  		}
   415  		n2, err := psr.Read(p[readerLen:])
   416  		return n1 + n2, err
   417  	}
   418  
   419  	psr.i++
   420  	if psr.i >= len(psr.parts) {
   421  		return 0, io.EOF
   422  	}
   423  	psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes)
   424  	return psr.Read(p)
   425  }
   426  
   427  func (ps *PartSet) StringShort() string {
   428  	if ps == nil {
   429  		return "nil-PartSet"
   430  	}
   431  	ps.mtx.Lock()
   432  	defer ps.mtx.Unlock()
   433  	return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total())
   434  }
   435  
   436  func (ps *PartSet) MarshalJSON() ([]byte, error) {
   437  	if ps == nil {
   438  		return []byte("{}"), nil
   439  	}
   440  
   441  	ps.mtx.Lock()
   442  	defer ps.mtx.Unlock()
   443  
   444  	return cdc.MarshalJSON(struct {
   445  		CountTotal    string         `json:"count/total"`
   446  		PartsBitArray *bits.BitArray `json:"parts_bit_array"`
   447  	}{
   448  		fmt.Sprintf("%d/%d", ps.Count(), ps.Total()),
   449  		ps.partsBitArray,
   450  	})
   451  }