github.com/MetalBlockchain/metalgo@v1.11.9/vms/platformvm/warp/validator_test.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package warp
     5  
     6  import (
     7  	"context"
     8  	"math"
     9  	"strconv"
    10  	"testing"
    11  
    12  	"github.com/stretchr/testify/require"
    13  	"go.uber.org/mock/gomock"
    14  
    15  	"github.com/MetalBlockchain/metalgo/ids"
    16  	"github.com/MetalBlockchain/metalgo/snow/validators"
    17  	"github.com/MetalBlockchain/metalgo/utils/crypto/bls"
    18  	"github.com/MetalBlockchain/metalgo/utils/set"
    19  )
    20  
    21  func TestGetCanonicalValidatorSet(t *testing.T) {
    22  	type test struct {
    23  		name           string
    24  		stateF         func(*gomock.Controller) validators.State
    25  		expectedVdrs   []*Validator
    26  		expectedWeight uint64
    27  		expectedErr    error
    28  	}
    29  
    30  	tests := []test{
    31  		{
    32  			name: "can't get validator set",
    33  			stateF: func(ctrl *gomock.Controller) validators.State {
    34  				state := validators.NewMockState(ctrl)
    35  				state.EXPECT().GetValidatorSet(gomock.Any(), pChainHeight, subnetID).Return(nil, errTest)
    36  				return state
    37  			},
    38  			expectedErr: errTest,
    39  		},
    40  		{
    41  			name: "all validators have public keys; no duplicate pub keys",
    42  			stateF: func(ctrl *gomock.Controller) validators.State {
    43  				state := validators.NewMockState(ctrl)
    44  				state.EXPECT().GetValidatorSet(gomock.Any(), pChainHeight, subnetID).Return(
    45  					map[ids.NodeID]*validators.GetValidatorOutput{
    46  						testVdrs[0].nodeID: {
    47  							NodeID:    testVdrs[0].nodeID,
    48  							PublicKey: testVdrs[0].vdr.PublicKey,
    49  							Weight:    testVdrs[0].vdr.Weight,
    50  						},
    51  						testVdrs[1].nodeID: {
    52  							NodeID:    testVdrs[1].nodeID,
    53  							PublicKey: testVdrs[1].vdr.PublicKey,
    54  							Weight:    testVdrs[1].vdr.Weight,
    55  						},
    56  					},
    57  					nil,
    58  				)
    59  				return state
    60  			},
    61  			expectedVdrs:   []*Validator{testVdrs[0].vdr, testVdrs[1].vdr},
    62  			expectedWeight: 6,
    63  			expectedErr:    nil,
    64  		},
    65  		{
    66  			name: "all validators have public keys; duplicate pub keys",
    67  			stateF: func(ctrl *gomock.Controller) validators.State {
    68  				state := validators.NewMockState(ctrl)
    69  				state.EXPECT().GetValidatorSet(gomock.Any(), pChainHeight, subnetID).Return(
    70  					map[ids.NodeID]*validators.GetValidatorOutput{
    71  						testVdrs[0].nodeID: {
    72  							NodeID:    testVdrs[0].nodeID,
    73  							PublicKey: testVdrs[0].vdr.PublicKey,
    74  							Weight:    testVdrs[0].vdr.Weight,
    75  						},
    76  						testVdrs[1].nodeID: {
    77  							NodeID:    testVdrs[1].nodeID,
    78  							PublicKey: testVdrs[1].vdr.PublicKey,
    79  							Weight:    testVdrs[1].vdr.Weight,
    80  						},
    81  						testVdrs[2].nodeID: {
    82  							NodeID:    testVdrs[2].nodeID,
    83  							PublicKey: testVdrs[0].vdr.PublicKey,
    84  							Weight:    testVdrs[0].vdr.Weight,
    85  						},
    86  					},
    87  					nil,
    88  				)
    89  				return state
    90  			},
    91  			expectedVdrs: []*Validator{
    92  				{
    93  					PublicKey:      testVdrs[0].vdr.PublicKey,
    94  					PublicKeyBytes: testVdrs[0].vdr.PublicKeyBytes,
    95  					Weight:         testVdrs[0].vdr.Weight * 2,
    96  					NodeIDs: []ids.NodeID{
    97  						testVdrs[0].nodeID,
    98  						testVdrs[2].nodeID,
    99  					},
   100  				},
   101  				testVdrs[1].vdr,
   102  			},
   103  			expectedWeight: 9,
   104  			expectedErr:    nil,
   105  		},
   106  		{
   107  			name: "validator without public key; no duplicate pub keys",
   108  			stateF: func(ctrl *gomock.Controller) validators.State {
   109  				state := validators.NewMockState(ctrl)
   110  				state.EXPECT().GetValidatorSet(gomock.Any(), pChainHeight, subnetID).Return(
   111  					map[ids.NodeID]*validators.GetValidatorOutput{
   112  						testVdrs[0].nodeID: {
   113  							NodeID:    testVdrs[0].nodeID,
   114  							PublicKey: nil,
   115  							Weight:    testVdrs[0].vdr.Weight,
   116  						},
   117  						testVdrs[1].nodeID: {
   118  							NodeID:    testVdrs[1].nodeID,
   119  							PublicKey: testVdrs[1].vdr.PublicKey,
   120  							Weight:    testVdrs[1].vdr.Weight,
   121  						},
   122  					},
   123  					nil,
   124  				)
   125  				return state
   126  			},
   127  			expectedVdrs:   []*Validator{testVdrs[1].vdr},
   128  			expectedWeight: 6,
   129  			expectedErr:    nil,
   130  		},
   131  	}
   132  
   133  	for _, tt := range tests {
   134  		t.Run(tt.name, func(t *testing.T) {
   135  			require := require.New(t)
   136  			ctrl := gomock.NewController(t)
   137  
   138  			state := tt.stateF(ctrl)
   139  
   140  			vdrs, weight, err := GetCanonicalValidatorSet(context.Background(), state, pChainHeight, subnetID)
   141  			require.ErrorIs(err, tt.expectedErr)
   142  			if err != nil {
   143  				return
   144  			}
   145  			require.Equal(tt.expectedWeight, weight)
   146  
   147  			// These are pointers so have to test equality like this
   148  			require.Len(vdrs, len(tt.expectedVdrs))
   149  			for i, expectedVdr := range tt.expectedVdrs {
   150  				gotVdr := vdrs[i]
   151  				expectedPKBytes := bls.PublicKeyToCompressedBytes(expectedVdr.PublicKey)
   152  				gotPKBytes := bls.PublicKeyToCompressedBytes(gotVdr.PublicKey)
   153  				require.Equal(expectedPKBytes, gotPKBytes)
   154  				require.Equal(expectedVdr.PublicKeyBytes, gotVdr.PublicKeyBytes)
   155  				require.Equal(expectedVdr.Weight, gotVdr.Weight)
   156  				require.ElementsMatch(expectedVdr.NodeIDs, gotVdr.NodeIDs)
   157  			}
   158  		})
   159  	}
   160  }
   161  
   162  func TestFilterValidators(t *testing.T) {
   163  	sk0, err := bls.NewSecretKey()
   164  	require.NoError(t, err)
   165  	pk0 := bls.PublicFromSecretKey(sk0)
   166  	vdr0 := &Validator{
   167  		PublicKey:      pk0,
   168  		PublicKeyBytes: bls.PublicKeyToUncompressedBytes(pk0),
   169  		Weight:         1,
   170  	}
   171  
   172  	sk1, err := bls.NewSecretKey()
   173  	require.NoError(t, err)
   174  	pk1 := bls.PublicFromSecretKey(sk1)
   175  	vdr1 := &Validator{
   176  		PublicKey:      pk1,
   177  		PublicKeyBytes: bls.PublicKeyToUncompressedBytes(pk1),
   178  		Weight:         2,
   179  	}
   180  
   181  	type test struct {
   182  		name         string
   183  		indices      set.Bits
   184  		vdrs         []*Validator
   185  		expectedVdrs []*Validator
   186  		expectedErr  error
   187  	}
   188  
   189  	tests := []test{
   190  		{
   191  			name:         "empty",
   192  			indices:      set.NewBits(),
   193  			vdrs:         []*Validator{},
   194  			expectedVdrs: []*Validator{},
   195  			expectedErr:  nil,
   196  		},
   197  		{
   198  			name:        "unknown validator",
   199  			indices:     set.NewBits(2),
   200  			vdrs:        []*Validator{vdr0, vdr1},
   201  			expectedErr: ErrUnknownValidator,
   202  		},
   203  		{
   204  			name:    "two filtered out",
   205  			indices: set.NewBits(),
   206  			vdrs: []*Validator{
   207  				vdr0,
   208  				vdr1,
   209  			},
   210  			expectedVdrs: []*Validator{},
   211  			expectedErr:  nil,
   212  		},
   213  		{
   214  			name:    "one filtered out",
   215  			indices: set.NewBits(1),
   216  			vdrs: []*Validator{
   217  				vdr0,
   218  				vdr1,
   219  			},
   220  			expectedVdrs: []*Validator{
   221  				vdr1,
   222  			},
   223  			expectedErr: nil,
   224  		},
   225  		{
   226  			name:    "none filtered out",
   227  			indices: set.NewBits(0, 1),
   228  			vdrs: []*Validator{
   229  				vdr0,
   230  				vdr1,
   231  			},
   232  			expectedVdrs: []*Validator{
   233  				vdr0,
   234  				vdr1,
   235  			},
   236  			expectedErr: nil,
   237  		},
   238  	}
   239  
   240  	for _, tt := range tests {
   241  		t.Run(tt.name, func(t *testing.T) {
   242  			require := require.New(t)
   243  
   244  			vdrs, err := FilterValidators(tt.indices, tt.vdrs)
   245  			require.ErrorIs(err, tt.expectedErr)
   246  			if tt.expectedErr != nil {
   247  				return
   248  			}
   249  			require.Equal(tt.expectedVdrs, vdrs)
   250  		})
   251  	}
   252  }
   253  
   254  func TestSumWeight(t *testing.T) {
   255  	vdr0 := &Validator{
   256  		Weight: 1,
   257  	}
   258  	vdr1 := &Validator{
   259  		Weight: 2,
   260  	}
   261  	vdr2 := &Validator{
   262  		Weight: math.MaxUint64,
   263  	}
   264  
   265  	type test struct {
   266  		name        string
   267  		vdrs        []*Validator
   268  		expectedSum uint64
   269  		expectedErr error
   270  	}
   271  
   272  	tests := []test{
   273  		{
   274  			name:        "empty",
   275  			vdrs:        []*Validator{},
   276  			expectedSum: 0,
   277  		},
   278  		{
   279  			name:        "one",
   280  			vdrs:        []*Validator{vdr0},
   281  			expectedSum: 1,
   282  		},
   283  		{
   284  			name:        "two",
   285  			vdrs:        []*Validator{vdr0, vdr1},
   286  			expectedSum: 3,
   287  		},
   288  		{
   289  			name:        "overflow",
   290  			vdrs:        []*Validator{vdr0, vdr2},
   291  			expectedErr: ErrWeightOverflow,
   292  		},
   293  	}
   294  
   295  	for _, tt := range tests {
   296  		t.Run(tt.name, func(t *testing.T) {
   297  			require := require.New(t)
   298  
   299  			sum, err := SumWeight(tt.vdrs)
   300  			require.ErrorIs(err, tt.expectedErr)
   301  			if tt.expectedErr != nil {
   302  				return
   303  			}
   304  			require.Equal(tt.expectedSum, sum)
   305  		})
   306  	}
   307  }
   308  
   309  func BenchmarkGetCanonicalValidatorSet(b *testing.B) {
   310  	pChainHeight := uint64(1)
   311  	subnetID := ids.GenerateTestID()
   312  	numNodes := 10_000
   313  	getValidatorOutputs := make([]*validators.GetValidatorOutput, 0, numNodes)
   314  	for i := 0; i < numNodes; i++ {
   315  		nodeID := ids.GenerateTestNodeID()
   316  		blsPrivateKey, err := bls.NewSecretKey()
   317  		require.NoError(b, err)
   318  		blsPublicKey := bls.PublicFromSecretKey(blsPrivateKey)
   319  		getValidatorOutputs = append(getValidatorOutputs, &validators.GetValidatorOutput{
   320  			NodeID:    nodeID,
   321  			PublicKey: blsPublicKey,
   322  			Weight:    20,
   323  		})
   324  	}
   325  
   326  	for _, size := range []int{0, 1, 10, 100, 1_000, 10_000} {
   327  		getValidatorsOutput := make(map[ids.NodeID]*validators.GetValidatorOutput)
   328  		for i := 0; i < size; i++ {
   329  			validator := getValidatorOutputs[i]
   330  			getValidatorsOutput[validator.NodeID] = validator
   331  		}
   332  		validatorState := &validators.TestState{
   333  			GetValidatorSetF: func(context.Context, uint64, ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) {
   334  				return getValidatorsOutput, nil
   335  			},
   336  		}
   337  
   338  		b.Run(strconv.Itoa(size), func(b *testing.B) {
   339  			for i := 0; i < b.N; i++ {
   340  				_, _, err := GetCanonicalValidatorSet(context.Background(), validatorState, pChainHeight, subnetID)
   341  				require.NoError(b, err)
   342  			}
   343  		})
   344  	}
   345  }