github.com/aakash4dev/cometbft@v0.38.2/types/vote_test.go (about)

     1  package types
     2  
     3  import (
     4  	"testing"
     5  	"time"
     6  
     7  	"github.com/cosmos/gogoproto/proto"
     8  	"github.com/stretchr/testify/assert"
     9  	"github.com/stretchr/testify/require"
    10  
    11  	"github.com/aakash4dev/cometbft/crypto"
    12  	"github.com/aakash4dev/cometbft/crypto/ed25519"
    13  	"github.com/aakash4dev/cometbft/crypto/tmhash"
    14  	"github.com/aakash4dev/cometbft/libs/protoio"
    15  	cmtproto "github.com/aakash4dev/cometbft/proto/tendermint/types"
    16  	cmttime "github.com/aakash4dev/cometbft/types/time"
    17  )
    18  
    19  func examplePrevote() *Vote {
    20  	return exampleVote(byte(cmtproto.PrevoteType))
    21  }
    22  
    23  func examplePrecommit() *Vote {
    24  	vote := exampleVote(byte(cmtproto.PrecommitType))
    25  	vote.ExtensionSignature = []byte("signature")
    26  	return vote
    27  }
    28  
    29  func exampleVote(t byte) *Vote {
    30  	var stamp, err = time.Parse(TimeFormat, "2017-12-25T03:00:01.234Z")
    31  	if err != nil {
    32  		panic(err)
    33  	}
    34  
    35  	return &Vote{
    36  		Type:      cmtproto.SignedMsgType(t),
    37  		Height:    12345,
    38  		Round:     2,
    39  		Timestamp: stamp,
    40  		BlockID: BlockID{
    41  			Hash: tmhash.Sum([]byte("blockID_hash")),
    42  			PartSetHeader: PartSetHeader{
    43  				Total: 1000000,
    44  				Hash:  tmhash.Sum([]byte("blockID_part_set_header_hash")),
    45  			},
    46  		},
    47  		ValidatorAddress: crypto.AddressHash([]byte("validator_address")),
    48  		ValidatorIndex:   56789,
    49  	}
    50  }
    51  
    52  func TestVoteSignable(t *testing.T) {
    53  	vote := examplePrecommit()
    54  	v := vote.ToProto()
    55  	signBytes := VoteSignBytes("test_chain_id", v)
    56  	pb := CanonicalizeVote("test_chain_id", v)
    57  	expected, err := protoio.MarshalDelimited(&pb)
    58  	require.NoError(t, err)
    59  
    60  	require.Equal(t, expected, signBytes, "Got unexpected sign bytes for Vote.")
    61  }
    62  
    63  func TestVoteSignBytesTestVectors(t *testing.T) {
    64  
    65  	tests := []struct {
    66  		chainID string
    67  		vote    *Vote
    68  		want    []byte
    69  	}{
    70  		0: {
    71  			"", &Vote{},
    72  			// NOTE: Height and Round are skipped here. This case needs to be considered while parsing.
    73  			[]byte{0xd, 0x2a, 0xb, 0x8, 0x80, 0x92, 0xb8, 0xc3, 0x98, 0xfe, 0xff, 0xff, 0xff, 0x1},
    74  		},
    75  		// with proper (fixed size) height and round (PreCommit):
    76  		1: {
    77  			"", &Vote{Height: 1, Round: 1, Type: cmtproto.PrecommitType},
    78  			[]byte{
    79  				0x21,                                   // length
    80  				0x8,                                    // (field_number << 3) | wire_type
    81  				0x2,                                    // PrecommitType
    82  				0x11,                                   // (field_number << 3) | wire_type
    83  				0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, // height
    84  				0x19,                                   // (field_number << 3) | wire_type
    85  				0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, // round
    86  				0x2a, // (field_number << 3) | wire_type
    87  				// remaining fields (timestamp):
    88  				0xb, 0x8, 0x80, 0x92, 0xb8, 0xc3, 0x98, 0xfe, 0xff, 0xff, 0xff, 0x1},
    89  		},
    90  		// with proper (fixed size) height and round (PreVote):
    91  		2: {
    92  			"", &Vote{Height: 1, Round: 1, Type: cmtproto.PrevoteType},
    93  			[]byte{
    94  				0x21,                                   // length
    95  				0x8,                                    // (field_number << 3) | wire_type
    96  				0x1,                                    // PrevoteType
    97  				0x11,                                   // (field_number << 3) | wire_type
    98  				0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, // height
    99  				0x19,                                   // (field_number << 3) | wire_type
   100  				0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, // round
   101  				0x2a, // (field_number << 3) | wire_type
   102  				// remaining fields (timestamp):
   103  				0xb, 0x8, 0x80, 0x92, 0xb8, 0xc3, 0x98, 0xfe, 0xff, 0xff, 0xff, 0x1},
   104  		},
   105  		3: {
   106  			"", &Vote{Height: 1, Round: 1},
   107  			[]byte{
   108  				0x1f,                                   // length
   109  				0x11,                                   // (field_number << 3) | wire_type
   110  				0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, // height
   111  				0x19,                                   // (field_number << 3) | wire_type
   112  				0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, // round
   113  				// remaining fields (timestamp):
   114  				0x2a,
   115  				0xb, 0x8, 0x80, 0x92, 0xb8, 0xc3, 0x98, 0xfe, 0xff, 0xff, 0xff, 0x1},
   116  		},
   117  		// containing non-empty chain_id:
   118  		4: {
   119  			"test_chain_id", &Vote{Height: 1, Round: 1},
   120  			[]byte{
   121  				0x2e,                                   // length
   122  				0x11,                                   // (field_number << 3) | wire_type
   123  				0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, // height
   124  				0x19,                                   // (field_number << 3) | wire_type
   125  				0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, // round
   126  				// remaining fields:
   127  				0x2a,                                                                // (field_number << 3) | wire_type
   128  				0xb, 0x8, 0x80, 0x92, 0xb8, 0xc3, 0x98, 0xfe, 0xff, 0xff, 0xff, 0x1, // timestamp
   129  				// (field_number << 3) | wire_type
   130  				0x32,
   131  				0xd, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x63, 0x68, 0x61, 0x69, 0x6e, 0x5f, 0x69, 0x64}, // chainID
   132  		},
   133  		// containing vote extension
   134  		5: {
   135  			"test_chain_id", &Vote{
   136  				Height:    1,
   137  				Round:     1,
   138  				Extension: []byte("extension"),
   139  			},
   140  			[]byte{
   141  				0x2e,                                   // length
   142  				0x11,                                   // (field_number << 3) | wire_type
   143  				0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, // height
   144  				0x19,                                   // (field_number << 3) | wire_type
   145  				0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, // round
   146  				// remaning fields:
   147  				0x2a,                                                                // (field_number << 3) | wire_type
   148  				0xb, 0x8, 0x80, 0x92, 0xb8, 0xc3, 0x98, 0xfe, 0xff, 0xff, 0xff, 0x1, // timestamp
   149  				// (field_number << 3) | wire_type
   150  				0x32,
   151  				0xd, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x63, 0x68, 0x61, 0x69, 0x6e, 0x5f, 0x69, 0x64, // chainID
   152  			}, // chainID
   153  		},
   154  	}
   155  	for i, tc := range tests {
   156  		v := tc.vote.ToProto()
   157  		got := VoteSignBytes(tc.chainID, v)
   158  		assert.Equal(t, len(tc.want), len(got), "test case #%v: got unexpected sign bytes length for Vote.", i)
   159  		assert.Equal(t, tc.want, got, "test case #%v: got unexpected sign bytes for Vote.", i)
   160  	}
   161  }
   162  
   163  func TestVoteProposalNotEq(t *testing.T) {
   164  	cv := CanonicalizeVote("", &cmtproto.Vote{Height: 1, Round: 1})
   165  	p := CanonicalizeProposal("", &cmtproto.Proposal{Height: 1, Round: 1})
   166  	vb, err := proto.Marshal(&cv)
   167  	require.NoError(t, err)
   168  	pb, err := proto.Marshal(&p)
   169  	require.NoError(t, err)
   170  	require.NotEqual(t, vb, pb)
   171  }
   172  
   173  func TestVoteVerifySignature(t *testing.T) {
   174  	privVal := NewMockPV()
   175  	pubkey, err := privVal.GetPubKey()
   176  	require.NoError(t, err)
   177  
   178  	vote := examplePrecommit()
   179  	v := vote.ToProto()
   180  	signBytes := VoteSignBytes("test_chain_id", v)
   181  
   182  	// sign it
   183  	err = privVal.SignVote("test_chain_id", v)
   184  	require.NoError(t, err)
   185  
   186  	// verify the same vote
   187  	valid := pubkey.VerifySignature(VoteSignBytes("test_chain_id", v), v.Signature)
   188  	require.True(t, valid)
   189  
   190  	// serialize, deserialize and verify again....
   191  	precommit := new(cmtproto.Vote)
   192  	bs, err := proto.Marshal(v)
   193  	require.NoError(t, err)
   194  	err = proto.Unmarshal(bs, precommit)
   195  	require.NoError(t, err)
   196  
   197  	// verify the transmitted vote
   198  	newSignBytes := VoteSignBytes("test_chain_id", precommit)
   199  	require.Equal(t, string(signBytes), string(newSignBytes))
   200  	valid = pubkey.VerifySignature(newSignBytes, precommit.Signature)
   201  	require.True(t, valid)
   202  }
   203  
   204  // TestVoteExtension tests that the vote verification behaves correctly in each case
   205  // of vote extension being set on the vote.
   206  func TestVoteExtension(t *testing.T) {
   207  	testCases := []struct {
   208  		name             string
   209  		extension        []byte
   210  		includeSignature bool
   211  		expectError      bool
   212  	}{
   213  		{
   214  			name:             "all fields present",
   215  			extension:        []byte("extension"),
   216  			includeSignature: true,
   217  			expectError:      false,
   218  		},
   219  		{
   220  			name:             "no extension signature",
   221  			extension:        []byte("extension"),
   222  			includeSignature: false,
   223  			expectError:      true,
   224  		},
   225  		{
   226  			name:             "empty extension",
   227  			includeSignature: true,
   228  			expectError:      false,
   229  		},
   230  		{
   231  			name:             "no extension and no signature",
   232  			includeSignature: false,
   233  			expectError:      true,
   234  		},
   235  	}
   236  
   237  	for _, tc := range testCases {
   238  		t.Run(tc.name, func(t *testing.T) {
   239  			height, round := int64(1), int32(0)
   240  			privVal := NewMockPV()
   241  			pk, err := privVal.GetPubKey()
   242  			require.NoError(t, err)
   243  			vote := &Vote{
   244  				ValidatorAddress: pk.Address(),
   245  				ValidatorIndex:   0,
   246  				Height:           height,
   247  				Round:            round,
   248  				Timestamp:        cmttime.Now(),
   249  				Type:             cmtproto.PrecommitType,
   250  				BlockID:          makeBlockIDRandom(),
   251  			}
   252  
   253  			v := vote.ToProto()
   254  			err = privVal.SignVote("test_chain_id", v)
   255  			require.NoError(t, err)
   256  			vote.Signature = v.Signature
   257  			if tc.includeSignature {
   258  				vote.ExtensionSignature = v.ExtensionSignature
   259  			}
   260  			err = vote.VerifyExtension("test_chain_id", pk)
   261  			if tc.expectError {
   262  				require.Error(t, err)
   263  			} else {
   264  				require.NoError(t, err)
   265  			}
   266  		})
   267  	}
   268  }
   269  
   270  func TestIsVoteTypeValid(t *testing.T) {
   271  	tc := []struct {
   272  		name string
   273  		in   cmtproto.SignedMsgType
   274  		out  bool
   275  	}{
   276  		{"Prevote", cmtproto.PrevoteType, true},
   277  		{"Precommit", cmtproto.PrecommitType, true},
   278  		{"InvalidType", cmtproto.SignedMsgType(0x3), false},
   279  	}
   280  
   281  	for _, tt := range tc {
   282  		tt := tt
   283  		t.Run(tt.name, func(st *testing.T) {
   284  			if rs := IsVoteTypeValid(tt.in); rs != tt.out {
   285  				t.Errorf("got unexpected Vote type. Expected:\n%v\nGot:\n%v", rs, tt.out)
   286  			}
   287  		})
   288  	}
   289  }
   290  
   291  func TestVoteVerify(t *testing.T) {
   292  	privVal := NewMockPV()
   293  	pubkey, err := privVal.GetPubKey()
   294  	require.NoError(t, err)
   295  
   296  	vote := examplePrevote()
   297  	vote.ValidatorAddress = pubkey.Address()
   298  
   299  	err = vote.Verify("test_chain_id", ed25519.GenPrivKey().PubKey())
   300  	if assert.Error(t, err) {
   301  		assert.Equal(t, ErrVoteInvalidValidatorAddress, err)
   302  	}
   303  
   304  	err = vote.Verify("test_chain_id", pubkey)
   305  	if assert.Error(t, err) {
   306  		assert.Equal(t, ErrVoteInvalidSignature, err)
   307  	}
   308  }
   309  
   310  func TestVoteString(t *testing.T) {
   311  	str := examplePrecommit().String()
   312  	expected := `Vote{56789:6AF1F4111082 12345/02/SIGNED_MSG_TYPE_PRECOMMIT(Precommit) 8B01023386C3 000000000000 000000000000 @ 2017-12-25T03:00:01.234Z}` //nolint:lll //ignore line length for tests
   313  	if str != expected {
   314  		t.Errorf("got unexpected string for Vote. Expected:\n%v\nGot:\n%v", expected, str)
   315  	}
   316  
   317  	str2 := examplePrevote().String()
   318  	expected = `Vote{56789:6AF1F4111082 12345/02/SIGNED_MSG_TYPE_PREVOTE(Prevote) 8B01023386C3 000000000000 000000000000 @ 2017-12-25T03:00:01.234Z}` //nolint:lll //ignore line length for tests
   319  	if str2 != expected {
   320  		t.Errorf("got unexpected string for Vote. Expected:\n%v\nGot:\n%v", expected, str2)
   321  	}
   322  }
   323  
   324  func signVote(t *testing.T, pv PrivValidator, chainID string, vote *Vote) {
   325  	t.Helper()
   326  
   327  	v := vote.ToProto()
   328  	require.NoError(t, pv.SignVote(chainID, v))
   329  	vote.Signature = v.Signature
   330  	vote.ExtensionSignature = v.ExtensionSignature
   331  }
   332  
   333  func TestValidVotes(t *testing.T) {
   334  	privVal := NewMockPV()
   335  
   336  	testCases := []struct {
   337  		name         string
   338  		vote         *Vote
   339  		malleateVote func(*Vote)
   340  	}{
   341  		{"good prevote", examplePrevote(), func(v *Vote) {}},
   342  		{"good precommit without vote extension", examplePrecommit(), func(v *Vote) { v.Extension = nil }},
   343  		{"good precommit with vote extension", examplePrecommit(), func(v *Vote) { v.Extension = []byte("extension") }},
   344  	}
   345  	for _, tc := range testCases {
   346  		signVote(t, privVal, "test_chain_id", tc.vote)
   347  		tc.malleateVote(tc.vote)
   348  		require.NoError(t, tc.vote.ValidateBasic(), "ValidateBasic for %s", tc.name)
   349  		require.NoError(t, tc.vote.EnsureExtension(), "EnsureExtension for %s", tc.name)
   350  	}
   351  }
   352  
   353  func TestInvalidVotes(t *testing.T) {
   354  	privVal := NewMockPV()
   355  
   356  	testCases := []struct {
   357  		name         string
   358  		malleateVote func(*Vote)
   359  	}{
   360  		{"negative height", func(v *Vote) { v.Height = -1 }},
   361  		{"negative round", func(v *Vote) { v.Round = -1 }},
   362  		{"zero Height", func(v *Vote) { v.Height = 0 }},
   363  		{"invalid block ID", func(v *Vote) { v.BlockID = BlockID{[]byte{1, 2, 3}, PartSetHeader{111, []byte("blockparts")}} }},
   364  		{"invalid address", func(v *Vote) { v.ValidatorAddress = make([]byte, 1) }},
   365  		{"invalid validator index", func(v *Vote) { v.ValidatorIndex = -1 }},
   366  		{"invalid signature", func(v *Vote) { v.Signature = nil }},
   367  		{"oversized signature", func(v *Vote) { v.Signature = make([]byte, MaxSignatureSize+1) }},
   368  	}
   369  	for _, tc := range testCases {
   370  		prevote := examplePrevote()
   371  		signVote(t, privVal, "test_chain_id", prevote)
   372  		tc.malleateVote(prevote)
   373  		require.Error(t, prevote.ValidateBasic(), "ValidateBasic for %s in invalid prevote", tc.name)
   374  		require.NoError(t, prevote.EnsureExtension(), "EnsureExtension for %s in invalid prevote", tc.name)
   375  
   376  		precommit := examplePrecommit()
   377  		signVote(t, privVal, "test_chain_id", precommit)
   378  		tc.malleateVote(precommit)
   379  		require.Error(t, precommit.ValidateBasic(), "ValidateBasic for %s in invalid precommit", tc.name)
   380  		require.NoError(t, precommit.EnsureExtension(), "EnsureExtension for %s in invalid precommit", tc.name)
   381  	}
   382  }
   383  
   384  func TestInvalidPrevotes(t *testing.T) {
   385  	privVal := NewMockPV()
   386  
   387  	testCases := []struct {
   388  		name         string
   389  		malleateVote func(*Vote)
   390  	}{
   391  		{"vote extension present", func(v *Vote) { v.Extension = []byte("extension") }},
   392  		{"vote extension signature present", func(v *Vote) { v.ExtensionSignature = []byte("signature") }},
   393  	}
   394  	for _, tc := range testCases {
   395  		prevote := examplePrevote()
   396  		signVote(t, privVal, "test_chain_id", prevote)
   397  		tc.malleateVote(prevote)
   398  		require.Error(t, prevote.ValidateBasic(), "ValidateBasic for %s", tc.name)
   399  		require.NoError(t, prevote.EnsureExtension(), "EnsureExtension for %s", tc.name)
   400  	}
   401  }
   402  
   403  func TestInvalidPrecommitExtensions(t *testing.T) {
   404  	privVal := NewMockPV()
   405  
   406  	testCases := []struct {
   407  		name         string
   408  		malleateVote func(*Vote)
   409  	}{
   410  		{"vote extension present without signature", func(v *Vote) {
   411  			v.Extension = []byte("extension")
   412  			v.ExtensionSignature = nil
   413  		}},
   414  		{"oversized vote extension signature", func(v *Vote) { v.ExtensionSignature = make([]byte, MaxSignatureSize+1) }},
   415  	}
   416  	for _, tc := range testCases {
   417  		precommit := examplePrecommit()
   418  		signVote(t, privVal, "test_chain_id", precommit)
   419  		tc.malleateVote(precommit)
   420  		// ValidateBasic ensures that vote extensions, if present, are well formed
   421  		require.Error(t, precommit.ValidateBasic(), "ValidateBasic for %s", tc.name)
   422  	}
   423  }
   424  
   425  func TestEnsureVoteExtension(t *testing.T) {
   426  	privVal := NewMockPV()
   427  
   428  	testCases := []struct {
   429  		name         string
   430  		malleateVote func(*Vote)
   431  		expectError  bool
   432  	}{
   433  		{"vote extension signature absent", func(v *Vote) {
   434  			v.Extension = nil
   435  			v.ExtensionSignature = nil
   436  		}, true},
   437  		{"vote extension signature present", func(v *Vote) {
   438  			v.ExtensionSignature = []byte("extension signature")
   439  		}, false},
   440  	}
   441  	for _, tc := range testCases {
   442  		precommit := examplePrecommit()
   443  		signVote(t, privVal, "test_chain_id", precommit)
   444  		tc.malleateVote(precommit)
   445  		if tc.expectError {
   446  			require.Error(t, precommit.EnsureExtension(), "EnsureExtension for %s", tc.name)
   447  		} else {
   448  			require.NoError(t, precommit.EnsureExtension(), "EnsureExtension for %s", tc.name)
   449  		}
   450  	}
   451  }
   452  
   453  func TestVoteProtobuf(t *testing.T) {
   454  	privVal := NewMockPV()
   455  	vote := examplePrecommit()
   456  	v := vote.ToProto()
   457  	err := privVal.SignVote("test_chain_id", v)
   458  	vote.Signature = v.Signature
   459  	require.NoError(t, err)
   460  
   461  	testCases := []struct {
   462  		msg                 string
   463  		vote                *Vote
   464  		convertsOk          bool
   465  		passesValidateBasic bool
   466  	}{
   467  		{"success", vote, true, true},
   468  		{"fail vote validate basic", &Vote{}, true, false},
   469  	}
   470  	for _, tc := range testCases {
   471  		protoProposal := tc.vote.ToProto()
   472  
   473  		v, err := VoteFromProto(protoProposal)
   474  		if tc.convertsOk {
   475  			require.NoError(t, err)
   476  		} else {
   477  			require.Error(t, err)
   478  		}
   479  
   480  		err = v.ValidateBasic()
   481  		if tc.passesValidateBasic {
   482  			require.NoError(t, err)
   483  			require.Equal(t, tc.vote, v, tc.msg)
   484  		} else {
   485  			require.Error(t, err)
   486  		}
   487  	}
   488  }