github.com/Finschia/ostracon@v1.1.5/types/part_set.go (about)

     1  package types
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  
     9  	tmproto "github.com/tendermint/tendermint/proto/tendermint/types"
    10  
    11  	"github.com/Finschia/ostracon/crypto/merkle"
    12  	"github.com/Finschia/ostracon/libs/bits"
    13  	tmbytes "github.com/Finschia/ostracon/libs/bytes"
    14  	tmjson "github.com/Finschia/ostracon/libs/json"
    15  	tmmath "github.com/Finschia/ostracon/libs/math"
    16  	tmsync "github.com/Finschia/ostracon/libs/sync"
    17  )
    18  
    19  var (
    20  	ErrPartSetUnexpectedIndex = errors.New("error part set unexpected index")
    21  	ErrPartSetInvalidProof    = errors.New("error part set invalid proof")
    22  )
    23  
    24  type Part struct {
    25  	Index uint32           `json:"index"`
    26  	Bytes tmbytes.HexBytes `json:"bytes"`
    27  	Proof merkle.Proof     `json:"proof"`
    28  }
    29  
    30  // ValidateBasic performs basic validation.
    31  func (part *Part) ValidateBasic() error {
    32  	if len(part.Bytes) > int(BlockPartSizeBytes) {
    33  		return fmt.Errorf("too big: %d bytes, max: %d", len(part.Bytes), BlockPartSizeBytes)
    34  	}
    35  	if err := part.Proof.ValidateBasic(); err != nil {
    36  		return fmt.Errorf("wrong Proof: %w", err)
    37  	}
    38  	return nil
    39  }
    40  
    41  // String returns a string representation of Part.
    42  //
    43  // See StringIndented.
    44  func (part *Part) String() string {
    45  	return part.StringIndented("")
    46  }
    47  
    48  // StringIndented returns an indented Part.
    49  //
    50  // See merkle.Proof#StringIndented
    51  func (part *Part) StringIndented(indent string) string {
    52  	return fmt.Sprintf(`Part{#%v
    53  %s  Bytes: %X...
    54  %s  Proof: %v
    55  %s}`,
    56  		part.Index,
    57  		indent, tmbytes.Fingerprint(part.Bytes),
    58  		indent, part.Proof.StringIndented(indent+"  "),
    59  		indent)
    60  }
    61  
    62  func (part *Part) ToProto() (*tmproto.Part, error) {
    63  	if part == nil {
    64  		return nil, errors.New("nil part")
    65  	}
    66  	pb := new(tmproto.Part)
    67  	proof := part.Proof.ToProto()
    68  
    69  	pb.Index = part.Index
    70  	pb.Bytes = part.Bytes
    71  	pb.Proof = *proof
    72  
    73  	return pb, nil
    74  }
    75  
    76  func PartFromProto(pb *tmproto.Part) (*Part, error) {
    77  	if pb == nil {
    78  		return nil, errors.New("nil part")
    79  	}
    80  
    81  	part := new(Part)
    82  	proof, err := merkle.ProofFromProto(&pb.Proof)
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  	part.Index = pb.Index
    87  	part.Bytes = pb.Bytes
    88  	part.Proof = *proof
    89  
    90  	return part, part.ValidateBasic()
    91  }
    92  
    93  //-------------------------------------
    94  
    95  type PartSetHeader struct {
    96  	Total uint32           `json:"total"`
    97  	Hash  tmbytes.HexBytes `json:"hash"`
    98  }
    99  
   100  // String returns a string representation of PartSetHeader.
   101  //
   102  // 1. total number of parts
   103  // 2. first 6 bytes of the hash
   104  func (psh PartSetHeader) String() string {
   105  	return fmt.Sprintf("%v:%X", psh.Total, tmbytes.Fingerprint(psh.Hash))
   106  }
   107  
   108  func (psh PartSetHeader) IsZero() bool {
   109  	return psh.Total == 0 && len(psh.Hash) == 0
   110  }
   111  
   112  func (psh PartSetHeader) Equals(other PartSetHeader) bool {
   113  	return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash)
   114  }
   115  
   116  // ValidateBasic performs basic validation.
   117  func (psh PartSetHeader) ValidateBasic() error {
   118  	// Hash can be empty in case of POLBlockID.PartSetHeader in Proposal.
   119  	if err := ValidateHash(psh.Hash); err != nil {
   120  		return fmt.Errorf("wrong Hash: %w", err)
   121  	}
   122  	return nil
   123  }
   124  
   125  // ToProto converts PartSetHeader to protobuf
   126  func (psh *PartSetHeader) ToProto() tmproto.PartSetHeader {
   127  	if psh == nil {
   128  		return tmproto.PartSetHeader{}
   129  	}
   130  
   131  	return tmproto.PartSetHeader{
   132  		Total: psh.Total,
   133  		Hash:  psh.Hash,
   134  	}
   135  }
   136  
   137  // FromProto sets a protobuf PartSetHeader to the given pointer
   138  func PartSetHeaderFromProto(ppsh *tmproto.PartSetHeader) (*PartSetHeader, error) {
   139  	if ppsh == nil {
   140  		return nil, errors.New("nil PartSetHeader")
   141  	}
   142  	psh := new(PartSetHeader)
   143  	psh.Total = ppsh.Total
   144  	psh.Hash = ppsh.Hash
   145  
   146  	return psh, psh.ValidateBasic()
   147  }
   148  
   149  //-------------------------------------
   150  
   151  type PartSet struct {
   152  	total uint32
   153  	hash  []byte
   154  
   155  	mtx           tmsync.Mutex
   156  	parts         []*Part
   157  	partsBitArray *bits.BitArray
   158  	count         uint32
   159  	// a count of the total size (in bytes). Used to ensure that the
   160  	// part set doesn't exceed the maximum block bytes
   161  	byteSize int64
   162  }
   163  
   164  // Returns an immutable, full PartSet from the data bytes.
   165  // The data bytes are split into "partSize" chunks, and merkle tree computed.
   166  // CONTRACT: partSize is greater than zero.
   167  func NewPartSetFromData(data []byte, partSize uint32) *PartSet {
   168  	// divide data into parts of size `partSize`
   169  	total := (uint32(len(data)) + partSize - 1) / partSize
   170  	parts := make([]*Part, total)
   171  	partsBytes := make([][]byte, total)
   172  	partsBitArray := bits.NewBitArray(int(total))
   173  	for i := uint32(0); i < total; i++ {
   174  		part := &Part{
   175  			Index: i,
   176  			Bytes: data[i*partSize : tmmath.MinInt(len(data), int((i+1)*partSize))],
   177  		}
   178  		parts[i] = part
   179  		partsBytes[i] = part.Bytes
   180  		partsBitArray.SetIndex(int(i), true)
   181  	}
   182  	// Compute merkle proofs
   183  	root, proofs := merkle.ProofsFromByteSlices(partsBytes)
   184  	for i := uint32(0); i < total; i++ {
   185  		parts[i].Proof = *proofs[i]
   186  	}
   187  	return &PartSet{
   188  		total:         total,
   189  		hash:          root,
   190  		parts:         parts,
   191  		partsBitArray: partsBitArray,
   192  		count:         total,
   193  		byteSize:      int64(len(data)),
   194  	}
   195  }
   196  
   197  // Returns an empty PartSet ready to be populated.
   198  func NewPartSetFromHeader(header PartSetHeader) *PartSet {
   199  	return &PartSet{
   200  		total:         header.Total,
   201  		hash:          header.Hash,
   202  		parts:         make([]*Part, header.Total),
   203  		partsBitArray: bits.NewBitArray(int(header.Total)),
   204  		count:         0,
   205  		byteSize:      0,
   206  	}
   207  }
   208  
   209  func (ps *PartSet) Header() PartSetHeader {
   210  	if ps == nil {
   211  		return PartSetHeader{}
   212  	}
   213  	return PartSetHeader{
   214  		Total: ps.total,
   215  		Hash:  ps.hash,
   216  	}
   217  }
   218  
   219  func (ps *PartSet) HasHeader(header PartSetHeader) bool {
   220  	if ps == nil {
   221  		return false
   222  	}
   223  	return ps.Header().Equals(header)
   224  }
   225  
   226  func (ps *PartSet) BitArray() *bits.BitArray {
   227  	ps.mtx.Lock()
   228  	defer ps.mtx.Unlock()
   229  	return ps.partsBitArray.Copy()
   230  }
   231  
   232  func (ps *PartSet) Hash() []byte {
   233  	if ps == nil {
   234  		return merkle.HashFromByteSlices(nil)
   235  	}
   236  	return ps.hash
   237  }
   238  
   239  func (ps *PartSet) HashesTo(hash []byte) bool {
   240  	if ps == nil {
   241  		return false
   242  	}
   243  	return bytes.Equal(ps.hash, hash)
   244  }
   245  
   246  func (ps *PartSet) Count() uint32 {
   247  	if ps == nil {
   248  		return 0
   249  	}
   250  	return ps.count
   251  }
   252  
   253  func (ps *PartSet) ByteSize() int64 {
   254  	if ps == nil {
   255  		return 0
   256  	}
   257  	return ps.byteSize
   258  }
   259  
   260  func (ps *PartSet) Total() uint32 {
   261  	if ps == nil {
   262  		return 0
   263  	}
   264  	return ps.total
   265  }
   266  
   267  func (ps *PartSet) AddPart(part *Part) (bool, error) {
   268  	if ps == nil {
   269  		return false, nil
   270  	}
   271  	ps.mtx.Lock()
   272  	defer ps.mtx.Unlock()
   273  
   274  	// Invalid part index
   275  	if part.Index >= ps.total {
   276  		return false, ErrPartSetUnexpectedIndex
   277  	}
   278  
   279  	// If part already exists, return false.
   280  	if ps.parts[part.Index] != nil {
   281  		return false, nil
   282  	}
   283  
   284  	// Check hash proof
   285  	if part.Proof.Verify(ps.Hash(), part.Bytes) != nil {
   286  		return false, ErrPartSetInvalidProof
   287  	}
   288  
   289  	// Add part
   290  	ps.parts[part.Index] = part
   291  	ps.partsBitArray.SetIndex(int(part.Index), true)
   292  	ps.count++
   293  	ps.byteSize += int64(len(part.Bytes))
   294  	return true, nil
   295  }
   296  
   297  func (ps *PartSet) GetPart(index int) *Part {
   298  	ps.mtx.Lock()
   299  	defer ps.mtx.Unlock()
   300  	return ps.parts[index]
   301  }
   302  
   303  func (ps *PartSet) IsComplete() bool {
   304  	return ps.count == ps.total
   305  }
   306  
   307  func (ps *PartSet) GetReader() io.Reader {
   308  	if !ps.IsComplete() {
   309  		panic("Cannot GetReader() on incomplete PartSet")
   310  	}
   311  	return NewPartSetReader(ps.parts)
   312  }
   313  
   314  type PartSetReader struct {
   315  	i      int
   316  	parts  []*Part
   317  	reader *bytes.Reader
   318  }
   319  
   320  func NewPartSetReader(parts []*Part) *PartSetReader {
   321  	return &PartSetReader{
   322  		i:      0,
   323  		parts:  parts,
   324  		reader: bytes.NewReader(parts[0].Bytes),
   325  	}
   326  }
   327  
   328  func (psr *PartSetReader) Read(p []byte) (n int, err error) {
   329  	readerLen := psr.reader.Len()
   330  	if readerLen >= len(p) {
   331  		return psr.reader.Read(p)
   332  	} else if readerLen > 0 {
   333  		n1, err := psr.Read(p[:readerLen])
   334  		if err != nil {
   335  			return n1, err
   336  		}
   337  		n2, err := psr.Read(p[readerLen:])
   338  		return n1 + n2, err
   339  	}
   340  
   341  	psr.i++
   342  	if psr.i >= len(psr.parts) {
   343  		return 0, io.EOF
   344  	}
   345  	psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes)
   346  	return psr.Read(p)
   347  }
   348  
   349  // StringShort returns a short version of String.
   350  //
   351  // (Count of Total)
   352  func (ps *PartSet) StringShort() string {
   353  	if ps == nil {
   354  		return "nil-PartSet"
   355  	}
   356  	ps.mtx.Lock()
   357  	defer ps.mtx.Unlock()
   358  	return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total())
   359  }
   360  
   361  func (ps *PartSet) MarshalJSON() ([]byte, error) {
   362  	if ps == nil {
   363  		return []byte("{}"), nil
   364  	}
   365  
   366  	ps.mtx.Lock()
   367  	defer ps.mtx.Unlock()
   368  
   369  	return tmjson.Marshal(struct {
   370  		CountTotal    string         `json:"count/total"`
   371  		PartsBitArray *bits.BitArray `json:"parts_bit_array"`
   372  	}{
   373  		fmt.Sprintf("%d/%d", ps.Count(), ps.Total()),
   374  		ps.partsBitArray,
   375  	})
   376  }