github.com/intfoundation/intchain@v0.0.0-20220727031208-4316ad31ca73/consensus/ipbft/types/part_set.go (about)

     1  package types
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"sync"
     9  
    10  	"golang.org/x/crypto/ripemd160"
    11  
    12  	. "github.com/intfoundation/go-common"
    13  	"github.com/intfoundation/go-merkle"
    14  	"github.com/intfoundation/go-wire"
    15  )
    16  
    17  var (
    18  	ErrPartSetUnexpectedIndex = errors.New("Error part set unexpected index")
    19  	ErrPartSetInvalidProof    = errors.New("Error part set invalid proof")
    20  )
    21  
// Part is one chunk of the data held by a PartSet, together with the merkle
// proof linking the chunk's hash to the set's root hash.
type Part struct {
	Index int                `json:"index"` // position of this part within its PartSet
	Bytes []byte             `json:"bytes"` // the chunk of data itself
	Proof merkle.SimpleProof `json:"proof"` // proof of Hash() against the PartSet's root hash

	// Cache of Hash(); nil until Hash() is first called.
	hash []byte
}
    30  
    31  func (part *Part) Hash() []byte {
    32  	if part.hash != nil {
    33  		return part.hash
    34  	} else {
    35  		hasher := ripemd160.New()
    36  		hasher.Write(part.Bytes) // doesn't err
    37  		part.hash = hasher.Sum(nil)
    38  		return part.hash
    39  	}
    40  }
    41  
    42  func (part *Part) String() string {
    43  	return part.StringIndented("")
    44  }
    45  
    46  func (part *Part) StringIndented(indent string) string {
    47  	return fmt.Sprintf(`Part{#%v
    48  %s  Bytes: %X...
    49  %s  Proof: %v
    50  %s}`,
    51  		part.Index,
    52  		indent, Fingerprint(part.Bytes),
    53  		indent, part.Proof.StringIndented(indent+"  "),
    54  		indent)
    55  }
    56  
    57  //-------------------------------------
    58  
// PartSetHeader identifies a PartSet by its part count and its root hash;
// it is what PartSet.Header returns and what NewPartSetFromHeader consumes.
type PartSetHeader struct {
	Total uint64 `json:"total"` // number of parts in the set
	Hash  []byte `json:"hash"`  // the set's merkle root hash
}
    63  
    64  func (psh PartSetHeader) String() string {
    65  	return fmt.Sprintf("%v:%X", psh.Total, Fingerprint(psh.Hash))
    66  }
    67  
    68  func (psh PartSetHeader) IsZero() bool {
    69  	return psh.Total == 0
    70  }
    71  
    72  func (psh PartSetHeader) Equals(other PartSetHeader) bool {
    73  	return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash)
    74  }
    75  
    76  func (psh PartSetHeader) WriteSignBytes(w io.Writer, n *int, err *error) {
    77  	wire.WriteJSON(CanonicalPartSetHeader(psh), w, n, err)
    78  }
    79  
    80  //-------------------------------------
    81  
