github.com/vipernet-xyz/tm@v0.34.24/privval/signer_client_test.go (about)

     1  package privval
     2  
     3  import (
     4  	"fmt"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  	"github.com/stretchr/testify/require"
    10  
    11  	"github.com/vipernet-xyz/tm/crypto"
    12  	"github.com/vipernet-xyz/tm/crypto/tmhash"
    13  	tmrand "github.com/vipernet-xyz/tm/libs/rand"
    14  	cryptoproto "github.com/vipernet-xyz/tm/proto/tendermint/crypto"
    15  	privvalproto "github.com/vipernet-xyz/tm/proto/tendermint/privval"
    16  	tmproto "github.com/vipernet-xyz/tm/proto/tendermint/types"
    17  	"github.com/vipernet-xyz/tm/types"
    18  )
    19  
    20  type signerTestCase struct {
    21  	chainID      string
    22  	mockPV       types.PrivValidator
    23  	signerClient *SignerClient
    24  	signerServer *SignerServer
    25  }
    26  
    27  func getSignerTestCases(t *testing.T) []signerTestCase {
    28  	testCases := make([]signerTestCase, 0)
    29  
    30  	// Get test cases for each possible dialer (DialTCP / DialUnix / etc)
    31  	for _, dtc := range getDialerTestCases(t) {
    32  		chainID := tmrand.Str(12)
    33  		mockPV := types.NewMockPV()
    34  
    35  		// get a pair of signer listener, signer dialer endpoints
    36  		sl, sd := getMockEndpoints(t, dtc.addr, dtc.dialer)
    37  		sc, err := NewSignerClient(sl, chainID)
    38  		require.NoError(t, err)
    39  		ss := NewSignerServer(sd, chainID, mockPV)
    40  
    41  		err = ss.Start()
    42  		require.NoError(t, err)
    43  
    44  		tc := signerTestCase{
    45  			chainID:      chainID,
    46  			mockPV:       mockPV,
    47  			signerClient: sc,
    48  			signerServer: ss,
    49  		}
    50  
    51  		testCases = append(testCases, tc)
    52  	}
    53  
    54  	return testCases
    55  }
    56  
    57  func TestSignerClose(t *testing.T) {
    58  	for _, tc := range getSignerTestCases(t) {
    59  		err := tc.signerClient.Close()
    60  		assert.NoError(t, err)
    61  
    62  		err = tc.signerServer.Stop()
    63  		assert.NoError(t, err)
    64  	}
    65  }
    66  
    67  func TestSignerPing(t *testing.T) {
    68  	for _, tc := range getSignerTestCases(t) {
    69  		tc := tc
    70  		t.Cleanup(func() {
    71  			if err := tc.signerServer.Stop(); err != nil {
    72  				t.Error(err)
    73  			}
    74  		})
    75  		t.Cleanup(func() {
    76  			if err := tc.signerClient.Close(); err != nil {
    77  				t.Error(err)
    78  			}
    79  		})
    80  
    81  		err := tc.signerClient.Ping()
    82  		assert.NoError(t, err)
    83  	}
    84  }
    85  
    86  func TestSignerGetPubKey(t *testing.T) {
    87  	for _, tc := range getSignerTestCases(t) {
    88  		tc := tc
    89  		t.Cleanup(func() {
    90  			if err := tc.signerServer.Stop(); err != nil {
    91  				t.Error(err)
    92  			}
    93  		})
    94  		t.Cleanup(func() {
    95  			if err := tc.signerClient.Close(); err != nil {
    96  				t.Error(err)
    97  			}
    98  		})
    99  
   100  		pubKey, err := tc.signerClient.GetPubKey()
   101  		require.NoError(t, err)
   102  		expectedPubKey, err := tc.mockPV.GetPubKey()
   103  		require.NoError(t, err)
   104  
   105  		assert.Equal(t, expectedPubKey, pubKey)
   106  
   107  		pubKey, err = tc.signerClient.GetPubKey()
   108  		require.NoError(t, err)
   109  		expectedpk, err := tc.mockPV.GetPubKey()
   110  		require.NoError(t, err)
   111  		expectedAddr := expectedpk.Address()
   112  
   113  		assert.Equal(t, expectedAddr, pubKey.Address())
   114  	}
   115  }
   116  
   117  func TestSignerProposal(t *testing.T) {
   118  	for _, tc := range getSignerTestCases(t) {
   119  		ts := time.Now()
   120  		hash := tmrand.Bytes(tmhash.Size)
   121  		have := &types.Proposal{
   122  			Type:      tmproto.ProposalType,
   123  			Height:    1,
   124  			Round:     2,
   125  			POLRound:  2,
   126  			BlockID:   types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   127  			Timestamp: ts,
   128  		}
   129  		want := &types.Proposal{
   130  			Type:      tmproto.ProposalType,
   131  			Height:    1,
   132  			Round:     2,
   133  			POLRound:  2,
   134  			BlockID:   types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   135  			Timestamp: ts,
   136  		}
   137  
   138  		tc := tc
   139  		t.Cleanup(func() {
   140  			if err := tc.signerServer.Stop(); err != nil {
   141  				t.Error(err)
   142  			}
   143  		})
   144  		t.Cleanup(func() {
   145  			if err := tc.signerClient.Close(); err != nil {
   146  				t.Error(err)
   147  			}
   148  		})
   149  
   150  		require.NoError(t, tc.mockPV.SignProposal(tc.chainID, want.ToProto()))
   151  		require.NoError(t, tc.signerClient.SignProposal(tc.chainID, have.ToProto()))
   152  
   153  		assert.Equal(t, want.Signature, have.Signature)
   154  	}
   155  }
   156  
   157  func TestSignerVote(t *testing.T) {
   158  	for _, tc := range getSignerTestCases(t) {
   159  		ts := time.Now()
   160  		hash := tmrand.Bytes(tmhash.Size)
   161  		valAddr := tmrand.Bytes(crypto.AddressSize)
   162  		want := &types.Vote{
   163  			Type:             tmproto.PrecommitType,
   164  			Height:           1,
   165  			Round:            2,
   166  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   167  			Timestamp:        ts,
   168  			ValidatorAddress: valAddr,
   169  			ValidatorIndex:   1,
   170  		}
   171  
   172  		have := &types.Vote{
   173  			Type:             tmproto.PrecommitType,
   174  			Height:           1,
   175  			Round:            2,
   176  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   177  			Timestamp:        ts,
   178  			ValidatorAddress: valAddr,
   179  			ValidatorIndex:   1,
   180  		}
   181  
   182  		tc := tc
   183  		t.Cleanup(func() {
   184  			if err := tc.signerServer.Stop(); err != nil {
   185  				t.Error(err)
   186  			}
   187  		})
   188  		t.Cleanup(func() {
   189  			if err := tc.signerClient.Close(); err != nil {
   190  				t.Error(err)
   191  			}
   192  		})
   193  
   194  		require.NoError(t, tc.mockPV.SignVote(tc.chainID, want.ToProto()))
   195  		require.NoError(t, tc.signerClient.SignVote(tc.chainID, have.ToProto()))
   196  
   197  		assert.Equal(t, want.Signature, have.Signature)
   198  	}
   199  }
   200  
   201  func TestSignerVoteResetDeadline(t *testing.T) {
   202  	for _, tc := range getSignerTestCases(t) {
   203  		ts := time.Now()
   204  		hash := tmrand.Bytes(tmhash.Size)
   205  		valAddr := tmrand.Bytes(crypto.AddressSize)
   206  		want := &types.Vote{
   207  			Type:             tmproto.PrecommitType,
   208  			Height:           1,
   209  			Round:            2,
   210  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   211  			Timestamp:        ts,
   212  			ValidatorAddress: valAddr,
   213  			ValidatorIndex:   1,
   214  		}
   215  
   216  		have := &types.Vote{
   217  			Type:             tmproto.PrecommitType,
   218  			Height:           1,
   219  			Round:            2,
   220  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   221  			Timestamp:        ts,
   222  			ValidatorAddress: valAddr,
   223  			ValidatorIndex:   1,
   224  		}
   225  
   226  		tc := tc
   227  		t.Cleanup(func() {
   228  			if err := tc.signerServer.Stop(); err != nil {
   229  				t.Error(err)
   230  			}
   231  		})
   232  		t.Cleanup(func() {
   233  			if err := tc.signerClient.Close(); err != nil {
   234  				t.Error(err)
   235  			}
   236  		})
   237  
   238  		time.Sleep(testTimeoutReadWrite2o3)
   239  
   240  		require.NoError(t, tc.mockPV.SignVote(tc.chainID, want.ToProto()))
   241  		require.NoError(t, tc.signerClient.SignVote(tc.chainID, have.ToProto()))
   242  		assert.Equal(t, want.Signature, have.Signature)
   243  
   244  		// TODO(jleni): Clarify what is actually being tested
   245  
   246  		// This would exceed the deadline if it was not extended by the previous message
   247  		time.Sleep(testTimeoutReadWrite2o3)
   248  
   249  		require.NoError(t, tc.mockPV.SignVote(tc.chainID, want.ToProto()))
   250  		require.NoError(t, tc.signerClient.SignVote(tc.chainID, have.ToProto()))
   251  		assert.Equal(t, want.Signature, have.Signature)
   252  	}
   253  }
   254  
   255  func TestSignerVoteKeepAlive(t *testing.T) {
   256  	for _, tc := range getSignerTestCases(t) {
   257  		ts := time.Now()
   258  		hash := tmrand.Bytes(tmhash.Size)
   259  		valAddr := tmrand.Bytes(crypto.AddressSize)
   260  		want := &types.Vote{
   261  			Type:             tmproto.PrecommitType,
   262  			Height:           1,
   263  			Round:            2,
   264  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   265  			Timestamp:        ts,
   266  			ValidatorAddress: valAddr,
   267  			ValidatorIndex:   1,
   268  		}
   269  
   270  		have := &types.Vote{
   271  			Type:             tmproto.PrecommitType,
   272  			Height:           1,
   273  			Round:            2,
   274  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   275  			Timestamp:        ts,
   276  			ValidatorAddress: valAddr,
   277  			ValidatorIndex:   1,
   278  		}
   279  
   280  		tc := tc
   281  		t.Cleanup(func() {
   282  			if err := tc.signerServer.Stop(); err != nil {
   283  				t.Error(err)
   284  			}
   285  		})
   286  		t.Cleanup(func() {
   287  			if err := tc.signerClient.Close(); err != nil {
   288  				t.Error(err)
   289  			}
   290  		})
   291  
   292  		// Check that even if the client does not request a
   293  		// signature for a long time. The service is still available
   294  
   295  		// in this particular case, we use the dialer logger to ensure that
   296  		// test messages are properly interleaved in the test logs
   297  		tc.signerServer.Logger.Debug("TEST: Forced Wait -------------------------------------------------")
   298  		time.Sleep(testTimeoutReadWrite * 3)
   299  		tc.signerServer.Logger.Debug("TEST: Forced Wait DONE---------------------------------------------")
   300  
   301  		require.NoError(t, tc.mockPV.SignVote(tc.chainID, want.ToProto()))
   302  		require.NoError(t, tc.signerClient.SignVote(tc.chainID, have.ToProto()))
   303  
   304  		assert.Equal(t, want.Signature, have.Signature)
   305  	}
   306  }
   307  
   308  func TestSignerSignProposalErrors(t *testing.T) {
   309  	for _, tc := range getSignerTestCases(t) {
   310  		// Replace service with a mock that always fails
   311  		tc.signerServer.privVal = types.NewErroringMockPV()
   312  		tc.mockPV = types.NewErroringMockPV()
   313  
   314  		tc := tc
   315  		t.Cleanup(func() {
   316  			if err := tc.signerServer.Stop(); err != nil {
   317  				t.Error(err)
   318  			}
   319  		})
   320  		t.Cleanup(func() {
   321  			if err := tc.signerClient.Close(); err != nil {
   322  				t.Error(err)
   323  			}
   324  		})
   325  
   326  		ts := time.Now()
   327  		hash := tmrand.Bytes(tmhash.Size)
   328  		proposal := &types.Proposal{
   329  			Type:      tmproto.ProposalType,
   330  			Height:    1,
   331  			Round:     2,
   332  			POLRound:  2,
   333  			BlockID:   types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   334  			Timestamp: ts,
   335  			Signature: []byte("signature"),
   336  		}
   337  
   338  		err := tc.signerClient.SignProposal(tc.chainID, proposal.ToProto())
   339  		require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error())
   340  
   341  		err = tc.mockPV.SignProposal(tc.chainID, proposal.ToProto())
   342  		require.Error(t, err)
   343  
   344  		err = tc.signerClient.SignProposal(tc.chainID, proposal.ToProto())
   345  		require.Error(t, err)
   346  	}
   347  }
   348  
   349  func TestSignerSignVoteErrors(t *testing.T) {
   350  	for _, tc := range getSignerTestCases(t) {
   351  		ts := time.Now()
   352  		hash := tmrand.Bytes(tmhash.Size)
   353  		valAddr := tmrand.Bytes(crypto.AddressSize)
   354  		vote := &types.Vote{
   355  			Type:             tmproto.PrecommitType,
   356  			Height:           1,
   357  			Round:            2,
   358  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   359  			Timestamp:        ts,
   360  			ValidatorAddress: valAddr,
   361  			ValidatorIndex:   1,
   362  			Signature:        []byte("signature"),
   363  		}
   364  
   365  		// Replace signer service privval with one that always fails
   366  		tc.signerServer.privVal = types.NewErroringMockPV()
   367  		tc.mockPV = types.NewErroringMockPV()
   368  
   369  		tc := tc
   370  		t.Cleanup(func() {
   371  			if err := tc.signerServer.Stop(); err != nil {
   372  				t.Error(err)
   373  			}
   374  		})
   375  		t.Cleanup(func() {
   376  			if err := tc.signerClient.Close(); err != nil {
   377  				t.Error(err)
   378  			}
   379  		})
   380  
   381  		err := tc.signerClient.SignVote(tc.chainID, vote.ToProto())
   382  		require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error())
   383  
   384  		err = tc.mockPV.SignVote(tc.chainID, vote.ToProto())
   385  		require.Error(t, err)
   386  
   387  		err = tc.signerClient.SignVote(tc.chainID, vote.ToProto())
   388  		require.Error(t, err)
   389  	}
   390  }
   391  
   392  func brokenHandler(privVal types.PrivValidator, request privvalproto.Message,
   393  	chainID string) (privvalproto.Message, error) {
   394  	var res privvalproto.Message
   395  	var err error
   396  
   397  	switch r := request.Sum.(type) {
   398  	// This is broken and will answer most requests with a pubkey response
   399  	case *privvalproto.Message_PubKeyRequest:
   400  		res = mustWrapMsg(&privvalproto.PubKeyResponse{PubKey: cryptoproto.PublicKey{}, Error: nil})
   401  	case *privvalproto.Message_SignVoteRequest:
   402  		res = mustWrapMsg(&privvalproto.PubKeyResponse{PubKey: cryptoproto.PublicKey{}, Error: nil})
   403  	case *privvalproto.Message_SignProposalRequest:
   404  		res = mustWrapMsg(&privvalproto.PubKeyResponse{PubKey: cryptoproto.PublicKey{}, Error: nil})
   405  	case *privvalproto.Message_PingRequest:
   406  		err, res = nil, mustWrapMsg(&privvalproto.PingResponse{})
   407  	default:
   408  		err = fmt.Errorf("unknown msg: %v", r)
   409  	}
   410  
   411  	return res, err
   412  }
   413  
   414  func TestSignerUnexpectedResponse(t *testing.T) {
   415  	for _, tc := range getSignerTestCases(t) {
   416  		tc.signerServer.privVal = types.NewMockPV()
   417  		tc.mockPV = types.NewMockPV()
   418  
   419  		tc.signerServer.SetRequestHandler(brokenHandler)
   420  
   421  		tc := tc
   422  		t.Cleanup(func() {
   423  			if err := tc.signerServer.Stop(); err != nil {
   424  				t.Error(err)
   425  			}
   426  		})
   427  		t.Cleanup(func() {
   428  			if err := tc.signerClient.Close(); err != nil {
   429  				t.Error(err)
   430  			}
   431  		})
   432  
   433  		ts := time.Now()
   434  		want := &types.Vote{Timestamp: ts, Type: tmproto.PrecommitType}
   435  
   436  		e := tc.signerClient.SignVote(tc.chainID, want.ToProto())
   437  		assert.EqualError(t, e, "empty response")
   438  	}
   439  }