github.com/okex/exchain@v1.8.0/libs/tendermint/types/part_set.go (about) 1 package types 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 "sync" 8 9 "github.com/tendermint/go-amino" 10 11 "github.com/pkg/errors" 12 13 "github.com/okex/exchain/libs/tendermint/crypto/merkle" 14 "github.com/okex/exchain/libs/tendermint/libs/bits" 15 tmbytes "github.com/okex/exchain/libs/tendermint/libs/bytes" 16 tmmath "github.com/okex/exchain/libs/tendermint/libs/math" 17 tmproto "github.com/okex/exchain/libs/tendermint/proto/types" 18 ) 19 20 var ( 21 ErrPartSetUnexpectedIndex = errors.New("error part set unexpected index") 22 ErrPartSetInvalidProof = errors.New("error part set invalid proof") 23 ) 24 25 type Part struct { 26 Index int `json:"index"` 27 Bytes tmbytes.HexBytes `json:"bytes"` 28 Proof merkle.SimpleProof `json:"proof"` 29 } 30 31 func (part *Part) UnmarshalFromAmino(cdc *amino.Codec, data []byte) error { 32 var dataLen uint64 = 0 33 var subData []byte 34 35 for { 36 data = data[dataLen:] 37 38 if len(data) == 0 { 39 break 40 } 41 42 pos, aminoType, err := amino.ParseProtoPosAndTypeMustOneByte(data[0]) 43 if err != nil { 44 return err 45 } 46 data = data[1:] 47 48 if aminoType == amino.Typ3_ByteLength { 49 var n int 50 dataLen, n, err = amino.DecodeUvarint(data) 51 if err != nil { 52 return err 53 } 54 55 data = data[n:] 56 if len(data) < int(dataLen) { 57 return fmt.Errorf("not enough data for %s, need %d, have %d", aminoType, dataLen, len(data)) 58 } 59 subData = data[:dataLen] 60 } 61 62 switch pos { 63 case 1: 64 uvint, n, err := amino.DecodeUvarint(data) 65 if err != nil { 66 return err 67 } 68 part.Index = int(uvint) 69 dataLen = uint64(n) 70 case 2: 71 part.Bytes = make([]byte, dataLen) 72 copy(part.Bytes, subData) 73 case 3: 74 err = part.Proof.UnmarshalFromAmino(cdc, subData) 75 if err != nil { 76 return err 77 } 78 default: 79 return fmt.Errorf("unexpect feild num %d", pos) 80 } 81 } 82 return nil 83 } 84 85 // ValidateBasic performs basic validation. 86 func (part *Part) ValidateBasic() error { 87 if part.Index < 0 { 88 return errors.New("negative Index") 89 } 90 if len(part.Bytes) > BlockPartSizeBytes { 91 return errors.Errorf("too big: %d bytes, max: %d", len(part.Bytes), BlockPartSizeBytes) 92 } 93 if err := part.Proof.ValidateBasic(); err != nil { 94 return errors.Wrap(err, "wrong Proof") 95 } 96 return nil 97 } 98 99 func (part *Part) String() string { 100 return part.StringIndented("") 101 } 102 103 func (part *Part) StringIndented(indent string) string { 104 return fmt.Sprintf(`Part{#%v 105 %s Bytes: %X... 106 %s Proof: %v 107 %s}`, 108 part.Index, 109 indent, tmbytes.Fingerprint(part.Bytes), 110 indent, part.Proof.StringIndented(indent+" "), 111 indent) 112 } 113 114 //------------------------------------- 115 116 type PartSetHeader struct { 117 Total int `json:"total"` 118 Hash tmbytes.HexBytes `json:"hash"` 119 } 120 121 func (psh PartSetHeader) AminoSize() int { 122 var size int 123 if psh.Total != 0 { 124 size += 1 + amino.UvarintSize(uint64(psh.Total)) 125 } 126 if len(psh.Hash) != 0 { 127 size += 1 + amino.UvarintSize(uint64(len(psh.Hash))) + len(psh.Hash) 128 } 129 return size 130 } 131 132 func (psh *PartSetHeader) UnmarshalFromAmino(_ *amino.Codec, data []byte) error { 133 var dataLen uint64 = 0 134 var subData []byte 135 136 for { 137 data = data[dataLen:] 138 139 if len(data) == 0 { 140 break 141 } 142 143 pos, aminoType, err := amino.ParseProtoPosAndTypeMustOneByte(data[0]) 144 if err != nil { 145 return err 146 } 147 data = data[1:] 148 149 if aminoType == amino.Typ3_ByteLength { 150 var n int 151 dataLen, n, err = amino.DecodeUvarint(data) 152 if err != nil { 153 return err 154 } 155 156 data = data[n:] 157 if len(data) < int(dataLen) { 158 return fmt.Errorf("not enough data for %s, need %d, have %d", aminoType, dataLen, len(data)) 159 } 160 subData = data[:dataLen] 161 } 162 163 switch pos { 164 case 1: 165 var n int 166 var uvint uint64 167 uvint, n, err = amino.DecodeUvarint(data) 168 if err != nil { 169 return err 170 } 171 psh.Total = int(uvint) 172 dataLen = uint64(n) 173 case 2: 174 psh.Hash = make([]byte, dataLen) 175 copy(psh.Hash, subData) 176 default: 177 return fmt.Errorf("unexpect feild num %d", pos) 178 } 179 } 180 return nil 181 } 182 183 func (psh PartSetHeader) String() string { 184 return fmt.Sprintf("%v:%X", psh.Total, tmbytes.Fingerprint(psh.Hash)) 185 } 186 187 func (psh PartSetHeader) IsZero() bool { 188 return psh.Total == 0 && len(psh.Hash) == 0 189 } 190 191 func (psh PartSetHeader) Equals(other PartSetHeader) bool { 192 return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash) 193 } 194 195 // ValidateBasic performs basic validation. 196 func (psh PartSetHeader) ValidateBasic() error { 197 if psh.Total < 0 { 198 return errors.New("negative Total") 199 } 200 // Hash can be empty in case of POLBlockID.PartsHeader in Proposal. 201 if err := ValidateHash(psh.Hash); err != nil { 202 return errors.Wrap(err, "Wrong Hash") 203 } 204 return nil 205 } 206 207 // ToProto converts BloPartSetHeaderckID to protobuf 208 func (psh *PartSetHeader) ToProto() tmproto.PartSetHeader { 209 if psh == nil { 210 return tmproto.PartSetHeader{} 211 } 212 213 return tmproto.PartSetHeader{ 214 Total: int64(psh.Total), 215 Hash: psh.Hash, 216 } 217 } 218 219 func (psh *PartSetHeader) ToIBCProto() tmproto.PartSetHeader { 220 if psh == nil { 221 return tmproto.PartSetHeader{} 222 } 223 return tmproto.PartSetHeader{ 224 Total: int64(psh.Total), 225 Hash: psh.Hash, 226 } 227 } 228 229 // FromProto sets a protobuf PartSetHeader to the given pointer 230 func PartSetHeaderFromProto(ppsh *tmproto.PartSetHeader) (*PartSetHeader, error) { 231 if ppsh == nil { 232 return nil, errors.New("nil PartSetHeader") 233 } 234 psh := new(PartSetHeader) 235 psh.Total = int(ppsh.Total) 236 psh.Hash = ppsh.Hash 237 238 return psh, psh.ValidateBasic() 239 } 240 241 //------------------------------------- 242 243 type PartSet struct { 244 total int 245 hash []byte 246 247 mtx sync.Mutex 248 parts []*Part 249 partsBitArray *bits.BitArray 250 count int 251 } 252 253 // Returns an immutable, full PartSet from the data bytes. 254 // The data bytes are split into "partSize" chunks, and merkle tree computed. 255 func NewPartSetFromData(data []byte, partSize int) *PartSet { 256 // divide data into 4kb parts. 257 total := (len(data) + partSize - 1) / partSize 258 parts := make([]*Part, total) 259 partsBytes := make([][]byte, total) 260 partsBitArray := bits.NewBitArray(total) 261 for i := 0; i < total; i++ { 262 part := &Part{ 263 Index: i, 264 Bytes: data[i*partSize : tmmath.MinInt(len(data), (i+1)*partSize)], 265 } 266 parts[i] = part 267 partsBytes[i] = part.Bytes 268 partsBitArray.SetIndex(i, true) 269 } 270 // Compute merkle proofs 271 root, proofs := merkle.SimpleProofsFromByteSlices(partsBytes) 272 for i := 0; i < total; i++ { 273 parts[i].Proof = *proofs[i] 274 } 275 return &PartSet{ 276 total: total, 277 hash: root, 278 parts: parts, 279 partsBitArray: partsBitArray, 280 count: total, 281 } 282 } 283 284 // Returns an empty PartSet ready to be populated. 285 func NewPartSetFromHeader(header PartSetHeader) *PartSet { 286 return &PartSet{ 287 total: header.Total, 288 hash: header.Hash, 289 parts: make([]*Part, header.Total), 290 partsBitArray: bits.NewBitArray(header.Total), 291 count: 0, 292 } 293 } 294 295 func (ps *PartSet) Header() PartSetHeader { 296 if ps == nil { 297 return PartSetHeader{} 298 } 299 return PartSetHeader{ 300 Total: ps.total, 301 Hash: ps.hash, 302 } 303 } 304 305 func (ps *PartSet) HasHeader(header PartSetHeader) bool { 306 if ps == nil { 307 return false 308 } 309 return ps.Header().Equals(header) 310 } 311 312 func (ps *PartSet) BitArray() *bits.BitArray { 313 ps.mtx.Lock() 314 defer ps.mtx.Unlock() 315 return ps.partsBitArray.Copy() 316 } 317 318 func (ps *PartSet) Hash() []byte { 319 if ps == nil { 320 return nil 321 } 322 return ps.hash 323 } 324 325 func (ps *PartSet) HashesTo(hash []byte) bool { 326 if ps == nil { 327 return false 328 } 329 return bytes.Equal(ps.hash, hash) 330 } 331 332 func (ps *PartSet) Count() int { 333 if ps == nil { 334 return 0 335 } 336 return ps.count 337 } 338 339 func (ps *PartSet) Total() int { 340 if ps == nil { 341 return 0 342 } 343 return ps.total 344 } 345 346 func (ps *PartSet) AddPart(part *Part) (bool, error) { 347 if ps == nil { 348 return false, nil 349 } 350 ps.mtx.Lock() 351 defer ps.mtx.Unlock() 352 353 // Invalid part index 354 if part.Index >= ps.total { 355 return false, ErrPartSetUnexpectedIndex 356 } 357 358 // If part already exists, return false. 359 if ps.parts[part.Index] != nil { 360 return false, nil 361 } 362 363 // Check hash proof 364 if part.Proof.Verify(ps.Hash(), part.Bytes) != nil { 365 return false, ErrPartSetInvalidProof 366 } 367 368 // Add part 369 ps.parts[part.Index] = part 370 ps.partsBitArray.SetIndex(part.Index, true) 371 ps.count++ 372 return true, nil 373 } 374 375 func (ps *PartSet) GetPart(index int) *Part { 376 ps.mtx.Lock() 377 defer ps.mtx.Unlock() 378 return ps.parts[index] 379 } 380 381 func (ps *PartSet) IsComplete() bool { 382 return ps.count == ps.total 383 } 384 385 func (ps *PartSet) GetReader() io.Reader { 386 if !ps.IsComplete() { 387 panic("Cannot GetReader() on incomplete PartSet") 388 } 389 return NewPartSetReader(ps.parts) 390 } 391 392 type PartSetReader struct { 393 i int 394 parts []*Part 395 reader *bytes.Reader 396 } 397 398 func NewPartSetReader(parts []*Part) *PartSetReader { 399 return &PartSetReader{ 400 i: 0, 401 parts: parts, 402 reader: bytes.NewReader(parts[0].Bytes), 403 } 404 } 405 406 func (psr *PartSetReader) Read(p []byte) (n int, err error) { 407 readerLen := psr.reader.Len() 408 if readerLen >= len(p) { 409 return psr.reader.Read(p) 410 } else if readerLen > 0 { 411 n1, err := psr.Read(p[:readerLen]) 412 if err != nil { 413 return n1, err 414 } 415 n2, err := psr.Read(p[readerLen:]) 416 return n1 + n2, err 417 } 418 419 psr.i++ 420 if psr.i >= len(psr.parts) { 421 return 0, io.EOF 422 } 423 psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes) 424 return psr.Read(p) 425 } 426 427 func (ps *PartSet) StringShort() string { 428 if ps == nil { 429 return "nil-PartSet" 430 } 431 ps.mtx.Lock() 432 defer ps.mtx.Unlock() 433 return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total()) 434 } 435 436 func (ps *PartSet) MarshalJSON() ([]byte, error) { 437 if ps == nil { 438 return []byte("{}"), nil 439 } 440 441 ps.mtx.Lock() 442 defer ps.mtx.Unlock() 443 444 return cdc.MarshalJSON(struct { 445 CountTotal string `json:"count/total"` 446 PartsBitArray *bits.BitArray `json:"parts_bit_array"` 447 }{ 448 fmt.Sprintf("%d/%d", ps.Count(), ps.Total()), 449 ps.partsBitArray, 450 }) 451 }