github.com/MetalBlockchain/metalgo@v1.11.9/network/p2p/validators.go (about) 1 // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package p2p 5 6 import ( 7 "cmp" 8 "context" 9 "math" 10 "sync" 11 "time" 12 13 "go.uber.org/zap" 14 15 "github.com/MetalBlockchain/metalgo/ids" 16 "github.com/MetalBlockchain/metalgo/snow/validators" 17 "github.com/MetalBlockchain/metalgo/utils" 18 "github.com/MetalBlockchain/metalgo/utils/logging" 19 "github.com/MetalBlockchain/metalgo/utils/sampler" 20 "github.com/MetalBlockchain/metalgo/utils/set" 21 ) 22 23 var ( 24 _ ValidatorSet = (*Validators)(nil) 25 _ ValidatorSubset = (*Validators)(nil) 26 _ NodeSampler = (*Validators)(nil) 27 ) 28 29 type ValidatorSet interface { 30 Has(ctx context.Context, nodeID ids.NodeID) bool // TODO return error 31 } 32 33 type ValidatorSubset interface { 34 Top(ctx context.Context, percentage float64) []ids.NodeID // TODO return error 35 } 36 37 func NewValidators( 38 peers *Peers, 39 log logging.Logger, 40 subnetID ids.ID, 41 validators validators.State, 42 maxValidatorSetStaleness time.Duration, 43 ) *Validators { 44 return &Validators{ 45 peers: peers, 46 log: log, 47 subnetID: subnetID, 48 validators: validators, 49 maxValidatorSetStaleness: maxValidatorSetStaleness, 50 } 51 } 52 53 // Validators contains a set of nodes that are staking. 54 type Validators struct { 55 peers *Peers 56 log logging.Logger 57 subnetID ids.ID 58 validators validators.State 59 maxValidatorSetStaleness time.Duration 60 61 lock sync.Mutex 62 validatorList []validator 63 validatorSet set.Set[ids.NodeID] 64 totalWeight uint64 65 lastUpdated time.Time 66 } 67 68 type validator struct { 69 nodeID ids.NodeID 70 weight uint64 71 } 72 73 func (v validator) Compare(other validator) int { 74 if weightCmp := cmp.Compare(v.weight, other.weight); weightCmp != 0 { 75 return -weightCmp // Sort in decreasing order of stake 76 } 77 return v.nodeID.Compare(other.nodeID) 78 } 79 80 func (v *Validators) refresh(ctx context.Context) { 81 if time.Since(v.lastUpdated) < v.maxValidatorSetStaleness { 82 return 83 } 84 85 // Even though validatorList may be nil, truncating will not panic. 86 v.validatorList = v.validatorList[:0] 87 v.validatorSet.Clear() 88 v.totalWeight = 0 89 90 height, err := v.validators.GetCurrentHeight(ctx) 91 if err != nil { 92 v.log.Warn("failed to get current height", zap.Error(err)) 93 return 94 } 95 validatorSet, err := v.validators.GetValidatorSet(ctx, height, v.subnetID) 96 if err != nil { 97 v.log.Warn("failed to get validator set", zap.Error(err)) 98 return 99 } 100 101 for nodeID, vdr := range validatorSet { 102 v.validatorList = append(v.validatorList, validator{ 103 nodeID: nodeID, 104 weight: vdr.Weight, 105 }) 106 v.validatorSet.Add(nodeID) 107 v.totalWeight += vdr.Weight 108 } 109 utils.Sort(v.validatorList) 110 111 v.lastUpdated = time.Now() 112 } 113 114 // Sample returns a random sample of connected validators 115 func (v *Validators) Sample(ctx context.Context, limit int) []ids.NodeID { 116 v.lock.Lock() 117 defer v.lock.Unlock() 118 119 v.refresh(ctx) 120 121 var ( 122 uniform = sampler.NewUniform() 123 sampled = make([]ids.NodeID, 0, limit) 124 ) 125 126 uniform.Initialize(uint64(len(v.validatorList))) 127 for len(sampled) < limit { 128 i, hasNext := uniform.Next() 129 if !hasNext { 130 break 131 } 132 133 nodeID := v.validatorList[i].nodeID 134 if !v.peers.has(nodeID) { 135 continue 136 } 137 138 sampled = append(sampled, nodeID) 139 } 140 141 return sampled 142 } 143 144 // Top returns the top [percentage] of validators, regardless of if they are 145 // connected or not. 146 func (v *Validators) Top(ctx context.Context, percentage float64) []ids.NodeID { 147 percentage = max(0, min(1, percentage)) // bound percentage inside [0, 1] 148 149 v.lock.Lock() 150 defer v.lock.Unlock() 151 152 v.refresh(ctx) 153 154 var ( 155 maxSize = int(math.Ceil(percentage * float64(len(v.validatorList)))) 156 top = make([]ids.NodeID, 0, maxSize) 157 currentStake uint64 158 targetStake = uint64(math.Ceil(percentage * float64(v.totalWeight))) 159 ) 160 161 for _, vdr := range v.validatorList { 162 if currentStake >= targetStake { 163 break 164 } 165 top = append(top, vdr.nodeID) 166 currentStake += vdr.weight 167 } 168 169 return top 170 } 171 172 // Has returns if nodeID is a connected validator 173 func (v *Validators) Has(ctx context.Context, nodeID ids.NodeID) bool { 174 v.lock.Lock() 175 defer v.lock.Unlock() 176 177 v.refresh(ctx) 178 179 return v.peers.has(nodeID) && v.validatorSet.Contains(nodeID) 180 }