github.com/line/ostracon@v1.0.10-0.20230328032236-7f20145f065d/privval/signer_client_test.go (about)

     1  package privval
     2  
     3  import (
     4  	"fmt"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  	"github.com/stretchr/testify/require"
    10  
    11  	cryptoproto "github.com/tendermint/tendermint/proto/tendermint/crypto"
    12  	privvalproto "github.com/tendermint/tendermint/proto/tendermint/privval"
    13  	tmproto "github.com/tendermint/tendermint/proto/tendermint/types"
    14  
    15  	"github.com/line/ostracon/crypto"
    16  	"github.com/line/ostracon/crypto/ed25519"
    17  	"github.com/line/ostracon/crypto/tmhash"
    18  	"github.com/line/ostracon/crypto/vrf"
    19  	tmrand "github.com/line/ostracon/libs/rand"
    20  	ocprivvalproto "github.com/line/ostracon/proto/ostracon/privval"
    21  	"github.com/line/ostracon/types"
    22  )
    23  
    24  type signerTestCase struct {
    25  	chainID      string
    26  	mockPV       types.PrivValidator
    27  	signerClient *SignerClient
    28  	signerServer *SignerServer
    29  }
    30  
    31  func getSignerTestCases(t *testing.T, mockPV types.PrivValidator, start bool) []signerTestCase {
    32  	testCases := make([]signerTestCase, 0)
    33  
    34  	// Get test cases for each possible dialer (DialTCP / DialUnix / etc)
    35  	for _, dtc := range getDialerTestCases(t) {
    36  		chainID := tmrand.Str(12)
    37  		mockKey := ed25519.GenPrivKey()
    38  		if mockPV == nil {
    39  			mockPV = types.NewMockPVWithParams(mockKey, false, false)
    40  		}
    41  
    42  		// get a pair of signer listener, signer dialer endpoints
    43  		sl, sd := getMockEndpoints(t, dtc.addr, dtc.dialer)
    44  		sc, err := NewSignerClient(sl, chainID)
    45  		require.NoError(t, err)
    46  		ss := NewSignerServer(sd, chainID, mockPV)
    47  
    48  		if start {
    49  			err = ss.Start()
    50  			require.NoError(t, err)
    51  		}
    52  
    53  		tc := signerTestCase{
    54  			chainID:      chainID,
    55  			mockPV:       mockPV,
    56  			signerClient: sc,
    57  			signerServer: ss,
    58  		}
    59  
    60  		testCases = append(testCases, tc)
    61  	}
    62  
    63  	return testCases
    64  }
    65  
    66  func TestSignerClose(t *testing.T) {
    67  	for _, tc := range getSignerTestCases(t, nil, true) {
    68  		err := tc.signerClient.Close()
    69  		assert.NoError(t, err)
    70  
    71  		err = tc.signerServer.Stop()
    72  		assert.NoError(t, err)
    73  	}
    74  }
    75  
    76  func TestSignerGetPubKey(t *testing.T) {
    77  	for _, tc := range getSignerTestCases(t, nil, true) {
    78  		tc := tc
    79  		t.Cleanup(func() {
    80  			if err := tc.signerServer.Stop(); err != nil {
    81  				t.Error(err)
    82  			}
    83  		})
    84  		t.Cleanup(func() {
    85  			if err := tc.signerClient.Close(); err != nil {
    86  				t.Error(err)
    87  			}
    88  		})
    89  
    90  		pubKey, err := tc.signerClient.GetPubKey()
    91  		require.NoError(t, err)
    92  		expectedPubKey, err := tc.mockPV.GetPubKey()
    93  		require.NoError(t, err)
    94  
    95  		assert.Equal(t, expectedPubKey, pubKey)
    96  
    97  		pubKey, err = tc.signerClient.GetPubKey()
    98  		require.NoError(t, err)
    99  		expectedpk, err := tc.mockPV.GetPubKey()
   100  		require.NoError(t, err)
   101  		expectedAddr := expectedpk.Address()
   102  
   103  		assert.Equal(t, expectedAddr, pubKey.Address())
   104  	}
   105  }
   106  
   107  func TestSignerProposal(t *testing.T) {
   108  	for _, tc := range getSignerTestCases(t, nil, true) {
   109  		ts := time.Now()
   110  		hash := tmrand.Bytes(tmhash.Size)
   111  		have := &types.Proposal{
   112  			Type:      tmproto.ProposalType,
   113  			Height:    1,
   114  			Round:     2,
   115  			POLRound:  2,
   116  			BlockID:   types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   117  			Timestamp: ts,
   118  		}
   119  		want := &types.Proposal{
   120  			Type:      tmproto.ProposalType,
   121  			Height:    1,
   122  			Round:     2,
   123  			POLRound:  2,
   124  			BlockID:   types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   125  			Timestamp: ts,
   126  		}
   127  
   128  		tc := tc
   129  		t.Cleanup(func() {
   130  			if err := tc.signerServer.Stop(); err != nil {
   131  				t.Error(err)
   132  			}
   133  		})
   134  		t.Cleanup(func() {
   135  			if err := tc.signerClient.Close(); err != nil {
   136  				t.Error(err)
   137  			}
   138  		})
   139  
   140  		require.NoError(t, tc.mockPV.SignProposal(tc.chainID, want.ToProto()))
   141  		require.NoError(t, tc.signerClient.SignProposal(tc.chainID, have.ToProto()))
   142  
   143  		assert.Equal(t, want.Signature, have.Signature)
   144  	}
   145  }
   146  
   147  func TestSignerGenerateVRFProof(t *testing.T) {
   148  	message := []byte("hello, world")
   149  	for _, tc := range getSignerTestCases(t, nil, true) {
   150  		tc := tc
   151  		t.Cleanup(func() {
   152  			if err := tc.signerServer.Stop(); err != nil {
   153  				t.Error(err)
   154  			}
   155  		})
   156  		t.Cleanup(func() {
   157  			if err := tc.signerClient.Close(); err != nil {
   158  				t.Error(err)
   159  			}
   160  		})
   161  
   162  		proof, err := tc.signerClient.GenerateVRFProof(message)
   163  		require.Nil(t, err)
   164  		require.True(t, len(proof) > 0)
   165  		output, err := vrf.ProofToHash(vrf.Proof(proof))
   166  		require.Nil(t, err)
   167  		require.NotNil(t, output)
   168  		pubKey, err := tc.signerClient.GetPubKey()
   169  		require.Nil(t, err)
   170  		ed25519PubKey, ok := pubKey.(ed25519.PubKey)
   171  		require.True(t, ok)
   172  		expected, err := vrf.Verify(ed25519PubKey, vrf.Proof(proof), message)
   173  		require.Nil(t, err)
   174  		assert.True(t, expected)
   175  	}
   176  }
   177  
   178  func TestSignerVote(t *testing.T) {
   179  	for _, tc := range getSignerTestCases(t, nil, true) {
   180  		ts := time.Now()
   181  		hash := tmrand.Bytes(tmhash.Size)
   182  		valAddr := tmrand.Bytes(crypto.AddressSize)
   183  		want := &types.Vote{
   184  			Type:             tmproto.PrecommitType,
   185  			Height:           1,
   186  			Round:            2,
   187  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   188  			Timestamp:        ts,
   189  			ValidatorAddress: valAddr,
   190  			ValidatorIndex:   1,
   191  		}
   192  
   193  		have := &types.Vote{
   194  			Type:             tmproto.PrecommitType,
   195  			Height:           1,
   196  			Round:            2,
   197  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   198  			Timestamp:        ts,
   199  			ValidatorAddress: valAddr,
   200  			ValidatorIndex:   1,
   201  		}
   202  
   203  		tc := tc
   204  		t.Cleanup(func() {
   205  			if err := tc.signerServer.Stop(); err != nil {
   206  				t.Error(err)
   207  			}
   208  		})
   209  		t.Cleanup(func() {
   210  			if err := tc.signerClient.Close(); err != nil {
   211  				t.Error(err)
   212  			}
   213  		})
   214  
   215  		require.NoError(t, tc.mockPV.SignVote(tc.chainID, want.ToProto()))
   216  		require.NoError(t, tc.signerClient.SignVote(tc.chainID, have.ToProto()))
   217  
   218  		assert.Equal(t, want.Signature, have.Signature)
   219  	}
   220  }
   221  
   222  func TestSignerVoteResetDeadline(t *testing.T) {
   223  	for _, tc := range getSignerTestCases(t, nil, true) {
   224  		ts := time.Now()
   225  		hash := tmrand.Bytes(tmhash.Size)
   226  		valAddr := tmrand.Bytes(crypto.AddressSize)
   227  		want := &types.Vote{
   228  			Type:             tmproto.PrecommitType,
   229  			Height:           1,
   230  			Round:            2,
   231  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   232  			Timestamp:        ts,
   233  			ValidatorAddress: valAddr,
   234  			ValidatorIndex:   1,
   235  		}
   236  
   237  		have := &types.Vote{
   238  			Type:             tmproto.PrecommitType,
   239  			Height:           1,
   240  			Round:            2,
   241  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   242  			Timestamp:        ts,
   243  			ValidatorAddress: valAddr,
   244  			ValidatorIndex:   1,
   245  		}
   246  
   247  		tc := tc
   248  		t.Cleanup(func() {
   249  			if err := tc.signerServer.Stop(); err != nil {
   250  				t.Error(err)
   251  			}
   252  		})
   253  		t.Cleanup(func() {
   254  			if err := tc.signerClient.Close(); err != nil {
   255  				t.Error(err)
   256  			}
   257  		})
   258  
   259  		time.Sleep(testTimeoutReadWrite2o3)
   260  
   261  		require.NoError(t, tc.mockPV.SignVote(tc.chainID, want.ToProto()))
   262  		require.NoError(t, tc.signerClient.SignVote(tc.chainID, have.ToProto()))
   263  		assert.Equal(t, want.Signature, have.Signature)
   264  
   265  		// TODO(jleni): Clarify what is actually being tested
   266  
   267  		// This would exceed the deadline if it was not extended by the previous message
   268  		time.Sleep(testTimeoutReadWrite2o3)
   269  
   270  		require.NoError(t, tc.mockPV.SignVote(tc.chainID, want.ToProto()))
   271  		require.NoError(t, tc.signerClient.SignVote(tc.chainID, have.ToProto()))
   272  		assert.Equal(t, want.Signature, have.Signature)
   273  	}
   274  }
   275  
   276  func TestSignerVoteKeepAlive(t *testing.T) {
   277  	for _, tc := range getSignerTestCases(t, nil, true) {
   278  		ts := time.Now()
   279  		hash := tmrand.Bytes(tmhash.Size)
   280  		valAddr := tmrand.Bytes(crypto.AddressSize)
   281  		want := &types.Vote{
   282  			Type:             tmproto.PrecommitType,
   283  			Height:           1,
   284  			Round:            2,
   285  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   286  			Timestamp:        ts,
   287  			ValidatorAddress: valAddr,
   288  			ValidatorIndex:   1,
   289  		}
   290  
   291  		have := &types.Vote{
   292  			Type:             tmproto.PrecommitType,
   293  			Height:           1,
   294  			Round:            2,
   295  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   296  			Timestamp:        ts,
   297  			ValidatorAddress: valAddr,
   298  			ValidatorIndex:   1,
   299  		}
   300  
   301  		tc := tc
   302  		t.Cleanup(func() {
   303  			if err := tc.signerServer.Stop(); err != nil {
   304  				t.Error(err)
   305  			}
   306  		})
   307  		t.Cleanup(func() {
   308  			if err := tc.signerClient.Close(); err != nil {
   309  				t.Error(err)
   310  			}
   311  		})
   312  
   313  		// Check that even if the client does not request a
   314  		// signature for a long time. The service is still available
   315  
   316  		// in this particular case, we use the dialer logger to ensure that
   317  		// test messages are properly interleaved in the test logs
   318  		tc.signerServer.Logger.Debug("TEST: Forced Wait -------------------------------------------------")
   319  		time.Sleep(testTimeoutReadWrite * 3)
   320  		tc.signerServer.Logger.Debug("TEST: Forced Wait DONE---------------------------------------------")
   321  
   322  		require.NoError(t, tc.mockPV.SignVote(tc.chainID, want.ToProto()))
   323  		require.NoError(t, tc.signerClient.SignVote(tc.chainID, have.ToProto()))
   324  
   325  		assert.Equal(t, want.Signature, have.Signature)
   326  	}
   327  }
   328  
   329  func TestSignerSignProposalErrors(t *testing.T) {
   330  	for _, tc := range getSignerTestCases(t, nil, true) {
   331  		// Replace service with a mock that always fails
   332  		tc.signerServer.privVal = types.NewErroringMockPV()
   333  		tc.mockPV = types.NewErroringMockPV()
   334  
   335  		tc := tc
   336  		t.Cleanup(func() {
   337  			if err := tc.signerServer.Stop(); err != nil {
   338  				t.Error(err)
   339  			}
   340  		})
   341  		t.Cleanup(func() {
   342  			if err := tc.signerClient.Close(); err != nil {
   343  				t.Error(err)
   344  			}
   345  		})
   346  
   347  		ts := time.Now()
   348  		hash := tmrand.Bytes(tmhash.Size)
   349  		proposal := &types.Proposal{
   350  			Type:      tmproto.ProposalType,
   351  			Height:    1,
   352  			Round:     2,
   353  			POLRound:  2,
   354  			BlockID:   types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   355  			Timestamp: ts,
   356  			Signature: []byte("signature"),
   357  		}
   358  
   359  		err := tc.signerClient.SignProposal(tc.chainID, proposal.ToProto())
   360  		require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error())
   361  
   362  		err = tc.mockPV.SignProposal(tc.chainID, proposal.ToProto())
   363  		require.Error(t, err)
   364  
   365  		err = tc.signerClient.SignProposal(tc.chainID, proposal.ToProto())
   366  		require.Error(t, err)
   367  	}
   368  }
   369  
   370  func TestSignerSignVoteErrors(t *testing.T) {
   371  	for _, tc := range getSignerTestCases(t, nil, true) {
   372  		ts := time.Now()
   373  		hash := tmrand.Bytes(tmhash.Size)
   374  		valAddr := tmrand.Bytes(crypto.AddressSize)
   375  		vote := &types.Vote{
   376  			Type:             tmproto.PrecommitType,
   377  			Height:           1,
   378  			Round:            2,
   379  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   380  			Timestamp:        ts,
   381  			ValidatorAddress: valAddr,
   382  			ValidatorIndex:   1,
   383  		}
   384  
   385  		// Replace signer service privval with one that always fails
   386  		tc.signerServer.privVal = types.NewErroringMockPV()
   387  		tc.mockPV = types.NewErroringMockPV()
   388  
   389  		tc := tc
   390  		t.Cleanup(func() {
   391  			if err := tc.signerServer.Stop(); err != nil {
   392  				t.Error(err)
   393  			}
   394  		})
   395  		t.Cleanup(func() {
   396  			if err := tc.signerClient.Close(); err != nil {
   397  				t.Error(err)
   398  			}
   399  		})
   400  
   401  		err := tc.signerClient.SignVote(tc.chainID, vote.ToProto())
   402  		require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error())
   403  
   404  		err = tc.mockPV.SignVote(tc.chainID, vote.ToProto())
   405  		require.Error(t, err)
   406  
   407  		err = tc.signerClient.SignVote(tc.chainID, vote.ToProto())
   408  		require.Error(t, err)
   409  	}
   410  }
   411  
   412  func brokenHandler(privVal types.PrivValidator, request ocprivvalproto.Message,
   413  	chainID string) (ocprivvalproto.Message, error) {
   414  	var res ocprivvalproto.Message
   415  	var err error
   416  
   417  	switch r := request.Sum.(type) {
   418  	// This is broken and will answer most requests with a pubkey response
   419  	case *ocprivvalproto.Message_PubKeyRequest:
   420  		res = mustWrapMsg(&privvalproto.PubKeyResponse{PubKey: cryptoproto.PublicKey{}, Error: nil})
   421  	case *ocprivvalproto.Message_SignVoteRequest:
   422  		res = mustWrapMsg(&privvalproto.PubKeyResponse{PubKey: cryptoproto.PublicKey{}, Error: nil})
   423  	case *ocprivvalproto.Message_SignProposalRequest:
   424  		res = mustWrapMsg(&privvalproto.PubKeyResponse{PubKey: cryptoproto.PublicKey{}, Error: nil})
   425  	case *ocprivvalproto.Message_PingRequest:
   426  		err, res = nil, mustWrapMsg(&privvalproto.PingResponse{})
   427  	default:
   428  		err = fmt.Errorf("unknown msg: %v", r)
   429  	}
   430  
   431  	return res, err
   432  }
   433  
   434  func TestSignerUnexpectedResponse(t *testing.T) {
   435  	for _, tc := range getSignerTestCases(t, types.NewMockPV(), false) {
   436  		tc.signerServer.SetRequestHandler(brokenHandler)
   437  		err := tc.signerServer.Start()
   438  		if err != nil {
   439  			panic(err)
   440  		}
   441  
   442  		tc := tc
   443  		t.Cleanup(func() {
   444  			if err := tc.signerServer.Stop(); err != nil {
   445  				t.Error(err)
   446  			}
   447  		})
   448  		t.Cleanup(func() {
   449  			if err := tc.signerClient.Close(); err != nil {
   450  				t.Error(err)
   451  			}
   452  		})
   453  
   454  		ts := time.Now()
   455  		want := &types.Vote{Timestamp: ts, Type: tmproto.PrecommitType}
   456  
   457  		e := tc.signerClient.SignVote(tc.chainID, want.ToProto())
   458  		assert.EqualError(t, e, "empty response")
   459  	}
   460  }