github.com/ava-labs/avalanchego@v1.11.11/network/p2p/validators_test.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package p2p
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/prometheus/client_golang/prometheus"
    13  	"github.com/stretchr/testify/require"
    14  	"go.uber.org/mock/gomock"
    15  
    16  	"github.com/ava-labs/avalanchego/ids"
    17  	"github.com/ava-labs/avalanchego/snow/engine/enginetest"
    18  	"github.com/ava-labs/avalanchego/snow/validators"
    19  	"github.com/ava-labs/avalanchego/snow/validators/validatorsmock"
    20  	"github.com/ava-labs/avalanchego/utils/logging"
    21  )
    22  
    23  func TestValidatorsSample(t *testing.T) {
    24  	errFoobar := errors.New("foobar")
    25  	nodeID1 := ids.GenerateTestNodeID()
    26  	nodeID2 := ids.GenerateTestNodeID()
    27  	nodeID3 := ids.GenerateTestNodeID()
    28  
    29  	type call struct {
    30  		limit int
    31  
    32  		time time.Time
    33  
    34  		height              uint64
    35  		getCurrentHeightErr error
    36  
    37  		validators         []ids.NodeID
    38  		getValidatorSetErr error
    39  
    40  		// superset of possible values in the result
    41  		expected []ids.NodeID
    42  	}
    43  
    44  	tests := []struct {
    45  		name         string
    46  		maxStaleness time.Duration
    47  		calls        []call
    48  	}{
    49  		{
    50  			// if we aren't connected to a validator, we shouldn't return it
    51  			name:         "drop disconnected validators",
    52  			maxStaleness: time.Hour,
    53  			calls: []call{
    54  				{
    55  					time:       time.Time{}.Add(time.Second),
    56  					limit:      2,
    57  					height:     1,
    58  					validators: []ids.NodeID{nodeID1, nodeID3},
    59  					expected:   []ids.NodeID{nodeID1},
    60  				},
    61  			},
    62  		},
    63  		{
    64  			// if we don't have as many validators as requested by the caller,
    65  			// we should return all the validators we have
    66  			name:         "less than limit validators",
    67  			maxStaleness: time.Hour,
    68  			calls: []call{
    69  				{
    70  					time:       time.Time{}.Add(time.Second),
    71  					limit:      2,
    72  					height:     1,
    73  					validators: []ids.NodeID{nodeID1},
    74  					expected:   []ids.NodeID{nodeID1},
    75  				},
    76  			},
    77  		},
    78  		{
    79  			// if we have as many validators as requested by the caller, we
    80  			// should return all the validators we have
    81  			name:         "equal to limit validators",
    82  			maxStaleness: time.Hour,
    83  			calls: []call{
    84  				{
    85  					time:       time.Time{}.Add(time.Second),
    86  					limit:      1,
    87  					height:     1,
    88  					validators: []ids.NodeID{nodeID1},
    89  					expected:   []ids.NodeID{nodeID1},
    90  				},
    91  			},
    92  		},
    93  		{
    94  			// if we have less validators than requested by the caller, we
    95  			// should return a subset of the validators that we have
    96  			name:         "less than limit validators",
    97  			maxStaleness: time.Hour,
    98  			calls: []call{
    99  				{
   100  					time:       time.Time{}.Add(time.Second),
   101  					limit:      1,
   102  					height:     1,
   103  					validators: []ids.NodeID{nodeID1, nodeID2},
   104  					expected:   []ids.NodeID{nodeID1, nodeID2},
   105  				},
   106  			},
   107  		},
   108  		{
   109  			name:         "within max staleness threshold",
   110  			maxStaleness: time.Hour,
   111  			calls: []call{
   112  				{
   113  					time:       time.Time{}.Add(time.Second),
   114  					limit:      1,
   115  					height:     1,
   116  					validators: []ids.NodeID{nodeID1},
   117  					expected:   []ids.NodeID{nodeID1},
   118  				},
   119  			},
   120  		},
   121  		{
   122  			name:         "beyond max staleness threshold",
   123  			maxStaleness: time.Hour,
   124  			calls: []call{
   125  				{
   126  					limit:      1,
   127  					time:       time.Time{}.Add(time.Hour),
   128  					height:     1,
   129  					validators: []ids.NodeID{nodeID1},
   130  					expected:   []ids.NodeID{nodeID1},
   131  				},
   132  			},
   133  		},
   134  		{
   135  			name:         "fail to get current height",
   136  			maxStaleness: time.Second,
   137  			calls: []call{
   138  				{
   139  					limit:               1,
   140  					time:                time.Time{}.Add(time.Hour),
   141  					getCurrentHeightErr: errFoobar,
   142  					expected:            []ids.NodeID{},
   143  				},
   144  			},
   145  		},
   146  		{
   147  			name:         "second get validator set call fails",
   148  			maxStaleness: time.Minute,
   149  			calls: []call{
   150  				{
   151  					limit:      1,
   152  					time:       time.Time{}.Add(time.Second),
   153  					height:     1,
   154  					validators: []ids.NodeID{nodeID1},
   155  					expected:   []ids.NodeID{nodeID1},
   156  				},
   157  				{
   158  					limit:              1,
   159  					time:               time.Time{}.Add(time.Hour),
   160  					height:             1,
   161  					getValidatorSetErr: errFoobar,
   162  					expected:           []ids.NodeID{},
   163  				},
   164  			},
   165  		},
   166  	}
   167  
   168  	for _, tt := range tests {
   169  		t.Run(tt.name, func(t *testing.T) {
   170  			require := require.New(t)
   171  			subnetID := ids.GenerateTestID()
   172  			ctrl := gomock.NewController(t)
   173  			mockValidators := validatorsmock.NewState(ctrl)
   174  
   175  			calls := make([]any, 0)
   176  			for _, call := range tt.calls {
   177  				calls = append(calls, mockValidators.EXPECT().
   178  					GetCurrentHeight(gomock.Any()).Return(call.height, call.getCurrentHeightErr))
   179  
   180  				if call.getCurrentHeightErr != nil {
   181  					continue
   182  				}
   183  
   184  				validatorSet := make(map[ids.NodeID]*validators.GetValidatorOutput, 0)
   185  				for _, validator := range call.validators {
   186  					validatorSet[validator] = &validators.GetValidatorOutput{
   187  						NodeID: validator,
   188  						Weight: 1,
   189  					}
   190  				}
   191  
   192  				calls = append(calls,
   193  					mockValidators.EXPECT().
   194  						GetValidatorSet(gomock.Any(), gomock.Any(), subnetID).
   195  						Return(validatorSet, call.getValidatorSetErr))
   196  			}
   197  			gomock.InOrder(calls...)
   198  
   199  			network, err := NewNetwork(logging.NoLog{}, &enginetest.SenderStub{}, prometheus.NewRegistry(), "")
   200  			require.NoError(err)
   201  
   202  			ctx := context.Background()
   203  			require.NoError(network.Connected(ctx, nodeID1, nil))
   204  			require.NoError(network.Connected(ctx, nodeID2, nil))
   205  
   206  			v := NewValidators(network.Peers, network.log, subnetID, mockValidators, tt.maxStaleness)
   207  			for _, call := range tt.calls {
   208  				v.lastUpdated = call.time
   209  				sampled := v.Sample(ctx, call.limit)
   210  				require.LessOrEqual(len(sampled), call.limit)
   211  				require.Subset(call.expected, sampled)
   212  			}
   213  		})
   214  	}
   215  }
   216  
   217  func TestValidatorsTop(t *testing.T) {
   218  	nodeID1 := ids.GenerateTestNodeID()
   219  	nodeID2 := ids.GenerateTestNodeID()
   220  	nodeID3 := ids.GenerateTestNodeID()
   221  
   222  	tests := []struct {
   223  		name       string
   224  		validators []validator
   225  		percentage float64
   226  		expected   []ids.NodeID
   227  	}{
   228  		{
   229  			name: "top 0% is empty",
   230  			validators: []validator{
   231  				{
   232  					nodeID: nodeID1,
   233  					weight: 1,
   234  				},
   235  				{
   236  					nodeID: nodeID2,
   237  					weight: 1,
   238  				},
   239  			},
   240  			percentage: 0,
   241  			expected:   []ids.NodeID{},
   242  		},
   243  		{
   244  			name: "top 100% is full",
   245  			validators: []validator{
   246  				{
   247  					nodeID: nodeID1,
   248  					weight: 2,
   249  				},
   250  				{
   251  					nodeID: nodeID2,
   252  					weight: 1,
   253  				},
   254  			},
   255  			percentage: 1,
   256  			expected: []ids.NodeID{
   257  				nodeID1,
   258  				nodeID2,
   259  			},
   260  		},
   261  		{
   262  			name: "top 50% takes larger validator",
   263  			validators: []validator{
   264  				{
   265  					nodeID: nodeID1,
   266  					weight: 2,
   267  				},
   268  				{
   269  					nodeID: nodeID2,
   270  					weight: 1,
   271  				},
   272  			},
   273  			percentage: .5,
   274  			expected: []ids.NodeID{
   275  				nodeID1,
   276  			},
   277  		},
   278  		{
   279  			name: "top 50% bound",
   280  			validators: []validator{
   281  				{
   282  					nodeID: nodeID1,
   283  					weight: 2,
   284  				},
   285  				{
   286  					nodeID: nodeID2,
   287  					weight: 1,
   288  				},
   289  				{
   290  					nodeID: nodeID3,
   291  					weight: 1,
   292  				},
   293  			},
   294  			percentage: .5,
   295  			expected: []ids.NodeID{
   296  				nodeID1,
   297  			},
   298  		},
   299  	}
   300  	for _, test := range tests {
   301  		t.Run(test.name, func(t *testing.T) {
   302  			require := require.New(t)
   303  			ctrl := gomock.NewController(t)
   304  
   305  			validatorSet := make(map[ids.NodeID]*validators.GetValidatorOutput, 0)
   306  			for _, validator := range test.validators {
   307  				validatorSet[validator.nodeID] = &validators.GetValidatorOutput{
   308  					NodeID: validator.nodeID,
   309  					Weight: validator.weight,
   310  				}
   311  			}
   312  
   313  			subnetID := ids.GenerateTestID()
   314  			mockValidators := validatorsmock.NewState(ctrl)
   315  
   316  			mockValidators.EXPECT().GetCurrentHeight(gomock.Any()).Return(uint64(1), nil)
   317  			mockValidators.EXPECT().GetValidatorSet(gomock.Any(), uint64(1), subnetID).Return(validatorSet, nil)
   318  
   319  			network, err := NewNetwork(logging.NoLog{}, &enginetest.SenderStub{}, prometheus.NewRegistry(), "")
   320  			require.NoError(err)
   321  
   322  			ctx := context.Background()
   323  			require.NoError(network.Connected(ctx, nodeID1, nil))
   324  			require.NoError(network.Connected(ctx, nodeID2, nil))
   325  
   326  			v := NewValidators(network.Peers, network.log, subnetID, mockValidators, time.Second)
   327  			nodeIDs := v.Top(ctx, test.percentage)
   328  			require.Equal(test.expected, nodeIDs)
   329  		})
   330  	}
   331  }