github.com/MetalBlockchain/metalgo@v1.11.9/network/p2p/validators_test.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package p2p
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/prometheus/client_golang/prometheus"
    13  	"github.com/stretchr/testify/require"
    14  	"go.uber.org/mock/gomock"
    15  
    16  	"github.com/MetalBlockchain/metalgo/ids"
    17  	"github.com/MetalBlockchain/metalgo/snow/engine/common"
    18  	"github.com/MetalBlockchain/metalgo/snow/validators"
    19  	"github.com/MetalBlockchain/metalgo/utils/logging"
    20  )
    21  
    22  func TestValidatorsSample(t *testing.T) {
    23  	errFoobar := errors.New("foobar")
    24  	nodeID1 := ids.GenerateTestNodeID()
    25  	nodeID2 := ids.GenerateTestNodeID()
    26  	nodeID3 := ids.GenerateTestNodeID()
    27  
    28  	type call struct {
    29  		limit int
    30  
    31  		time time.Time
    32  
    33  		height              uint64
    34  		getCurrentHeightErr error
    35  
    36  		validators         []ids.NodeID
    37  		getValidatorSetErr error
    38  
    39  		// superset of possible values in the result
    40  		expected []ids.NodeID
    41  	}
    42  
    43  	tests := []struct {
    44  		name         string
    45  		maxStaleness time.Duration
    46  		calls        []call
    47  	}{
    48  		{
    49  			// if we aren't connected to a validator, we shouldn't return it
    50  			name:         "drop disconnected validators",
    51  			maxStaleness: time.Hour,
    52  			calls: []call{
    53  				{
    54  					time:       time.Time{}.Add(time.Second),
    55  					limit:      2,
    56  					height:     1,
    57  					validators: []ids.NodeID{nodeID1, nodeID3},
    58  					expected:   []ids.NodeID{nodeID1},
    59  				},
    60  			},
    61  		},
    62  		{
    63  			// if we don't have as many validators as requested by the caller,
    64  			// we should return all the validators we have
    65  			name:         "less than limit validators",
    66  			maxStaleness: time.Hour,
    67  			calls: []call{
    68  				{
    69  					time:       time.Time{}.Add(time.Second),
    70  					limit:      2,
    71  					height:     1,
    72  					validators: []ids.NodeID{nodeID1},
    73  					expected:   []ids.NodeID{nodeID1},
    74  				},
    75  			},
    76  		},
    77  		{
    78  			// if we have as many validators as requested by the caller, we
    79  			// should return all the validators we have
    80  			name:         "equal to limit validators",
    81  			maxStaleness: time.Hour,
    82  			calls: []call{
    83  				{
    84  					time:       time.Time{}.Add(time.Second),
    85  					limit:      1,
    86  					height:     1,
    87  					validators: []ids.NodeID{nodeID1},
    88  					expected:   []ids.NodeID{nodeID1},
    89  				},
    90  			},
    91  		},
    92  		{
    93  			// if we have less validators than requested by the caller, we
    94  			// should return a subset of the validators that we have
    95  			name:         "less than limit validators",
    96  			maxStaleness: time.Hour,
    97  			calls: []call{
    98  				{
    99  					time:       time.Time{}.Add(time.Second),
   100  					limit:      1,
   101  					height:     1,
   102  					validators: []ids.NodeID{nodeID1, nodeID2},
   103  					expected:   []ids.NodeID{nodeID1, nodeID2},
   104  				},
   105  			},
   106  		},
   107  		{
   108  			name:         "within max staleness threshold",
   109  			maxStaleness: time.Hour,
   110  			calls: []call{
   111  				{
   112  					time:       time.Time{}.Add(time.Second),
   113  					limit:      1,
   114  					height:     1,
   115  					validators: []ids.NodeID{nodeID1},
   116  					expected:   []ids.NodeID{nodeID1},
   117  				},
   118  			},
   119  		},
   120  		{
   121  			name:         "beyond max staleness threshold",
   122  			maxStaleness: time.Hour,
   123  			calls: []call{
   124  				{
   125  					limit:      1,
   126  					time:       time.Time{}.Add(time.Hour),
   127  					height:     1,
   128  					validators: []ids.NodeID{nodeID1},
   129  					expected:   []ids.NodeID{nodeID1},
   130  				},
   131  			},
   132  		},
   133  		{
   134  			name:         "fail to get current height",
   135  			maxStaleness: time.Second,
   136  			calls: []call{
   137  				{
   138  					limit:               1,
   139  					time:                time.Time{}.Add(time.Hour),
   140  					getCurrentHeightErr: errFoobar,
   141  					expected:            []ids.NodeID{},
   142  				},
   143  			},
   144  		},
   145  		{
   146  			name:         "second get validator set call fails",
   147  			maxStaleness: time.Minute,
   148  			calls: []call{
   149  				{
   150  					limit:      1,
   151  					time:       time.Time{}.Add(time.Second),
   152  					height:     1,
   153  					validators: []ids.NodeID{nodeID1},
   154  					expected:   []ids.NodeID{nodeID1},
   155  				},
   156  				{
   157  					limit:              1,
   158  					time:               time.Time{}.Add(time.Hour),
   159  					height:             1,
   160  					getValidatorSetErr: errFoobar,
   161  					expected:           []ids.NodeID{},
   162  				},
   163  			},
   164  		},
   165  	}
   166  
   167  	for _, tt := range tests {
   168  		t.Run(tt.name, func(t *testing.T) {
   169  			require := require.New(t)
   170  			subnetID := ids.GenerateTestID()
   171  			ctrl := gomock.NewController(t)
   172  			mockValidators := validators.NewMockState(ctrl)
   173  
   174  			calls := make([]any, 0)
   175  			for _, call := range tt.calls {
   176  				calls = append(calls, mockValidators.EXPECT().
   177  					GetCurrentHeight(gomock.Any()).Return(call.height, call.getCurrentHeightErr))
   178  
   179  				if call.getCurrentHeightErr != nil {
   180  					continue
   181  				}
   182  
   183  				validatorSet := make(map[ids.NodeID]*validators.GetValidatorOutput, 0)
   184  				for _, validator := range call.validators {
   185  					validatorSet[validator] = &validators.GetValidatorOutput{
   186  						NodeID: validator,
   187  						Weight: 1,
   188  					}
   189  				}
   190  
   191  				calls = append(calls,
   192  					mockValidators.EXPECT().
   193  						GetValidatorSet(gomock.Any(), gomock.Any(), subnetID).
   194  						Return(validatorSet, call.getValidatorSetErr))
   195  			}
   196  			gomock.InOrder(calls...)
   197  
   198  			network, err := NewNetwork(logging.NoLog{}, &common.FakeSender{}, prometheus.NewRegistry(), "")
   199  			require.NoError(err)
   200  
   201  			ctx := context.Background()
   202  			require.NoError(network.Connected(ctx, nodeID1, nil))
   203  			require.NoError(network.Connected(ctx, nodeID2, nil))
   204  
   205  			v := NewValidators(network.Peers, network.log, subnetID, mockValidators, tt.maxStaleness)
   206  			for _, call := range tt.calls {
   207  				v.lastUpdated = call.time
   208  				sampled := v.Sample(ctx, call.limit)
   209  				require.LessOrEqual(len(sampled), call.limit)
   210  				require.Subset(call.expected, sampled)
   211  			}
   212  		})
   213  	}
   214  }
   215  
   216  func TestValidatorsTop(t *testing.T) {
   217  	nodeID1 := ids.GenerateTestNodeID()
   218  	nodeID2 := ids.GenerateTestNodeID()
   219  	nodeID3 := ids.GenerateTestNodeID()
   220  
   221  	tests := []struct {
   222  		name       string
   223  		validators []validator
   224  		percentage float64
   225  		expected   []ids.NodeID
   226  	}{
   227  		{
   228  			name: "top 0% is empty",
   229  			validators: []validator{
   230  				{
   231  					nodeID: nodeID1,
   232  					weight: 1,
   233  				},
   234  				{
   235  					nodeID: nodeID2,
   236  					weight: 1,
   237  				},
   238  			},
   239  			percentage: 0,
   240  			expected:   []ids.NodeID{},
   241  		},
   242  		{
   243  			name: "top 100% is full",
   244  			validators: []validator{
   245  				{
   246  					nodeID: nodeID1,
   247  					weight: 2,
   248  				},
   249  				{
   250  					nodeID: nodeID2,
   251  					weight: 1,
   252  				},
   253  			},
   254  			percentage: 1,
   255  			expected: []ids.NodeID{
   256  				nodeID1,
   257  				nodeID2,
   258  			},
   259  		},
   260  		{
   261  			name: "top 50% takes larger validator",
   262  			validators: []validator{
   263  				{
   264  					nodeID: nodeID1,
   265  					weight: 2,
   266  				},
   267  				{
   268  					nodeID: nodeID2,
   269  					weight: 1,
   270  				},
   271  			},
   272  			percentage: .5,
   273  			expected: []ids.NodeID{
   274  				nodeID1,
   275  			},
   276  		},
   277  		{
   278  			name: "top 50% bound",
   279  			validators: []validator{
   280  				{
   281  					nodeID: nodeID1,
   282  					weight: 2,
   283  				},
   284  				{
   285  					nodeID: nodeID2,
   286  					weight: 1,
   287  				},
   288  				{
   289  					nodeID: nodeID3,
   290  					weight: 1,
   291  				},
   292  			},
   293  			percentage: .5,
   294  			expected: []ids.NodeID{
   295  				nodeID1,
   296  			},
   297  		},
   298  	}
   299  	for _, test := range tests {
   300  		t.Run(test.name, func(t *testing.T) {
   301  			require := require.New(t)
   302  			ctrl := gomock.NewController(t)
   303  
   304  			validatorSet := make(map[ids.NodeID]*validators.GetValidatorOutput, 0)
   305  			for _, validator := range test.validators {
   306  				validatorSet[validator.nodeID] = &validators.GetValidatorOutput{
   307  					NodeID: validator.nodeID,
   308  					Weight: validator.weight,
   309  				}
   310  			}
   311  
   312  			subnetID := ids.GenerateTestID()
   313  			mockValidators := validators.NewMockState(ctrl)
   314  
   315  			mockValidators.EXPECT().GetCurrentHeight(gomock.Any()).Return(uint64(1), nil)
   316  			mockValidators.EXPECT().GetValidatorSet(gomock.Any(), uint64(1), subnetID).Return(validatorSet, nil)
   317  
   318  			network, err := NewNetwork(logging.NoLog{}, &common.FakeSender{}, prometheus.NewRegistry(), "")
   319  			require.NoError(err)
   320  
   321  			ctx := context.Background()
   322  			require.NoError(network.Connected(ctx, nodeID1, nil))
   323  			require.NoError(network.Connected(ctx, nodeID2, nil))
   324  
   325  			v := NewValidators(network.Peers, network.log, subnetID, mockValidators, time.Second)
   326  			nodeIDs := v.Top(ctx, test.percentage)
   327  			require.Equal(test.expected, nodeIDs)
   328  		})
   329  	}
   330  }