github.com/prysmaticlabs/prysm@v1.4.4/beacon-chain/state/v1/state_trie.go (about) 1 package v1 2 3 import ( 4 "context" 5 "runtime" 6 "sort" 7 "sync" 8 9 "github.com/pkg/errors" 10 "github.com/prometheus/client_golang/prometheus" 11 "github.com/prometheus/client_golang/prometheus/promauto" 12 iface "github.com/prysmaticlabs/prysm/beacon-chain/state/interface" 13 "github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil" 14 pbp2p "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1" 15 v1 "github.com/prysmaticlabs/prysm/proto/eth/v1" 16 "github.com/prysmaticlabs/prysm/shared/bytesutil" 17 "github.com/prysmaticlabs/prysm/shared/hashutil" 18 "github.com/prysmaticlabs/prysm/shared/htrutils" 19 "github.com/prysmaticlabs/prysm/shared/params" 20 "github.com/prysmaticlabs/prysm/shared/sliceutil" 21 "go.opencensus.io/trace" 22 "google.golang.org/protobuf/proto" 23 ) 24 25 var ( 26 stateCount = promauto.NewGauge(prometheus.GaugeOpts{ 27 Name: "beacon_state_count", 28 Help: "Count the number of active beacon state objects.", 29 }) 30 ) 31 32 // InitializeFromProto the beacon state from a protobuf representation. 33 func InitializeFromProto(st *pbp2p.BeaconState) (*BeaconState, error) { 34 return InitializeFromProtoUnsafe(proto.Clone(st).(*pbp2p.BeaconState)) 35 } 36 37 // InitializeFromProtoUnsafe directly uses the beacon state protobuf pointer 38 // and sets it as the inner state of the BeaconState type. 39 func InitializeFromProtoUnsafe(st *pbp2p.BeaconState) (*BeaconState, error) { 40 if st == nil { 41 return nil, errors.New("received nil state") 42 } 43 44 fieldCount := params.BeaconConfig().BeaconStateFieldCount 45 b := &BeaconState{ 46 state: st, 47 dirtyFields: make(map[fieldIndex]bool, fieldCount), 48 dirtyIndices: make(map[fieldIndex][]uint64, fieldCount), 49 stateFieldLeaves: make(map[fieldIndex]*FieldTrie, fieldCount), 50 sharedFieldReferences: make(map[fieldIndex]*stateutil.Reference, 10), 51 rebuildTrie: make(map[fieldIndex]bool, fieldCount), 52 valMapHandler: stateutil.NewValMapHandler(st.Validators), 53 } 54 55 for i := 0; i < fieldCount; i++ { 56 b.dirtyFields[fieldIndex(i)] = true 57 b.rebuildTrie[fieldIndex(i)] = true 58 b.dirtyIndices[fieldIndex(i)] = []uint64{} 59 b.stateFieldLeaves[fieldIndex(i)] = &FieldTrie{ 60 field: fieldIndex(i), 61 reference: stateutil.NewRef(1), 62 RWMutex: new(sync.RWMutex), 63 } 64 } 65 66 // Initialize field reference tracking for shared data. 67 b.sharedFieldReferences[randaoMixes] = stateutil.NewRef(1) 68 b.sharedFieldReferences[stateRoots] = stateutil.NewRef(1) 69 b.sharedFieldReferences[blockRoots] = stateutil.NewRef(1) 70 b.sharedFieldReferences[previousEpochAttestations] = stateutil.NewRef(1) 71 b.sharedFieldReferences[currentEpochAttestations] = stateutil.NewRef(1) 72 b.sharedFieldReferences[slashings] = stateutil.NewRef(1) 73 b.sharedFieldReferences[eth1DataVotes] = stateutil.NewRef(1) 74 b.sharedFieldReferences[validators] = stateutil.NewRef(1) 75 b.sharedFieldReferences[balances] = stateutil.NewRef(1) 76 b.sharedFieldReferences[historicalRoots] = stateutil.NewRef(1) 77 78 stateCount.Inc() 79 return b, nil 80 } 81 82 // Copy returns a deep copy of the beacon state. 83 func (b *BeaconState) Copy() iface.BeaconState { 84 if !b.hasInnerState() { 85 return nil 86 } 87 88 b.lock.RLock() 89 defer b.lock.RUnlock() 90 fieldCount := params.BeaconConfig().BeaconStateFieldCount 91 dst := &BeaconState{ 92 state: &pbp2p.BeaconState{ 93 // Primitive types, safe to copy. 94 GenesisTime: b.state.GenesisTime, 95 Slot: b.state.Slot, 96 Eth1DepositIndex: b.state.Eth1DepositIndex, 97 98 // Large arrays, infrequently changed, constant size. 99 RandaoMixes: b.state.RandaoMixes, 100 StateRoots: b.state.StateRoots, 101 BlockRoots: b.state.BlockRoots, 102 PreviousEpochAttestations: b.state.PreviousEpochAttestations, 103 CurrentEpochAttestations: b.state.CurrentEpochAttestations, 104 Slashings: b.state.Slashings, 105 Eth1DataVotes: b.state.Eth1DataVotes, 106 107 // Large arrays, increases over time. 108 Validators: b.state.Validators, 109 Balances: b.state.Balances, 110 HistoricalRoots: b.state.HistoricalRoots, 111 112 // Everything else, too small to be concerned about, constant size. 113 Fork: b.fork(), 114 LatestBlockHeader: b.latestBlockHeader(), 115 Eth1Data: b.eth1Data(), 116 JustificationBits: b.justificationBits(), 117 PreviousJustifiedCheckpoint: b.previousJustifiedCheckpoint(), 118 CurrentJustifiedCheckpoint: b.currentJustifiedCheckpoint(), 119 FinalizedCheckpoint: b.finalizedCheckpoint(), 120 GenesisValidatorsRoot: b.genesisValidatorRoot(), 121 }, 122 dirtyFields: make(map[fieldIndex]bool, fieldCount), 123 dirtyIndices: make(map[fieldIndex][]uint64, fieldCount), 124 rebuildTrie: make(map[fieldIndex]bool, fieldCount), 125 sharedFieldReferences: make(map[fieldIndex]*stateutil.Reference, 10), 126 stateFieldLeaves: make(map[fieldIndex]*FieldTrie, fieldCount), 127 128 // Copy on write validator index map. 129 valMapHandler: b.valMapHandler, 130 } 131 132 for field, ref := range b.sharedFieldReferences { 133 ref.AddRef() 134 dst.sharedFieldReferences[field] = ref 135 } 136 137 // Increment ref for validator map 138 b.valMapHandler.AddRef() 139 140 for i := range b.dirtyFields { 141 dst.dirtyFields[i] = true 142 } 143 144 for i := range b.dirtyIndices { 145 indices := make([]uint64, len(b.dirtyIndices[i])) 146 copy(indices, b.dirtyIndices[i]) 147 dst.dirtyIndices[i] = indices 148 } 149 150 for i := range b.rebuildTrie { 151 dst.rebuildTrie[i] = true 152 } 153 154 for fldIdx, fieldTrie := range b.stateFieldLeaves { 155 dst.stateFieldLeaves[fldIdx] = fieldTrie 156 if fieldTrie.reference != nil { 157 fieldTrie.Lock() 158 fieldTrie.reference.AddRef() 159 fieldTrie.Unlock() 160 } 161 } 162 163 if b.merkleLayers != nil { 164 dst.merkleLayers = make([][][]byte, len(b.merkleLayers)) 165 for i, layer := range b.merkleLayers { 166 dst.merkleLayers[i] = make([][]byte, len(layer)) 167 for j, content := range layer { 168 dst.merkleLayers[i][j] = make([]byte, len(content)) 169 copy(dst.merkleLayers[i][j], content) 170 } 171 } 172 } 173 174 stateCount.Inc() 175 // Finalizer runs when dst is being destroyed in garbage collection. 176 runtime.SetFinalizer(dst, func(b *BeaconState) { 177 for field, v := range b.sharedFieldReferences { 178 v.MinusRef() 179 if b.stateFieldLeaves[field].reference != nil { 180 b.stateFieldLeaves[field].reference.MinusRef() 181 } 182 183 } 184 for i := 0; i < fieldCount; i++ { 185 field := fieldIndex(i) 186 delete(b.stateFieldLeaves, field) 187 delete(b.dirtyIndices, field) 188 delete(b.dirtyFields, field) 189 delete(b.sharedFieldReferences, field) 190 delete(b.stateFieldLeaves, field) 191 } 192 stateCount.Sub(1) 193 }) 194 return dst 195 } 196 197 // HashTreeRoot of the beacon state retrieves the Merkle root of the trie 198 // representation of the beacon state based on the Ethereum Simple Serialize specification. 199 func (b *BeaconState) HashTreeRoot(ctx context.Context) ([32]byte, error) { 200 ctx, span := trace.StartSpan(ctx, "beaconState.HashTreeRoot") 201 defer span.End() 202 203 b.lock.Lock() 204 defer b.lock.Unlock() 205 206 if b.merkleLayers == nil || len(b.merkleLayers) == 0 { 207 fieldRoots, err := computeFieldRoots(ctx, b.state) 208 if err != nil { 209 return [32]byte{}, err 210 } 211 layers := stateutil.Merkleize(fieldRoots) 212 b.merkleLayers = layers 213 b.dirtyFields = make(map[fieldIndex]bool, params.BeaconConfig().BeaconStateFieldCount) 214 } 215 216 for field := range b.dirtyFields { 217 root, err := b.rootSelector(ctx, field) 218 if err != nil { 219 return [32]byte{}, err 220 } 221 b.merkleLayers[0][field] = root[:] 222 b.recomputeRoot(int(field)) 223 delete(b.dirtyFields, field) 224 } 225 return bytesutil.ToBytes32(b.merkleLayers[len(b.merkleLayers)-1][0]), nil 226 } 227 228 // ToProto returns a protobuf *v1.BeaconState representation of the state. 229 func (b *BeaconState) ToProto() (*v1.BeaconState, error) { 230 sourceFork := b.Fork() 231 sourceLatestBlockHeader := b.LatestBlockHeader() 232 sourceEth1Data := b.Eth1Data() 233 sourceEth1DataVotes := b.Eth1DataVotes() 234 sourceValidators := b.Validators() 235 sourcePrevEpochAtts, err := b.PreviousEpochAttestations() 236 if err != nil { 237 return nil, errors.Wrap(err, "could not get previous epoch attestations") 238 } 239 sourceCurrEpochAtts, err := b.CurrentEpochAttestations() 240 if err != nil { 241 return nil, errors.Wrap(err, "could not get current epoch attestations") 242 } 243 sourcePrevJustifiedCheckpoint := b.PreviousJustifiedCheckpoint() 244 sourceCurrJustifiedCheckpoint := b.CurrentJustifiedCheckpoint() 245 sourceFinalizedCheckpoint := b.FinalizedCheckpoint() 246 247 resultEth1DataVotes := make([]*v1.Eth1Data, len(sourceEth1DataVotes)) 248 for i, vote := range sourceEth1DataVotes { 249 resultEth1DataVotes[i] = &v1.Eth1Data{ 250 DepositRoot: vote.DepositRoot, 251 DepositCount: vote.DepositCount, 252 BlockHash: vote.BlockHash, 253 } 254 } 255 resultValidators := make([]*v1.Validator, len(sourceValidators)) 256 for i, validator := range sourceValidators { 257 resultValidators[i] = &v1.Validator{ 258 Pubkey: validator.PublicKey, 259 WithdrawalCredentials: validator.WithdrawalCredentials, 260 EffectiveBalance: validator.EffectiveBalance, 261 Slashed: validator.Slashed, 262 ActivationEligibilityEpoch: validator.ActivationEligibilityEpoch, 263 ActivationEpoch: validator.ActivationEpoch, 264 ExitEpoch: validator.ExitEpoch, 265 WithdrawableEpoch: validator.WithdrawableEpoch, 266 } 267 } 268 resultPrevEpochAtts := make([]*v1.PendingAttestation, len(sourcePrevEpochAtts)) 269 for i, att := range sourcePrevEpochAtts { 270 data := att.Data 271 resultPrevEpochAtts[i] = &v1.PendingAttestation{ 272 AggregationBits: att.AggregationBits, 273 Data: &v1.AttestationData{ 274 Slot: data.Slot, 275 Index: data.CommitteeIndex, 276 BeaconBlockRoot: data.BeaconBlockRoot, 277 Source: &v1.Checkpoint{ 278 Epoch: data.Source.Epoch, 279 Root: data.Source.Root, 280 }, 281 Target: &v1.Checkpoint{ 282 Epoch: data.Target.Epoch, 283 Root: data.Target.Root, 284 }, 285 }, 286 InclusionDelay: att.InclusionDelay, 287 ProposerIndex: att.ProposerIndex, 288 } 289 } 290 resultCurrEpochAtts := make([]*v1.PendingAttestation, len(sourceCurrEpochAtts)) 291 for i, att := range sourceCurrEpochAtts { 292 data := att.Data 293 resultCurrEpochAtts[i] = &v1.PendingAttestation{ 294 AggregationBits: att.AggregationBits, 295 Data: &v1.AttestationData{ 296 Slot: data.Slot, 297 Index: data.CommitteeIndex, 298 BeaconBlockRoot: data.BeaconBlockRoot, 299 Source: &v1.Checkpoint{ 300 Epoch: data.Source.Epoch, 301 Root: data.Source.Root, 302 }, 303 Target: &v1.Checkpoint{ 304 Epoch: data.Target.Epoch, 305 Root: data.Target.Root, 306 }, 307 }, 308 InclusionDelay: att.InclusionDelay, 309 ProposerIndex: att.ProposerIndex, 310 } 311 } 312 result := &v1.BeaconState{ 313 GenesisTime: b.GenesisTime(), 314 GenesisValidatorsRoot: b.GenesisValidatorRoot(), 315 Slot: b.Slot(), 316 Fork: &v1.Fork{ 317 PreviousVersion: sourceFork.PreviousVersion, 318 CurrentVersion: sourceFork.CurrentVersion, 319 Epoch: sourceFork.Epoch, 320 }, 321 LatestBlockHeader: &v1.BeaconBlockHeader{ 322 Slot: sourceLatestBlockHeader.Slot, 323 ProposerIndex: sourceLatestBlockHeader.ProposerIndex, 324 ParentRoot: sourceLatestBlockHeader.ParentRoot, 325 StateRoot: sourceLatestBlockHeader.StateRoot, 326 BodyRoot: sourceLatestBlockHeader.BodyRoot, 327 }, 328 BlockRoots: b.BlockRoots(), 329 StateRoots: b.StateRoots(), 330 HistoricalRoots: b.HistoricalRoots(), 331 Eth1Data: &v1.Eth1Data{ 332 DepositRoot: sourceEth1Data.DepositRoot, 333 DepositCount: sourceEth1Data.DepositCount, 334 BlockHash: sourceEth1Data.BlockHash, 335 }, 336 Eth1DataVotes: resultEth1DataVotes, 337 Eth1DepositIndex: b.Eth1DepositIndex(), 338 Validators: resultValidators, 339 Balances: b.Balances(), 340 RandaoMixes: b.RandaoMixes(), 341 Slashings: b.Slashings(), 342 PreviousEpochAttestations: resultPrevEpochAtts, 343 CurrentEpochAttestations: resultCurrEpochAtts, 344 JustificationBits: b.JustificationBits(), 345 PreviousJustifiedCheckpoint: &v1.Checkpoint{ 346 Epoch: sourcePrevJustifiedCheckpoint.Epoch, 347 Root: sourcePrevJustifiedCheckpoint.Root, 348 }, 349 CurrentJustifiedCheckpoint: &v1.Checkpoint{ 350 Epoch: sourceCurrJustifiedCheckpoint.Epoch, 351 Root: sourceCurrJustifiedCheckpoint.Root, 352 }, 353 FinalizedCheckpoint: &v1.Checkpoint{ 354 Epoch: sourceFinalizedCheckpoint.Epoch, 355 Root: sourceFinalizedCheckpoint.Root, 356 }, 357 } 358 359 return result, nil 360 } 361 362 // FieldReferencesCount returns the reference count held by each field. This 363 // also includes the field trie held by each field. 364 func (b *BeaconState) FieldReferencesCount() map[string]uint64 { 365 refMap := make(map[string]uint64) 366 b.lock.RLock() 367 defer b.lock.RUnlock() 368 for i, f := range b.sharedFieldReferences { 369 refMap[i.String()] = uint64(f.Refs()) 370 } 371 for i, f := range b.stateFieldLeaves { 372 numOfRefs := uint64(f.reference.Refs()) 373 f.RLock() 374 if len(f.fieldLayers) != 0 { 375 refMap[i.String()+"_trie"] = numOfRefs 376 } 377 f.RUnlock() 378 } 379 return refMap 380 } 381 382 // IsNil checks if the state and the underlying proto 383 // object are nil. 384 func (b *BeaconState) IsNil() bool { 385 return b == nil || b.state == nil 386 } 387 388 func (b *BeaconState) rootSelector(ctx context.Context, field fieldIndex) ([32]byte, error) { 389 ctx, span := trace.StartSpan(ctx, "beaconState.rootSelector") 390 defer span.End() 391 span.AddAttributes(trace.StringAttribute("field", field.String())) 392 393 hasher := hashutil.CustomSHA256Hasher() 394 switch field { 395 case genesisTime: 396 return htrutils.Uint64Root(b.state.GenesisTime), nil 397 case genesisValidatorRoot: 398 return bytesutil.ToBytes32(b.state.GenesisValidatorsRoot), nil 399 case slot: 400 return htrutils.Uint64Root(uint64(b.state.Slot)), nil 401 case eth1DepositIndex: 402 return htrutils.Uint64Root(b.state.Eth1DepositIndex), nil 403 case fork: 404 return htrutils.ForkRoot(b.state.Fork) 405 case latestBlockHeader: 406 return stateutil.BlockHeaderRoot(b.state.LatestBlockHeader) 407 case blockRoots: 408 if b.rebuildTrie[field] { 409 err := b.resetFieldTrie(field, b.state.BlockRoots, uint64(params.BeaconConfig().SlotsPerHistoricalRoot)) 410 if err != nil { 411 return [32]byte{}, err 412 } 413 b.dirtyIndices[field] = []uint64{} 414 delete(b.rebuildTrie, field) 415 return b.stateFieldLeaves[field].TrieRoot() 416 } 417 return b.recomputeFieldTrie(blockRoots, b.state.BlockRoots) 418 case stateRoots: 419 if b.rebuildTrie[field] { 420 err := b.resetFieldTrie(field, b.state.StateRoots, uint64(params.BeaconConfig().SlotsPerHistoricalRoot)) 421 if err != nil { 422 return [32]byte{}, err 423 } 424 b.dirtyIndices[field] = []uint64{} 425 delete(b.rebuildTrie, field) 426 return b.stateFieldLeaves[field].TrieRoot() 427 } 428 return b.recomputeFieldTrie(stateRoots, b.state.StateRoots) 429 case historicalRoots: 430 return htrutils.HistoricalRootsRoot(b.state.HistoricalRoots) 431 case eth1Data: 432 return eth1Root(hasher, b.state.Eth1Data) 433 case eth1DataVotes: 434 if b.rebuildTrie[field] { 435 err := b.resetFieldTrie( 436 field, 437 b.state.Eth1DataVotes, 438 uint64(params.BeaconConfig().SlotsPerEpoch.Mul(uint64(params.BeaconConfig().EpochsPerEth1VotingPeriod))), 439 ) 440 if err != nil { 441 return [32]byte{}, err 442 } 443 b.dirtyIndices[field] = []uint64{} 444 delete(b.rebuildTrie, field) 445 return b.stateFieldLeaves[field].TrieRoot() 446 } 447 return b.recomputeFieldTrie(field, b.state.Eth1DataVotes) 448 case validators: 449 if b.rebuildTrie[field] { 450 err := b.resetFieldTrie(field, b.state.Validators, params.BeaconConfig().ValidatorRegistryLimit) 451 if err != nil { 452 return [32]byte{}, err 453 } 454 b.dirtyIndices[validators] = []uint64{} 455 delete(b.rebuildTrie, validators) 456 return b.stateFieldLeaves[field].TrieRoot() 457 } 458 return b.recomputeFieldTrie(validators, b.state.Validators) 459 case balances: 460 return stateutil.Uint64ListRootWithRegistryLimit(b.state.Balances) 461 case randaoMixes: 462 if b.rebuildTrie[field] { 463 err := b.resetFieldTrie(field, b.state.RandaoMixes, uint64(params.BeaconConfig().EpochsPerHistoricalVector)) 464 if err != nil { 465 return [32]byte{}, err 466 } 467 b.dirtyIndices[field] = []uint64{} 468 delete(b.rebuildTrie, field) 469 return b.stateFieldLeaves[field].TrieRoot() 470 } 471 return b.recomputeFieldTrie(randaoMixes, b.state.RandaoMixes) 472 case slashings: 473 return htrutils.SlashingsRoot(b.state.Slashings) 474 case previousEpochAttestations: 475 if b.rebuildTrie[field] { 476 err := b.resetFieldTrie( 477 field, 478 b.state.PreviousEpochAttestations, 479 uint64(params.BeaconConfig().SlotsPerEpoch.Mul(params.BeaconConfig().MaxAttestations)), 480 ) 481 if err != nil { 482 return [32]byte{}, err 483 } 484 b.dirtyIndices[field] = []uint64{} 485 delete(b.rebuildTrie, field) 486 return b.stateFieldLeaves[field].TrieRoot() 487 } 488 return b.recomputeFieldTrie(field, b.state.PreviousEpochAttestations) 489 case currentEpochAttestations: 490 if b.rebuildTrie[field] { 491 err := b.resetFieldTrie( 492 field, 493 b.state.CurrentEpochAttestations, 494 uint64(params.BeaconConfig().SlotsPerEpoch.Mul(params.BeaconConfig().MaxAttestations)), 495 ) 496 if err != nil { 497 return [32]byte{}, err 498 } 499 b.dirtyIndices[field] = []uint64{} 500 delete(b.rebuildTrie, field) 501 return b.stateFieldLeaves[field].TrieRoot() 502 } 503 return b.recomputeFieldTrie(field, b.state.CurrentEpochAttestations) 504 case justificationBits: 505 return bytesutil.ToBytes32(b.state.JustificationBits), nil 506 case previousJustifiedCheckpoint: 507 return htrutils.CheckpointRoot(hasher, b.state.PreviousJustifiedCheckpoint) 508 case currentJustifiedCheckpoint: 509 return htrutils.CheckpointRoot(hasher, b.state.CurrentJustifiedCheckpoint) 510 case finalizedCheckpoint: 511 return htrutils.CheckpointRoot(hasher, b.state.FinalizedCheckpoint) 512 } 513 return [32]byte{}, errors.New("invalid field index provided") 514 } 515 516 func (b *BeaconState) recomputeFieldTrie(index fieldIndex, elements interface{}) ([32]byte, error) { 517 fTrie := b.stateFieldLeaves[index] 518 if fTrie.reference.Refs() > 1 { 519 fTrie.Lock() 520 defer fTrie.Unlock() 521 fTrie.reference.MinusRef() 522 newTrie := fTrie.CopyTrie() 523 b.stateFieldLeaves[index] = newTrie 524 fTrie = newTrie 525 } 526 // remove duplicate indexes 527 b.dirtyIndices[index] = sliceutil.SetUint64(b.dirtyIndices[index]) 528 // sort indexes again 529 sort.Slice(b.dirtyIndices[index], func(i int, j int) bool { 530 return b.dirtyIndices[index][i] < b.dirtyIndices[index][j] 531 }) 532 root, err := fTrie.RecomputeTrie(b.dirtyIndices[index], elements) 533 if err != nil { 534 return [32]byte{}, err 535 } 536 b.dirtyIndices[index] = []uint64{} 537 return root, nil 538 } 539 540 func (b *BeaconState) resetFieldTrie(index fieldIndex, elements interface{}, length uint64) error { 541 fTrie, err := NewFieldTrie(index, elements, length) 542 if err != nil { 543 return err 544 } 545 b.stateFieldLeaves[index] = fTrie 546 b.dirtyIndices[index] = []uint64{} 547 return nil 548 }