github.com/ava-labs/avalanchego@v1.11.11/network/p2p/validators_test.go (about) 1 // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package p2p 5 6 import ( 7 "context" 8 "errors" 9 "testing" 10 "time" 11 12 "github.com/prometheus/client_golang/prometheus" 13 "github.com/stretchr/testify/require" 14 "go.uber.org/mock/gomock" 15 16 "github.com/ava-labs/avalanchego/ids" 17 "github.com/ava-labs/avalanchego/snow/engine/enginetest" 18 "github.com/ava-labs/avalanchego/snow/validators" 19 "github.com/ava-labs/avalanchego/snow/validators/validatorsmock" 20 "github.com/ava-labs/avalanchego/utils/logging" 21 ) 22 23 func TestValidatorsSample(t *testing.T) { 24 errFoobar := errors.New("foobar") 25 nodeID1 := ids.GenerateTestNodeID() 26 nodeID2 := ids.GenerateTestNodeID() 27 nodeID3 := ids.GenerateTestNodeID() 28 29 type call struct { 30 limit int 31 32 time time.Time 33 34 height uint64 35 getCurrentHeightErr error 36 37 validators []ids.NodeID 38 getValidatorSetErr error 39 40 // superset of possible values in the result 41 expected []ids.NodeID 42 } 43 44 tests := []struct { 45 name string 46 maxStaleness time.Duration 47 calls []call 48 }{ 49 { 50 // if we aren't connected to a validator, we shouldn't return it 51 name: "drop disconnected validators", 52 maxStaleness: time.Hour, 53 calls: []call{ 54 { 55 time: time.Time{}.Add(time.Second), 56 limit: 2, 57 height: 1, 58 validators: []ids.NodeID{nodeID1, nodeID3}, 59 expected: []ids.NodeID{nodeID1}, 60 }, 61 }, 62 }, 63 { 64 // if we don't have as many validators as requested by the caller, 65 // we should return all the validators we have 66 name: "less than limit validators", 67 maxStaleness: time.Hour, 68 calls: []call{ 69 { 70 time: time.Time{}.Add(time.Second), 71 limit: 2, 72 height: 1, 73 validators: []ids.NodeID{nodeID1}, 74 expected: []ids.NodeID{nodeID1}, 75 }, 76 }, 77 }, 78 { 79 // if we have as many validators as requested by the caller, we 80 // should return all the validators we have 81 name: "equal to limit validators", 82 maxStaleness: time.Hour, 83 calls: []call{ 84 { 85 time: time.Time{}.Add(time.Second), 86 limit: 1, 87 height: 1, 88 validators: []ids.NodeID{nodeID1}, 89 expected: []ids.NodeID{nodeID1}, 90 }, 91 }, 92 }, 93 { 94 // if we have less validators than requested by the caller, we 95 // should return a subset of the validators that we have 96 name: "less than limit validators", 97 maxStaleness: time.Hour, 98 calls: []call{ 99 { 100 time: time.Time{}.Add(time.Second), 101 limit: 1, 102 height: 1, 103 validators: []ids.NodeID{nodeID1, nodeID2}, 104 expected: []ids.NodeID{nodeID1, nodeID2}, 105 }, 106 }, 107 }, 108 { 109 name: "within max staleness threshold", 110 maxStaleness: time.Hour, 111 calls: []call{ 112 { 113 time: time.Time{}.Add(time.Second), 114 limit: 1, 115 height: 1, 116 validators: []ids.NodeID{nodeID1}, 117 expected: []ids.NodeID{nodeID1}, 118 }, 119 }, 120 }, 121 { 122 name: "beyond max staleness threshold", 123 maxStaleness: time.Hour, 124 calls: []call{ 125 { 126 limit: 1, 127 time: time.Time{}.Add(time.Hour), 128 height: 1, 129 validators: []ids.NodeID{nodeID1}, 130 expected: []ids.NodeID{nodeID1}, 131 }, 132 }, 133 }, 134 { 135 name: "fail to get current height", 136 maxStaleness: time.Second, 137 calls: []call{ 138 { 139 limit: 1, 140 time: time.Time{}.Add(time.Hour), 141 getCurrentHeightErr: errFoobar, 142 expected: []ids.NodeID{}, 143 }, 144 }, 145 }, 146 { 147 name: "second get validator set call fails", 148 maxStaleness: time.Minute, 149 calls: []call{ 150 { 151 limit: 1, 152 time: time.Time{}.Add(time.Second), 153 height: 1, 154 validators: []ids.NodeID{nodeID1}, 155 expected: []ids.NodeID{nodeID1}, 156 }, 157 { 158 limit: 1, 159 time: time.Time{}.Add(time.Hour), 160 height: 1, 161 getValidatorSetErr: errFoobar, 162 expected: []ids.NodeID{}, 163 }, 164 }, 165 }, 166 } 167 168 for _, tt := range tests { 169 t.Run(tt.name, func(t *testing.T) { 170 require := require.New(t) 171 subnetID := ids.GenerateTestID() 172 ctrl := gomock.NewController(t) 173 mockValidators := validatorsmock.NewState(ctrl) 174 175 calls := make([]any, 0) 176 for _, call := range tt.calls { 177 calls = append(calls, mockValidators.EXPECT(). 178 GetCurrentHeight(gomock.Any()).Return(call.height, call.getCurrentHeightErr)) 179 180 if call.getCurrentHeightErr != nil { 181 continue 182 } 183 184 validatorSet := make(map[ids.NodeID]*validators.GetValidatorOutput, 0) 185 for _, validator := range call.validators { 186 validatorSet[validator] = &validators.GetValidatorOutput{ 187 NodeID: validator, 188 Weight: 1, 189 } 190 } 191 192 calls = append(calls, 193 mockValidators.EXPECT(). 194 GetValidatorSet(gomock.Any(), gomock.Any(), subnetID). 195 Return(validatorSet, call.getValidatorSetErr)) 196 } 197 gomock.InOrder(calls...) 198 199 network, err := NewNetwork(logging.NoLog{}, &enginetest.SenderStub{}, prometheus.NewRegistry(), "") 200 require.NoError(err) 201 202 ctx := context.Background() 203 require.NoError(network.Connected(ctx, nodeID1, nil)) 204 require.NoError(network.Connected(ctx, nodeID2, nil)) 205 206 v := NewValidators(network.Peers, network.log, subnetID, mockValidators, tt.maxStaleness) 207 for _, call := range tt.calls { 208 v.lastUpdated = call.time 209 sampled := v.Sample(ctx, call.limit) 210 require.LessOrEqual(len(sampled), call.limit) 211 require.Subset(call.expected, sampled) 212 } 213 }) 214 } 215 } 216 217 func TestValidatorsTop(t *testing.T) { 218 nodeID1 := ids.GenerateTestNodeID() 219 nodeID2 := ids.GenerateTestNodeID() 220 nodeID3 := ids.GenerateTestNodeID() 221 222 tests := []struct { 223 name string 224 validators []validator 225 percentage float64 226 expected []ids.NodeID 227 }{ 228 { 229 name: "top 0% is empty", 230 validators: []validator{ 231 { 232 nodeID: nodeID1, 233 weight: 1, 234 }, 235 { 236 nodeID: nodeID2, 237 weight: 1, 238 }, 239 }, 240 percentage: 0, 241 expected: []ids.NodeID{}, 242 }, 243 { 244 name: "top 100% is full", 245 validators: []validator{ 246 { 247 nodeID: nodeID1, 248 weight: 2, 249 }, 250 { 251 nodeID: nodeID2, 252 weight: 1, 253 }, 254 }, 255 percentage: 1, 256 expected: []ids.NodeID{ 257 nodeID1, 258 nodeID2, 259 }, 260 }, 261 { 262 name: "top 50% takes larger validator", 263 validators: []validator{ 264 { 265 nodeID: nodeID1, 266 weight: 2, 267 }, 268 { 269 nodeID: nodeID2, 270 weight: 1, 271 }, 272 }, 273 percentage: .5, 274 expected: []ids.NodeID{ 275 nodeID1, 276 }, 277 }, 278 { 279 name: "top 50% bound", 280 validators: []validator{ 281 { 282 nodeID: nodeID1, 283 weight: 2, 284 }, 285 { 286 nodeID: nodeID2, 287 weight: 1, 288 }, 289 { 290 nodeID: nodeID3, 291 weight: 1, 292 }, 293 }, 294 percentage: .5, 295 expected: []ids.NodeID{ 296 nodeID1, 297 }, 298 }, 299 } 300 for _, test := range tests { 301 t.Run(test.name, func(t *testing.T) { 302 require := require.New(t) 303 ctrl := gomock.NewController(t) 304 305 validatorSet := make(map[ids.NodeID]*validators.GetValidatorOutput, 0) 306 for _, validator := range test.validators { 307 validatorSet[validator.nodeID] = &validators.GetValidatorOutput{ 308 NodeID: validator.nodeID, 309 Weight: validator.weight, 310 } 311 } 312 313 subnetID := ids.GenerateTestID() 314 mockValidators := validatorsmock.NewState(ctrl) 315 316 mockValidators.EXPECT().GetCurrentHeight(gomock.Any()).Return(uint64(1), nil) 317 mockValidators.EXPECT().GetValidatorSet(gomock.Any(), uint64(1), subnetID).Return(validatorSet, nil) 318 319 network, err := NewNetwork(logging.NoLog{}, &enginetest.SenderStub{}, prometheus.NewRegistry(), "") 320 require.NoError(err) 321 322 ctx := context.Background() 323 require.NoError(network.Connected(ctx, nodeID1, nil)) 324 require.NoError(network.Connected(ctx, nodeID2, nil)) 325 326 v := NewValidators(network.Peers, network.log, subnetID, mockValidators, time.Second) 327 nodeIDs := v.Top(ctx, test.percentage) 328 require.Equal(test.expected, nodeIDs) 329 }) 330 } 331 }