github.com/ava-labs/avalanchego@v1.11.11/vms/platformvm/warp/validator_test.go (about) 1 // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package warp 5 6 import ( 7 "context" 8 "math" 9 "strconv" 10 "testing" 11 12 "github.com/stretchr/testify/require" 13 "go.uber.org/mock/gomock" 14 15 "github.com/ava-labs/avalanchego/ids" 16 "github.com/ava-labs/avalanchego/snow/validators" 17 "github.com/ava-labs/avalanchego/snow/validators/validatorsmock" 18 "github.com/ava-labs/avalanchego/snow/validators/validatorstest" 19 "github.com/ava-labs/avalanchego/utils/crypto/bls" 20 "github.com/ava-labs/avalanchego/utils/set" 21 ) 22 23 func TestGetCanonicalValidatorSet(t *testing.T) { 24 type test struct { 25 name string 26 stateF func(*gomock.Controller) validators.State 27 expectedVdrs []*Validator 28 expectedWeight uint64 29 expectedErr error 30 } 31 32 tests := []test{ 33 { 34 name: "can't get validator set", 35 stateF: func(ctrl *gomock.Controller) validators.State { 36 state := validatorsmock.NewState(ctrl) 37 state.EXPECT().GetValidatorSet(gomock.Any(), pChainHeight, subnetID).Return(nil, errTest) 38 return state 39 }, 40 expectedErr: errTest, 41 }, 42 { 43 name: "all validators have public keys; no duplicate pub keys", 44 stateF: func(ctrl *gomock.Controller) validators.State { 45 state := validatorsmock.NewState(ctrl) 46 state.EXPECT().GetValidatorSet(gomock.Any(), pChainHeight, subnetID).Return( 47 map[ids.NodeID]*validators.GetValidatorOutput{ 48 testVdrs[0].nodeID: { 49 NodeID: testVdrs[0].nodeID, 50 PublicKey: testVdrs[0].vdr.PublicKey, 51 Weight: testVdrs[0].vdr.Weight, 52 }, 53 testVdrs[1].nodeID: { 54 NodeID: testVdrs[1].nodeID, 55 PublicKey: testVdrs[1].vdr.PublicKey, 56 Weight: testVdrs[1].vdr.Weight, 57 }, 58 }, 59 nil, 60 ) 61 return state 62 }, 63 expectedVdrs: []*Validator{testVdrs[0].vdr, testVdrs[1].vdr}, 64 expectedWeight: 6, 65 expectedErr: nil, 66 }, 67 { 68 name: "all validators have public keys; duplicate pub keys", 69 stateF: func(ctrl *gomock.Controller) validators.State { 70 state := validatorsmock.NewState(ctrl) 71 state.EXPECT().GetValidatorSet(gomock.Any(), pChainHeight, subnetID).Return( 72 map[ids.NodeID]*validators.GetValidatorOutput{ 73 testVdrs[0].nodeID: { 74 NodeID: testVdrs[0].nodeID, 75 PublicKey: testVdrs[0].vdr.PublicKey, 76 Weight: testVdrs[0].vdr.Weight, 77 }, 78 testVdrs[1].nodeID: { 79 NodeID: testVdrs[1].nodeID, 80 PublicKey: testVdrs[1].vdr.PublicKey, 81 Weight: testVdrs[1].vdr.Weight, 82 }, 83 testVdrs[2].nodeID: { 84 NodeID: testVdrs[2].nodeID, 85 PublicKey: testVdrs[0].vdr.PublicKey, 86 Weight: testVdrs[0].vdr.Weight, 87 }, 88 }, 89 nil, 90 ) 91 return state 92 }, 93 expectedVdrs: []*Validator{ 94 { 95 PublicKey: testVdrs[0].vdr.PublicKey, 96 PublicKeyBytes: testVdrs[0].vdr.PublicKeyBytes, 97 Weight: testVdrs[0].vdr.Weight * 2, 98 NodeIDs: []ids.NodeID{ 99 testVdrs[0].nodeID, 100 testVdrs[2].nodeID, 101 }, 102 }, 103 testVdrs[1].vdr, 104 }, 105 expectedWeight: 9, 106 expectedErr: nil, 107 }, 108 { 109 name: "validator without public key; no duplicate pub keys", 110 stateF: func(ctrl *gomock.Controller) validators.State { 111 state := validatorsmock.NewState(ctrl) 112 state.EXPECT().GetValidatorSet(gomock.Any(), pChainHeight, subnetID).Return( 113 map[ids.NodeID]*validators.GetValidatorOutput{ 114 testVdrs[0].nodeID: { 115 NodeID: testVdrs[0].nodeID, 116 PublicKey: nil, 117 Weight: testVdrs[0].vdr.Weight, 118 }, 119 testVdrs[1].nodeID: { 120 NodeID: testVdrs[1].nodeID, 121 PublicKey: testVdrs[1].vdr.PublicKey, 122 Weight: testVdrs[1].vdr.Weight, 123 }, 124 }, 125 nil, 126 ) 127 return state 128 }, 129 expectedVdrs: []*Validator{testVdrs[1].vdr}, 130 expectedWeight: 6, 131 expectedErr: nil, 132 }, 133 } 134 135 for _, tt := range tests { 136 t.Run(tt.name, func(t *testing.T) { 137 require := require.New(t) 138 ctrl := gomock.NewController(t) 139 140 state := tt.stateF(ctrl) 141 142 vdrs, weight, err := GetCanonicalValidatorSet(context.Background(), state, pChainHeight, subnetID) 143 require.ErrorIs(err, tt.expectedErr) 144 if err != nil { 145 return 146 } 147 require.Equal(tt.expectedWeight, weight) 148 149 // These are pointers so have to test equality like this 150 require.Len(vdrs, len(tt.expectedVdrs)) 151 for i, expectedVdr := range tt.expectedVdrs { 152 gotVdr := vdrs[i] 153 expectedPKBytes := bls.PublicKeyToCompressedBytes(expectedVdr.PublicKey) 154 gotPKBytes := bls.PublicKeyToCompressedBytes(gotVdr.PublicKey) 155 require.Equal(expectedPKBytes, gotPKBytes) 156 require.Equal(expectedVdr.PublicKeyBytes, gotVdr.PublicKeyBytes) 157 require.Equal(expectedVdr.Weight, gotVdr.Weight) 158 require.ElementsMatch(expectedVdr.NodeIDs, gotVdr.NodeIDs) 159 } 160 }) 161 } 162 } 163 164 func TestFilterValidators(t *testing.T) { 165 sk0, err := bls.NewSecretKey() 166 require.NoError(t, err) 167 pk0 := bls.PublicFromSecretKey(sk0) 168 vdr0 := &Validator{ 169 PublicKey: pk0, 170 PublicKeyBytes: bls.PublicKeyToUncompressedBytes(pk0), 171 Weight: 1, 172 } 173 174 sk1, err := bls.NewSecretKey() 175 require.NoError(t, err) 176 pk1 := bls.PublicFromSecretKey(sk1) 177 vdr1 := &Validator{ 178 PublicKey: pk1, 179 PublicKeyBytes: bls.PublicKeyToUncompressedBytes(pk1), 180 Weight: 2, 181 } 182 183 type test struct { 184 name string 185 indices set.Bits 186 vdrs []*Validator 187 expectedVdrs []*Validator 188 expectedErr error 189 } 190 191 tests := []test{ 192 { 193 name: "empty", 194 indices: set.NewBits(), 195 vdrs: []*Validator{}, 196 expectedVdrs: []*Validator{}, 197 expectedErr: nil, 198 }, 199 { 200 name: "unknown validator", 201 indices: set.NewBits(2), 202 vdrs: []*Validator{vdr0, vdr1}, 203 expectedErr: ErrUnknownValidator, 204 }, 205 { 206 name: "two filtered out", 207 indices: set.NewBits(), 208 vdrs: []*Validator{ 209 vdr0, 210 vdr1, 211 }, 212 expectedVdrs: []*Validator{}, 213 expectedErr: nil, 214 }, 215 { 216 name: "one filtered out", 217 indices: set.NewBits(1), 218 vdrs: []*Validator{ 219 vdr0, 220 vdr1, 221 }, 222 expectedVdrs: []*Validator{ 223 vdr1, 224 }, 225 expectedErr: nil, 226 }, 227 { 228 name: "none filtered out", 229 indices: set.NewBits(0, 1), 230 vdrs: []*Validator{ 231 vdr0, 232 vdr1, 233 }, 234 expectedVdrs: []*Validator{ 235 vdr0, 236 vdr1, 237 }, 238 expectedErr: nil, 239 }, 240 } 241 242 for _, tt := range tests { 243 t.Run(tt.name, func(t *testing.T) { 244 require := require.New(t) 245 246 vdrs, err := FilterValidators(tt.indices, tt.vdrs) 247 require.ErrorIs(err, tt.expectedErr) 248 if tt.expectedErr != nil { 249 return 250 } 251 require.Equal(tt.expectedVdrs, vdrs) 252 }) 253 } 254 } 255 256 func TestSumWeight(t *testing.T) { 257 vdr0 := &Validator{ 258 Weight: 1, 259 } 260 vdr1 := &Validator{ 261 Weight: 2, 262 } 263 vdr2 := &Validator{ 264 Weight: math.MaxUint64, 265 } 266 267 type test struct { 268 name string 269 vdrs []*Validator 270 expectedSum uint64 271 expectedErr error 272 } 273 274 tests := []test{ 275 { 276 name: "empty", 277 vdrs: []*Validator{}, 278 expectedSum: 0, 279 }, 280 { 281 name: "one", 282 vdrs: []*Validator{vdr0}, 283 expectedSum: 1, 284 }, 285 { 286 name: "two", 287 vdrs: []*Validator{vdr0, vdr1}, 288 expectedSum: 3, 289 }, 290 { 291 name: "overflow", 292 vdrs: []*Validator{vdr0, vdr2}, 293 expectedErr: ErrWeightOverflow, 294 }, 295 } 296 297 for _, tt := range tests { 298 t.Run(tt.name, func(t *testing.T) { 299 require := require.New(t) 300 301 sum, err := SumWeight(tt.vdrs) 302 require.ErrorIs(err, tt.expectedErr) 303 if tt.expectedErr != nil { 304 return 305 } 306 require.Equal(tt.expectedSum, sum) 307 }) 308 } 309 } 310 311 func BenchmarkGetCanonicalValidatorSet(b *testing.B) { 312 pChainHeight := uint64(1) 313 subnetID := ids.GenerateTestID() 314 numNodes := 10_000 315 getValidatorOutputs := make([]*validators.GetValidatorOutput, 0, numNodes) 316 for i := 0; i < numNodes; i++ { 317 nodeID := ids.GenerateTestNodeID() 318 blsPrivateKey, err := bls.NewSecretKey() 319 require.NoError(b, err) 320 blsPublicKey := bls.PublicFromSecretKey(blsPrivateKey) 321 getValidatorOutputs = append(getValidatorOutputs, &validators.GetValidatorOutput{ 322 NodeID: nodeID, 323 PublicKey: blsPublicKey, 324 Weight: 20, 325 }) 326 } 327 328 for _, size := range []int{0, 1, 10, 100, 1_000, 10_000} { 329 getValidatorsOutput := make(map[ids.NodeID]*validators.GetValidatorOutput) 330 for i := 0; i < size; i++ { 331 validator := getValidatorOutputs[i] 332 getValidatorsOutput[validator.NodeID] = validator 333 } 334 validatorState := &validatorstest.State{ 335 GetValidatorSetF: func(context.Context, uint64, ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) { 336 return getValidatorsOutput, nil 337 }, 338 } 339 340 b.Run(strconv.Itoa(size), func(b *testing.B) { 341 for i := 0; i < b.N; i++ { 342 _, _, err := GetCanonicalValidatorSet(context.Background(), validatorState, pChainHeight, subnetID) 343 require.NoError(b, err) 344 } 345 }) 346 } 347 }