github.com/adoriasoft/tendermint@v0.34.0-dev1.0.20200722151356-96d84601a75a/types/part_set.go (about) 1 package types 2 3 import ( 4 "bytes" 5 "errors" 6 "fmt" 7 "io" 8 9 "github.com/tendermint/tendermint/crypto/merkle" 10 "github.com/tendermint/tendermint/libs/bits" 11 tmbytes "github.com/tendermint/tendermint/libs/bytes" 12 tmjson "github.com/tendermint/tendermint/libs/json" 13 tmmath "github.com/tendermint/tendermint/libs/math" 14 tmsync "github.com/tendermint/tendermint/libs/sync" 15 tmproto "github.com/tendermint/tendermint/proto/tendermint/types" 16 ) 17 18 var ( 19 ErrPartSetUnexpectedIndex = errors.New("error part set unexpected index") 20 ErrPartSetInvalidProof = errors.New("error part set invalid proof") 21 ) 22 23 type Part struct { 24 Index uint32 `json:"index"` 25 Bytes tmbytes.HexBytes `json:"bytes"` 26 Proof merkle.Proof `json:"proof"` 27 } 28 29 // ValidateBasic performs basic validation. 30 func (part *Part) ValidateBasic() error { 31 if len(part.Bytes) > int(BlockPartSizeBytes) { 32 return fmt.Errorf("too big: %d bytes, max: %d", len(part.Bytes), BlockPartSizeBytes) 33 } 34 if err := part.Proof.ValidateBasic(); err != nil { 35 return fmt.Errorf("wrong Proof: %w", err) 36 } 37 return nil 38 } 39 40 func (part *Part) String() string { 41 return part.StringIndented("") 42 } 43 44 func (part *Part) StringIndented(indent string) string { 45 return fmt.Sprintf(`Part{#%v 46 %s Bytes: %X... 47 %s Proof: %v 48 %s}`, 49 part.Index, 50 indent, tmbytes.Fingerprint(part.Bytes), 51 indent, part.Proof.StringIndented(indent+" "), 52 indent) 53 } 54 55 func (part *Part) ToProto() (*tmproto.Part, error) { 56 if part == nil { 57 return nil, errors.New("nil part") 58 } 59 pb := new(tmproto.Part) 60 proof := part.Proof.ToProto() 61 62 pb.Index = part.Index 63 pb.Bytes = part.Bytes 64 pb.Proof = *proof 65 66 return pb, nil 67 } 68 69 func PartFromProto(pb *tmproto.Part) (*Part, error) { 70 if pb == nil { 71 return nil, errors.New("nil part") 72 } 73 74 part := new(Part) 75 proof, err := merkle.ProofFromProto(&pb.Proof) 76 if err != nil { 77 return nil, err 78 } 79 part.Index = pb.Index 80 part.Bytes = pb.Bytes 81 part.Proof = *proof 82 83 return part, part.ValidateBasic() 84 } 85 86 //------------------------------------- 87 88 type PartSetHeader struct { 89 Total uint32 `json:"total"` 90 Hash tmbytes.HexBytes `json:"hash"` 91 } 92 93 func (psh PartSetHeader) String() string { 94 return fmt.Sprintf("%v:%X", psh.Total, tmbytes.Fingerprint(psh.Hash)) 95 } 96 97 func (psh PartSetHeader) IsZero() bool { 98 return psh.Total == 0 && len(psh.Hash) == 0 99 } 100 101 func (psh PartSetHeader) Equals(other PartSetHeader) bool { 102 return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash) 103 } 104 105 // ValidateBasic performs basic validation. 106 func (psh PartSetHeader) ValidateBasic() error { 107 // Hash can be empty in case of POLBlockID.PartSetHeader in Proposal. 108 if err := ValidateHash(psh.Hash); err != nil { 109 return fmt.Errorf("wrong Hash: %w", err) 110 } 111 return nil 112 } 113 114 // ToProto converts BloPartSetHeaderckID to protobuf 115 func (psh *PartSetHeader) ToProto() tmproto.PartSetHeader { 116 if psh == nil { 117 return tmproto.PartSetHeader{} 118 } 119 120 return tmproto.PartSetHeader{ 121 Total: psh.Total, 122 Hash: psh.Hash, 123 } 124 } 125 126 // FromProto sets a protobuf PartSetHeader to the given pointer 127 func PartSetHeaderFromProto(ppsh *tmproto.PartSetHeader) (*PartSetHeader, error) { 128 if ppsh == nil { 129 return nil, errors.New("nil PartSetHeader") 130 } 131 psh := new(PartSetHeader) 132 psh.Total = ppsh.Total 133 psh.Hash = ppsh.Hash 134 135 return psh, psh.ValidateBasic() 136 } 137 138 //------------------------------------- 139 140 type PartSet struct { 141 total uint32 142 hash []byte 143 144 mtx tmsync.Mutex 145 parts []*Part 146 partsBitArray *bits.BitArray 147 count uint32 148 } 149 150 // Returns an immutable, full PartSet from the data bytes. 151 // The data bytes are split into "partSize" chunks, and merkle tree computed. 152 // CONTRACT: partSize is greater than zero. 153 func NewPartSetFromData(data []byte, partSize uint32) *PartSet { 154 // divide data into 4kb parts. 155 total := (uint32(len(data)) + partSize - 1) / partSize 156 parts := make([]*Part, total) 157 partsBytes := make([][]byte, total) 158 partsBitArray := bits.NewBitArray(int(total)) 159 for i := uint32(0); i < total; i++ { 160 part := &Part{ 161 Index: i, 162 Bytes: data[i*partSize : tmmath.MinInt(len(data), int((i+1)*partSize))], 163 } 164 parts[i] = part 165 partsBytes[i] = part.Bytes 166 partsBitArray.SetIndex(int(i), true) 167 } 168 // Compute merkle proofs 169 root, proofs := merkle.ProofsFromByteSlices(partsBytes) 170 for i := uint32(0); i < total; i++ { 171 parts[i].Proof = *proofs[i] 172 } 173 return &PartSet{ 174 total: total, 175 hash: root, 176 parts: parts, 177 partsBitArray: partsBitArray, 178 count: total, 179 } 180 } 181 182 // Returns an empty PartSet ready to be populated. 183 func NewPartSetFromHeader(header PartSetHeader) *PartSet { 184 return &PartSet{ 185 total: header.Total, 186 hash: header.Hash, 187 parts: make([]*Part, header.Total), 188 partsBitArray: bits.NewBitArray(int(header.Total)), 189 count: 0, 190 } 191 } 192 193 func (ps *PartSet) Header() PartSetHeader { 194 if ps == nil { 195 return PartSetHeader{} 196 } 197 return PartSetHeader{ 198 Total: ps.total, 199 Hash: ps.hash, 200 } 201 } 202 203 func (ps *PartSet) HasHeader(header PartSetHeader) bool { 204 if ps == nil { 205 return false 206 } 207 return ps.Header().Equals(header) 208 } 209 210 func (ps *PartSet) BitArray() *bits.BitArray { 211 ps.mtx.Lock() 212 defer ps.mtx.Unlock() 213 return ps.partsBitArray.Copy() 214 } 215 216 func (ps *PartSet) Hash() []byte { 217 if ps == nil { 218 return nil 219 } 220 return ps.hash 221 } 222 223 func (ps *PartSet) HashesTo(hash []byte) bool { 224 if ps == nil { 225 return false 226 } 227 return bytes.Equal(ps.hash, hash) 228 } 229 230 func (ps *PartSet) Count() uint32 { 231 if ps == nil { 232 return 0 233 } 234 return ps.count 235 } 236 237 func (ps *PartSet) Total() uint32 { 238 if ps == nil { 239 return 0 240 } 241 return ps.total 242 } 243 244 func (ps *PartSet) AddPart(part *Part) (bool, error) { 245 if ps == nil { 246 return false, nil 247 } 248 ps.mtx.Lock() 249 defer ps.mtx.Unlock() 250 251 // Invalid part index 252 if part.Index >= ps.total { 253 return false, ErrPartSetUnexpectedIndex 254 } 255 256 // If part already exists, return false. 257 if ps.parts[part.Index] != nil { 258 return false, nil 259 } 260 261 // Check hash proof 262 if part.Proof.Verify(ps.Hash(), part.Bytes) != nil { 263 return false, ErrPartSetInvalidProof 264 } 265 266 // Add part 267 ps.parts[part.Index] = part 268 ps.partsBitArray.SetIndex(int(part.Index), true) 269 ps.count++ 270 return true, nil 271 } 272 273 func (ps *PartSet) GetPart(index int) *Part { 274 ps.mtx.Lock() 275 defer ps.mtx.Unlock() 276 return ps.parts[index] 277 } 278 279 func (ps *PartSet) IsComplete() bool { 280 return ps.count == ps.total 281 } 282 283 func (ps *PartSet) GetReader() io.Reader { 284 if !ps.IsComplete() { 285 panic("Cannot GetReader() on incomplete PartSet") 286 } 287 return NewPartSetReader(ps.parts) 288 } 289 290 type PartSetReader struct { 291 i int 292 parts []*Part 293 reader *bytes.Reader 294 } 295 296 func NewPartSetReader(parts []*Part) *PartSetReader { 297 return &PartSetReader{ 298 i: 0, 299 parts: parts, 300 reader: bytes.NewReader(parts[0].Bytes), 301 } 302 } 303 304 func (psr *PartSetReader) Read(p []byte) (n int, err error) { 305 readerLen := psr.reader.Len() 306 if readerLen >= len(p) { 307 return psr.reader.Read(p) 308 } else if readerLen > 0 { 309 n1, err := psr.Read(p[:readerLen]) 310 if err != nil { 311 return n1, err 312 } 313 n2, err := psr.Read(p[readerLen:]) 314 return n1 + n2, err 315 } 316 317 psr.i++ 318 if psr.i >= len(psr.parts) { 319 return 0, io.EOF 320 } 321 psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes) 322 return psr.Read(p) 323 } 324 325 func (ps *PartSet) StringShort() string { 326 if ps == nil { 327 return "nil-PartSet" 328 } 329 ps.mtx.Lock() 330 defer ps.mtx.Unlock() 331 return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total()) 332 } 333 334 func (ps *PartSet) MarshalJSON() ([]byte, error) { 335 if ps == nil { 336 return []byte("{}"), nil 337 } 338 339 ps.mtx.Lock() 340 defer ps.mtx.Unlock() 341 342 return tmjson.Marshal(struct { 343 CountTotal string `json:"count/total"` 344 PartsBitArray *bits.BitArray `json:"parts_bit_array"` 345 }{ 346 fmt.Sprintf("%d/%d", ps.Count(), ps.Total()), 347 ps.partsBitArray, 348 }) 349 }