github.com/MetalBlockchain/metalgo@v1.11.9/snow/validators/set_test.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package validators
     5  
     6  import (
     7  	"math"
     8  	"testing"
     9  
    10  	"github.com/stretchr/testify/require"
    11  
    12  	"github.com/MetalBlockchain/metalgo/ids"
    13  	"github.com/MetalBlockchain/metalgo/utils/crypto/bls"
    14  	"github.com/MetalBlockchain/metalgo/utils/set"
    15  
    16  	safemath "github.com/MetalBlockchain/metalgo/utils/math"
    17  )
    18  
    19  var _ SetCallbackListener = (*setCallbackListener)(nil)
    20  
    21  type setCallbackListener struct {
    22  	t         *testing.T
    23  	onAdd     func(ids.NodeID, *bls.PublicKey, ids.ID, uint64)
    24  	onWeight  func(ids.NodeID, uint64, uint64)
    25  	onRemoved func(ids.NodeID, uint64)
    26  }
    27  
    28  func (c *setCallbackListener) OnValidatorAdded(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) {
    29  	if c.onAdd != nil {
    30  		c.onAdd(nodeID, pk, txID, weight)
    31  	} else {
    32  		c.t.Fail()
    33  	}
    34  }
    35  
    36  func (c *setCallbackListener) OnValidatorRemoved(nodeID ids.NodeID, weight uint64) {
    37  	if c.onRemoved != nil {
    38  		c.onRemoved(nodeID, weight)
    39  	} else {
    40  		c.t.Fail()
    41  	}
    42  }
    43  
    44  func (c *setCallbackListener) OnValidatorWeightChanged(nodeID ids.NodeID, oldWeight, newWeight uint64) {
    45  	if c.onWeight != nil {
    46  		c.onWeight(nodeID, oldWeight, newWeight)
    47  	} else {
    48  		c.t.Fail()
    49  	}
    50  }
    51  
    52  func TestSetAddDuplicate(t *testing.T) {
    53  	require := require.New(t)
    54  
    55  	s := newSet(ids.Empty, nil)
    56  
    57  	nodeID := ids.GenerateTestNodeID()
    58  	require.NoError(s.Add(nodeID, nil, ids.Empty, 1))
    59  
    60  	err := s.Add(nodeID, nil, ids.Empty, 1)
    61  	require.ErrorIs(err, errDuplicateValidator)
    62  }
    63  
    64  func TestSetAddOverflow(t *testing.T) {
    65  	require := require.New(t)
    66  
    67  	s := newSet(ids.Empty, nil)
    68  	require.NoError(s.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1))
    69  
    70  	require.NoError(s.Add(ids.GenerateTestNodeID(), nil, ids.Empty, math.MaxUint64))
    71  
    72  	_, err := s.TotalWeight()
    73  	require.ErrorIs(err, errTotalWeightNotUint64)
    74  }
    75  
    76  func TestSetAddWeightOverflow(t *testing.T) {
    77  	require := require.New(t)
    78  
    79  	s := newSet(ids.Empty, nil)
    80  
    81  	require.NoError(s.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1))
    82  
    83  	nodeID := ids.GenerateTestNodeID()
    84  	require.NoError(s.Add(nodeID, nil, ids.Empty, 1))
    85  
    86  	require.NoError(s.AddWeight(nodeID, math.MaxUint64-1))
    87  
    88  	_, err := s.TotalWeight()
    89  	require.ErrorIs(err, errTotalWeightNotUint64)
    90  }
    91  
    92  func TestSetGetWeight(t *testing.T) {
    93  	require := require.New(t)
    94  
    95  	s := newSet(ids.Empty, nil)
    96  
    97  	nodeID := ids.GenerateTestNodeID()
    98  	require.Zero(s.GetWeight(nodeID))
    99  
   100  	require.NoError(s.Add(nodeID, nil, ids.Empty, 1))
   101  
   102  	require.Equal(uint64(1), s.GetWeight(nodeID))
   103  }
   104  
   105  func TestSetSubsetWeight(t *testing.T) {
   106  	require := require.New(t)
   107  
   108  	nodeID0 := ids.GenerateTestNodeID()
   109  	nodeID1 := ids.GenerateTestNodeID()
   110  	nodeID2 := ids.GenerateTestNodeID()
   111  
   112  	weight0 := uint64(93)
   113  	weight1 := uint64(123)
   114  	weight2 := uint64(810)
   115  
   116  	subset := set.Of(nodeID0, nodeID1)
   117  
   118  	s := newSet(ids.Empty, nil)
   119  
   120  	require.NoError(s.Add(nodeID0, nil, ids.Empty, weight0))
   121  	require.NoError(s.Add(nodeID1, nil, ids.Empty, weight1))
   122  	require.NoError(s.Add(nodeID2, nil, ids.Empty, weight2))
   123  
   124  	expectedWeight := weight0 + weight1
   125  	subsetWeight, err := s.SubsetWeight(subset)
   126  	require.NoError(err)
   127  	require.Equal(expectedWeight, subsetWeight)
   128  }
   129  
   130  func TestSetRemoveWeightMissingValidator(t *testing.T) {
   131  	require := require.New(t)
   132  
   133  	s := newSet(ids.Empty, nil)
   134  
   135  	require.NoError(s.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1))
   136  
   137  	err := s.RemoveWeight(ids.GenerateTestNodeID(), 1)
   138  	require.ErrorIs(err, errMissingValidator)
   139  }
   140  
   141  func TestSetRemoveWeightUnderflow(t *testing.T) {
   142  	require := require.New(t)
   143  
   144  	s := newSet(ids.Empty, nil)
   145  
   146  	require.NoError(s.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1))
   147  
   148  	nodeID := ids.GenerateTestNodeID()
   149  	require.NoError(s.Add(nodeID, nil, ids.Empty, 1))
   150  
   151  	err := s.RemoveWeight(nodeID, 2)
   152  	require.ErrorIs(err, safemath.ErrUnderflow)
   153  
   154  	totalWeight, err := s.TotalWeight()
   155  	require.NoError(err)
   156  	require.Equal(uint64(2), totalWeight)
   157  }
   158  
   159  func TestSetGet(t *testing.T) {
   160  	require := require.New(t)
   161  
   162  	s := newSet(ids.Empty, nil)
   163  
   164  	nodeID := ids.GenerateTestNodeID()
   165  	_, ok := s.Get(nodeID)
   166  	require.False(ok)
   167  
   168  	sk, err := bls.NewSecretKey()
   169  	require.NoError(err)
   170  
   171  	pk := bls.PublicFromSecretKey(sk)
   172  	require.NoError(s.Add(nodeID, pk, ids.Empty, 1))
   173  
   174  	vdr0, ok := s.Get(nodeID)
   175  	require.True(ok)
   176  	require.Equal(nodeID, vdr0.NodeID)
   177  	require.Equal(pk, vdr0.PublicKey)
   178  	require.Equal(uint64(1), vdr0.Weight)
   179  
   180  	require.NoError(s.AddWeight(nodeID, 1))
   181  
   182  	vdr1, ok := s.Get(nodeID)
   183  	require.True(ok)
   184  	require.Equal(nodeID, vdr0.NodeID)
   185  	require.Equal(pk, vdr0.PublicKey)
   186  	require.Equal(uint64(1), vdr0.Weight)
   187  	require.Equal(nodeID, vdr1.NodeID)
   188  	require.Equal(pk, vdr1.PublicKey)
   189  	require.Equal(uint64(2), vdr1.Weight)
   190  
   191  	require.NoError(s.RemoveWeight(nodeID, 2))
   192  	_, ok = s.Get(nodeID)
   193  	require.False(ok)
   194  }
   195  
   196  func TestSetLen(t *testing.T) {
   197  	require := require.New(t)
   198  
   199  	s := newSet(ids.Empty, nil)
   200  
   201  	setLen := s.Len()
   202  	require.Zero(setLen)
   203  
   204  	nodeID0 := ids.GenerateTestNodeID()
   205  	require.NoError(s.Add(nodeID0, nil, ids.Empty, 1))
   206  
   207  	setLen = s.Len()
   208  	require.Equal(1, setLen)
   209  
   210  	nodeID1 := ids.GenerateTestNodeID()
   211  	require.NoError(s.Add(nodeID1, nil, ids.Empty, 1))
   212  
   213  	setLen = s.Len()
   214  	require.Equal(2, setLen)
   215  
   216  	require.NoError(s.RemoveWeight(nodeID1, 1))
   217  
   218  	setLen = s.Len()
   219  	require.Equal(1, setLen)
   220  
   221  	require.NoError(s.RemoveWeight(nodeID0, 1))
   222  
   223  	setLen = s.Len()
   224  	require.Zero(setLen)
   225  }
   226  
   227  func TestSetMap(t *testing.T) {
   228  	require := require.New(t)
   229  
   230  	s := newSet(ids.Empty, nil)
   231  
   232  	m := s.Map()
   233  	require.Empty(m)
   234  
   235  	sk, err := bls.NewSecretKey()
   236  	require.NoError(err)
   237  
   238  	pk := bls.PublicFromSecretKey(sk)
   239  	nodeID0 := ids.GenerateTestNodeID()
   240  	require.NoError(s.Add(nodeID0, pk, ids.Empty, 2))
   241  
   242  	m = s.Map()
   243  	require.Len(m, 1)
   244  	require.Contains(m, nodeID0)
   245  
   246  	node0 := m[nodeID0]
   247  	require.Equal(nodeID0, node0.NodeID)
   248  	require.Equal(pk, node0.PublicKey)
   249  	require.Equal(uint64(2), node0.Weight)
   250  
   251  	nodeID1 := ids.GenerateTestNodeID()
   252  	require.NoError(s.Add(nodeID1, nil, ids.Empty, 1))
   253  
   254  	m = s.Map()
   255  	require.Len(m, 2)
   256  	require.Contains(m, nodeID0)
   257  	require.Contains(m, nodeID1)
   258  
   259  	node0 = m[nodeID0]
   260  	require.Equal(nodeID0, node0.NodeID)
   261  	require.Equal(pk, node0.PublicKey)
   262  	require.Equal(uint64(2), node0.Weight)
   263  
   264  	node1 := m[nodeID1]
   265  	require.Equal(nodeID1, node1.NodeID)
   266  	require.Nil(node1.PublicKey)
   267  	require.Equal(uint64(1), node1.Weight)
   268  
   269  	require.NoError(s.RemoveWeight(nodeID0, 1))
   270  	require.Equal(nodeID0, node0.NodeID)
   271  	require.Equal(pk, node0.PublicKey)
   272  	require.Equal(uint64(2), node0.Weight)
   273  
   274  	m = s.Map()
   275  	require.Len(m, 2)
   276  	require.Contains(m, nodeID0)
   277  	require.Contains(m, nodeID1)
   278  
   279  	node0 = m[nodeID0]
   280  	require.Equal(nodeID0, node0.NodeID)
   281  	require.Equal(pk, node0.PublicKey)
   282  	require.Equal(uint64(1), node0.Weight)
   283  
   284  	node1 = m[nodeID1]
   285  	require.Equal(nodeID1, node1.NodeID)
   286  	require.Nil(node1.PublicKey)
   287  	require.Equal(uint64(1), node1.Weight)
   288  
   289  	require.NoError(s.RemoveWeight(nodeID0, 1))
   290  
   291  	m = s.Map()
   292  	require.Len(m, 1)
   293  	require.Contains(m, nodeID1)
   294  
   295  	node1 = m[nodeID1]
   296  	require.Equal(nodeID1, node1.NodeID)
   297  	require.Nil(node1.PublicKey)
   298  	require.Equal(uint64(1), node1.Weight)
   299  
   300  	require.NoError(s.RemoveWeight(nodeID1, 1))
   301  
   302  	require.Empty(s.Map())
   303  }
   304  
   305  func TestSetWeight(t *testing.T) {
   306  	require := require.New(t)
   307  
   308  	vdr0 := ids.BuildTestNodeID([]byte{1})
   309  	weight0 := uint64(93)
   310  	vdr1 := ids.BuildTestNodeID([]byte{2})
   311  	weight1 := uint64(123)
   312  
   313  	s := newSet(ids.Empty, nil)
   314  	require.NoError(s.Add(vdr0, nil, ids.Empty, weight0))
   315  
   316  	require.NoError(s.Add(vdr1, nil, ids.Empty, weight1))
   317  
   318  	setWeight, err := s.TotalWeight()
   319  	require.NoError(err)
   320  	expectedWeight := weight0 + weight1
   321  	require.Equal(expectedWeight, setWeight)
   322  }
   323  
   324  func TestSetSample(t *testing.T) {
   325  	require := require.New(t)
   326  
   327  	s := newSet(ids.Empty, nil)
   328  
   329  	sampled, err := s.Sample(0)
   330  	require.NoError(err)
   331  	require.Empty(sampled)
   332  
   333  	sk, err := bls.NewSecretKey()
   334  	require.NoError(err)
   335  
   336  	nodeID0 := ids.GenerateTestNodeID()
   337  	pk := bls.PublicFromSecretKey(sk)
   338  	require.NoError(s.Add(nodeID0, pk, ids.Empty, 1))
   339  
   340  	sampled, err = s.Sample(1)
   341  	require.NoError(err)
   342  	require.Equal([]ids.NodeID{nodeID0}, sampled)
   343  
   344  	_, err = s.Sample(2)
   345  	require.ErrorIs(err, errInsufficientWeight)
   346  
   347  	nodeID1 := ids.GenerateTestNodeID()
   348  	require.NoError(s.Add(nodeID1, nil, ids.Empty, math.MaxInt64-1))
   349  
   350  	sampled, err = s.Sample(1)
   351  	require.NoError(err)
   352  	require.Equal([]ids.NodeID{nodeID1}, sampled)
   353  
   354  	sampled, err = s.Sample(2)
   355  	require.NoError(err)
   356  	require.Equal([]ids.NodeID{nodeID1, nodeID1}, sampled)
   357  
   358  	sampled, err = s.Sample(3)
   359  	require.NoError(err)
   360  	require.Equal([]ids.NodeID{nodeID1, nodeID1, nodeID1}, sampled)
   361  }
   362  
   363  func TestSetString(t *testing.T) {
   364  	require := require.New(t)
   365  
   366  	nodeID0 := ids.EmptyNodeID
   367  	nodeID1 := ids.BuildTestNodeID([]byte{
   368  		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
   369  		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
   370  	})
   371  
   372  	s := newSet(ids.Empty, nil)
   373  	require.NoError(s.Add(nodeID0, nil, ids.Empty, 1))
   374  
   375  	require.NoError(s.Add(nodeID1, nil, ids.Empty, math.MaxInt64-1))
   376  
   377  	expected := `Validator Set: (Size = 2, Weight = 9223372036854775807)
   378      Validator[0]: NodeID-111111111111111111116DBWJs, 1
   379      Validator[1]: NodeID-QLbz7JHiBTspS962RLKV8GndWFwdYhk6V, 9223372036854775806`
   380  	result := s.String()
   381  	require.Equal(expected, result)
   382  }
   383  
   384  func TestSetAddCallback(t *testing.T) {
   385  	require := require.New(t)
   386  
   387  	nodeID0 := ids.BuildTestNodeID([]byte{1})
   388  	sk0, err := bls.NewSecretKey()
   389  	require.NoError(err)
   390  	pk0 := bls.PublicFromSecretKey(sk0)
   391  	txID0 := ids.GenerateTestID()
   392  	weight0 := uint64(1)
   393  
   394  	s := newSet(ids.Empty, nil)
   395  	callCount := 0
   396  	require.False(s.HasCallbackRegistered())
   397  	s.RegisterCallbackListener(&setCallbackListener{
   398  		t: t,
   399  		onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) {
   400  			require.Equal(nodeID0, nodeID)
   401  			require.Equal(pk0, pk)
   402  			require.Equal(txID0, txID)
   403  			require.Equal(weight0, weight)
   404  			callCount++
   405  		},
   406  	})
   407  	require.True(s.HasCallbackRegistered())
   408  	require.NoError(s.Add(nodeID0, pk0, txID0, weight0))
   409  	require.Equal(1, callCount)
   410  }
   411  
   412  func TestSetAddWeightCallback(t *testing.T) {
   413  	require := require.New(t)
   414  
   415  	nodeID0 := ids.BuildTestNodeID([]byte{1})
   416  	txID0 := ids.GenerateTestID()
   417  	weight0 := uint64(1)
   418  	weight1 := uint64(93)
   419  
   420  	s := newSet(ids.Empty, nil)
   421  	require.NoError(s.Add(nodeID0, nil, txID0, weight0))
   422  
   423  	callCount := 0
   424  	require.False(s.HasCallbackRegistered())
   425  	s.RegisterCallbackListener(&setCallbackListener{
   426  		t: t,
   427  		onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) {
   428  			require.Equal(nodeID0, nodeID)
   429  			require.Nil(pk)
   430  			require.Equal(txID0, txID)
   431  			require.Equal(weight0, weight)
   432  			callCount++
   433  		},
   434  		onWeight: func(nodeID ids.NodeID, oldWeight, newWeight uint64) {
   435  			require.Equal(nodeID0, nodeID)
   436  			require.Equal(weight0, oldWeight)
   437  			require.Equal(weight0+weight1, newWeight)
   438  			callCount++
   439  		},
   440  	})
   441  	require.True(s.HasCallbackRegistered())
   442  	require.NoError(s.AddWeight(nodeID0, weight1))
   443  	require.Equal(2, callCount)
   444  }
   445  
   446  func TestSetRemoveWeightCallback(t *testing.T) {
   447  	require := require.New(t)
   448  
   449  	nodeID0 := ids.BuildTestNodeID([]byte{1})
   450  	txID0 := ids.GenerateTestID()
   451  	weight0 := uint64(93)
   452  	weight1 := uint64(92)
   453  
   454  	s := newSet(ids.Empty, nil)
   455  	require.NoError(s.Add(nodeID0, nil, txID0, weight0))
   456  
   457  	callCount := 0
   458  	require.False(s.HasCallbackRegistered())
   459  	s.RegisterCallbackListener(&setCallbackListener{
   460  		t: t,
   461  		onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) {
   462  			require.Equal(nodeID0, nodeID)
   463  			require.Nil(pk)
   464  			require.Equal(txID0, txID)
   465  			require.Equal(weight0, weight)
   466  			callCount++
   467  		},
   468  		onWeight: func(nodeID ids.NodeID, oldWeight, newWeight uint64) {
   469  			require.Equal(nodeID0, nodeID)
   470  			require.Equal(weight0, oldWeight)
   471  			require.Equal(weight0-weight1, newWeight)
   472  			callCount++
   473  		},
   474  	})
   475  	require.True(s.HasCallbackRegistered())
   476  	require.NoError(s.RemoveWeight(nodeID0, weight1))
   477  	require.Equal(2, callCount)
   478  }
   479  
   480  func TestSetValidatorRemovedCallback(t *testing.T) {
   481  	require := require.New(t)
   482  
   483  	nodeID0 := ids.BuildTestNodeID([]byte{1})
   484  	txID0 := ids.GenerateTestID()
   485  	weight0 := uint64(93)
   486  
   487  	s := newSet(ids.Empty, nil)
   488  	require.NoError(s.Add(nodeID0, nil, txID0, weight0))
   489  
   490  	callCount := 0
   491  	require.False(s.HasCallbackRegistered())
   492  	s.RegisterCallbackListener(&setCallbackListener{
   493  		t: t,
   494  		onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) {
   495  			require.Equal(nodeID0, nodeID)
   496  			require.Nil(pk)
   497  			require.Equal(txID0, txID)
   498  			require.Equal(weight0, weight)
   499  			callCount++
   500  		},
   501  		onRemoved: func(nodeID ids.NodeID, weight uint64) {
   502  			require.Equal(nodeID0, nodeID)
   503  			require.Equal(weight0, weight)
   504  			callCount++
   505  		},
   506  	})
   507  	require.True(s.HasCallbackRegistered())
   508  	require.NoError(s.RemoveWeight(nodeID0, weight0))
   509  	require.Equal(2, callCount)
   510  }