github.com/MetalBlockchain/metalgo@v1.11.9/network/p2p/validators_test.go (about) 1 // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package p2p 5 6 import ( 7 "context" 8 "errors" 9 "testing" 10 "time" 11 12 "github.com/prometheus/client_golang/prometheus" 13 "github.com/stretchr/testify/require" 14 "go.uber.org/mock/gomock" 15 16 "github.com/MetalBlockchain/metalgo/ids" 17 "github.com/MetalBlockchain/metalgo/snow/engine/common" 18 "github.com/MetalBlockchain/metalgo/snow/validators" 19 "github.com/MetalBlockchain/metalgo/utils/logging" 20 ) 21 22 func TestValidatorsSample(t *testing.T) { 23 errFoobar := errors.New("foobar") 24 nodeID1 := ids.GenerateTestNodeID() 25 nodeID2 := ids.GenerateTestNodeID() 26 nodeID3 := ids.GenerateTestNodeID() 27 28 type call struct { 29 limit int 30 31 time time.Time 32 33 height uint64 34 getCurrentHeightErr error 35 36 validators []ids.NodeID 37 getValidatorSetErr error 38 39 // superset of possible values in the result 40 expected []ids.NodeID 41 } 42 43 tests := []struct { 44 name string 45 maxStaleness time.Duration 46 calls []call 47 }{ 48 { 49 // if we aren't connected to a validator, we shouldn't return it 50 name: "drop disconnected validators", 51 maxStaleness: time.Hour, 52 calls: []call{ 53 { 54 time: time.Time{}.Add(time.Second), 55 limit: 2, 56 height: 1, 57 validators: []ids.NodeID{nodeID1, nodeID3}, 58 expected: []ids.NodeID{nodeID1}, 59 }, 60 }, 61 }, 62 { 63 // if we don't have as many validators as requested by the caller, 64 // we should return all the validators we have 65 name: "less than limit validators", 66 maxStaleness: time.Hour, 67 calls: []call{ 68 { 69 time: time.Time{}.Add(time.Second), 70 limit: 2, 71 height: 1, 72 validators: []ids.NodeID{nodeID1}, 73 expected: []ids.NodeID{nodeID1}, 74 }, 75 }, 76 }, 77 { 78 // if we have as many validators as requested by the caller, we 79 // should return all the validators we have 80 name: "equal to limit validators", 81 maxStaleness: time.Hour, 82 calls: []call{ 83 { 84 time: time.Time{}.Add(time.Second), 85 limit: 1, 86 height: 1, 87 validators: []ids.NodeID{nodeID1}, 88 expected: []ids.NodeID{nodeID1}, 89 }, 90 }, 91 }, 92 { 93 // if we have less validators than requested by the caller, we 94 // should return a subset of the validators that we have 95 name: "less than limit validators", 96 maxStaleness: time.Hour, 97 calls: []call{ 98 { 99 time: time.Time{}.Add(time.Second), 100 limit: 1, 101 height: 1, 102 validators: []ids.NodeID{nodeID1, nodeID2}, 103 expected: []ids.NodeID{nodeID1, nodeID2}, 104 }, 105 }, 106 }, 107 { 108 name: "within max staleness threshold", 109 maxStaleness: time.Hour, 110 calls: []call{ 111 { 112 time: time.Time{}.Add(time.Second), 113 limit: 1, 114 height: 1, 115 validators: []ids.NodeID{nodeID1}, 116 expected: []ids.NodeID{nodeID1}, 117 }, 118 }, 119 }, 120 { 121 name: "beyond max staleness threshold", 122 maxStaleness: time.Hour, 123 calls: []call{ 124 { 125 limit: 1, 126 time: time.Time{}.Add(time.Hour), 127 height: 1, 128 validators: []ids.NodeID{nodeID1}, 129 expected: []ids.NodeID{nodeID1}, 130 }, 131 }, 132 }, 133 { 134 name: "fail to get current height", 135 maxStaleness: time.Second, 136 calls: []call{ 137 { 138 limit: 1, 139 time: time.Time{}.Add(time.Hour), 140 getCurrentHeightErr: errFoobar, 141 expected: []ids.NodeID{}, 142 }, 143 }, 144 }, 145 { 146 name: "second get validator set call fails", 147 maxStaleness: time.Minute, 148 calls: []call{ 149 { 150 limit: 1, 151 time: time.Time{}.Add(time.Second), 152 height: 1, 153 validators: []ids.NodeID{nodeID1}, 154 expected: []ids.NodeID{nodeID1}, 155 }, 156 { 157 limit: 1, 158 time: time.Time{}.Add(time.Hour), 159 height: 1, 160 getValidatorSetErr: errFoobar, 161 expected: []ids.NodeID{}, 162 }, 163 }, 164 }, 165 } 166 167 for _, tt := range tests { 168 t.Run(tt.name, func(t *testing.T) { 169 require := require.New(t) 170 subnetID := ids.GenerateTestID() 171 ctrl := gomock.NewController(t) 172 mockValidators := validators.NewMockState(ctrl) 173 174 calls := make([]any, 0) 175 for _, call := range tt.calls { 176 calls = append(calls, mockValidators.EXPECT(). 177 GetCurrentHeight(gomock.Any()).Return(call.height, call.getCurrentHeightErr)) 178 179 if call.getCurrentHeightErr != nil { 180 continue 181 } 182 183 validatorSet := make(map[ids.NodeID]*validators.GetValidatorOutput, 0) 184 for _, validator := range call.validators { 185 validatorSet[validator] = &validators.GetValidatorOutput{ 186 NodeID: validator, 187 Weight: 1, 188 } 189 } 190 191 calls = append(calls, 192 mockValidators.EXPECT(). 193 GetValidatorSet(gomock.Any(), gomock.Any(), subnetID). 194 Return(validatorSet, call.getValidatorSetErr)) 195 } 196 gomock.InOrder(calls...) 197 198 network, err := NewNetwork(logging.NoLog{}, &common.FakeSender{}, prometheus.NewRegistry(), "") 199 require.NoError(err) 200 201 ctx := context.Background() 202 require.NoError(network.Connected(ctx, nodeID1, nil)) 203 require.NoError(network.Connected(ctx, nodeID2, nil)) 204 205 v := NewValidators(network.Peers, network.log, subnetID, mockValidators, tt.maxStaleness) 206 for _, call := range tt.calls { 207 v.lastUpdated = call.time 208 sampled := v.Sample(ctx, call.limit) 209 require.LessOrEqual(len(sampled), call.limit) 210 require.Subset(call.expected, sampled) 211 } 212 }) 213 } 214 } 215 216 func TestValidatorsTop(t *testing.T) { 217 nodeID1 := ids.GenerateTestNodeID() 218 nodeID2 := ids.GenerateTestNodeID() 219 nodeID3 := ids.GenerateTestNodeID() 220 221 tests := []struct { 222 name string 223 validators []validator 224 percentage float64 225 expected []ids.NodeID 226 }{ 227 { 228 name: "top 0% is empty", 229 validators: []validator{ 230 { 231 nodeID: nodeID1, 232 weight: 1, 233 }, 234 { 235 nodeID: nodeID2, 236 weight: 1, 237 }, 238 }, 239 percentage: 0, 240 expected: []ids.NodeID{}, 241 }, 242 { 243 name: "top 100% is full", 244 validators: []validator{ 245 { 246 nodeID: nodeID1, 247 weight: 2, 248 }, 249 { 250 nodeID: nodeID2, 251 weight: 1, 252 }, 253 }, 254 percentage: 1, 255 expected: []ids.NodeID{ 256 nodeID1, 257 nodeID2, 258 }, 259 }, 260 { 261 name: "top 50% takes larger validator", 262 validators: []validator{ 263 { 264 nodeID: nodeID1, 265 weight: 2, 266 }, 267 { 268 nodeID: nodeID2, 269 weight: 1, 270 }, 271 }, 272 percentage: .5, 273 expected: []ids.NodeID{ 274 nodeID1, 275 }, 276 }, 277 { 278 name: "top 50% bound", 279 validators: []validator{ 280 { 281 nodeID: nodeID1, 282 weight: 2, 283 }, 284 { 285 nodeID: nodeID2, 286 weight: 1, 287 }, 288 { 289 nodeID: nodeID3, 290 weight: 1, 291 }, 292 }, 293 percentage: .5, 294 expected: []ids.NodeID{ 295 nodeID1, 296 }, 297 }, 298 } 299 for _, test := range tests { 300 t.Run(test.name, func(t *testing.T) { 301 require := require.New(t) 302 ctrl := gomock.NewController(t) 303 304 validatorSet := make(map[ids.NodeID]*validators.GetValidatorOutput, 0) 305 for _, validator := range test.validators { 306 validatorSet[validator.nodeID] = &validators.GetValidatorOutput{ 307 NodeID: validator.nodeID, 308 Weight: validator.weight, 309 } 310 } 311 312 subnetID := ids.GenerateTestID() 313 mockValidators := validators.NewMockState(ctrl) 314 315 mockValidators.EXPECT().GetCurrentHeight(gomock.Any()).Return(uint64(1), nil) 316 mockValidators.EXPECT().GetValidatorSet(gomock.Any(), uint64(1), subnetID).Return(validatorSet, nil) 317 318 network, err := NewNetwork(logging.NoLog{}, &common.FakeSender{}, prometheus.NewRegistry(), "") 319 require.NoError(err) 320 321 ctx := context.Background() 322 require.NoError(network.Connected(ctx, nodeID1, nil)) 323 require.NoError(network.Connected(ctx, nodeID2, nil)) 324 325 v := NewValidators(network.Peers, network.log, subnetID, mockValidators, time.Second) 326 nodeIDs := v.Top(ctx, test.percentage) 327 require.Equal(test.expected, nodeIDs) 328 }) 329 } 330 }