github.com/evdatsion/aphelion-dpos-bft@v0.32.1/types/part_set.go (about) 1 package types 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 "sync" 8 9 "github.com/pkg/errors" 10 11 "github.com/evdatsion/aphelion-dpos-bft/crypto/merkle" 12 cmn "github.com/evdatsion/aphelion-dpos-bft/libs/common" 13 ) 14 15 var ( 16 ErrPartSetUnexpectedIndex = errors.New("Error part set unexpected index") 17 ErrPartSetInvalidProof = errors.New("Error part set invalid proof") 18 ) 19 20 type Part struct { 21 Index int `json:"index"` 22 Bytes cmn.HexBytes `json:"bytes"` 23 Proof merkle.SimpleProof `json:"proof"` 24 } 25 26 // ValidateBasic performs basic validation. 27 func (part *Part) ValidateBasic() error { 28 if part.Index < 0 { 29 return errors.New("Negative Index") 30 } 31 if len(part.Bytes) > BlockPartSizeBytes { 32 return fmt.Errorf("Too big (max: %d)", BlockPartSizeBytes) 33 } 34 return nil 35 } 36 37 func (part *Part) String() string { 38 return part.StringIndented("") 39 } 40 41 func (part *Part) StringIndented(indent string) string { 42 return fmt.Sprintf(`Part{#%v 43 %s Bytes: %X... 44 %s Proof: %v 45 %s}`, 46 part.Index, 47 indent, cmn.Fingerprint(part.Bytes), 48 indent, part.Proof.StringIndented(indent+" "), 49 indent) 50 } 51 52 //------------------------------------- 53 54 type PartSetHeader struct { 55 Total int `json:"total"` 56 Hash cmn.HexBytes `json:"hash"` 57 } 58 59 func (psh PartSetHeader) String() string { 60 return fmt.Sprintf("%v:%X", psh.Total, cmn.Fingerprint(psh.Hash)) 61 } 62 63 func (psh PartSetHeader) IsZero() bool { 64 return psh.Total == 0 && len(psh.Hash) == 0 65 } 66 67 func (psh PartSetHeader) Equals(other PartSetHeader) bool { 68 return psh.Total == other.Total && bytes.Equal(psh.Hash, other.Hash) 69 } 70 71 // ValidateBasic performs basic validation. 72 func (psh PartSetHeader) ValidateBasic() error { 73 if psh.Total < 0 { 74 return errors.New("Negative Total") 75 } 76 // Hash can be empty in case of POLBlockID.PartsHeader in Proposal. 77 if err := ValidateHash(psh.Hash); err != nil { 78 return errors.Wrap(err, "Wrong Hash") 79 } 80 return nil 81 } 82 83 //------------------------------------- 84 85 type PartSet struct { 86 total int 87 hash []byte 88 89 mtx sync.Mutex 90 parts []*Part 91 partsBitArray *cmn.BitArray 92 count int 93 } 94 95 // Returns an immutable, full PartSet from the data bytes. 96 // The data bytes are split into "partSize" chunks, and merkle tree computed. 97 func NewPartSetFromData(data []byte, partSize int) *PartSet { 98 // divide data into 4kb parts. 99 total := (len(data) + partSize - 1) / partSize 100 parts := make([]*Part, total) 101 partsBytes := make([][]byte, total) 102 partsBitArray := cmn.NewBitArray(total) 103 for i := 0; i < total; i++ { 104 part := &Part{ 105 Index: i, 106 Bytes: data[i*partSize : cmn.MinInt(len(data), (i+1)*partSize)], 107 } 108 parts[i] = part 109 partsBytes[i] = part.Bytes 110 partsBitArray.SetIndex(i, true) 111 } 112 // Compute merkle proofs 113 root, proofs := merkle.SimpleProofsFromByteSlices(partsBytes) 114 for i := 0; i < total; i++ { 115 parts[i].Proof = *proofs[i] 116 } 117 return &PartSet{ 118 total: total, 119 hash: root, 120 parts: parts, 121 partsBitArray: partsBitArray, 122 count: total, 123 } 124 } 125 126 // Returns an empty PartSet ready to be populated. 127 func NewPartSetFromHeader(header PartSetHeader) *PartSet { 128 return &PartSet{ 129 total: header.Total, 130 hash: header.Hash, 131 parts: make([]*Part, header.Total), 132 partsBitArray: cmn.NewBitArray(header.Total), 133 count: 0, 134 } 135 } 136 137 func (ps *PartSet) Header() PartSetHeader { 138 if ps == nil { 139 return PartSetHeader{} 140 } 141 return PartSetHeader{ 142 Total: ps.total, 143 Hash: ps.hash, 144 } 145 } 146 147 func (ps *PartSet) HasHeader(header PartSetHeader) bool { 148 if ps == nil { 149 return false 150 } 151 return ps.Header().Equals(header) 152 } 153 154 func (ps *PartSet) BitArray() *cmn.BitArray { 155 ps.mtx.Lock() 156 defer ps.mtx.Unlock() 157 return ps.partsBitArray.Copy() 158 } 159 160 func (ps *PartSet) Hash() []byte { 161 if ps == nil { 162 return nil 163 } 164 return ps.hash 165 } 166 167 func (ps *PartSet) HashesTo(hash []byte) bool { 168 if ps == nil { 169 return false 170 } 171 return bytes.Equal(ps.hash, hash) 172 } 173 174 func (ps *PartSet) Count() int { 175 if ps == nil { 176 return 0 177 } 178 return ps.count 179 } 180 181 func (ps *PartSet) Total() int { 182 if ps == nil { 183 return 0 184 } 185 return ps.total 186 } 187 188 func (ps *PartSet) AddPart(part *Part) (bool, error) { 189 if ps == nil { 190 return false, nil 191 } 192 ps.mtx.Lock() 193 defer ps.mtx.Unlock() 194 195 // Invalid part index 196 if part.Index >= ps.total { 197 return false, ErrPartSetUnexpectedIndex 198 } 199 200 // If part already exists, return false. 201 if ps.parts[part.Index] != nil { 202 return false, nil 203 } 204 205 // Check hash proof 206 if part.Proof.Verify(ps.Hash(), part.Bytes) != nil { 207 return false, ErrPartSetInvalidProof 208 } 209 210 // Add part 211 ps.parts[part.Index] = part 212 ps.partsBitArray.SetIndex(part.Index, true) 213 ps.count++ 214 return true, nil 215 } 216 217 func (ps *PartSet) GetPart(index int) *Part { 218 ps.mtx.Lock() 219 defer ps.mtx.Unlock() 220 return ps.parts[index] 221 } 222 223 func (ps *PartSet) IsComplete() bool { 224 return ps.count == ps.total 225 } 226 227 func (ps *PartSet) GetReader() io.Reader { 228 if !ps.IsComplete() { 229 panic("Cannot GetReader() on incomplete PartSet") 230 } 231 return NewPartSetReader(ps.parts) 232 } 233 234 type PartSetReader struct { 235 i int 236 parts []*Part 237 reader *bytes.Reader 238 } 239 240 func NewPartSetReader(parts []*Part) *PartSetReader { 241 return &PartSetReader{ 242 i: 0, 243 parts: parts, 244 reader: bytes.NewReader(parts[0].Bytes), 245 } 246 } 247 248 func (psr *PartSetReader) Read(p []byte) (n int, err error) { 249 readerLen := psr.reader.Len() 250 if readerLen >= len(p) { 251 return psr.reader.Read(p) 252 } else if readerLen > 0 { 253 n1, err := psr.Read(p[:readerLen]) 254 if err != nil { 255 return n1, err 256 } 257 n2, err := psr.Read(p[readerLen:]) 258 return n1 + n2, err 259 } 260 261 psr.i++ 262 if psr.i >= len(psr.parts) { 263 return 0, io.EOF 264 } 265 psr.reader = bytes.NewReader(psr.parts[psr.i].Bytes) 266 return psr.Read(p) 267 } 268 269 func (ps *PartSet) StringShort() string { 270 if ps == nil { 271 return "nil-PartSet" 272 } 273 ps.mtx.Lock() 274 defer ps.mtx.Unlock() 275 return fmt.Sprintf("(%v of %v)", ps.Count(), ps.Total()) 276 } 277 278 func (ps *PartSet) MarshalJSON() ([]byte, error) { 279 if ps == nil { 280 return []byte("{}"), nil 281 } 282 283 ps.mtx.Lock() 284 defer ps.mtx.Unlock() 285 286 return cdc.MarshalJSON(struct { 287 CountTotal string `json:"count/total"` 288 PartsBitArray *cmn.BitArray `json:"parts_bit_array"` 289 }{ 290 fmt.Sprintf("%d/%d", ps.Count(), ps.Total()), 291 ps.partsBitArray, 292 }) 293 }