github.com/intfoundation/intchain@v0.0.0-20220727031208-4316ad31ca73/consensus/ipbft/types/part_set.go (about) 1 package types 2 3 import ( 4 "bytes" 5 "errors" 6 "fmt" 7 "io" 8 "sync" 9 10 "golang.org/x/crypto/ripemd160" 11 12 . "github.com/intfoundation/go-common" 13 "github.com/intfoundation/go-merkle" 14 "github.com/intfoundation/go-wire" 15 ) 16 17 var ( 18 ErrPartSetUnexpectedIndex = errors.New("Error part set unexpected index") 19 ErrPartSetInvalidProof = errors.New("Error part set invalid proof") 20 ) 21 22 type Part struct { 23 Index int `json:"index"` 24 Bytes []byte `json:"bytes"` 25 Proof merkle.SimpleProof `json:"proof"` 26 27 // Cache 28 hash []byte 29 } 30 31 func (part *Part) Hash() []byte { 32 if part.hash != nil { 33 return part.hash 34 } else { 35 hasher := ripemd160.New() 36 hasher.Write(part.Bytes) // doesn't err 37 part.hash = hasher.Sum(nil) 38 return part.hash 39 } 40 } 41 42 func (part *Part) String() string { 43 return part.StringIndented("") 44 } 45 46 func (part *Part) StringIndented(indent string) string { 47 return fmt.Sprintf(`Part{#%v 48 %s Bytes: %X... 49 %s Proof: %v 50 %s}`, 51 part.Index, 52 indent, Fingerprint(part.Bytes), 53 indent, part.Proof.StringIndented(indent+" "), 54 indent) 55 } 56 57 //------------------------------------- 58 59 type PartSetHeader struct { 60 Total uint64 `json:"total"` 61 Hash []byte `json:"hash"` 62 } 63 64 func (psh PartSetHeader) String() string { 65 return fmt.Sprintf("%v:%X", psh.Total, Fingerprint(psh.Hash)) 66 } 67 68 func (psh PartSetHeader) IsZero() bool { 69 return psh.Total == 0 70 } 71 72 func (psh PartSetHeader) Equals(other PartSetHeader) bool { 73 return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash) 74 } 75 76 func (psh PartSetHeader) WriteSignBytes(w io.Writer, n *int, err *error) { 77 wire.WriteJSON(CanonicalPartSetHeader(psh), w, n, err) 78 } 79 80 //------------------------------------- 81 82 type PartSet struct { 83 total int 84 hash []byte 85 86 mtx sync.Mutex 87 parts []*Part 88 partsBitArray *BitArray 89 count int 90 } 91 92 // Returns an immutable, full PartSet from the data bytes. 93 // The data bytes are split into "partSize" chunks, and merkle tree computed. 94 func NewPartSetFromData(data []byte, partSize int) *PartSet { 95 // divide data into 4kb parts. 96 total := (len(data) + partSize - 1) / partSize 97 parts := make([]*Part, total) 98 parts_ := make([]merkle.Hashable, total) 99 partsBitArray := NewBitArray(uint64(total)) 100 for i := 0; i < total; i++ { 101 part := &Part{ 102 Index: i, 103 Bytes: data[i*partSize : MinInt(len(data), (i+1)*partSize)], 104 } 105 parts[i] = part 106 parts_[i] = part 107 partsBitArray.SetIndex(uint64(i), true) 108 } 109 // Compute merkle proofs 110 root, proofs := merkle.SimpleProofsFromHashables(parts_) 111 for i := 0; i < total; i++ { 112 parts[i].Proof = *proofs[i] 113 } 114 return &PartSet{ 115 total: total, 116 hash: root, 117 parts: parts, 118 partsBitArray: partsBitArray, 119 count: total, 120 } 121 } 122 123 // Returns an empty PartSet ready to be populated. 124 func NewPartSetFromHeader(header PartSetHeader) *PartSet { 125 return &PartSet{ 126 total: int(header.Total), 127 hash: header.Hash, 128 parts: make([]*Part, header.Total), 129 partsBitArray: NewBitArray(header.Total), 130 count: 0, 131 } 132 } 133 134 func (ps *PartSet) Header() PartSetHeader { 135 if ps == nil { 136 return PartSetHeader{} 137 } else { 138 return PartSetHeader{ 139 Total: uint64(ps.total), 140 Hash: ps.hash, 141 } 142 } 143 } 144 145 func (ps *PartSet) HasHeader(header PartSetHeader) bool { 146 if ps == nil { 147 return false 148 } else { 149 return ps.Header().Equals(header) 150 } 151 } 152 153 func (ps *PartSet) BitArray() *BitArray { 154 ps.mtx.Lock() 155 defer ps.mtx.Unlock() 156 return ps.partsBitArray.Copy() 157 } 158 159 func (ps *PartSet) Hash() []byte { 160 if ps == nil { 161 return nil 162 } 163 return ps.hash 164 } 165 166 func (ps *PartSet) HashesTo(hash []byte) bool { 167 if ps == nil { 168 return false 169 } 170 return bytes.Equal(ps.hash, hash) 171 } 172 173 func (ps *PartSet) Count() int { 174 if ps == nil { 175 return 0 176 } 177 return ps.count 178 } 179 180 func (ps *PartSet) Total() int { 181 if ps == nil { 182 return 0 183 } 184 return ps.total 185 } 186 187 func (ps *PartSet) AddPart(part *Part, verify bool) (bool, error) { 188 ps.mtx.Lock() 189 defer ps.mtx.Unlock() 190 191 // Invalid part index 192 if part.Index >= ps.total { 193 return false, ErrPartSetUnexpectedIndex 194 } 195 196 // If part already exists, return false. 197 if ps.parts[part.Index] != nil { 198 return false, nil 199 } 200 201 // Check hash proof 202 if verify { 203 if !part.Proof.Verify(part.Index, ps.total, part.Hash(), ps.Hash()) { 204 return false, ErrPartSetInvalidProof 205 } 206 } 207 208 // Add part 209 ps.parts[part.Index] = part 210 ps.partsBitArray.SetIndex(uint64(part.Index), true) 211 ps.count++ 212 return true, nil 213 } 214 215 func (ps *PartSet) GetPart(index int) *Part { 216 ps.mtx.Lock() 217 defer ps.mtx.Unlock() 218 return ps.parts[index] 219 } 220 221 func (ps *PartSet) IsComplete() bool { 222 return ps.count == ps.total 223 } 224 225 func (ps *PartSet) GetReader() io.Reader { 226 if !ps.IsComplete() { 227 PanicSanity("Cannot GetReader() on incomplete PartSet") 228 } 229 return NewPartSetReader(ps.parts) 230 } 231 232 type PartSetReader struct { 233 i int 234 parts []*Part 235 reader *bytes.Reader 236 } 237 238 func NewPartSetReader(parts []*Part) *PartSetReader { 239 return &PartSetReader{ 240 i: 0, 241 parts: parts, 242 reader: bytes.NewReader(parts[0].Bytes), 243 } 244 } 245 246 func (psr *PartSetReader) Read(p []byte) (n int, err error) { 247 readerLen := psr.reader.Len() 248 if readerLen >= len(p) { 249 return psr.reader.Read(p) 250 } else if readerLen > 0 { 251 n1, err := psr.Read(p[:readerLen]) 252 if err != nil { 253 return n1, err 254 } 255 n2, err := psr.Read(p[readerLen:]) 256 return n1 + n2, err 257 } 258 259 psr.i += 1 260 if psr.i >= len(psr.parts) { 261 return 0, io.EOF 262 } 263 psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes) 264 return psr.Read(p) 265 } 266 267 func (ps *PartSet) StringShort() string { 268 if ps == nil { 269 return "nil-PartSet" 270 } else { 271 ps.mtx.Lock() 272 defer ps.mtx.Unlock() 273 return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total()) 274 } 275 }