github.com/Finschia/ostracon@v1.1.5/types/part_set.go (about) 1 package types 2 3 import ( 4 "bytes" 5 "errors" 6 "fmt" 7 "io" 8 9 tmproto "github.com/tendermint/tendermint/proto/tendermint/types" 10 11 "github.com/Finschia/ostracon/crypto/merkle" 12 "github.com/Finschia/ostracon/libs/bits" 13 tmbytes "github.com/Finschia/ostracon/libs/bytes" 14 tmjson "github.com/Finschia/ostracon/libs/json" 15 tmmath "github.com/Finschia/ostracon/libs/math" 16 tmsync "github.com/Finschia/ostracon/libs/sync" 17 ) 18 19 var ( 20 ErrPartSetUnexpectedIndex = errors.New("error part set unexpected index") 21 ErrPartSetInvalidProof = errors.New("error part set invalid proof") 22 ) 23 24 type Part struct { 25 Index uint32 `json:"index"` 26 Bytes tmbytes.HexBytes `json:"bytes"` 27 Proof merkle.Proof `json:"proof"` 28 } 29 30 // ValidateBasic performs basic validation. 31 func (part *Part) ValidateBasic() error { 32 if len(part.Bytes) > int(BlockPartSizeBytes) { 33 return fmt.Errorf("too big: %d bytes, max: %d", len(part.Bytes), BlockPartSizeBytes) 34 } 35 if err := part.Proof.ValidateBasic(); err != nil { 36 return fmt.Errorf("wrong Proof: %w", err) 37 } 38 return nil 39 } 40 41 // String returns a string representation of Part. 42 // 43 // See StringIndented. 44 func (part *Part) String() string { 45 return part.StringIndented("") 46 } 47 48 // StringIndented returns an indented Part. 49 // 50 // See merkle.Proof#StringIndented 51 func (part *Part) StringIndented(indent string) string { 52 return fmt.Sprintf(`Part{#%v 53 %s Bytes: %X... 54 %s Proof: %v 55 %s}`, 56 part.Index, 57 indent, tmbytes.Fingerprint(part.Bytes), 58 indent, part.Proof.StringIndented(indent+" "), 59 indent) 60 } 61 62 func (part *Part) ToProto() (*tmproto.Part, error) { 63 if part == nil { 64 return nil, errors.New("nil part") 65 } 66 pb := new(tmproto.Part) 67 proof := part.Proof.ToProto() 68 69 pb.Index = part.Index 70 pb.Bytes = part.Bytes 71 pb.Proof = *proof 72 73 return pb, nil 74 } 75 76 func PartFromProto(pb *tmproto.Part) (*Part, error) { 77 if pb == nil { 78 return nil, errors.New("nil part") 79 } 80 81 part := new(Part) 82 proof, err := merkle.ProofFromProto(&pb.Proof) 83 if err != nil { 84 return nil, err 85 } 86 part.Index = pb.Index 87 part.Bytes = pb.Bytes 88 part.Proof = *proof 89 90 return part, part.ValidateBasic() 91 } 92 93 //------------------------------------- 94 95 type PartSetHeader struct { 96 Total uint32 `json:"total"` 97 Hash tmbytes.HexBytes `json:"hash"` 98 } 99 100 // String returns a string representation of PartSetHeader. 101 // 102 // 1. total number of parts 103 // 2. first 6 bytes of the hash 104 func (psh PartSetHeader) String() string { 105 return fmt.Sprintf("%v:%X", psh.Total, tmbytes.Fingerprint(psh.Hash)) 106 } 107 108 func (psh PartSetHeader) IsZero() bool { 109 return psh.Total == 0 && len(psh.Hash) == 0 110 } 111 112 func (psh PartSetHeader) Equals(other PartSetHeader) bool { 113 return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash) 114 } 115 116 // ValidateBasic performs basic validation. 117 func (psh PartSetHeader) ValidateBasic() error { 118 // Hash can be empty in case of POLBlockID.PartSetHeader in Proposal. 119 if err := ValidateHash(psh.Hash); err != nil { 120 return fmt.Errorf("wrong Hash: %w", err) 121 } 122 return nil 123 } 124 125 // ToProto converts PartSetHeader to protobuf 126 func (psh *PartSetHeader) ToProto() tmproto.PartSetHeader { 127 if psh == nil { 128 return tmproto.PartSetHeader{} 129 } 130 131 return tmproto.PartSetHeader{ 132 Total: psh.Total, 133 Hash: psh.Hash, 134 } 135 } 136 137 // FromProto sets a protobuf PartSetHeader to the given pointer 138 func PartSetHeaderFromProto(ppsh *tmproto.PartSetHeader) (*PartSetHeader, error) { 139 if ppsh == nil { 140 return nil, errors.New("nil PartSetHeader") 141 } 142 psh := new(PartSetHeader) 143 psh.Total = ppsh.Total 144 psh.Hash = ppsh.Hash 145 146 return psh, psh.ValidateBasic() 147 } 148 149 //------------------------------------- 150 151 type PartSet struct { 152 total uint32 153 hash []byte 154 155 mtx tmsync.Mutex 156 parts []*Part 157 partsBitArray *bits.BitArray 158 count uint32 159 // a count of the total size (in bytes). Used to ensure that the 160 // part set doesn't exceed the maximum block bytes 161 byteSize int64 162 } 163 164 // Returns an immutable, full PartSet from the data bytes. 165 // The data bytes are split into "partSize" chunks, and merkle tree computed. 166 // CONTRACT: partSize is greater than zero. 167 func NewPartSetFromData(data []byte, partSize uint32) *PartSet { 168 // divide data into parts of size `partSize` 169 total := (uint32(len(data)) + partSize - 1) / partSize 170 parts := make([]*Part, total) 171 partsBytes := make([][]byte, total) 172 partsBitArray := bits.NewBitArray(int(total)) 173 for i := uint32(0); i < total; i++ { 174 part := &Part{ 175 Index: i, 176 Bytes: data[i*partSize : tmmath.MinInt(len(data), int((i+1)*partSize))], 177 } 178 parts[i] = part 179 partsBytes[i] = part.Bytes 180 partsBitArray.SetIndex(int(i), true) 181 } 182 // Compute merkle proofs 183 root, proofs := merkle.ProofsFromByteSlices(partsBytes) 184 for i := uint32(0); i < total; i++ { 185 parts[i].Proof = *proofs[i] 186 } 187 return &PartSet{ 188 total: total, 189 hash: root, 190 parts: parts, 191 partsBitArray: partsBitArray, 192 count: total, 193 byteSize: int64(len(data)), 194 } 195 } 196 197 // Returns an empty PartSet ready to be populated. 198 func NewPartSetFromHeader(header PartSetHeader) *PartSet { 199 return &PartSet{ 200 total: header.Total, 201 hash: header.Hash, 202 parts: make([]*Part, header.Total), 203 partsBitArray: bits.NewBitArray(int(header.Total)), 204 count: 0, 205 byteSize: 0, 206 } 207 } 208 209 func (ps *PartSet) Header() PartSetHeader { 210 if ps == nil { 211 return PartSetHeader{} 212 } 213 return PartSetHeader{ 214 Total: ps.total, 215 Hash: ps.hash, 216 } 217 } 218 219 func (ps *PartSet) HasHeader(header PartSetHeader) bool { 220 if ps == nil { 221 return false 222 } 223 return ps.Header().Equals(header) 224 } 225 226 func (ps *PartSet) BitArray() *bits.BitArray { 227 ps.mtx.Lock() 228 defer ps.mtx.Unlock() 229 return ps.partsBitArray.Copy() 230 } 231 232 func (ps *PartSet) Hash() []byte { 233 if ps == nil { 234 return merkle.HashFromByteSlices(nil) 235 } 236 return ps.hash 237 } 238 239 func (ps *PartSet) HashesTo(hash []byte) bool { 240 if ps == nil { 241 return false 242 } 243 return bytes.Equal(ps.hash, hash) 244 } 245 246 func (ps *PartSet) Count() uint32 { 247 if ps == nil { 248 return 0 249 } 250 return ps.count 251 } 252 253 func (ps *PartSet) ByteSize() int64 { 254 if ps == nil { 255 return 0 256 } 257 return ps.byteSize 258 } 259 260 func (ps *PartSet) Total() uint32 { 261 if ps == nil { 262 return 0 263 } 264 return ps.total 265 } 266 267 func (ps *PartSet) AddPart(part *Part) (bool, error) { 268 if ps == nil { 269 return false, nil 270 } 271 ps.mtx.Lock() 272 defer ps.mtx.Unlock() 273 274 // Invalid part index 275 if part.Index >= ps.total { 276 return false, ErrPartSetUnexpectedIndex 277 } 278 279 // If part already exists, return false. 280 if ps.parts[part.Index] != nil { 281 return false, nil 282 } 283 284 // Check hash proof 285 if part.Proof.Verify(ps.Hash(), part.Bytes) != nil { 286 return false, ErrPartSetInvalidProof 287 } 288 289 // Add part 290 ps.parts[part.Index] = part 291 ps.partsBitArray.SetIndex(int(part.Index), true) 292 ps.count++ 293 ps.byteSize += int64(len(part.Bytes)) 294 return true, nil 295 } 296 297 func (ps *PartSet) GetPart(index int) *Part { 298 ps.mtx.Lock() 299 defer ps.mtx.Unlock() 300 return ps.parts[index] 301 } 302 303 func (ps *PartSet) IsComplete() bool { 304 return ps.count == ps.total 305 } 306 307 func (ps *PartSet) GetReader() io.Reader { 308 if !ps.IsComplete() { 309 panic("Cannot GetReader() on incomplete PartSet") 310 } 311 return NewPartSetReader(ps.parts) 312 } 313 314 type PartSetReader struct { 315 i int 316 parts []*Part 317 reader *bytes.Reader 318 } 319 320 func NewPartSetReader(parts []*Part) *PartSetReader { 321 return &PartSetReader{ 322 i: 0, 323 parts: parts, 324 reader: bytes.NewReader(parts[0].Bytes), 325 } 326 } 327 328 func (psr *PartSetReader) Read(p []byte) (n int, err error) { 329 readerLen := psr.reader.Len() 330 if readerLen >= len(p) { 331 return psr.reader.Read(p) 332 } else if readerLen > 0 { 333 n1, err := psr.Read(p[:readerLen]) 334 if err != nil { 335 return n1, err 336 } 337 n2, err := psr.Read(p[readerLen:]) 338 return n1 + n2, err 339 } 340 341 psr.i++ 342 if psr.i >= len(psr.parts) { 343 return 0, io.EOF 344 } 345 psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes) 346 return psr.Read(p) 347 } 348 349 // StringShort returns a short version of String. 350 // 351 // (Count of Total) 352 func (ps *PartSet) StringShort() string { 353 if ps == nil { 354 return "nil-PartSet" 355 } 356 ps.mtx.Lock() 357 defer ps.mtx.Unlock() 358 return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total()) 359 } 360 361 func (ps *PartSet) MarshalJSON() ([]byte, error) { 362 if ps == nil { 363 return []byte("{}"), nil 364 } 365 366 ps.mtx.Lock() 367 defer ps.mtx.Unlock() 368 369 return tmjson.Marshal(struct { 370 CountTotal string `json:"count/total"` 371 PartsBitArray *bits.BitArray `json:"parts_bit_array"` 372 }{ 373 fmt.Sprintf("%d/%d", ps.Count(), ps.Total()), 374 ps.partsBitArray, 375 }) 376 }