// PartSet is a collection of Parts that together form one piece of data,
// identified by a merkle root hash. It is either built whole from the data
// (NewPartSetFromData) or filled in part by part (NewPartSetFromHeader
// followed by AddPart).
type PartSet struct {
	total int    // number of parts the set holds when complete
	hash  []byte // merkle root hash (see NewPartSetFromData)

	// mtx is taken by AddPart, GetPart, BitArray and StringShort; note that
	// Count and IsComplete read count without it.
	mtx           sync.Mutex
	parts         []*Part   // parts by index; a nil entry has not been added yet
	partsBitArray *BitArray // bit i is set once parts[i] has been added
	count         int       // number of parts added so far
}
    91  
    92  // Returns an immutable, full PartSet from the data bytes.
    93  // The data bytes are split into "partSize" chunks, and merkle tree computed.
    94  func NewPartSetFromData(data []byte, partSize int) *PartSet {
    95  	// divide data into 4kb parts.
    96  	total := (len(data) + partSize - 1) / partSize
    97  	parts := make([]*Part, total)
    98  	parts_ := make([]merkle.Hashable, total)
    99  	partsBitArray := NewBitArray(uint64(total))
   100  	for i := 0; i < total; i++ {
   101  		part := &Part{
   102  			Index: i,
   103  			Bytes: data[i*partSize : MinInt(len(data), (i+1)*partSize)],
   104  		}
   105  		parts[i] = part
   106  		parts_[i] = part
   107  		partsBitArray.SetIndex(uint64(i), true)
   108  	}
   109  	// Compute merkle proofs
   110  	root, proofs := merkle.SimpleProofsFromHashables(parts_)
   111  	for i := 0; i < total; i++ {
   112  		parts[i].Proof = *proofs[i]
   113  	}
   114  	return &PartSet{
   115  		total:         total,
   116  		hash:          root,
   117  		parts:         parts,
   118  		partsBitArray: partsBitArray,
   119  		count:         total,
   120  	}
   121  }
   122  
   123  // Returns an empty PartSet ready to be populated.
   124  func NewPartSetFromHeader(header PartSetHeader) *PartSet {
   125  	return &PartSet{
   126  		total:         int(header.Total),
   127  		hash:          header.Hash,
   128  		parts:         make([]*Part, header.Total),
   129  		partsBitArray: NewBitArray(header.Total),
   130  		count:         0,
   131  	}
   132  }
   133  
   134  func (ps *PartSet) Header() PartSetHeader {
   135  	if ps == nil {
   136  		return PartSetHeader{}
   137  	} else {
   138  		return PartSetHeader{
   139  			Total: uint64(ps.total),
   140  			Hash:  ps.hash,
   141  		}
   142  	}
   143  }
   144  
   145  func (ps *PartSet) HasHeader(header PartSetHeader) bool {
   146  	if ps == nil {
   147  		return false
   148  	} else {
   149  		return ps.Header().Equals(header)
   150  	}
   151  }
   152  
   153  func (ps *PartSet) BitArray() *BitArray {
   154  	ps.mtx.Lock()
   155  	defer ps.mtx.Unlock()
   156  	return ps.partsBitArray.Copy()
   157  }
   158  
   159  func (ps *PartSet) Hash() []byte {
   160  	if ps == nil {
   161  		return nil
   162  	}
   163  	return ps.hash
   164  }
   165  
   166  func (ps *PartSet) HashesTo(hash []byte) bool {
   167  	if ps == nil {
   168  		return false
   169  	}
   170  	return bytes.Equal(ps.hash, hash)
   171  }
   172  
   173  func (ps *PartSet) Count() int {
   174  	if ps == nil {
   175  		return 0
   176  	}
   177  	return ps.count
   178  }
   179  
   180  func (ps *PartSet) Total() int {
   181  	if ps == nil {
   182  		return 0
   183  	}
   184  	return ps.total
   185  }
   186  
   187  func (ps *PartSet) AddPart(part *Part, verify bool) (bool, error) {
   188  	ps.mtx.Lock()
   189  	defer ps.mtx.Unlock()
   190  
   191  	// Invalid part index
   192  	if part.Index >= ps.total {
   193  		return false, ErrPartSetUnexpectedIndex
   194  	}
   195  
   196  	// If part already exists, return false.
   197  	if ps.parts[part.Index] != nil {
   198  		return false, nil
   199  	}
   200  
   201  	// Check hash proof
   202  	if verify {
   203  		if !part.Proof.Verify(part.Index, ps.total, part.Hash(), ps.Hash()) {
   204  			return false, ErrPartSetInvalidProof
   205  		}
   206  	}
   207  
   208  	// Add part
   209  	ps.parts[part.Index] = part
   210  	ps.partsBitArray.SetIndex(uint64(part.Index), true)
   211  	ps.count++
   212  	return true, nil
   213  }
   214  
   215  func (ps *PartSet) GetPart(index int) *Part {
   216  	ps.mtx.Lock()
   217  	defer ps.mtx.Unlock()
   218  	return ps.parts[index]
   219  }
   220  
   221  func (ps *PartSet) IsComplete() bool {
   222  	return ps.count == ps.total
   223  }
   224  
   225  func (ps *PartSet) GetReader() io.Reader {
   226  	if !ps.IsComplete() {
   227  		PanicSanity("Cannot GetReader() on incomplete PartSet")
   228  	}
   229  	return NewPartSetReader(ps.parts)
   230  }
   231  
   232  type PartSetReader struct {
   233  	i      int
   234  	parts  []*Part
   235  	reader *bytes.Reader
   236  }
   237  
   238  func NewPartSetReader(parts []*Part) *PartSetReader {
   239  	return &PartSetReader{
   240  		i:      0,
   241  		parts:  parts,
   242  		reader: bytes.NewReader(parts[0].Bytes),
   243  	}
   244  }
   245  
   246  func (psr *PartSetReader) Read(p []byte) (n int, err error) {
   247  	readerLen := psr.reader.Len()
   248  	if readerLen >= len(p) {
   249  		return psr.reader.Read(p)
   250  	} else if readerLen > 0 {
   251  		n1, err := psr.Read(p[:readerLen])
   252  		if err != nil {
   253  			return n1, err
   254  		}
   255  		n2, err := psr.Read(p[readerLen:])
   256  		return n1 + n2, err
   257  	}
   258  
   259  	psr.i += 1
   260  	if psr.i >= len(psr.parts) {
   261  		return 0, io.EOF
   262  	}
   263  	psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes)
   264  	return psr.Read(p)
   265  }
   266  
   267  func (ps *PartSet) StringShort() string {
   268  	if ps == nil {
   269  		return "nil-PartSet"
   270  	} else {
   271  		ps.mtx.Lock()
   272  		defer ps.mtx.Unlock()
   273  		return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total())
   274  	}
   275  }