github.com/MetalBlockchain/metalgo@v1.11.9/snow/validators/set.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package validators
     5  
     6  import (
     7  	"errors"
     8  	"fmt"
     9  	"math/big"
    10  	"slices"
    11  	"strings"
    12  	"sync"
    13  
    14  	"github.com/MetalBlockchain/metalgo/ids"
    15  	"github.com/MetalBlockchain/metalgo/utils/crypto/bls"
    16  	"github.com/MetalBlockchain/metalgo/utils/formatting"
    17  	"github.com/MetalBlockchain/metalgo/utils/math"
    18  	"github.com/MetalBlockchain/metalgo/utils/sampler"
    19  	"github.com/MetalBlockchain/metalgo/utils/set"
    20  )
    21  
// Sentinel errors returned by vdrSet operations.
var (
	// errDuplicateValidator is returned by add when the node ID is already
	// present in the set.
	errDuplicateValidator   = errors.New("duplicate validator")
	// errMissingValidator is returned by addWeight and removeWeight when the
	// node ID is not present in the set.
	errMissingValidator     = errors.New("missing validator")
	// errTotalWeightNotUint64 is wrapped by TotalWeight when the accumulated
	// big.Int total does not fit in a uint64.
	errTotalWeightNotUint64 = errors.New("total weight is not a uint64")
	// errInsufficientWeight is returned by sample when the sampler reports
	// that it could not produce the requested number of indices.
	errInsufficientWeight   = errors.New("insufficient weight")
)
    28  
    29  // newSet returns a new, empty set of validators.
    30  func newSet(subnetID ids.ID, callbackListeners []ManagerCallbackListener) *vdrSet {
    31  	return &vdrSet{
    32  		subnetID:                 subnetID,
    33  		vdrs:                     make(map[ids.NodeID]*Validator),
    34  		totalWeight:              new(big.Int),
    35  		sampler:                  sampler.NewWeightedWithoutReplacement(),
    36  		managerCallbackListeners: slices.Clone(callbackListeners),
    37  	}
    38  }
    39  
// vdrSet is a weighted collection of validators for a single subnet. The
// exported methods take [lock]; the unexported helpers expect the caller to
// already hold it.
type vdrSet struct {
	// subnetID is forwarded to every manager callback listener.
	subnetID ids.ID

	// lock guards all of the fields below.
	lock        sync.RWMutex
	// vdrs indexes the validators by node ID.
	vdrs        map[ids.NodeID]*Validator
	// vdrSlice holds the same validators in positional order; each
	// validator's index field is its position in this slice.
	vdrSlice    []*Validator
	// weights is parallel to vdrSlice (weights[i] is vdrSlice[i].Weight) and
	// is the input handed to the sampler.
	weights     []uint64
	// totalWeight is the running sum of all validator weights. It is kept as
	// a big.Int and is only converted to a uint64 in TotalWeight.
	totalWeight *big.Int

	// samplerInitialized is cleared whenever the weights change so that the
	// next call to sample re-initializes the sampler from [weights].
	samplerInitialized bool
	sampler            sampler.WeightedWithoutReplacement

	// managerCallbackListeners receive events tagged with subnetID;
	// setCallbackListeners receive the same events without it.
	managerCallbackListeners []ManagerCallbackListener
	setCallbackListeners     []SetCallbackListener
}
    55  
// Add inserts a new validator into the set while holding the write lock.
// It returns errDuplicateValidator if [nodeID] is already present.
func (s *vdrSet) Add(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) error {
	s.lock.Lock()
	defer s.lock.Unlock()

	return s.add(nodeID, pk, txID, weight)
}
    62  
    63  func (s *vdrSet) add(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) error {
    64  	_, nodeExists := s.vdrs[nodeID]
    65  	if nodeExists {
    66  		return errDuplicateValidator
    67  	}
    68  
    69  	vdr := &Validator{
    70  		NodeID:    nodeID,
    71  		PublicKey: pk,
    72  		TxID:      txID,
    73  		Weight:    weight,
    74  		index:     len(s.vdrSlice),
    75  	}
    76  	s.vdrs[nodeID] = vdr
    77  	s.vdrSlice = append(s.vdrSlice, vdr)
    78  	s.weights = append(s.weights, weight)
    79  	s.totalWeight.Add(s.totalWeight, new(big.Int).SetUint64(weight))
    80  	s.samplerInitialized = false
    81  
    82  	s.callValidatorAddedCallbacks(nodeID, pk, txID, weight)
    83  	return nil
    84  }
    85  
