github.com/ava-labs/avalanchego@v1.11.11/vms/platformvm/warp/validator_test.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package warp
     5  
     6  import (
     7  	"context"
     8  	"math"
     9  	"strconv"
    10  	"testing"
    11  
    12  	"github.com/stretchr/testify/require"
    13  	"go.uber.org/mock/gomock"
    14  
    15  	"github.com/ava-labs/avalanchego/ids"
    16  	"github.com/ava-labs/avalanchego/snow/validators"
    17  	"github.com/ava-labs/avalanchego/snow/validators/validatorsmock"
    18  	"github.com/ava-labs/avalanchego/snow/validators/validatorstest"
    19  	"github.com/ava-labs/avalanchego/utils/crypto/bls"
    20  	"github.com/ava-labs/avalanchego/utils/set"
    21  )
    22  
    23  func TestGetCanonicalValidatorSet(t *testing.T) {
    24  	type test struct {
    25  		name           string
    26  		stateF         func(*gomock.Controller) validators.State
    27  		expectedVdrs   []*Validator
    28  		expectedWeight uint64
    29  		expectedErr    error
    30  	}
    31  
    32  	tests := []test{
    33  		{
    34  			name: "can't get validator set",
    35  			stateF: func(ctrl *gomock.Controller) validators.State {
    36  				state := validatorsmock.NewState(ctrl)
    37  				state.EXPECT().GetValidatorSet(gomock.Any(), pChainHeight, subnetID).Return(nil, errTest)
    38  				return state
    39  			},
    40  			expectedErr: errTest,
    41  		},
    42  		{
    43  			name: "all validators have public keys; no duplicate pub keys",
    44  			stateF: func(ctrl *gomock.Controller) validators.State {
    45  				state := validatorsmock.NewState(ctrl)
    46  				state.EXPECT().GetValidatorSet(gomock.Any(), pChainHeight, subnetID).Return(
    47  					map[ids.NodeID]*validators.GetValidatorOutput{
    48  						testVdrs[0].nodeID: {
    49  							NodeID:    testVdrs[0].nodeID,
    50  							PublicKey: testVdrs[0].vdr.PublicKey,
    51  							Weight:    testVdrs[0].vdr.Weight,
    52  						},
    53  						testVdrs[1].nodeID: {
    54  							NodeID:    testVdrs[1].nodeID,
    55  							PublicKey: testVdrs[1].vdr.PublicKey,
    56  							Weight:    testVdrs[1].vdr.Weight,
    57  						},
    58  					},
    59  					nil,
    60  				)
    61  				return state
    62  			},
    63  			expectedVdrs:   []*Validator{testVdrs[0].vdr, testVdrs[1].vdr},
    64  			expectedWeight: 6,
    65  			expectedErr:    nil,
    66  		},
    67  		{
    68  			name: "all validators have public keys; duplicate pub keys",
    69  			stateF: func(ctrl *gomock.Controller) validators.State {
    70  				state := validatorsmock.NewState(ctrl)
    71  				state.EXPECT().GetValidatorSet(gomock.Any(), pChainHeight, subnetID).Return(
    72  					map[ids.NodeID]*validators.GetValidatorOutput{
    73  						testVdrs[0].nodeID: {
    74  							NodeID:    testVdrs[0].nodeID,
    75  							PublicKey: testVdrs[0].vdr.PublicKey,
    76  							Weight:    testVdrs[0].vdr.Weight,
    77  						},
    78  						testVdrs[1].nodeID: {
    79  							NodeID:    testVdrs[1].nodeID,
    80  							PublicKey: testVdrs[1].vdr.PublicKey,
    81  							Weight:    testVdrs[1].vdr.Weight,
    82  						},
    83  						testVdrs[2].nodeID: {
    84  							NodeID:    testVdrs[2].nodeID,
    85  							PublicKey: testVdrs[0].vdr.PublicKey,
    86  							Weight:    testVdrs[0].vdr.Weight,
    87  						},
    88  					},
    89  					nil,
    90  				)
    91  				return state
    92  			},
    93  			expectedVdrs: []*Validator{
    94  				{
    95  					PublicKey:      testVdrs[0].vdr.PublicKey,
    96  					PublicKeyBytes: testVdrs[0].vdr.PublicKeyBytes,
    97  					Weight:         testVdrs[0].vdr.Weight * 2,
    98  					NodeIDs: []ids.NodeID{
    99  						testVdrs[0].nodeID,
   100  						testVdrs[2].nodeID,
   101  					},
   102  				},
   103  				testVdrs[1].vdr,
   104  			},
   105  			expectedWeight: 9,
   106  			expectedErr:    nil,
   107  		},
   108  		{
   109  			name: "validator without public key; no duplicate pub keys",
   110  			stateF: func(ctrl *gomock.Controller) validators.State {
   111  				state := validatorsmock.NewState(ctrl)
   112  				state.EXPECT().GetValidatorSet(gomock.Any(), pChainHeight, subnetID).Return(
   113  					map[ids.NodeID]*validators.GetValidatorOutput{
   114  						testVdrs[0].nodeID: {
   115  							NodeID:    testVdrs[0].nodeID,
   116  							PublicKey: nil,
   117  							Weight:    testVdrs[0].vdr.Weight,
   118  						},
   119  						testVdrs[1].nodeID: {
   120  							NodeID:    testVdrs[1].nodeID,
   121  							PublicKey: testVdrs[1].vdr.PublicKey,
   122  							Weight:    testVdrs[1].vdr.Weight,
   123  						},
   124  					},
   125  					nil,
   126  				)
   127  				return state
   128  			},
   129  			expectedVdrs:   []*Validator{testVdrs[1].vdr},
   130  			expectedWeight: 6,
   131  			expectedErr:    nil,
   132  		},
   133  	}
   134  
   135  	for _, tt := range tests {
   136  		t.Run(tt.name, func(t *testing.T) {
   137  			require := require.New(t)
   138  			ctrl := gomock.NewController(t)
   139  
   140  			state := tt.stateF(ctrl)
   141  
   142  			vdrs, weight, err := GetCanonicalValidatorSet(context.Background(), state, pChainHeight, subnetID)
   143  			require.ErrorIs(err, tt.expectedErr)
   144  			if err != nil {
   145  				return
   146  			}
   147  			require.Equal(tt.expectedWeight, weight)
   148  
   149  			// These are pointers so have to test equality like this
   150  			require.Len(vdrs, len(tt.expectedVdrs))
   151  			for i, expectedVdr := range tt.expectedVdrs {
   152  				gotVdr := vdrs[i]
   153  				expectedPKBytes := bls.PublicKeyToCompressedBytes(expectedVdr.PublicKey)
   154  				gotPKBytes := bls.PublicKeyToCompressedBytes(gotVdr.PublicKey)
   155  				require.Equal(expectedPKBytes, gotPKBytes)
   156  				require.Equal(expectedVdr.PublicKeyBytes, gotVdr.PublicKeyBytes)
   157  				require.Equal(expectedVdr.Weight, gotVdr.Weight)
   158  				require.ElementsMatch(expectedVdr.NodeIDs, gotVdr.NodeIDs)
   159  			}
   160  		})
   161  	}
   162  }
   163  
   164  func TestFilterValidators(t *testing.T) {
   165  	sk0, err := bls.NewSecretKey()
   166  	require.NoError(t, err)
   167  	pk0 := bls.PublicFromSecretKey(sk0)
   168  	vdr0 := &Validator{
   169  		PublicKey:      pk0,
   170  		PublicKeyBytes: bls.PublicKeyToUncompressedBytes(pk0),
   171  		Weight:         1,
   172  	}
   173  
   174  	sk1, err := bls.NewSecretKey()
   175  	require.NoError(t, err)
   176  	pk1 := bls.PublicFromSecretKey(sk1)
   177  	vdr1 := &Validator{
   178  		PublicKey:      pk1,
   179  		PublicKeyBytes: bls.PublicKeyToUncompressedBytes(pk1),
   180  		Weight:         2,
   181  	}
   182  
   183  	type test struct {
   184  		name         string
   185  		indices      set.Bits
   186  		vdrs         []*Validator
   187  		expectedVdrs []*Validator
   188  		expectedErr  error
   189  	}
   190  
   191  	tests := []test{
   192  		{
   193  			name:         "empty",
   194  			indices:      set.NewBits(),
   195  			vdrs:         []*Validator{},
   196  			expectedVdrs: []*Validator{},
   197  			expectedErr:  nil,
   198  		},
   199  		{
   200  			name:        "unknown validator",
   201  			indices:     set.NewBits(2),
   202  			vdrs:        []*Validator{vdr0, vdr1},
   203  			expectedErr: ErrUnknownValidator,
   204  		},
   205  		{
   206  			name:    "two filtered out",
   207  			indices: set.NewBits(),
   208  			vdrs: []*Validator{
   209  				vdr0,
   210  				vdr1,
   211  			},
   212  			expectedVdrs: []*Validator{},
   213  			expectedErr:  nil,
   214  		},
   215  		{
   216  			name:    "one filtered out",
   217  			indices: set.NewBits(1),
   218  			vdrs: []*Validator{
   219  				vdr0,
   220  				vdr1,
   221  			},
   222  			expectedVdrs: []*Validator{
   223  				vdr1,
   224  			},
   225  			expectedErr: nil,
   226  		},
   227  		{
   228  			name:    "none filtered out",
   229  			indices: set.NewBits(0, 1),
   230  			vdrs: []*Validator{
   231  				vdr0,
   232  				vdr1,
   233  			},
   234  			expectedVdrs: []*Validator{
   235  				vdr0,
   236  				vdr1,
   237  			},
   238  			expectedErr: nil,
   239  		},
   240  	}
   241  
   242  	for _, tt := range tests {
   243  		t.Run(tt.name, func(t *testing.T) {
   244  			require := require.New(t)
   245  
   246  			vdrs, err := FilterValidators(tt.indices, tt.vdrs)
   247  			require.ErrorIs(err, tt.expectedErr)
   248  			if tt.expectedErr != nil {
   249  				return
   250  			}
   251  			require.Equal(tt.expectedVdrs, vdrs)
   252  		})
   253  	}
   254  }
   255  
   256  func TestSumWeight(t *testing.T) {
   257  	vdr0 := &Validator{
   258  		Weight: 1,
   259  	}
   260  	vdr1 := &Validator{
   261  		Weight: 2,
   262  	}
   263  	vdr2 := &Validator{
   264  		Weight: math.MaxUint64,
   265  	}
   266  
   267  	type test struct {
   268  		name        string
   269  		vdrs        []*Validator
   270  		expectedSum uint64
   271  		expectedErr error
   272  	}
   273  
   274  	tests := []test{
   275  		{
   276  			name:        "empty",
   277  			vdrs:        []*Validator{},
   278  			expectedSum: 0,
   279  		},
   280  		{
   281  			name:        "one",
   282  			vdrs:        []*Validator{vdr0},
   283  			expectedSum: 1,
   284  		},
   285  		{
   286  			name:        "two",
   287  			vdrs:        []*Validator{vdr0, vdr1},
   288  			expectedSum: 3,
   289  		},
   290  		{
   291  			name:        "overflow",
   292  			vdrs:        []*Validator{vdr0, vdr2},
   293  			expectedErr: ErrWeightOverflow,
   294  		},
   295  	}
   296  
   297  	for _, tt := range tests {
   298  		t.Run(tt.name, func(t *testing.T) {
   299  			require := require.New(t)
   300  
   301  			sum, err := SumWeight(tt.vdrs)
   302  			require.ErrorIs(err, tt.expectedErr)
   303  			if tt.expectedErr != nil {
   304  				return
   305  			}
   306  			require.Equal(tt.expectedSum, sum)
   307  		})
   308  	}
   309  }
   310  
   311  func BenchmarkGetCanonicalValidatorSet(b *testing.B) {
   312  	pChainHeight := uint64(1)
   313  	subnetID := ids.GenerateTestID()
   314  	numNodes := 10_000
   315  	getValidatorOutputs := make([]*validators.GetValidatorOutput, 0, numNodes)
   316  	for i := 0; i < numNodes; i++ {
   317  		nodeID := ids.GenerateTestNodeID()
   318  		blsPrivateKey, err := bls.NewSecretKey()
   319  		require.NoError(b, err)
   320  		blsPublicKey := bls.PublicFromSecretKey(blsPrivateKey)
   321  		getValidatorOutputs = append(getValidatorOutputs, &validators.GetValidatorOutput{
   322  			NodeID:    nodeID,
   323  			PublicKey: blsPublicKey,
   324  			Weight:    20,
   325  		})
   326  	}
   327  
   328  	for _, size := range []int{0, 1, 10, 100, 1_000, 10_000} {
   329  		getValidatorsOutput := make(map[ids.NodeID]*validators.GetValidatorOutput)
   330  		for i := 0; i < size; i++ {
   331  			validator := getValidatorOutputs[i]
   332  			getValidatorsOutput[validator.NodeID] = validator
   333  		}
   334  		validatorState := &validatorstest.State{
   335  			GetValidatorSetF: func(context.Context, uint64, ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) {
   336  				return getValidatorsOutput, nil
   337  			},
   338  		}
   339  
   340  		b.Run(strconv.Itoa(size), func(b *testing.B) {
   341  			for i := 0; i < b.N; i++ {
   342  				_, _, err := GetCanonicalValidatorSet(context.Background(), validatorState, pChainHeight, subnetID)
   343  				require.NoError(b, err)
   344  			}
   345  		})
   346  	}
   347  }