github.com/MetalBlockchain/metalgo@v1.11.9/snow/validators/manager_test.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package validators
     5  
     6  import (
     7  	"math"
     8  	"testing"
     9  
    10  	"github.com/stretchr/testify/require"
    11  
    12  	"github.com/MetalBlockchain/metalgo/ids"
    13  	"github.com/MetalBlockchain/metalgo/utils/crypto/bls"
    14  	"github.com/MetalBlockchain/metalgo/utils/set"
    15  
    16  	safemath "github.com/MetalBlockchain/metalgo/utils/math"
    17  )
    18  
// Compile-time assertion that *managerCallbackListener satisfies
// ManagerCallbackListener.
var _ ManagerCallbackListener = (*managerCallbackListener)(nil)

// managerCallbackListener is a test helper that implements
// ManagerCallbackListener by forwarding each callback to an optional hook.
// When the hook for a callback is nil, the callback marks t as failed.
type managerCallbackListener struct {
	// t is the test that is failed when a callback without a hook fires.
	t         *testing.T
	// onAdd is invoked by OnValidatorAdded.
	onAdd     func(ids.ID, ids.NodeID, *bls.PublicKey, ids.ID, uint64)
	// onWeight is invoked by OnValidatorWeightChanged.
	onWeight  func(ids.ID, ids.NodeID, uint64, uint64)
	// onRemoved is invoked by OnValidatorRemoved.
	onRemoved func(ids.ID, ids.NodeID, uint64)
}
    27  
    28  func (c *managerCallbackListener) OnValidatorAdded(subnetID ids.ID, nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) {
    29  	if c.onAdd != nil {
    30  		c.onAdd(subnetID, nodeID, pk, txID, weight)
    31  	} else {
    32  		c.t.Fail()
    33  	}
    34  }
    35  
    36  func (c *managerCallbackListener) OnValidatorRemoved(subnetID ids.ID, nodeID ids.NodeID, weight uint64) {
    37  	if c.onRemoved != nil {
    38  		c.onRemoved(subnetID, nodeID, weight)
    39  	} else {
    40  		c.t.Fail()
    41  	}
    42  }
    43  
    44  func (c *managerCallbackListener) OnValidatorWeightChanged(subnetID ids.ID, nodeID ids.NodeID, oldWeight, newWeight uint64) {
    45  	if c.onWeight != nil {
    46  		c.onWeight(subnetID, nodeID, oldWeight, newWeight)
    47  	} else {
    48  		c.t.Fail()
    49  	}
    50  }
    51  
    52  func TestAddZeroWeight(t *testing.T) {
    53  	require := require.New(t)
    54  
    55  	m := NewManager().(*manager)
    56  	err := m.AddStaker(ids.GenerateTestID(), ids.GenerateTestNodeID(), nil, ids.Empty, 0)
    57  	require.ErrorIs(err, ErrZeroWeight)
    58  	require.Empty(m.subnetToVdrs)
    59  }
    60  
    61  func TestAddDuplicate(t *testing.T) {
    62  	require := require.New(t)
    63  
    64  	m := NewManager()
    65  	subnetID := ids.GenerateTestID()
    66  
    67  	nodeID := ids.GenerateTestNodeID()
    68  	require.NoError(m.AddStaker(subnetID, nodeID, nil, ids.Empty, 1))
    69  
    70  	err := m.AddStaker(subnetID, nodeID, nil, ids.Empty, 1)
    71  	require.ErrorIs(err, errDuplicateValidator)
    72  }
    73  
    74  func TestAddOverflow(t *testing.T) {
    75  	require := require.New(t)
    76  
    77  	m := NewManager()
    78  	subnetID := ids.GenerateTestID()
    79  	nodeID1 := ids.GenerateTestNodeID()
    80  	nodeID2 := ids.GenerateTestNodeID()
    81  	require.NoError(m.AddStaker(subnetID, nodeID1, nil, ids.Empty, 1))
    82  
    83  	require.NoError(m.AddStaker(subnetID, nodeID2, nil, ids.Empty, math.MaxUint64))
    84  
    85  	_, err := m.TotalWeight(subnetID)
    86  	require.ErrorIs(err, errTotalWeightNotUint64)
    87  
    88  	set := set.Of(nodeID1, nodeID2)
    89  	_, err = m.SubsetWeight(subnetID, set)
    90  	require.ErrorIs(err, safemath.ErrOverflow)
    91  }
    92  
    93  func TestAddWeightZeroWeight(t *testing.T) {
    94  	require := require.New(t)
    95  
    96  	m := NewManager()
    97  	subnetID := ids.GenerateTestID()
    98  
    99  	nodeID := ids.GenerateTestNodeID()
   100  	require.NoError(m.AddStaker(subnetID, nodeID, nil, ids.Empty, 1))
   101  
   102  	err := m.AddWeight(subnetID, nodeID, 0)
   103  	require.ErrorIs(err, ErrZeroWeight)
   104  }
   105  
   106  func TestAddWeightOverflow(t *testing.T) {
   107  	require := require.New(t)
   108  
   109  	m := NewManager()
   110  	subnetID := ids.GenerateTestID()
   111  	require.NoError(m.AddStaker(subnetID, ids.GenerateTestNodeID(), nil, ids.Empty, 1))
   112  
   113  	nodeID := ids.GenerateTestNodeID()
   114  	require.NoError(m.AddStaker(subnetID, nodeID, nil, ids.Empty, 1))
   115  
   116  	require.NoError(m.AddWeight(subnetID, nodeID, math.MaxUint64-1))
   117  
   118  	_, err := m.TotalWeight(subnetID)
   119  	require.ErrorIs(err, errTotalWeightNotUint64)
   120  }
   121  
   122  func TestGetWeight(t *testing.T) {
   123  	require := require.New(t)
   124  
   125  	m := NewManager()
   126  	subnetID := ids.GenerateTestID()
   127  
   128  	nodeID := ids.GenerateTestNodeID()
   129  	require.Zero(m.GetWeight(subnetID, nodeID))
   130  
   131  	require.NoError(m.AddStaker(subnetID, nodeID, nil, ids.Empty, 1))
   132  
   133  	totalWeight, err := m.TotalWeight(subnetID)
   134  	require.NoError(err)
   135  	require.Equal(uint64(1), totalWeight)
   136  }
   137  
   138  func TestSubsetWeight(t *testing.T) {
   139  	require := require.New(t)
   140  
   141  	nodeID0 := ids.GenerateTestNodeID()
   142  	nodeID1 := ids.GenerateTestNodeID()
   143  	nodeID2 := ids.GenerateTestNodeID()
   144  
   145  	weight0 := uint64(93)
   146  	weight1 := uint64(123)
   147  	weight2 := uint64(810)
   148  
   149  	subset := set.Of(nodeID0, nodeID1)
   150  
   151  	m := NewManager()
   152  	subnetID := ids.GenerateTestID()
   153  
   154  	require.NoError(m.AddStaker(subnetID, nodeID0, nil, ids.Empty, weight0))
   155  	require.NoError(m.AddStaker(subnetID, nodeID1, nil, ids.Empty, weight1))
   156  	require.NoError(m.AddStaker(subnetID, nodeID2, nil, ids.Empty, weight2))
   157  
   158  	expectedWeight := weight0 + weight1
   159  	subsetWeight, err := m.SubsetWeight(subnetID, subset)
   160  	require.NoError(err)
   161  	require.Equal(expectedWeight, subsetWeight)
   162  }
   163  
   164  func TestRemoveWeightZeroWeight(t *testing.T) {
   165  	require := require.New(t)
   166  
   167  	m := NewManager()
   168  	subnetID := ids.GenerateTestID()
   169  	nodeID := ids.GenerateTestNodeID()
   170  	require.NoError(m.AddStaker(subnetID, nodeID, nil, ids.Empty, 1))
   171  
   172  	err := m.RemoveWeight(subnetID, nodeID, 0)
   173  	require.ErrorIs(err, ErrZeroWeight)
   174  }
   175  
   176  func TestRemoveWeightMissingValidator(t *testing.T) {
   177  	require := require.New(t)
   178  
   179  	m := NewManager()
   180  	subnetID := ids.GenerateTestID()
   181  
   182  	require.NoError(m.AddStaker(subnetID, ids.GenerateTestNodeID(), nil, ids.Empty, 1))
   183  
   184  	err := m.RemoveWeight(subnetID, ids.GenerateTestNodeID(), 1)
   185  	require.ErrorIs(err, errMissingValidator)
   186  }
   187  
   188  func TestRemoveWeightUnderflow(t *testing.T) {
   189  	require := require.New(t)
   190  
   191  	m := NewManager()
   192  	subnetID := ids.GenerateTestID()
   193  
   194  	require.NoError(m.AddStaker(subnetID, ids.GenerateTestNodeID(), nil, ids.Empty, 1))
   195  
   196  	nodeID := ids.GenerateTestNodeID()
   197  	require.NoError(m.AddStaker(subnetID, nodeID, nil, ids.Empty, 1))
   198  
   199  	err := m.RemoveWeight(subnetID, nodeID, 2)
   200  	require.ErrorIs(err, safemath.ErrUnderflow)
   201  
   202  	totalWeight, err := m.TotalWeight(subnetID)
   203  	require.NoError(err)
   204  	require.Equal(uint64(2), totalWeight)
   205  }
   206  
   207  func TestGet(t *testing.T) {
   208  	require := require.New(t)
   209  
   210  	m := NewManager()
   211  	subnetID := ids.GenerateTestID()
   212  
   213  	nodeID := ids.GenerateTestNodeID()
   214  	_, ok := m.GetValidator(subnetID, nodeID)
   215  	require.False(ok)
   216  
   217  	sk, err := bls.NewSecretKey()
   218  	require.NoError(err)
   219  
   220  	pk := bls.PublicFromSecretKey(sk)
   221  	require.NoError(m.AddStaker(subnetID, nodeID, pk, ids.Empty, 1))
   222  
   223  	vdr0, ok := m.GetValidator(subnetID, nodeID)
   224  	require.True(ok)
   225  	require.Equal(nodeID, vdr0.NodeID)
   226  	require.Equal(pk, vdr0.PublicKey)
   227  	require.Equal(uint64(1), vdr0.Weight)
   228  
   229  	require.NoError(m.AddWeight(subnetID, nodeID, 1))
   230  
   231  	vdr1, ok := m.GetValidator(subnetID, nodeID)
   232  	require.True(ok)
   233  	require.Equal(nodeID, vdr0.NodeID)
   234  	require.Equal(pk, vdr0.PublicKey)
   235  	require.Equal(uint64(1), vdr0.Weight)
   236  	require.Equal(nodeID, vdr1.NodeID)
   237  	require.Equal(pk, vdr1.PublicKey)
   238  	require.Equal(uint64(2), vdr1.Weight)
   239  
   240  	require.NoError(m.RemoveWeight(subnetID, nodeID, 2))
   241  	_, ok = m.GetValidator(subnetID, nodeID)
   242  	require.False(ok)
   243  }
   244  
   245  func TestLen(t *testing.T) {
   246  	require := require.New(t)
   247  
   248  	m := NewManager()
   249  	subnetID := ids.GenerateTestID()
   250  
   251  	count := m.Count(subnetID)
   252  	require.Zero(count)
   253  
   254  	nodeID0 := ids.GenerateTestNodeID()
   255  	require.NoError(m.AddStaker(subnetID, nodeID0, nil, ids.Empty, 1))
   256  
   257  	count = m.Count(subnetID)
   258  	require.Equal(1, count)
   259  
   260  	nodeID1 := ids.GenerateTestNodeID()
   261  	require.NoError(m.AddStaker(subnetID, nodeID1, nil, ids.Empty, 1))
   262  
   263  	count = m.Count(subnetID)
   264  	require.Equal(2, count)
   265  
   266  	require.NoError(m.RemoveWeight(subnetID, nodeID1, 1))
   267  
   268  	count = m.Count(subnetID)
   269  	require.Equal(1, count)
   270  
   271  	require.NoError(m.RemoveWeight(subnetID, nodeID0, 1))
   272  
   273  	count = m.Count(subnetID)
   274  	require.Zero(count)
   275  }
   276  
   277  func TestGetMap(t *testing.T) {
   278  	require := require.New(t)
   279  
   280  	m := NewManager()
   281  	subnetID := ids.GenerateTestID()
   282  
   283  	mp := m.GetMap(subnetID)
   284  	require.Empty(mp)
   285  
   286  	sk, err := bls.NewSecretKey()
   287  	require.NoError(err)
   288  
   289  	pk := bls.PublicFromSecretKey(sk)
   290  	nodeID0 := ids.GenerateTestNodeID()
   291  	require.NoError(m.AddStaker(subnetID, nodeID0, pk, ids.Empty, 2))
   292  
   293  	mp = m.GetMap(subnetID)
   294  	require.Len(mp, 1)
   295  	require.Contains(mp, nodeID0)
   296  
   297  	node0 := mp[nodeID0]
   298  	require.Equal(nodeID0, node0.NodeID)
   299  	require.Equal(pk, node0.PublicKey)
   300  	require.Equal(uint64(2), node0.Weight)
   301  
   302  	nodeID1 := ids.GenerateTestNodeID()
   303  	require.NoError(m.AddStaker(subnetID, nodeID1, nil, ids.Empty, 1))
   304  
   305  	mp = m.GetMap(subnetID)
   306  	require.Len(mp, 2)
   307  	require.Contains(mp, nodeID0)
   308  	require.Contains(mp, nodeID1)
   309  
   310  	node0 = mp[nodeID0]
   311  	require.Equal(nodeID0, node0.NodeID)
   312  	require.Equal(pk, node0.PublicKey)
   313  	require.Equal(uint64(2), node0.Weight)
   314  
   315  	node1 := mp[nodeID1]
   316  	require.Equal(nodeID1, node1.NodeID)
   317  	require.Nil(node1.PublicKey)
   318  	require.Equal(uint64(1), node1.Weight)
   319  
   320  	require.NoError(m.RemoveWeight(subnetID, nodeID0, 1))
   321  	require.Equal(nodeID0, node0.NodeID)
   322  	require.Equal(pk, node0.PublicKey)
   323  	require.Equal(uint64(2), node0.Weight)
   324  
   325  	mp = m.GetMap(subnetID)
   326  	require.Len(mp, 2)
   327  	require.Contains(mp, nodeID0)
   328  	require.Contains(mp, nodeID1)
   329  
   330  	node0 = mp[nodeID0]
   331  	require.Equal(nodeID0, node0.NodeID)
   332  	require.Equal(pk, node0.PublicKey)
   333  	require.Equal(uint64(1), node0.Weight)
   334  
   335  	node1 = mp[nodeID1]
   336  	require.Equal(nodeID1, node1.NodeID)
   337  	require.Nil(node1.PublicKey)
   338  	require.Equal(uint64(1), node1.Weight)
   339  
   340  	require.NoError(m.RemoveWeight(subnetID, nodeID0, 1))
   341  
   342  	mp = m.GetMap(subnetID)
   343  	require.Len(mp, 1)
   344  	require.Contains(mp, nodeID1)
   345  
   346  	node1 = mp[nodeID1]
   347  	require.Equal(nodeID1, node1.NodeID)
   348  	require.Nil(node1.PublicKey)
   349  	require.Equal(uint64(1), node1.Weight)
   350  
   351  	require.NoError(m.RemoveWeight(subnetID, nodeID1, 1))
   352  
   353  	require.Empty(m.GetMap(subnetID))
   354  }
   355  
   356  func TestWeight(t *testing.T) {
   357  	require := require.New(t)
   358  
   359  	vdr0 := ids.BuildTestNodeID([]byte{1})
   360  	weight0 := uint64(93)
   361  	vdr1 := ids.BuildTestNodeID([]byte{2})
   362  	weight1 := uint64(123)
   363  
   364  	m := NewManager()
   365  	subnetID := ids.GenerateTestID()
   366  	require.NoError(m.AddStaker(subnetID, vdr0, nil, ids.Empty, weight0))
   367  
   368  	require.NoError(m.AddStaker(subnetID, vdr1, nil, ids.Empty, weight1))
   369  
   370  	setWeight, err := m.TotalWeight(subnetID)
   371  	require.NoError(err)
   372  	expectedWeight := weight0 + weight1
   373  	require.Equal(expectedWeight, setWeight)
   374  }
   375  
   376  func TestSample(t *testing.T) {
   377  	require := require.New(t)
   378  
   379  	m := NewManager()
   380  	subnetID := ids.GenerateTestID()
   381  
   382  	sampled, err := m.Sample(subnetID, 0)
   383  	require.NoError(err)
   384  	require.Empty(sampled)
   385  
   386  	sk, err := bls.NewSecretKey()
   387  	require.NoError(err)
   388  
   389  	nodeID0 := ids.GenerateTestNodeID()
   390  	pk := bls.PublicFromSecretKey(sk)
   391  	require.NoError(m.AddStaker(subnetID, nodeID0, pk, ids.Empty, 1))
   392  
   393  	sampled, err = m.Sample(subnetID, 1)
   394  	require.NoError(err)
   395  	require.Equal([]ids.NodeID{nodeID0}, sampled)
   396  
   397  	_, err = m.Sample(subnetID, 2)
   398  	require.ErrorIs(err, errInsufficientWeight)
   399  
   400  	nodeID1 := ids.GenerateTestNodeID()
   401  	require.NoError(m.AddStaker(subnetID, nodeID1, nil, ids.Empty, math.MaxInt64-1))
   402  
   403  	sampled, err = m.Sample(subnetID, 1)
   404  	require.NoError(err)
   405  	require.Equal([]ids.NodeID{nodeID1}, sampled)
   406  
   407  	sampled, err = m.Sample(subnetID, 2)
   408  	require.NoError(err)
   409  	require.Equal([]ids.NodeID{nodeID1, nodeID1}, sampled)
   410  
   411  	sampled, err = m.Sample(subnetID, 3)
   412  	require.NoError(err)
   413  	require.Equal([]ids.NodeID{nodeID1, nodeID1, nodeID1}, sampled)
   414  }
   415  
   416  func TestString(t *testing.T) {
   417  	require := require.New(t)
   418  
   419  	nodeID0 := ids.EmptyNodeID
   420  	nodeID1, err := ids.NodeIDFromString("NodeID-QLbz7JHiBTspS962RLKV8GndWFwdYhk6V")
   421  	require.NoError(err)
   422  
   423  	subnetID0, err := ids.FromString("TtF4d2QWbk5vzQGTEPrN48x6vwgAoAmKQ9cbp79inpQmcRKES")
   424  	require.NoError(err)
   425  	subnetID1, err := ids.FromString("2mcwQKiD8VEspmMJpL1dc7okQQ5dDVAWeCBZ7FWBFAbxpv3t7w")
   426  	require.NoError(err)
   427  
   428  	m := NewManager()
   429  	require.NoError(m.AddStaker(subnetID0, nodeID0, nil, ids.Empty, 1))
   430  	require.NoError(m.AddStaker(subnetID0, nodeID1, nil, ids.Empty, math.MaxInt64-1))
   431  	require.NoError(m.AddStaker(subnetID1, nodeID1, nil, ids.Empty, 1))
   432  
   433  	expected := `Validator Manager: (Size = 2)
   434      Subnet[TtF4d2QWbk5vzQGTEPrN48x6vwgAoAmKQ9cbp79inpQmcRKES]: Validator Set: (Size = 2, Weight = 9223372036854775807)
   435          Validator[0]: NodeID-111111111111111111116DBWJs, 1
   436          Validator[1]: NodeID-QLbz7JHiBTspS962RLKV8GndWFwdYhk6V, 9223372036854775806
   437      Subnet[2mcwQKiD8VEspmMJpL1dc7okQQ5dDVAWeCBZ7FWBFAbxpv3t7w]: Validator Set: (Size = 1, Weight = 1)
   438          Validator[0]: NodeID-QLbz7JHiBTspS962RLKV8GndWFwdYhk6V, 1`
   439  	result := m.String()
   440  	require.Equal(expected, result)
   441  }
   442  
   443  func TestAddCallback(t *testing.T) {
   444  	require := require.New(t)
   445  
   446  	expectedSK, err := bls.NewSecretKey()
   447  	require.NoError(err)
   448  
   449  	var (
   450  		expectedNodeID           = ids.GenerateTestNodeID()
   451  		expectedPK               = bls.PublicFromSecretKey(expectedSK)
   452  		expectedTxID             = ids.GenerateTestID()
   453  		expectedWeight    uint64 = 1
   454  		expectedSubnetID0        = ids.GenerateTestID()
   455  		expectedSubnetID1        = ids.GenerateTestID()
   456  
   457  		m                = NewManager()
   458  		managerCallCount = 0
   459  		setCallCount     = 0
   460  	)
   461  	m.RegisterCallbackListener(&managerCallbackListener{
   462  		t: t,
   463  		onAdd: func(subnetID ids.ID, nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) {
   464  			require.Contains([]ids.ID{expectedSubnetID0, expectedSubnetID1}, subnetID)
   465  			require.Equal(expectedNodeID, nodeID)
   466  			require.Equal(expectedPK, pk)
   467  			require.Equal(expectedTxID, txID)
   468  			require.Equal(expectedWeight, weight)
   469  			managerCallCount++
   470  		},
   471  	})
   472  	m.RegisterSetCallbackListener(expectedSubnetID0, &setCallbackListener{
   473  		t: t,
   474  		onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) {
   475  			require.Equal(expectedNodeID, nodeID)
   476  			require.Equal(expectedPK, pk)
   477  			require.Equal(expectedTxID, txID)
   478  			require.Equal(expectedWeight, weight)
   479  			setCallCount++
   480  		},
   481  	})
   482  	require.NoError(m.AddStaker(expectedSubnetID0, expectedNodeID, expectedPK, expectedTxID, expectedWeight))
   483  	require.Equal(1, managerCallCount) // should be called for expectedSubnetID0
   484  	require.Equal(1, setCallCount)     // should be called for expectedSubnetID0
   485  
   486  	require.NoError(m.AddStaker(expectedSubnetID1, expectedNodeID, expectedPK, expectedTxID, expectedWeight))
   487  	require.Equal(2, managerCallCount) // should be called for expectedSubnetID1
   488  	require.Equal(1, setCallCount)     // should not be called for expectedSubnetID1
   489  }
   490  
   491  func TestAddWeightCallback(t *testing.T) {
   492  	require := require.New(t)
   493  
   494  	expectedSK, err := bls.NewSecretKey()
   495  	require.NoError(err)
   496  
   497  	var (
   498  		expectedNodeID             = ids.GenerateTestNodeID()
   499  		expectedPK                 = bls.PublicFromSecretKey(expectedSK)
   500  		expectedTxID               = ids.GenerateTestID()
   501  		expectedOldWeight   uint64 = 1
   502  		expectedAddedWeight uint64 = 10
   503  		expectedNewWeight          = expectedOldWeight + expectedAddedWeight
   504  		expectedSubnetID0          = ids.GenerateTestID()
   505  		expectedSubnetID1          = ids.GenerateTestID()
   506  
   507  		m                      = NewManager()
   508  		managerAddCallCount    = 0
   509  		managerChangeCallCount = 0
   510  		setAddCallCount        = 0
   511  		setChangeCallCount     = 0
   512  	)
   513  
   514  	require.NoError(m.AddStaker(expectedSubnetID0, expectedNodeID, expectedPK, expectedTxID, expectedOldWeight))
   515  
   516  	m.RegisterCallbackListener(&managerCallbackListener{
   517  		t: t,
   518  		onAdd: func(subnetID ids.ID, nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) {
   519  			require.Contains([]ids.ID{expectedSubnetID0, expectedSubnetID1}, subnetID)
   520  			require.Equal(expectedNodeID, nodeID)
   521  			require.Equal(expectedPK, pk)
   522  			require.Equal(expectedTxID, txID)
   523  			require.Equal(expectedOldWeight, weight)
   524  			managerAddCallCount++
   525  		},
   526  		onWeight: func(subnetID ids.ID, nodeID ids.NodeID, oldWeight, newWeight uint64) {
   527  			require.Contains([]ids.ID{expectedSubnetID0, expectedSubnetID1}, subnetID)
   528  			require.Equal(expectedNodeID, nodeID)
   529  			require.Equal(expectedOldWeight, oldWeight)
   530  			require.Equal(expectedNewWeight, newWeight)
   531  			managerChangeCallCount++
   532  		},
   533  	})
   534  	m.RegisterSetCallbackListener(expectedSubnetID0, &setCallbackListener{
   535  		t: t,
   536  		onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) {
   537  			require.Equal(expectedNodeID, nodeID)
   538  			require.Equal(expectedPK, pk)
   539  			require.Equal(expectedTxID, txID)
   540  			require.Equal(expectedOldWeight, weight)
   541  			setAddCallCount++
   542  		},
   543  		onWeight: func(nodeID ids.NodeID, oldWeight, newWeight uint64) {
   544  			require.Equal(expectedNodeID, nodeID)
   545  			require.Equal(expectedOldWeight, oldWeight)
   546  			require.Equal(expectedNewWeight, newWeight)
   547  			setChangeCallCount++
   548  		},
   549  	})
   550  	require.Equal(1, managerAddCallCount)
   551  	require.Zero(managerChangeCallCount)
   552  	require.Equal(1, setAddCallCount)
   553  	require.Zero(setChangeCallCount)
   554  
   555  	require.NoError(m.AddWeight(expectedSubnetID0, expectedNodeID, expectedAddedWeight))
   556  	require.Equal(1, managerAddCallCount)
   557  	require.Equal(1, managerChangeCallCount)
   558  	require.Equal(1, setAddCallCount)
   559  	require.Equal(1, setChangeCallCount)
   560  
   561  	require.NoError(m.AddStaker(expectedSubnetID1, expectedNodeID, expectedPK, expectedTxID, expectedOldWeight))
   562  	require.Equal(2, managerAddCallCount)
   563  	require.Equal(1, managerChangeCallCount)
   564  	require.Equal(1, setAddCallCount)
   565  	require.Equal(1, setChangeCallCount)
   566  
   567  	require.NoError(m.AddWeight(expectedSubnetID1, expectedNodeID, expectedAddedWeight))
   568  	require.Equal(2, managerAddCallCount)
   569  	require.Equal(2, managerChangeCallCount)
   570  	require.Equal(1, setAddCallCount)
   571  	require.Equal(1, setChangeCallCount)
   572  }
   573  
   574  func TestRemoveWeightCallback(t *testing.T) {
   575  	require := require.New(t)
   576  
   577  	expectedSK, err := bls.NewSecretKey()
   578  	require.NoError(err)
   579  
   580  	var (
   581  		expectedNodeID               = ids.GenerateTestNodeID()
   582  		expectedPK                   = bls.PublicFromSecretKey(expectedSK)
   583  		expectedTxID                 = ids.GenerateTestID()
   584  		expectedNewWeight     uint64 = 1
   585  		expectedRemovedWeight uint64 = 10
   586  		expectedOldWeight            = expectedNewWeight + expectedRemovedWeight
   587  		expectedSubnetID0            = ids.GenerateTestID()
   588  		expectedSubnetID1            = ids.GenerateTestID()
   589  
   590  		m                      = NewManager()
   591  		managerAddCallCount    = 0
   592  		managerChangeCallCount = 0
   593  		setAddCallCount        = 0
   594  		setChangeCallCount     = 0
   595  	)
   596  
   597  	require.NoError(m.AddStaker(expectedSubnetID0, expectedNodeID, expectedPK, expectedTxID, expectedOldWeight))
   598  
   599  	m.RegisterCallbackListener(&managerCallbackListener{
   600  		t: t,
   601  		onAdd: func(subnetID ids.ID, nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) {
   602  			require.Contains([]ids.ID{expectedSubnetID0, expectedSubnetID1}, subnetID)
   603  			require.Equal(expectedNodeID, nodeID)
   604  			require.Equal(expectedPK, pk)
   605  			require.Equal(expectedTxID, txID)
   606  			require.Equal(expectedOldWeight, weight)
   607  			managerAddCallCount++
   608  		},
   609  		onWeight: func(subnetID ids.ID, nodeID ids.NodeID, oldWeight, newWeight uint64) {
   610  			require.Contains([]ids.ID{expectedSubnetID0, expectedSubnetID1}, subnetID)
   611  			require.Equal(expectedNodeID, nodeID)
   612  			require.Equal(expectedOldWeight, oldWeight)
   613  			require.Equal(expectedNewWeight, newWeight)
   614  			managerChangeCallCount++
   615  		},
   616  	})
   617  	m.RegisterSetCallbackListener(expectedSubnetID0, &setCallbackListener{
   618  		t: t,
   619  		onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) {
   620  			require.Equal(expectedNodeID, nodeID)
   621  			require.Equal(expectedPK, pk)
   622  			require.Equal(expectedTxID, txID)
   623  			require.Equal(expectedOldWeight, weight)
   624  			setAddCallCount++
   625  		},
   626  		onWeight: func(nodeID ids.NodeID, oldWeight, newWeight uint64) {
   627  			require.Equal(expectedNodeID, nodeID)
   628  			require.Equal(expectedOldWeight, oldWeight)
   629  			require.Equal(expectedNewWeight, newWeight)
   630  			setChangeCallCount++
   631  		},
   632  	})
   633  	require.Equal(1, managerAddCallCount)
   634  	require.Zero(managerChangeCallCount)
   635  	require.Equal(1, setAddCallCount)
   636  	require.Zero(setChangeCallCount)
   637  
   638  	require.NoError(m.RemoveWeight(expectedSubnetID0, expectedNodeID, expectedRemovedWeight))
   639  	require.Equal(1, managerAddCallCount)
   640  	require.Equal(1, managerChangeCallCount)
   641  	require.Equal(1, setAddCallCount)
   642  	require.Equal(1, setChangeCallCount)
   643  
   644  	require.NoError(m.AddStaker(expectedSubnetID1, expectedNodeID, expectedPK, expectedTxID, expectedOldWeight))
   645  	require.Equal(2, managerAddCallCount)
   646  	require.Equal(1, managerChangeCallCount)
   647  	require.Equal(1, setAddCallCount)
   648  	require.Equal(1, setChangeCallCount)
   649  
   650  	require.NoError(m.RemoveWeight(expectedSubnetID1, expectedNodeID, expectedRemovedWeight))
   651  	require.Equal(2, managerAddCallCount)
   652  	require.Equal(2, managerChangeCallCount)
   653  	require.Equal(1, setAddCallCount)
   654  	require.Equal(1, setChangeCallCount)
   655  }
   656  
   657  func TestRemoveCallback(t *testing.T) {
   658  	require := require.New(t)
   659  
   660  	expectedSK, err := bls.NewSecretKey()
   661  	require.NoError(err)
   662  
   663  	var (
   664  		expectedNodeID           = ids.GenerateTestNodeID()
   665  		expectedPK               = bls.PublicFromSecretKey(expectedSK)
   666  		expectedTxID             = ids.GenerateTestID()
   667  		expectedWeight    uint64 = 1
   668  		expectedSubnetID0        = ids.GenerateTestID()
   669  		expectedSubnetID1        = ids.GenerateTestID()
   670  
   671  		m                      = NewManager()
   672  		managerAddCallCount    = 0
   673  		managerRemoveCallCount = 0
   674  		setAddCallCount        = 0
   675  		setRemoveCallCount     = 0
   676  	)
   677  
   678  	require.NoError(m.AddStaker(expectedSubnetID0, expectedNodeID, expectedPK, expectedTxID, expectedWeight))
   679  
   680  	m.RegisterCallbackListener(&managerCallbackListener{
   681  		t: t,
   682  		onAdd: func(subnetID ids.ID, nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) {
   683  			require.Contains([]ids.ID{expectedSubnetID0, expectedSubnetID1}, subnetID)
   684  			require.Equal(expectedNodeID, nodeID)
   685  			require.Equal(expectedPK, pk)
   686  			require.Equal(expectedTxID, txID)
   687  			require.Equal(expectedWeight, weight)
   688  			managerAddCallCount++
   689  		},
   690  		onRemoved: func(subnetID ids.ID, nodeID ids.NodeID, weight uint64) {
   691  			require.Contains([]ids.ID{expectedSubnetID0, expectedSubnetID1}, subnetID)
   692  			require.Equal(expectedNodeID, nodeID)
   693  			require.Equal(expectedWeight, weight)
   694  			managerRemoveCallCount++
   695  		},
   696  	})
   697  	m.RegisterSetCallbackListener(expectedSubnetID0, &setCallbackListener{
   698  		t: t,
   699  		onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) {
   700  			require.Equal(expectedNodeID, nodeID)
   701  			require.Equal(expectedPK, pk)
   702  			require.Equal(expectedTxID, txID)
   703  			require.Equal(expectedWeight, weight)
   704  			setAddCallCount++
   705  		},
   706  		onRemoved: func(nodeID ids.NodeID, weight uint64) {
   707  			require.Equal(expectedNodeID, nodeID)
   708  			require.Equal(expectedWeight, weight)
   709  			setRemoveCallCount++
   710  		},
   711  	})
   712  	require.Equal(1, managerAddCallCount)
   713  	require.Zero(managerRemoveCallCount)
   714  	require.Equal(1, setAddCallCount)
   715  	require.Zero(setRemoveCallCount)
   716  
   717  	require.NoError(m.RemoveWeight(expectedSubnetID0, expectedNodeID, expectedWeight))
   718  	require.Equal(1, managerAddCallCount)
   719  	require.Equal(1, managerRemoveCallCount)
   720  	require.Equal(1, setAddCallCount)
   721  	require.Equal(1, setRemoveCallCount)
   722  
   723  	require.NoError(m.AddStaker(expectedSubnetID1, expectedNodeID, expectedPK, expectedTxID, expectedWeight))
   724  	require.Equal(2, managerAddCallCount)
   725  	require.Equal(1, managerRemoveCallCount)
   726  	require.Equal(1, setAddCallCount)
   727  	require.Equal(1, setRemoveCallCount)
   728  
   729  	require.NoError(m.RemoveWeight(expectedSubnetID1, expectedNodeID, expectedWeight))
   730  	require.Equal(2, managerAddCallCount)
   731  	require.Equal(2, managerRemoveCallCount)
   732  	require.Equal(1, setAddCallCount)
   733  	require.Equal(1, setRemoveCallCount)
   734  }