// AddWeight increases the weight of an existing validator by [weight] while
// holding the write lock. It returns errMissingValidator if [nodeID] is not in
// the set.
func (s *vdrSet) AddWeight(nodeID ids.NodeID, weight uint64) error {
	s.lock.Lock()
	defer s.lock.Unlock()

	return s.addWeight(nodeID, weight)
}
    92  
    93  func (s *vdrSet) addWeight(nodeID ids.NodeID, weight uint64) error {
    94  	vdr, nodeExists := s.vdrs[nodeID]
    95  	if !nodeExists {
    96  		return errMissingValidator
    97  	}
    98  
    99  	oldWeight := vdr.Weight
   100  	newWeight, err := math.Add64(oldWeight, weight)
   101  	if err != nil {
   102  		return err
   103  	}
   104  	vdr.Weight = newWeight
   105  	s.weights[vdr.index] = newWeight
   106  	s.totalWeight.Add(s.totalWeight, new(big.Int).SetUint64(weight))
   107  	s.samplerInitialized = false
   108  
   109  	s.callWeightChangeCallbacks(nodeID, oldWeight, vdr.Weight)
   110  	return nil
   111  }
   112  
// GetWeight returns the weight of [nodeID] while holding the read lock, or 0
// if the node is not in the set.
func (s *vdrSet) GetWeight(nodeID ids.NodeID) uint64 {
	s.lock.RLock()
	defer s.lock.RUnlock()

	return s.getWeight(nodeID)
}
   119  
   120  func (s *vdrSet) getWeight(nodeID ids.NodeID) uint64 {
   121  	if vdr, ok := s.vdrs[nodeID]; ok {
   122  		return vdr.Weight
   123  	}
   124  	return 0
   125  }
   126  
// SubsetWeight returns the sum of the weights of the nodes in [subset] while
// holding the read lock. Nodes that are not in the set contribute 0.
func (s *vdrSet) SubsetWeight(subset set.Set[ids.NodeID]) (uint64, error) {
	s.lock.RLock()
	defer s.lock.RUnlock()

	return s.subsetWeight(subset)
}
   133  
   134  func (s *vdrSet) subsetWeight(subset set.Set[ids.NodeID]) (uint64, error) {
   135  	var (
   136  		totalWeight uint64
   137  		err         error
   138  	)
   139  	for nodeID := range subset {
   140  		totalWeight, err = math.Add64(totalWeight, s.getWeight(nodeID))
   141  		if err != nil {
   142  			return 0, err
   143  		}
   144  	}
   145  	return totalWeight, nil
   146  }
   147  
// RemoveWeight decreases the weight of an existing validator by [weight] while
// holding the write lock. A validator whose weight reaches 0 is removed from
// the set. It returns errMissingValidator if [nodeID] is not in the set.
func (s *vdrSet) RemoveWeight(nodeID ids.NodeID, weight uint64) error {
	s.lock.Lock()
	defer s.lock.Unlock()

	return s.removeWeight(nodeID, weight)
}
   154  
   155  func (s *vdrSet) removeWeight(nodeID ids.NodeID, weight uint64) error {
   156  	vdr, ok := s.vdrs[nodeID]
   157  	if !ok {
   158  		return errMissingValidator
   159  	}
   160  
   161  	oldWeight := vdr.Weight
   162  	// We first calculate the new weight of the validator, as this guarantees
   163  	// that none of the following operations can underflow.
   164  	newWeight, err := math.Sub(oldWeight, weight)
   165  	if err != nil {
   166  		return err
   167  	}
   168  
   169  	if newWeight == 0 {
   170  		// Get the last element
   171  		lastIndex := len(s.vdrSlice) - 1
   172  		vdrToSwap := s.vdrSlice[lastIndex]
   173  
   174  		// Move element at last index --> index of removed validator
   175  		vdrToSwap.index = vdr.index
   176  		s.vdrSlice[vdr.index] = vdrToSwap
   177  		s.weights[vdr.index] = vdrToSwap.Weight
   178  
   179  		// Remove validator
   180  		delete(s.vdrs, nodeID)
   181  		s.vdrSlice[lastIndex] = nil
   182  		s.vdrSlice = s.vdrSlice[:lastIndex]
   183  		s.weights = s.weights[:lastIndex]
   184  
   185  		s.callValidatorRemovedCallbacks(nodeID, oldWeight)
   186  	} else {
   187  		vdr.Weight = newWeight
   188  		s.weights[vdr.index] = newWeight
   189  
   190  		s.callWeightChangeCallbacks(nodeID, oldWeight, newWeight)
   191  	}
   192  	s.totalWeight.Sub(s.totalWeight, new(big.Int).SetUint64(weight))
   193  	s.samplerInitialized = false
   194  	return nil
   195  }
   196  
