github.com/MetalBlockchain/metalgo@v1.11.9/snow/validators/set.go (about) 1 // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package validators 5 6 import ( 7 "errors" 8 "fmt" 9 "math/big" 10 "slices" 11 "strings" 12 "sync" 13 14 "github.com/MetalBlockchain/metalgo/ids" 15 "github.com/MetalBlockchain/metalgo/utils/crypto/bls" 16 "github.com/MetalBlockchain/metalgo/utils/formatting" 17 "github.com/MetalBlockchain/metalgo/utils/math" 18 "github.com/MetalBlockchain/metalgo/utils/sampler" 19 "github.com/MetalBlockchain/metalgo/utils/set" 20 ) 21 22 var ( 23 errDuplicateValidator = errors.New("duplicate validator") 24 errMissingValidator = errors.New("missing validator") 25 errTotalWeightNotUint64 = errors.New("total weight is not a uint64") 26 errInsufficientWeight = errors.New("insufficient weight") 27 ) 28 29 // newSet returns a new, empty set of validators. 30 func newSet(subnetID ids.ID, callbackListeners []ManagerCallbackListener) *vdrSet { 31 return &vdrSet{ 32 subnetID: subnetID, 33 vdrs: make(map[ids.NodeID]*Validator), 34 totalWeight: new(big.Int), 35 sampler: sampler.NewWeightedWithoutReplacement(), 36 managerCallbackListeners: slices.Clone(callbackListeners), 37 } 38 } 39 40 type vdrSet struct { 41 subnetID ids.ID 42 43 lock sync.RWMutex 44 vdrs map[ids.NodeID]*Validator 45 vdrSlice []*Validator 46 weights []uint64 47 totalWeight *big.Int 48 49 samplerInitialized bool 50 sampler sampler.WeightedWithoutReplacement 51 52 managerCallbackListeners []ManagerCallbackListener 53 setCallbackListeners []SetCallbackListener 54 } 55 56 func (s *vdrSet) Add(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) error { 57 s.lock.Lock() 58 defer s.lock.Unlock() 59 60 return s.add(nodeID, pk, txID, weight) 61 } 62 63 func (s *vdrSet) add(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) error { 64 _, nodeExists := s.vdrs[nodeID] 65 if nodeExists { 66 return errDuplicateValidator 67 } 68 69 vdr := &Validator{ 70 NodeID: nodeID, 71 PublicKey: pk, 72 TxID: txID, 73 Weight: weight, 74 index: len(s.vdrSlice), 75 } 76 s.vdrs[nodeID] = vdr 77 s.vdrSlice = append(s.vdrSlice, vdr) 78 s.weights = append(s.weights, weight) 79 s.totalWeight.Add(s.totalWeight, new(big.Int).SetUint64(weight)) 80 s.samplerInitialized = false 81 82 s.callValidatorAddedCallbacks(nodeID, pk, txID, weight) 83 return nil 84 } 85 86 func (s *vdrSet) AddWeight(nodeID ids.NodeID, weight uint64) error { 87 s.lock.Lock() 88 defer s.lock.Unlock() 89 90 return s.addWeight(nodeID, weight) 91 } 92 93 func (s *vdrSet) addWeight(nodeID ids.NodeID, weight uint64) error { 94 vdr, nodeExists := s.vdrs[nodeID] 95 if !nodeExists { 96 return errMissingValidator 97 } 98 99 oldWeight := vdr.Weight 100 newWeight, err := math.Add64(oldWeight, weight) 101 if err != nil { 102 return err 103 } 104 vdr.Weight = newWeight 105 s.weights[vdr.index] = newWeight 106 s.totalWeight.Add(s.totalWeight, new(big.Int).SetUint64(weight)) 107 s.samplerInitialized = false 108 109 s.callWeightChangeCallbacks(nodeID, oldWeight, vdr.Weight) 110 return nil 111 } 112 113 func (s *vdrSet) GetWeight(nodeID ids.NodeID) uint64 { 114 s.lock.RLock() 115 defer s.lock.RUnlock() 116 117 return s.getWeight(nodeID) 118 } 119 120 func (s *vdrSet) getWeight(nodeID ids.NodeID) uint64 { 121 if vdr, ok := s.vdrs[nodeID]; ok { 122 return vdr.Weight 123 } 124 return 0 125 } 126 127 func (s *vdrSet) SubsetWeight(subset set.Set[ids.NodeID]) (uint64, error) { 128 s.lock.RLock() 129 defer s.lock.RUnlock() 130 131 return s.subsetWeight(subset) 132 } 133 134 func (s *vdrSet) subsetWeight(subset set.Set[ids.NodeID]) (uint64, error) { 135 var ( 136 totalWeight uint64 137 err error 138 ) 139 for nodeID := range subset { 140 totalWeight, err = math.Add64(totalWeight, s.getWeight(nodeID)) 141 if err != nil { 142 return 0, err 143 } 144 } 145 return totalWeight, nil 146 } 147 148 func (s *vdrSet) RemoveWeight(nodeID ids.NodeID, weight uint64) error { 149 s.lock.Lock() 150 defer s.lock.Unlock() 151 152 return s.removeWeight(nodeID, weight) 153 } 154 155 func (s *vdrSet) removeWeight(nodeID ids.NodeID, weight uint64) error { 156 vdr, ok := s.vdrs[nodeID] 157 if !ok { 158 return errMissingValidator 159 } 160 161 oldWeight := vdr.Weight 162 // We first calculate the new weight of the validator, as this guarantees 163 // that none of the following operations can underflow. 164 newWeight, err := math.Sub(oldWeight, weight) 165 if err != nil { 166 return err 167 } 168 169 if newWeight == 0 { 170 // Get the last element 171 lastIndex := len(s.vdrSlice) - 1 172 vdrToSwap := s.vdrSlice[lastIndex] 173 174 // Move element at last index --> index of removed validator 175 vdrToSwap.index = vdr.index 176 s.vdrSlice[vdr.index] = vdrToSwap 177 s.weights[vdr.index] = vdrToSwap.Weight 178 179 // Remove validator 180 delete(s.vdrs, nodeID) 181 s.vdrSlice[lastIndex] = nil 182 s.vdrSlice = s.vdrSlice[:lastIndex] 183 s.weights = s.weights[:lastIndex] 184 185 s.callValidatorRemovedCallbacks(nodeID, oldWeight) 186 } else { 187 vdr.Weight = newWeight 188 s.weights[vdr.index] = newWeight 189 190 s.callWeightChangeCallbacks(nodeID, oldWeight, newWeight) 191 } 192 s.totalWeight.Sub(s.totalWeight, new(big.Int).SetUint64(weight)) 193 s.samplerInitialized = false 194 return nil 195 } 196 197 func (s *vdrSet) Get(nodeID ids.NodeID) (*Validator, bool) { 198 s.lock.RLock() 199 defer s.lock.RUnlock() 200 201 return s.get(nodeID) 202 } 203 204 func (s *vdrSet) get(nodeID ids.NodeID) (*Validator, bool) { 205 vdr, ok := s.vdrs[nodeID] 206 if !ok { 207 return nil, false 208 } 209 copiedVdr := *vdr 210 return &copiedVdr, true 211 } 212 213 func (s *vdrSet) Len() int { 214 s.lock.RLock() 215 defer s.lock.RUnlock() 216 217 return s.len() 218 } 219 220 func (s *vdrSet) len() int { 221 return len(s.vdrSlice) 222 } 223 224 func (s *vdrSet) HasCallbackRegistered() bool { 225 s.lock.RLock() 226 defer s.lock.RUnlock() 227 228 return len(s.setCallbackListeners) > 0 229 } 230 231 func (s *vdrSet) Map() map[ids.NodeID]*GetValidatorOutput { 232 s.lock.RLock() 233 defer s.lock.RUnlock() 234 235 set := make(map[ids.NodeID]*GetValidatorOutput, len(s.vdrSlice)) 236 for _, vdr := range s.vdrSlice { 237 set[vdr.NodeID] = &GetValidatorOutput{ 238 NodeID: vdr.NodeID, 239 PublicKey: vdr.PublicKey, 240 Weight: vdr.Weight, 241 } 242 } 243 return set 244 } 245 246 func (s *vdrSet) Sample(size int) ([]ids.NodeID, error) { 247 s.lock.Lock() 248 defer s.lock.Unlock() 249 250 return s.sample(size) 251 } 252 253 func (s *vdrSet) sample(size int) ([]ids.NodeID, error) { 254 if !s.samplerInitialized { 255 if err := s.sampler.Initialize(s.weights); err != nil { 256 return nil, err 257 } 258 s.samplerInitialized = true 259 } 260 261 indices, ok := s.sampler.Sample(size) 262 if !ok { 263 return nil, errInsufficientWeight 264 } 265 266 list := make([]ids.NodeID, size) 267 for i, index := range indices { 268 list[i] = s.vdrSlice[index].NodeID 269 } 270 return list, nil 271 } 272 273 func (s *vdrSet) TotalWeight() (uint64, error) { 274 s.lock.RLock() 275 defer s.lock.RUnlock() 276 277 if !s.totalWeight.IsUint64() { 278 return 0, fmt.Errorf("%w, total weight: %s", errTotalWeightNotUint64, s.totalWeight) 279 } 280 281 return s.totalWeight.Uint64(), nil 282 } 283 284 func (s *vdrSet) String() string { 285 return s.PrefixedString("") 286 } 287 288 func (s *vdrSet) PrefixedString(prefix string) string { 289 s.lock.RLock() 290 defer s.lock.RUnlock() 291 292 return s.prefixedString(prefix) 293 } 294 295 func (s *vdrSet) prefixedString(prefix string) string { 296 sb := strings.Builder{} 297 298 sb.WriteString(fmt.Sprintf("Validator Set: (Size = %d, Weight = %d)", 299 len(s.vdrSlice), 300 s.totalWeight, 301 )) 302 format := fmt.Sprintf("\n%s Validator[%s]: %%33s, %%d", prefix, formatting.IntFormat(len(s.vdrSlice)-1)) 303 for i, vdr := range s.vdrSlice { 304 sb.WriteString(fmt.Sprintf( 305 format, 306 i, 307 vdr.NodeID, 308 vdr.Weight, 309 )) 310 } 311 312 return sb.String() 313 } 314 315 func (s *vdrSet) RegisterManagerCallbackListener(callbackListener ManagerCallbackListener) { 316 s.lock.Lock() 317 defer s.lock.Unlock() 318 319 s.managerCallbackListeners = append(s.managerCallbackListeners, callbackListener) 320 for _, vdr := range s.vdrSlice { 321 callbackListener.OnValidatorAdded(s.subnetID, vdr.NodeID, vdr.PublicKey, vdr.TxID, vdr.Weight) 322 } 323 } 324 325 func (s *vdrSet) RegisterCallbackListener(callbackListener SetCallbackListener) { 326 s.lock.Lock() 327 defer s.lock.Unlock() 328 329 s.setCallbackListeners = append(s.setCallbackListeners, callbackListener) 330 for _, vdr := range s.vdrSlice { 331 callbackListener.OnValidatorAdded(vdr.NodeID, vdr.PublicKey, vdr.TxID, vdr.Weight) 332 } 333 } 334 335 // Assumes [s.lock] is held 336 func (s *vdrSet) callWeightChangeCallbacks(node ids.NodeID, oldWeight, newWeight uint64) { 337 for _, callbackListener := range s.managerCallbackListeners { 338 callbackListener.OnValidatorWeightChanged(s.subnetID, node, oldWeight, newWeight) 339 } 340 for _, callbackListener := range s.setCallbackListeners { 341 callbackListener.OnValidatorWeightChanged(node, oldWeight, newWeight) 342 } 343 } 344 345 // Assumes [s.lock] is held 346 func (s *vdrSet) callValidatorAddedCallbacks(node ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { 347 for _, callbackListener := range s.managerCallbackListeners { 348 callbackListener.OnValidatorAdded(s.subnetID, node, pk, txID, weight) 349 } 350 for _, callbackListener := range s.setCallbackListeners { 351 callbackListener.OnValidatorAdded(node, pk, txID, weight) 352 } 353 } 354 355 // Assumes [s.lock] is held 356 func (s *vdrSet) callValidatorRemovedCallbacks(node ids.NodeID, weight uint64) { 357 for _, callbackListener := range s.managerCallbackListeners { 358 callbackListener.OnValidatorRemoved(s.subnetID, node, weight) 359 } 360 for _, callbackListener := range s.setCallbackListeners { 361 callbackListener.OnValidatorRemoved(node, weight) 362 } 363 } 364 365 func (s *vdrSet) GetValidatorIDs() []ids.NodeID { 366 s.lock.RLock() 367 defer s.lock.RUnlock() 368 369 list := make([]ids.NodeID, len(s.vdrSlice)) 370 for i, vdr := range s.vdrSlice { 371 list[i] = vdr.NodeID 372 } 373 return list 374 }