github.com/MetalBlockchain/metalgo@v1.11.9/vms/platformvm/warp/validator_test.go (about) 1 // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package warp 5 6 import ( 7 "context" 8 "math" 9 "strconv" 10 "testing" 11 12 "github.com/stretchr/testify/require" 13 "go.uber.org/mock/gomock" 14 15 "github.com/MetalBlockchain/metalgo/ids" 16 "github.com/MetalBlockchain/metalgo/snow/validators" 17 "github.com/MetalBlockchain/metalgo/utils/crypto/bls" 18 "github.com/MetalBlockchain/metalgo/utils/set" 19 ) 20 21 func TestGetCanonicalValidatorSet(t *testing.T) { 22 type test struct { 23 name string 24 stateF func(*gomock.Controller) validators.State 25 expectedVdrs []*Validator 26 expectedWeight uint64 27 expectedErr error 28 } 29 30 tests := []test{ 31 { 32 name: "can't get validator set", 33 stateF: func(ctrl *gomock.Controller) validators.State { 34 state := validators.NewMockState(ctrl) 35 state.EXPECT().GetValidatorSet(gomock.Any(), pChainHeight, subnetID).Return(nil, errTest) 36 return state 37 }, 38 expectedErr: errTest, 39 }, 40 { 41 name: "all validators have public keys; no duplicate pub keys", 42 stateF: func(ctrl *gomock.Controller) validators.State { 43 state := validators.NewMockState(ctrl) 44 state.EXPECT().GetValidatorSet(gomock.Any(), pChainHeight, subnetID).Return( 45 map[ids.NodeID]*validators.GetValidatorOutput{ 46 testVdrs[0].nodeID: { 47 NodeID: testVdrs[0].nodeID, 48 PublicKey: testVdrs[0].vdr.PublicKey, 49 Weight: testVdrs[0].vdr.Weight, 50 }, 51 testVdrs[1].nodeID: { 52 NodeID: testVdrs[1].nodeID, 53 PublicKey: testVdrs[1].vdr.PublicKey, 54 Weight: testVdrs[1].vdr.Weight, 55 }, 56 }, 57 nil, 58 ) 59 return state 60 }, 61 expectedVdrs: []*Validator{testVdrs[0].vdr, testVdrs[1].vdr}, 62 expectedWeight: 6, 63 expectedErr: nil, 64 }, 65 { 66 name: "all validators have public keys; duplicate pub keys", 67 stateF: func(ctrl *gomock.Controller) validators.State { 68 state := validators.NewMockState(ctrl) 69 state.EXPECT().GetValidatorSet(gomock.Any(), pChainHeight, subnetID).Return( 70 map[ids.NodeID]*validators.GetValidatorOutput{ 71 testVdrs[0].nodeID: { 72 NodeID: testVdrs[0].nodeID, 73 PublicKey: testVdrs[0].vdr.PublicKey, 74 Weight: testVdrs[0].vdr.Weight, 75 }, 76 testVdrs[1].nodeID: { 77 NodeID: testVdrs[1].nodeID, 78 PublicKey: testVdrs[1].vdr.PublicKey, 79 Weight: testVdrs[1].vdr.Weight, 80 }, 81 testVdrs[2].nodeID: { 82 NodeID: testVdrs[2].nodeID, 83 PublicKey: testVdrs[0].vdr.PublicKey, 84 Weight: testVdrs[0].vdr.Weight, 85 }, 86 }, 87 nil, 88 ) 89 return state 90 }, 91 expectedVdrs: []*Validator{ 92 { 93 PublicKey: testVdrs[0].vdr.PublicKey, 94 PublicKeyBytes: testVdrs[0].vdr.PublicKeyBytes, 95 Weight: testVdrs[0].vdr.Weight * 2, 96 NodeIDs: []ids.NodeID{ 97 testVdrs[0].nodeID, 98 testVdrs[2].nodeID, 99 }, 100 }, 101 testVdrs[1].vdr, 102 }, 103 expectedWeight: 9, 104 expectedErr: nil, 105 }, 106 { 107 name: "validator without public key; no duplicate pub keys", 108 stateF: func(ctrl *gomock.Controller) validators.State { 109 state := validators.NewMockState(ctrl) 110 state.EXPECT().GetValidatorSet(gomock.Any(), pChainHeight, subnetID).Return( 111 map[ids.NodeID]*validators.GetValidatorOutput{ 112 testVdrs[0].nodeID: { 113 NodeID: testVdrs[0].nodeID, 114 PublicKey: nil, 115 Weight: testVdrs[0].vdr.Weight, 116 }, 117 testVdrs[1].nodeID: { 118 NodeID: testVdrs[1].nodeID, 119 PublicKey: testVdrs[1].vdr.PublicKey, 120 Weight: testVdrs[1].vdr.Weight, 121 }, 122 }, 123 nil, 124 ) 125 return state 126 }, 127 expectedVdrs: []*Validator{testVdrs[1].vdr}, 128 expectedWeight: 6, 129 expectedErr: nil, 130 }, 131 } 132 133 for _, tt := range tests { 134 t.Run(tt.name, func(t *testing.T) { 135 require := require.New(t) 136 ctrl := gomock.NewController(t) 137 138 state := tt.stateF(ctrl) 139 140 vdrs, weight, err := GetCanonicalValidatorSet(context.Background(), state, pChainHeight, subnetID) 141 require.ErrorIs(err, tt.expectedErr) 142 if err != nil { 143 return 144 } 145 require.Equal(tt.expectedWeight, weight) 146 147 // These are pointers so have to test equality like this 148 require.Len(vdrs, len(tt.expectedVdrs)) 149 for i, expectedVdr := range tt.expectedVdrs { 150 gotVdr := vdrs[i] 151 expectedPKBytes := bls.PublicKeyToCompressedBytes(expectedVdr.PublicKey) 152 gotPKBytes := bls.PublicKeyToCompressedBytes(gotVdr.PublicKey) 153 require.Equal(expectedPKBytes, gotPKBytes) 154 require.Equal(expectedVdr.PublicKeyBytes, gotVdr.PublicKeyBytes) 155 require.Equal(expectedVdr.Weight, gotVdr.Weight) 156 require.ElementsMatch(expectedVdr.NodeIDs, gotVdr.NodeIDs) 157 } 158 }) 159 } 160 } 161 162 func TestFilterValidators(t *testing.T) { 163 sk0, err := bls.NewSecretKey() 164 require.NoError(t, err) 165 pk0 := bls.PublicFromSecretKey(sk0) 166 vdr0 := &Validator{ 167 PublicKey: pk0, 168 PublicKeyBytes: bls.PublicKeyToUncompressedBytes(pk0), 169 Weight: 1, 170 } 171 172 sk1, err := bls.NewSecretKey() 173 require.NoError(t, err) 174 pk1 := bls.PublicFromSecretKey(sk1) 175 vdr1 := &Validator{ 176 PublicKey: pk1, 177 PublicKeyBytes: bls.PublicKeyToUncompressedBytes(pk1), 178 Weight: 2, 179 } 180 181 type test struct { 182 name string 183 indices set.Bits 184 vdrs []*Validator 185 expectedVdrs []*Validator 186 expectedErr error 187 } 188 189 tests := []test{ 190 { 191 name: "empty", 192 indices: set.NewBits(), 193 vdrs: []*Validator{}, 194 expectedVdrs: []*Validator{}, 195 expectedErr: nil, 196 }, 197 { 198 name: "unknown validator", 199 indices: set.NewBits(2), 200 vdrs: []*Validator{vdr0, vdr1}, 201 expectedErr: ErrUnknownValidator, 202 }, 203 { 204 name: "two filtered out", 205 indices: set.NewBits(), 206 vdrs: []*Validator{ 207 vdr0, 208 vdr1, 209 }, 210 expectedVdrs: []*Validator{}, 211 expectedErr: nil, 212 }, 213 { 214 name: "one filtered out", 215 indices: set.NewBits(1), 216 vdrs: []*Validator{ 217 vdr0, 218 vdr1, 219 }, 220 expectedVdrs: []*Validator{ 221 vdr1, 222 }, 223 expectedErr: nil, 224 }, 225 { 226 name: "none filtered out", 227 indices: set.NewBits(0, 1), 228 vdrs: []*Validator{ 229 vdr0, 230 vdr1, 231 }, 232 expectedVdrs: []*Validator{ 233 vdr0, 234 vdr1, 235 }, 236 expectedErr: nil, 237 }, 238 } 239 240 for _, tt := range tests { 241 t.Run(tt.name, func(t *testing.T) { 242 require := require.New(t) 243 244 vdrs, err := FilterValidators(tt.indices, tt.vdrs) 245 require.ErrorIs(err, tt.expectedErr) 246 if tt.expectedErr != nil { 247 return 248 } 249 require.Equal(tt.expectedVdrs, vdrs) 250 }) 251 } 252 } 253 254 func TestSumWeight(t *testing.T) { 255 vdr0 := &Validator{ 256 Weight: 1, 257 } 258 vdr1 := &Validator{ 259 Weight: 2, 260 } 261 vdr2 := &Validator{ 262 Weight: math.MaxUint64, 263 } 264 265 type test struct { 266 name string 267 vdrs []*Validator 268 expectedSum uint64 269 expectedErr error 270 } 271 272 tests := []test{ 273 { 274 name: "empty", 275 vdrs: []*Validator{}, 276 expectedSum: 0, 277 }, 278 { 279 name: "one", 280 vdrs: []*Validator{vdr0}, 281 expectedSum: 1, 282 }, 283 { 284 name: "two", 285 vdrs: []*Validator{vdr0, vdr1}, 286 expectedSum: 3, 287 }, 288 { 289 name: "overflow", 290 vdrs: []*Validator{vdr0, vdr2}, 291 expectedErr: ErrWeightOverflow, 292 }, 293 } 294 295 for _, tt := range tests { 296 t.Run(tt.name, func(t *testing.T) { 297 require := require.New(t) 298 299 sum, err := SumWeight(tt.vdrs) 300 require.ErrorIs(err, tt.expectedErr) 301 if tt.expectedErr != nil { 302 return 303 } 304 require.Equal(tt.expectedSum, sum) 305 }) 306 } 307 } 308 309 func BenchmarkGetCanonicalValidatorSet(b *testing.B) { 310 pChainHeight := uint64(1) 311 subnetID := ids.GenerateTestID() 312 numNodes := 10_000 313 getValidatorOutputs := make([]*validators.GetValidatorOutput, 0, numNodes) 314 for i := 0; i < numNodes; i++ { 315 nodeID := ids.GenerateTestNodeID() 316 blsPrivateKey, err := bls.NewSecretKey() 317 require.NoError(b, err) 318 blsPublicKey := bls.PublicFromSecretKey(blsPrivateKey) 319 getValidatorOutputs = append(getValidatorOutputs, &validators.GetValidatorOutput{ 320 NodeID: nodeID, 321 PublicKey: blsPublicKey, 322 Weight: 20, 323 }) 324 } 325 326 for _, size := range []int{0, 1, 10, 100, 1_000, 10_000} { 327 getValidatorsOutput := make(map[ids.NodeID]*validators.GetValidatorOutput) 328 for i := 0; i < size; i++ { 329 validator := getValidatorOutputs[i] 330 getValidatorsOutput[validator.NodeID] = validator 331 } 332 validatorState := &validators.TestState{ 333 GetValidatorSetF: func(context.Context, uint64, ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) { 334 return getValidatorsOutput, nil 335 }, 336 } 337 338 b.Run(strconv.Itoa(size), func(b *testing.B) { 339 for i := 0; i < b.N; i++ { 340 _, _, err := GetCanonicalValidatorSet(context.Background(), validatorState, pChainHeight, subnetID) 341 require.NoError(b, err) 342 } 343 }) 344 } 345 }