github.com/evdatsion/aphelion-dpos-bft@v0.32.1/types/part_set.go (about)

     1  package types
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"sync"
     8  
     9  	"github.com/pkg/errors"
    10  
    11  	"github.com/evdatsion/aphelion-dpos-bft/crypto/merkle"
    12  	cmn "github.com/evdatsion/aphelion-dpos-bft/libs/common"
    13  )
    14  
        // Errors returned by PartSet.AddPart when it rejects a part.
    15  var (
        	// ErrPartSetUnexpectedIndex is returned when the part's Index is not
        	// below the set's total.
    16  	ErrPartSetUnexpectedIndex = errors.New("Error part set unexpected index")
        	// ErrPartSetInvalidProof is returned when the part's proof fails to
        	// verify against the set's hash.
    17  	ErrPartSetInvalidProof    = errors.New("Error part set invalid proof")
    18  )
    19  
        // Part is one chunk of a larger byte payload, together with the proof
        // that PartSet.AddPart verifies against the set's hash.
    20  type Part struct {
        	// Index is the position of this part within its set.
    21  	Index int                `json:"index"`
        	// Bytes is the chunk payload; ValidateBasic caps its length at
        	// BlockPartSizeBytes.
    22  	Bytes cmn.HexBytes       `json:"bytes"`
        	// Proof is checked against the set hash and Bytes when the part is added.
    23  	Proof merkle.SimpleProof `json:"proof"`
    24  }
    25  
    26  // ValidateBasic performs basic validation.
    27  func (part *Part) ValidateBasic() error {
    28  	if part.Index < 0 {
    29  		return errors.New("Negative Index")
    30  	}
    31  	if len(part.Bytes) > BlockPartSizeBytes {
    32  		return fmt.Errorf("Too big (max: %d)", BlockPartSizeBytes)
    33  	}
    34  	return nil
    35  }
    36  
    37  func (part *Part) String() string {
    38  	return part.StringIndented("")
    39  }
    40  
    41  func (part *Part) StringIndented(indent string) string {
    42  	return fmt.Sprintf(`Part{#%v
    43  %s  Bytes: %X...
    44  %s  Proof: %v
    45  %s}`,
    46  		part.Index,
    47  		indent, cmn.Fingerprint(part.Bytes),
    48  		indent, part.Proof.StringIndented(indent+"  "),
    49  		indent)
    50  }
    51  
    52  //-------------------------------------
    53  
        // PartSetHeader identifies a part set by its number of parts and its hash.
    54  type PartSetHeader struct {
        	// Total is the number of parts in the set.
    55  	Total int          `json:"total"`
        	// Hash is the set's hash (the merkle root when the set is built by
        	// NewPartSetFromData).
    56  	Hash  cmn.HexBytes `json:"hash"`
    57  }
    58  
    59  func (psh PartSetHeader) String() string {
    60  	return fmt.Sprintf("%v:%X", psh.Total, cmn.Fingerprint(psh.Hash))
    61  }
    62  
    63  func (psh PartSetHeader) IsZero() bool {
    64  	return psh.Total == 0 && len(psh.Hash) == 0
    65  }
    66  
    67  func (psh PartSetHeader) Equals(other PartSetHeader) bool {
    68  	return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash)
    69  }
    70  
    71  // ValidateBasic performs basic validation.
    72  func (psh PartSetHeader) ValidateBasic() error {
    73  	if psh.Total < 0 {
    74  		return errors.New("Negative Total")
    75  	}
    76  	// Hash can be empty in case of POLBlockID.PartsHeader in Proposal.
    77  	if err := ValidateHash(psh.Hash); err != nil {
    78  		return errors.Wrap(err, "Wrong Hash")
    79  	}
    80  	return nil
    81  }
    82  
    83  //-------------------------------------
    84  
        // PartSet collects the parts of a chunked payload. total and hash are set
        // by the constructors; AddPart, GetPart and BitArray take mtx before
        // touching the fields declared after it.
    85  type PartSet struct {
        	// total is the number of parts the set holds once complete.
    86  	total int
        	// hash is what AddPart verifies each part's proof against.
    87  	hash  []byte
    88  
    89  	mtx           sync.Mutex
        	// parts holds the parts by index; a slot is nil until its part is added.
    90  	parts         []*Part
        	// partsBitArray has bit i set once part i has been added.
    91  	partsBitArray *cmn.BitArray
        	// count is the number of parts added so far.
    92  	count         int
    93  }
    94  
    95  // Returns an immutable, full PartSet from the data bytes.
    96  // The data bytes are split into "partSize" chunks, and merkle tree computed.
    97  func NewPartSetFromData(data []byte, partSize int) *PartSet {
    98  	// divide data into 4kb parts.
    99  	total := (len(data) + partSize - 1) / partSize
   100  	parts := make([]*Part, total)
   101  	partsBytes := make([][]byte, total)
   102  	partsBitArray := cmn.NewBitArray(total)
   103  	for i := 0; i < total; i++ {
   104  		part := &Part{
   105  			Index: i,
   106  			Bytes: data[i*partSize : cmn.MinInt(len(data), (i+1)*partSize)],
   107  		}
   108  		parts[i] = part
   109  		partsBytes[i] = part.Bytes
   110  		partsBitArray.SetIndex(i, true)
   111  	}
   112  	// Compute merkle proofs
   113  	root, proofs := merkle.SimpleProofsFromByteSlices(partsBytes)
   114  	for i := 0; i < total; i++ {
   115  		parts[i].Proof = *proofs[i]
   116  	}
   117  	return &PartSet{
   118  		total:         total,
   119  		hash:          root,
   120  		parts:         parts,
   121  		partsBitArray: partsBitArray,
   122  		count:         total,
   123  	}
   124  }
   125  
   126  // Returns an empty PartSet ready to be populated.
   127  func NewPartSetFromHeader(header PartSetHeader) *PartSet {
   128  	return &PartSet{
   129  		total:         header.Total,
   130  		hash:          header.Hash,
   131  		parts:         make([]*Part, header.Total),
   132  		partsBitArray: cmn.NewBitArray(header.Total),
   133  		count:         0,
   134  	}
   135  }
   136  
   137  func (ps *PartSet) Header() PartSetHeader {
   138  	if ps == nil {
   139  		return PartSetHeader{}
   140  	}
   141  	return PartSetHeader{
   142  		Total: ps.total,
   143  		Hash:  ps.hash,
   144  	}
   145  }
   146  
   147  func (ps *PartSet) HasHeader(header PartSetHeader) bool {
   148  	if ps == nil {
   149  		return false
   150  	}
   151  	return ps.Header().Equals(header)
   152  }
   153  
   154  func (ps *PartSet) BitArray() *cmn.BitArray {
   155  	ps.mtx.Lock()
   156  	defer ps.mtx.Unlock()
   157  	return ps.partsBitArray.Copy()
   158  }
   159  
   160  func (ps *PartSet) Hash() []byte {
   161  	if ps == nil {
   162  		return nil
   163  	}
   164  	return ps.hash
   165  }
   166  
   167  func (ps *PartSet) HashesTo(hash []byte) bool {
   168  	if ps == nil {
   169  		return false
   170  	}
   171  	return bytes.Equal(ps.hash, hash)
   172  }
   173  
   174  func (ps *PartSet) Count() int {
   175  	if ps == nil {
   176  		return 0
   177  	}
   178  	return ps.count
   179  }
   180  
   181  func (ps *PartSet) Total() int {
   182  	if ps == nil {
   183  		return 0
   184  	}
   185  	return ps.total
   186  }
   187  
   188  func (ps *PartSet) AddPart(part *Part) (bool, error) {
   189  	if ps == nil {
   190  		return false, nil
   191  	}
   192  	ps.mtx.Lock()
   193  	defer ps.mtx.Unlock()
   194  
   195  	// Invalid part index
   196  	if part.Index >= ps.total {
   197  		return false, ErrPartSetUnexpectedIndex
   198  	}
   199  
   200  	// If part already exists, return false.
   201  	if ps.parts[part.Index] != nil {
   202  		return false, nil
   203  	}
   204  
   205  	// Check hash proof
   206  	if part.Proof.Verify(ps.Hash(), part.Bytes) != nil {
   207  		return false, ErrPartSetInvalidProof
   208  	}
   209  
   210  	// Add part
   211  	ps.parts[part.Index] = part
   212  	ps.partsBitArray.SetIndex(part.Index, true)
   213  	ps.count++
   214  	return true, nil
   215  }
   216  
   217  func (ps *PartSet) GetPart(index int) *Part {
   218  	ps.mtx.Lock()
   219  	defer ps.mtx.Unlock()
   220  	return ps.parts[index]
   221  }
   222  
   223  func (ps *PartSet) IsComplete() bool {
   224  	return ps.count == ps.total
   225  }
   226  
   227  func (ps *PartSet) GetReader() io.Reader {
   228  	if !ps.IsComplete() {
   229  		panic("Cannot GetReader() on incomplete PartSet")
   230  	}
   231  	return NewPartSetReader(ps.parts)
   232  }
   233  
        // PartSetReader reads the Bytes of a slice of parts one after another, in
        // slice order, through its Read method.
   234  type PartSetReader struct {
        	// i is the index of the part currently being read.
   235  	i      int
        	// parts are the parts whose Bytes are read in order.
   236  	parts  []*Part
        	// reader reads the Bytes of parts[i].
   237  	reader *bytes.Reader
   238  }
   239  
   240  func NewPartSetReader(parts []*Part) *PartSetReader {
   241  	return &PartSetReader{
   242  		i:      0,
   243  		parts:  parts,
   244  		reader: bytes.NewReader(parts[0].Bytes),
   245  	}
   246  }
   247  
   248  func (psr *PartSetReader) Read(p []byte) (n int, err error) {
   249  	readerLen := psr.reader.Len()
   250  	if readerLen >= len(p) {
   251  		return psr.reader.Read(p)
   252  	} else if readerLen > 0 {
   253  		n1, err := psr.Read(p[:readerLen])
   254  		if err != nil {
   255  			return n1, err
   256  		}
   257  		n2, err := psr.Read(p[readerLen:])
   258  		return n1 + n2, err
   259  	}
   260  
   261  	psr.i++
   262  	if psr.i >= len(psr.parts) {
   263  		return 0, io.EOF
   264  	}
   265  	psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes)
   266  	return psr.Read(p)
   267  }
   268  
   269  func (ps *PartSet) StringShort() string {
   270  	if ps == nil {
   271  		return "nil-PartSet"
   272  	}
   273  	ps.mtx.Lock()
   274  	defer ps.mtx.Unlock()
   275  	return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total())
   276  }
   277  
   278  func (ps *PartSet) MarshalJSON() ([]byte, error) {
   279  	if ps == nil {
   280  		return []byte("{}"), nil
   281  	}
   282  
   283  	ps.mtx.Lock()
   284  	defer ps.mtx.Unlock()
   285  
   286  	return cdc.MarshalJSON(struct {
   287  		CountTotal    string        `json:"count/total"`
   288  		PartsBitArray *cmn.BitArray `json:"parts_bit_array"`
   289  	}{
   290  		fmt.Sprintf("%d/%d", ps.Count(), ps.Total()),
   291  		ps.partsBitArray,
   292  	})
   293  }