github.com/lazyledger/lazyledger-core@v0.35.0-dev.0.20210613111200-4c651f053571/privval/signer_client_test.go (about)

     1  package privval
     2  
     3  import (
     4  	"fmt"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  	"github.com/stretchr/testify/require"
    10  
    11  	"github.com/lazyledger/lazyledger-core/crypto"
    12  	"github.com/lazyledger/lazyledger-core/crypto/tmhash"
    13  	tmrand "github.com/lazyledger/lazyledger-core/libs/rand"
    14  	cryptoproto "github.com/lazyledger/lazyledger-core/proto/tendermint/crypto"
    15  	privvalproto "github.com/lazyledger/lazyledger-core/proto/tendermint/privval"
    16  	tmproto "github.com/lazyledger/lazyledger-core/proto/tendermint/types"
    17  	"github.com/lazyledger/lazyledger-core/types"
    18  )
    19  
    20  type signerTestCase struct {
    21  	chainID      string
    22  	mockPV       types.PrivValidator
    23  	signerClient *SignerClient
    24  	signerServer *SignerServer
    25  }
    26  
    27  func getSignerTestCases(t *testing.T) []signerTestCase {
    28  	testCases := make([]signerTestCase, 0)
    29  
    30  	// Get test cases for each possible dialer (DialTCP / DialUnix / etc)
    31  	for _, dtc := range getDialerTestCases(t) {
    32  		chainID := tmrand.Str(12)
    33  		mockPV := types.NewMockPV()
    34  
    35  		// get a pair of signer listener, signer dialer endpoints
    36  		sl, sd := getMockEndpoints(t, dtc.addr, dtc.dialer)
    37  		sc, err := NewSignerClient(sl, chainID)
    38  		require.NoError(t, err)
    39  		ss := NewSignerServer(sd, chainID, mockPV)
    40  
    41  		err = ss.Start()
    42  		require.NoError(t, err)
    43  
    44  		tc := signerTestCase{
    45  			chainID:      chainID,
    46  			mockPV:       mockPV,
    47  			signerClient: sc,
    48  			signerServer: ss,
    49  		}
    50  
    51  		testCases = append(testCases, tc)
    52  	}
    53  
    54  	return testCases
    55  }
    56  
    57  func TestSignerClose(t *testing.T) {
    58  	for _, tc := range getSignerTestCases(t) {
    59  		err := tc.signerClient.Close()
    60  		assert.NoError(t, err)
    61  
    62  		err = tc.signerServer.Stop()
    63  		assert.NoError(t, err)
    64  	}
    65  }
    66  
    67  func TestSignerPing(t *testing.T) {
    68  	for _, tc := range getSignerTestCases(t) {
    69  		tc := tc
    70  		t.Cleanup(func() {
    71  			if err := tc.signerServer.Stop(); err != nil {
    72  				t.Error(err)
    73  			}
    74  		})
    75  		t.Cleanup(func() {
    76  			if err := tc.signerClient.Close(); err != nil {
    77  				t.Error(err)
    78  			}
    79  		})
    80  
    81  		err := tc.signerClient.Ping()
    82  		assert.NoError(t, err)
    83  	}
    84  }
    85  
    86  func TestSignerGetPubKey(t *testing.T) {
    87  	for _, tc := range getSignerTestCases(t) {
    88  		tc := tc
    89  		t.Cleanup(func() {
    90  			if err := tc.signerServer.Stop(); err != nil {
    91  				t.Error(err)
    92  			}
    93  		})
    94  		t.Cleanup(func() {
    95  			if err := tc.signerClient.Close(); err != nil {
    96  				t.Error(err)
    97  			}
    98  		})
    99  
   100  		pubKey, err := tc.signerClient.GetPubKey()
   101  		require.NoError(t, err)
   102  		expectedPubKey, err := tc.mockPV.GetPubKey()
   103  		require.NoError(t, err)
   104  
   105  		assert.Equal(t, expectedPubKey, pubKey)
   106  
   107  		pubKey, err = tc.signerClient.GetPubKey()
   108  		require.NoError(t, err)
   109  		expectedpk, err := tc.mockPV.GetPubKey()
   110  		require.NoError(t, err)
   111  		expectedAddr := expectedpk.Address()
   112  
   113  		assert.Equal(t, expectedAddr, pubKey.Address())
   114  	}
   115  }
   116  
   117  func TestSignerProposal(t *testing.T) {
   118  	for _, tc := range getSignerTestCases(t) {
   119  		ts := time.Now()
   120  		hash := tmrand.Bytes(tmhash.Size)
   121  		have := &types.Proposal{
   122  			Type:      tmproto.ProposalType,
   123  			Height:    1,
   124  			Round:     2,
   125  			POLRound:  2,
   126  			BlockID:   types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   127  			Timestamp: ts,
   128  			DAHeader:  &types.DataAvailabilityHeader{},
   129  		}
   130  		want := &types.Proposal{
   131  			Type:      tmproto.ProposalType,
   132  			Height:    1,
   133  			Round:     2,
   134  			POLRound:  2,
   135  			BlockID:   types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   136  			Timestamp: ts,
   137  			DAHeader:  &types.DataAvailabilityHeader{},
   138  		}
   139  
   140  		tc := tc
   141  		t.Cleanup(func() {
   142  			if err := tc.signerServer.Stop(); err != nil {
   143  				t.Error(err)
   144  			}
   145  		})
   146  		t.Cleanup(func() {
   147  			if err := tc.signerClient.Close(); err != nil {
   148  				t.Error(err)
   149  			}
   150  		})
   151  
   152  		p, err := want.ToProto()
   153  		require.NoError(t, err)
   154  		err = tc.mockPV.SignProposal(tc.chainID, p)
   155  		require.NoError(t, err)
   156  
   157  		p, err = have.ToProto()
   158  		require.NoError(t, err)
   159  		err = tc.signerClient.SignProposal(tc.chainID, p)
   160  		require.NoError(t, err)
   161  
   162  		assert.Equal(t, want.Signature, have.Signature)
   163  	}
   164  }
   165  
   166  func TestSignerVote(t *testing.T) {
   167  	for _, tc := range getSignerTestCases(t) {
   168  		ts := time.Now()
   169  		hash := tmrand.Bytes(tmhash.Size)
   170  		valAddr := tmrand.Bytes(crypto.AddressSize)
   171  		want := &types.Vote{
   172  			Type:             tmproto.PrecommitType,
   173  			Height:           1,
   174  			Round:            2,
   175  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   176  			Timestamp:        ts,
   177  			ValidatorAddress: valAddr,
   178  			ValidatorIndex:   1,
   179  		}
   180  
   181  		have := &types.Vote{
   182  			Type:             tmproto.PrecommitType,
   183  			Height:           1,
   184  			Round:            2,
   185  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   186  			Timestamp:        ts,
   187  			ValidatorAddress: valAddr,
   188  			ValidatorIndex:   1,
   189  		}
   190  
   191  		tc := tc
   192  		t.Cleanup(func() {
   193  			if err := tc.signerServer.Stop(); err != nil {
   194  				t.Error(err)
   195  			}
   196  		})
   197  		t.Cleanup(func() {
   198  			if err := tc.signerClient.Close(); err != nil {
   199  				t.Error(err)
   200  			}
   201  		})
   202  
   203  		require.NoError(t, tc.mockPV.SignVote(tc.chainID, want.ToProto()))
   204  		require.NoError(t, tc.signerClient.SignVote(tc.chainID, have.ToProto()))
   205  
   206  		assert.Equal(t, want.Signature, have.Signature)
   207  	}
   208  }
   209  
   210  func TestSignerVoteResetDeadline(t *testing.T) {
   211  	for _, tc := range getSignerTestCases(t) {
   212  		ts := time.Now()
   213  		hash := tmrand.Bytes(tmhash.Size)
   214  		valAddr := tmrand.Bytes(crypto.AddressSize)
   215  		want := &types.Vote{
   216  			Type:             tmproto.PrecommitType,
   217  			Height:           1,
   218  			Round:            2,
   219  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   220  			Timestamp:        ts,
   221  			ValidatorAddress: valAddr,
   222  			ValidatorIndex:   1,
   223  		}
   224  
   225  		have := &types.Vote{
   226  			Type:             tmproto.PrecommitType,
   227  			Height:           1,
   228  			Round:            2,
   229  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   230  			Timestamp:        ts,
   231  			ValidatorAddress: valAddr,
   232  			ValidatorIndex:   1,
   233  		}
   234  
   235  		tc := tc
   236  		t.Cleanup(func() {
   237  			if err := tc.signerServer.Stop(); err != nil {
   238  				t.Error(err)
   239  			}
   240  		})
   241  		t.Cleanup(func() {
   242  			if err := tc.signerClient.Close(); err != nil {
   243  				t.Error(err)
   244  			}
   245  		})
   246  
   247  		time.Sleep(testTimeoutReadWrite2o3)
   248  
   249  		require.NoError(t, tc.mockPV.SignVote(tc.chainID, want.ToProto()))
   250  		require.NoError(t, tc.signerClient.SignVote(tc.chainID, have.ToProto()))
   251  		assert.Equal(t, want.Signature, have.Signature)
   252  
   253  		// TODO(jleni): Clarify what is actually being tested
   254  
   255  		// This would exceed the deadline if it was not extended by the previous message
   256  		time.Sleep(testTimeoutReadWrite2o3)
   257  
   258  		require.NoError(t, tc.mockPV.SignVote(tc.chainID, want.ToProto()))
   259  		require.NoError(t, tc.signerClient.SignVote(tc.chainID, have.ToProto()))
   260  		assert.Equal(t, want.Signature, have.Signature)
   261  	}
   262  }
   263  
   264  func TestSignerVoteKeepAlive(t *testing.T) {
   265  	for _, tc := range getSignerTestCases(t) {
   266  		ts := time.Now()
   267  		hash := tmrand.Bytes(tmhash.Size)
   268  		valAddr := tmrand.Bytes(crypto.AddressSize)
   269  		want := &types.Vote{
   270  			Type:             tmproto.PrecommitType,
   271  			Height:           1,
   272  			Round:            2,
   273  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   274  			Timestamp:        ts,
   275  			ValidatorAddress: valAddr,
   276  			ValidatorIndex:   1,
   277  		}
   278  
   279  		have := &types.Vote{
   280  			Type:             tmproto.PrecommitType,
   281  			Height:           1,
   282  			Round:            2,
   283  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   284  			Timestamp:        ts,
   285  			ValidatorAddress: valAddr,
   286  			ValidatorIndex:   1,
   287  		}
   288  
   289  		tc := tc
   290  		t.Cleanup(func() {
   291  			if err := tc.signerServer.Stop(); err != nil {
   292  				t.Error(err)
   293  			}
   294  		})
   295  		t.Cleanup(func() {
   296  			if err := tc.signerClient.Close(); err != nil {
   297  				t.Error(err)
   298  			}
   299  		})
   300  
   301  		// Check that even if the client does not request a
   302  		// signature for a long time. The service is still available
   303  
   304  		// in this particular case, we use the dialer logger to ensure that
   305  		// test messages are properly interleaved in the test logs
   306  		tc.signerServer.Logger.Debug("TEST: Forced Wait -------------------------------------------------")
   307  		time.Sleep(testTimeoutReadWrite * 3)
   308  		tc.signerServer.Logger.Debug("TEST: Forced Wait DONE---------------------------------------------")
   309  
   310  		require.NoError(t, tc.mockPV.SignVote(tc.chainID, want.ToProto()))
   311  		require.NoError(t, tc.signerClient.SignVote(tc.chainID, have.ToProto()))
   312  
   313  		assert.Equal(t, want.Signature, have.Signature)
   314  	}
   315  }
   316  
   317  func TestSignerSignProposalErrors(t *testing.T) {
   318  	for _, tc := range getSignerTestCases(t) {
   319  		// Replace service with a mock that always fails
   320  		tc.signerServer.privVal = types.NewErroringMockPV()
   321  		tc.mockPV = types.NewErroringMockPV()
   322  
   323  		tc := tc
   324  		t.Cleanup(func() {
   325  			if err := tc.signerServer.Stop(); err != nil {
   326  				t.Error(err)
   327  			}
   328  		})
   329  		t.Cleanup(func() {
   330  			if err := tc.signerClient.Close(); err != nil {
   331  				t.Error(err)
   332  			}
   333  		})
   334  
   335  		ts := time.Now()
   336  		hash := tmrand.Bytes(tmhash.Size)
   337  		proposal := &types.Proposal{
   338  			Type:      tmproto.ProposalType,
   339  			Height:    1,
   340  			Round:     2,
   341  			POLRound:  2,
   342  			BlockID:   types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   343  			Timestamp: ts,
   344  			Signature: []byte("signature"),
   345  			DAHeader:  &types.DataAvailabilityHeader{},
   346  		}
   347  
   348  		p, err := proposal.ToProto()
   349  		require.NoError(t, err)
   350  		err = tc.signerClient.SignProposal(tc.chainID, p)
   351  		require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error())
   352  
   353  		p, err = proposal.ToProto()
   354  		require.NoError(t, err)
   355  		err = tc.mockPV.SignProposal(tc.chainID, p)
   356  		require.Error(t, err)
   357  
   358  		p, err = proposal.ToProto()
   359  		require.NoError(t, err)
   360  		err = tc.signerClient.SignProposal(tc.chainID, p)
   361  		require.Error(t, err)
   362  	}
   363  }
   364  
   365  func TestSignerSignVoteErrors(t *testing.T) {
   366  	for _, tc := range getSignerTestCases(t) {
   367  		ts := time.Now()
   368  		hash := tmrand.Bytes(tmhash.Size)
   369  		valAddr := tmrand.Bytes(crypto.AddressSize)
   370  		vote := &types.Vote{
   371  			Type:             tmproto.PrecommitType,
   372  			Height:           1,
   373  			Round:            2,
   374  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   375  			Timestamp:        ts,
   376  			ValidatorAddress: valAddr,
   377  			ValidatorIndex:   1,
   378  			Signature:        []byte("signature"),
   379  		}
   380  
   381  		// Replace signer service privval with one that always fails
   382  		tc.signerServer.privVal = types.NewErroringMockPV()
   383  		tc.mockPV = types.NewErroringMockPV()
   384  
   385  		tc := tc
   386  		t.Cleanup(func() {
   387  			if err := tc.signerServer.Stop(); err != nil {
   388  				t.Error(err)
   389  			}
   390  		})
   391  		t.Cleanup(func() {
   392  			if err := tc.signerClient.Close(); err != nil {
   393  				t.Error(err)
   394  			}
   395  		})
   396  
   397  		err := tc.signerClient.SignVote(tc.chainID, vote.ToProto())
   398  		require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error())
   399  
   400  		err = tc.mockPV.SignVote(tc.chainID, vote.ToProto())
   401  		require.Error(t, err)
   402  
   403  		err = tc.signerClient.SignVote(tc.chainID, vote.ToProto())
   404  		require.Error(t, err)
   405  	}
   406  }
   407  
   408  func brokenHandler(privVal types.PrivValidator, request privvalproto.Message,
   409  	chainID string) (privvalproto.Message, error) {
   410  	var res privvalproto.Message
   411  	var err error
   412  
   413  	switch r := request.Sum.(type) {
   414  	// This is broken and will answer most requests with a pubkey response
   415  	case *privvalproto.Message_PubKeyRequest:
   416  		res = mustWrapMsg(&privvalproto.PubKeyResponse{PubKey: cryptoproto.PublicKey{}, Error: nil})
   417  	case *privvalproto.Message_SignVoteRequest:
   418  		res = mustWrapMsg(&privvalproto.PubKeyResponse{PubKey: cryptoproto.PublicKey{}, Error: nil})
   419  	case *privvalproto.Message_SignProposalRequest:
   420  		res = mustWrapMsg(&privvalproto.PubKeyResponse{PubKey: cryptoproto.PublicKey{}, Error: nil})
   421  	case *privvalproto.Message_PingRequest:
   422  		err, res = nil, mustWrapMsg(&privvalproto.PingResponse{})
   423  	default:
   424  		err = fmt.Errorf("unknown msg: %v", r)
   425  	}
   426  
   427  	return res, err
   428  }
   429  
   430  func TestSignerUnexpectedResponse(t *testing.T) {
   431  	for _, tc := range getSignerTestCases(t) {
   432  		tc.signerServer.privVal = types.NewMockPV()
   433  		tc.mockPV = types.NewMockPV()
   434  
   435  		tc.signerServer.SetRequestHandler(brokenHandler)
   436  
   437  		tc := tc
   438  		t.Cleanup(func() {
   439  			if err := tc.signerServer.Stop(); err != nil {
   440  				t.Error(err)
   441  			}
   442  		})
   443  		t.Cleanup(func() {
   444  			if err := tc.signerClient.Close(); err != nil {
   445  				t.Error(err)
   446  			}
   447  		})
   448  
   449  		ts := time.Now()
   450  		want := &types.Vote{Timestamp: ts, Type: tmproto.PrecommitType}
   451  
   452  		e := tc.signerClient.SignVote(tc.chainID, want.ToProto())
   453  		assert.EqualError(t, e, "empty response")
   454  	}
   455  }