// Get returns a copy of the validator identified by [nodeID] while holding the
// read lock. The boolean is false if the node is not in the set.
func (s *vdrSet) Get(nodeID ids.NodeID) (*Validator, bool) {
	s.lock.RLock()
	defer s.lock.RUnlock()

	return s.get(nodeID)
}
   203  
   204  func (s *vdrSet) get(nodeID ids.NodeID) (*Validator, bool) {
   205  	vdr, ok := s.vdrs[nodeID]
   206  	if !ok {
   207  		return nil, false
   208  	}
   209  	copiedVdr := *vdr
   210  	return &copiedVdr, true
   211  }
   212  
// Len returns the number of validators in the set while holding the read
// lock.
func (s *vdrSet) Len() int {
	s.lock.RLock()
	defer s.lock.RUnlock()

	return s.len()
}
   219  
// len returns the number of validators currently stored in vdrSlice. It does
// not take the lock itself.
func (s *vdrSet) len() int {
	return len(s.vdrSlice)
}
   223  
// HasCallbackRegistered reports whether at least one SetCallbackListener has
// been registered. Manager callback listeners are not considered.
func (s *vdrSet) HasCallbackRegistered() bool {
	s.lock.RLock()
	defer s.lock.RUnlock()

	return len(s.setCallbackListeners) > 0
}
   230  
   231  func (s *vdrSet) Map() map[ids.NodeID]*GetValidatorOutput {
   232  	s.lock.RLock()
   233  	defer s.lock.RUnlock()
   234  
   235  	set := make(map[ids.NodeID]*GetValidatorOutput, len(s.vdrSlice))
   236  	for _, vdr := range s.vdrSlice {
   237  		set[vdr.NodeID] = &GetValidatorOutput{
   238  			NodeID:    vdr.NodeID,
   239  			PublicKey: vdr.PublicKey,
   240  			Weight:    vdr.Weight,
   241  		}
   242  	}
   243  	return set
   244  }
   245  
// Sample returns [size] node IDs chosen by the set's sampler.
//
// The write lock (not the read lock) is taken because sample may re-initialize
// the sampler and update samplerInitialized.
func (s *vdrSet) Sample(size int) ([]ids.NodeID, error) {
	s.lock.Lock()
	defer s.lock.Unlock()

	return s.sample(size)
}
   252  
   253  func (s *vdrSet) sample(size int) ([]ids.NodeID, error) {
   254  	if !s.samplerInitialized {
   255  		if err := s.sampler.Initialize(s.weights); err != nil {
   256  			return nil, err
   257  		}
   258  		s.samplerInitialized = true
   259  	}
   260  
   261  	indices, ok := s.sampler.Sample(size)
   262  	if !ok {
   263  		return nil, errInsufficientWeight
   264  	}
   265  
   266  	list := make([]ids.NodeID, size)
   267  	for i, index := range indices {
   268  		list[i] = s.vdrSlice[index].NodeID
   269  	}
   270  	return list, nil
   271  }
   272  
   273  func (s *vdrSet) TotalWeight() (uint64, error) {
   274  	s.lock.RLock()
   275  	defer s.lock.RUnlock()
   276  
   277  	if !s.totalWeight.IsUint64() {
   278  		return 0, fmt.Errorf("%w, total weight: %s", errTotalWeightNotUint64, s.totalWeight)
   279  	}
   280  
   281  	return s.totalWeight.Uint64(), nil
   282  }
   283  
