github.com/aakash4dev/cometbft@v0.38.2/types/part_set.go (about) 1 package types 2 3 import ( 4 "bytes" 5 "errors" 6 "fmt" 7 "io" 8 9 "github.com/aakash4dev/cometbft/crypto/merkle" 10 "github.com/aakash4dev/cometbft/libs/bits" 11 cmtbytes "github.com/aakash4dev/cometbft/libs/bytes" 12 cmtjson "github.com/aakash4dev/cometbft/libs/json" 13 cmtmath "github.com/aakash4dev/cometbft/libs/math" 14 cmtsync "github.com/aakash4dev/cometbft/libs/sync" 15 cmtproto "github.com/aakash4dev/cometbft/proto/tendermint/types" 16 ) 17 18 var ( 19 ErrPartSetUnexpectedIndex = errors.New("error part set unexpected index") 20 ErrPartSetInvalidProof = errors.New("error part set invalid proof") 21 ) 22 23 type Part struct { 24 Index uint32 `json:"index"` 25 Bytes cmtbytes.HexBytes `json:"bytes"` 26 Proof merkle.Proof `json:"proof"` 27 } 28 29 // ValidateBasic performs basic validation. 30 func (part *Part) ValidateBasic() error { 31 if len(part.Bytes) > int(BlockPartSizeBytes) { 32 return fmt.Errorf("too big: %d bytes, max: %d", len(part.Bytes), BlockPartSizeBytes) 33 } 34 if err := part.Proof.ValidateBasic(); err != nil { 35 return fmt.Errorf("wrong Proof: %w", err) 36 } 37 return nil 38 } 39 40 // String returns a string representation of Part. 41 // 42 // See StringIndented. 43 func (part *Part) String() string { 44 return part.StringIndented("") 45 } 46 47 // StringIndented returns an indented Part. 48 // 49 // See merkle.Proof#StringIndented 50 func (part *Part) StringIndented(indent string) string { 51 return fmt.Sprintf(`Part{#%v 52 %s Bytes: %X... 53 %s Proof: %v 54 %s}`, 55 part.Index, 56 indent, cmtbytes.Fingerprint(part.Bytes), 57 indent, part.Proof.StringIndented(indent+" "), 58 indent) 59 } 60 61 func (part *Part) ToProto() (*cmtproto.Part, error) { 62 if part == nil { 63 return nil, errors.New("nil part") 64 } 65 pb := new(cmtproto.Part) 66 proof := part.Proof.ToProto() 67 68 pb.Index = part.Index 69 pb.Bytes = part.Bytes 70 pb.Proof = *proof 71 72 return pb, nil 73 } 74 75 func PartFromProto(pb *cmtproto.Part) (*Part, error) { 76 if pb == nil { 77 return nil, errors.New("nil part") 78 } 79 80 part := new(Part) 81 proof, err := merkle.ProofFromProto(&pb.Proof) 82 if err != nil { 83 return nil, err 84 } 85 part.Index = pb.Index 86 part.Bytes = pb.Bytes 87 part.Proof = *proof 88 89 return part, part.ValidateBasic() 90 } 91 92 //------------------------------------- 93 94 type PartSetHeader struct { 95 Total uint32 `json:"total"` 96 Hash cmtbytes.HexBytes `json:"hash"` 97 } 98 99 // String returns a string representation of PartSetHeader. 100 // 101 // 1. total number of parts 102 // 2. first 6 bytes of the hash 103 func (psh PartSetHeader) String() string { 104 return fmt.Sprintf("%v:%X", psh.Total, cmtbytes.Fingerprint(psh.Hash)) 105 } 106 107 func (psh PartSetHeader) IsZero() bool { 108 return psh.Total == 0 && len(psh.Hash) == 0 109 } 110 111 func (psh PartSetHeader) Equals(other PartSetHeader) bool { 112 return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash) 113 } 114 115 // ValidateBasic performs basic validation. 116 func (psh PartSetHeader) ValidateBasic() error { 117 // Hash can be empty in case of POLBlockID.PartSetHeader in Proposal. 118 if err := ValidateHash(psh.Hash); err != nil { 119 return fmt.Errorf("wrong Hash: %w", err) 120 } 121 return nil 122 } 123 124 // ToProto converts PartSetHeader to protobuf 125 func (psh *PartSetHeader) ToProto() cmtproto.PartSetHeader { 126 if psh == nil { 127 return cmtproto.PartSetHeader{} 128 } 129 130 return cmtproto.PartSetHeader{ 131 Total: psh.Total, 132 Hash: psh.Hash, 133 } 134 } 135 136 // FromProto sets a protobuf PartSetHeader to the given pointer 137 func PartSetHeaderFromProto(ppsh *cmtproto.PartSetHeader) (*PartSetHeader, error) { 138 if ppsh == nil { 139 return nil, errors.New("nil PartSetHeader") 140 } 141 psh := new(PartSetHeader) 142 psh.Total = ppsh.Total 143 psh.Hash = ppsh.Hash 144 145 return psh, psh.ValidateBasic() 146 } 147 148 // ProtoPartSetHeaderIsZero is similar to the IsZero function for 149 // PartSetHeader, but for the Protobuf representation. 150 func ProtoPartSetHeaderIsZero(ppsh *cmtproto.PartSetHeader) bool { 151 return ppsh.Total == 0 && len(ppsh.Hash) == 0 152 } 153 154 //------------------------------------- 155 156 type PartSet struct { 157 total uint32 158 hash []byte 159 160 mtx cmtsync.Mutex 161 parts []*Part 162 partsBitArray *bits.BitArray 163 count uint32 164 // a count of the total size (in bytes). Used to ensure that the 165 // part set doesn't exceed the maximum block bytes 166 byteSize int64 167 } 168 169 // Returns an immutable, full PartSet from the data bytes. 170 // The data bytes are split into "partSize" chunks, and merkle tree computed. 171 // CONTRACT: partSize is greater than zero. 172 func NewPartSetFromData(data []byte, partSize uint32) *PartSet { 173 // divide data into parts of size `partSize` 174 total := (uint32(len(data)) + partSize - 1) / partSize 175 parts := make([]*Part, total) 176 partsBytes := make([][]byte, total) 177 partsBitArray := bits.NewBitArray(int(total)) 178 for i := uint32(0); i < total; i++ { 179 part := &Part{ 180 Index: i, 181 Bytes: data[i*partSize : cmtmath.MinInt(len(data), int((i+1)*partSize))], 182 } 183 parts[i] = part 184 partsBytes[i] = part.Bytes 185 partsBitArray.SetIndex(int(i), true) 186 } 187 // Compute merkle proofs 188 root, proofs := merkle.ProofsFromByteSlices(partsBytes) 189 for i := uint32(0); i < total; i++ { 190 parts[i].Proof = *proofs[i] 191 } 192 return &PartSet{ 193 total: total, 194 hash: root, 195 parts: parts, 196 partsBitArray: partsBitArray, 197 count: total, 198 byteSize: int64(len(data)), 199 } 200 } 201 202 // Returns an empty PartSet ready to be populated. 203 func NewPartSetFromHeader(header PartSetHeader) *PartSet { 204 return &PartSet{ 205 total: header.Total, 206 hash: header.Hash, 207 parts: make([]*Part, header.Total), 208 partsBitArray: bits.NewBitArray(int(header.Total)), 209 count: 0, 210 byteSize: 0, 211 } 212 } 213 214 func (ps *PartSet) Header() PartSetHeader { 215 if ps == nil { 216 return PartSetHeader{} 217 } 218 return PartSetHeader{ 219 Total: ps.total, 220 Hash: ps.hash, 221 } 222 } 223 224 func (ps *PartSet) HasHeader(header PartSetHeader) bool { 225 if ps == nil { 226 return false 227 } 228 return ps.Header().Equals(header) 229 } 230 231 func (ps *PartSet) BitArray() *bits.BitArray { 232 ps.mtx.Lock() 233 defer ps.mtx.Unlock() 234 return ps.partsBitArray.Copy() 235 } 236 237 func (ps *PartSet) Hash() []byte { 238 if ps == nil { 239 return merkle.HashFromByteSlices(nil) 240 } 241 return ps.hash 242 } 243 244 func (ps *PartSet) HashesTo(hash []byte) bool { 245 if ps == nil { 246 return false 247 } 248 return bytes.Equal(ps.hash, hash) 249 } 250 251 func (ps *PartSet) Count() uint32 { 252 if ps == nil { 253 return 0 254 } 255 return ps.count 256 } 257 258 func (ps *PartSet) ByteSize() int64 { 259 if ps == nil { 260 return 0 261 } 262 return ps.byteSize 263 } 264 265 func (ps *PartSet) Total() uint32 { 266 if ps == nil { 267 return 0 268 } 269 return ps.total 270 } 271 272 func (ps *PartSet) AddPart(part *Part) (bool, error) { 273 // TODO: remove this? would be preferable if this only returned (false, nil) 274 // when its a duplicate block part 275 if ps == nil { 276 return false, nil 277 } 278 279 ps.mtx.Lock() 280 defer ps.mtx.Unlock() 281 282 // Invalid part index 283 if part.Index >= ps.total { 284 return false, ErrPartSetUnexpectedIndex 285 } 286 287 // If part already exists, return false. 288 if ps.parts[part.Index] != nil { 289 return false, nil 290 } 291 292 // Check hash proof 293 if part.Proof.Verify(ps.Hash(), part.Bytes) != nil { 294 return false, ErrPartSetInvalidProof 295 } 296 297 // Add part 298 ps.parts[part.Index] = part 299 ps.partsBitArray.SetIndex(int(part.Index), true) 300 ps.count++ 301 ps.byteSize += int64(len(part.Bytes)) 302 return true, nil 303 } 304 305 func (ps *PartSet) GetPart(index int) *Part { 306 ps.mtx.Lock() 307 defer ps.mtx.Unlock() 308 return ps.parts[index] 309 } 310 311 func (ps *PartSet) IsComplete() bool { 312 return ps.count == ps.total 313 } 314 315 func (ps *PartSet) GetReader() io.Reader { 316 if !ps.IsComplete() { 317 panic("Cannot GetReader() on incomplete PartSet") 318 } 319 return NewPartSetReader(ps.parts) 320 } 321 322 type PartSetReader struct { 323 i int 324 parts []*Part 325 reader *bytes.Reader 326 } 327 328 func NewPartSetReader(parts []*Part) *PartSetReader { 329 return &PartSetReader{ 330 i: 0, 331 parts: parts, 332 reader: bytes.NewReader(parts[0].Bytes), 333 } 334 } 335 336 func (psr *PartSetReader) Read(p []byte) (n int, err error) { 337 readerLen := psr.reader.Len() 338 if readerLen >= len(p) { 339 return psr.reader.Read(p) 340 } else if readerLen > 0 { 341 n1, err := psr.Read(p[:readerLen]) 342 if err != nil { 343 return n1, err 344 } 345 n2, err := psr.Read(p[readerLen:]) 346 return n1 + n2, err 347 } 348 349 psr.i++ 350 if psr.i >= len(psr.parts) { 351 return 0, io.EOF 352 } 353 psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes) 354 return psr.Read(p) 355 } 356 357 // StringShort returns a short version of String. 358 // 359 // (Count of Total) 360 func (ps *PartSet) StringShort() string { 361 if ps == nil { 362 return "nil-PartSet" 363 } 364 ps.mtx.Lock() 365 defer ps.mtx.Unlock() 366 return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total()) 367 } 368 369 func (ps *PartSet) MarshalJSON() ([]byte, error) { 370 if ps == nil { 371 return []byte("{}"), nil 372 } 373 374 ps.mtx.Lock() 375 defer ps.mtx.Unlock() 376 377 return cmtjson.Marshal(struct { 378 CountTotal string `json:"count/total"` 379 PartsBitArray *bits.BitArray `json:"parts_bit_array"` 380 }{ 381 fmt.Sprintf("%d/%d", ps.Count(), ps.Total()), 382 ps.partsBitArray, 383 }) 384 }