// String returns the textual representation of the set with no line prefix.
func (s *vdrSet) String() string {
	return s.PrefixedString("")
}
   287  
// PrefixedString returns the textual representation of the set while holding
// the read lock. [prefix] is inserted at the start of every validator line.
func (s *vdrSet) PrefixedString(prefix string) string {
	s.lock.RLock()
	defer s.lock.RUnlock()

	return s.prefixedString(prefix)
}
   294  
   295  func (s *vdrSet) prefixedString(prefix string) string {
   296  	sb := strings.Builder{}
   297  
   298  	sb.WriteString(fmt.Sprintf("Validator Set: (Size = %d, Weight = %d)",
   299  		len(s.vdrSlice),
   300  		s.totalWeight,
   301  	))
   302  	format := fmt.Sprintf("\n%s    Validator[%s]: %%33s, %%d", prefix, formatting.IntFormat(len(s.vdrSlice)-1))
   303  	for i, vdr := range s.vdrSlice {
   304  		sb.WriteString(fmt.Sprintf(
   305  			format,
   306  			i,
   307  			vdr.NodeID,
   308  			vdr.Weight,
   309  		))
   310  	}
   311  
   312  	return sb.String()
   313  }
   314  
   315  func (s *vdrSet) RegisterManagerCallbackListener(callbackListener ManagerCallbackListener) {
   316  	s.lock.Lock()
   317  	defer s.lock.Unlock()
   318  
   319  	s.managerCallbackListeners = append(s.managerCallbackListeners, callbackListener)
   320  	for _, vdr := range s.vdrSlice {
   321  		callbackListener.OnValidatorAdded(s.subnetID, vdr.NodeID, vdr.PublicKey, vdr.TxID, vdr.Weight)
   322  	}
   323  }
   324  
   325  func (s *vdrSet) RegisterCallbackListener(callbackListener SetCallbackListener) {
   326  	s.lock.Lock()
   327  	defer s.lock.Unlock()
   328  
   329  	s.setCallbackListeners = append(s.setCallbackListeners, callbackListener)
   330  	for _, vdr := range s.vdrSlice {
   331  		callbackListener.OnValidatorAdded(vdr.NodeID, vdr.PublicKey, vdr.TxID, vdr.Weight)
   332  	}
   333  }
   334  
   335  // Assumes [s.lock] is held
   336  func (s *vdrSet) callWeightChangeCallbacks(node ids.NodeID, oldWeight, newWeight uint64) {
   337  	for _, callbackListener := range s.managerCallbackListeners {
   338  		callbackListener.OnValidatorWeightChanged(s.subnetID, node, oldWeight, newWeight)
   339  	}
   340  	for _, callbackListener := range s.setCallbackListeners {
   341  		callbackListener.OnValidatorWeightChanged(node, oldWeight, newWeight)
   342  	}
   343  }
   344  
   345  // Assumes [s.lock] is held
   346  func (s *vdrSet) callValidatorAddedCallbacks(node ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) {
   347  	for _, callbackListener := range s.managerCallbackListeners {
   348  		callbackListener.OnValidatorAdded(s.subnetID, node, pk, txID, weight)
   349  	}
   350  	for _, callbackListener := range s.setCallbackListeners {
   351  		callbackListener.OnValidatorAdded(node, pk, txID, weight)
   352  	}
   353  }
   354  
   355  // Assumes [s.lock] is held
   356  func (s *vdrSet) callValidatorRemovedCallbacks(node ids.NodeID, weight uint64) {
   357  	for _, callbackListener := range s.managerCallbackListeners {
   358  		callbackListener.OnValidatorRemoved(s.subnetID, node, weight)
   359  	}
   360  	for _, callbackListener := range s.setCallbackListeners {
   361  		callbackListener.OnValidatorRemoved(node, weight)
   362  	}
   363  }
   364  
   365  func (s *vdrSet) GetValidatorIDs() []ids.NodeID {
   366  	s.lock.RLock()
   367  	defer s.lock.RUnlock()
   368  
   369  	list := make([]ids.NodeID, len(s.vdrSlice))
   370  	for i, vdr := range s.vdrSlice {
   371  		list[i] = vdr.NodeID
   372  	}
   373  	return list
   374  }