github.com/Finschia/ostracon@v1.1.5/privval/signer_client_test.go (about)

     1  package privval
     2  
     3  import (
     4  	"fmt"
     5  	curve25519voi "github.com/oasisprotocol/curve25519-voi/primitives/ed25519"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/stretchr/testify/assert"
    10  	"github.com/stretchr/testify/require"
    11  
    12  	cryptoproto "github.com/tendermint/tendermint/proto/tendermint/crypto"
    13  	privvalproto "github.com/tendermint/tendermint/proto/tendermint/privval"
    14  	tmproto "github.com/tendermint/tendermint/proto/tendermint/types"
    15  
    16  	"github.com/Finschia/ostracon/crypto"
    17  	"github.com/Finschia/ostracon/crypto/ed25519"
    18  	"github.com/Finschia/ostracon/crypto/tmhash"
    19  	tmrand "github.com/Finschia/ostracon/libs/rand"
    20  	ocprivvalproto "github.com/Finschia/ostracon/proto/ostracon/privval"
    21  	"github.com/Finschia/ostracon/types"
    22  	vrf "github.com/oasisprotocol/curve25519-voi/primitives/ed25519/extra/ecvrf"
    23  )
    24  
    25  type signerTestCase struct {
    26  	chainID      string
    27  	mockPV       types.PrivValidator
    28  	signerClient *SignerClient
    29  	signerServer *SignerServer
    30  }
    31  
    32  func getSignerTestCases(t *testing.T, mockPV types.PrivValidator, start bool) []signerTestCase {
    33  	testCases := make([]signerTestCase, 0)
    34  
    35  	// Get test cases for each possible dialer (DialTCP / DialUnix / etc)
    36  	for _, dtc := range getDialerTestCases(t) {
    37  		chainID := tmrand.Str(12)
    38  		mockKey := ed25519.GenPrivKey()
    39  		if mockPV == nil {
    40  			mockPV = types.NewMockPVWithParams(mockKey, false, false)
    41  		}
    42  
    43  		// get a pair of signer listener, signer dialer endpoints
    44  		sl, sd := getMockEndpoints(t, dtc.addr, dtc.dialer)
    45  		sc, err := NewSignerClient(sl, chainID)
    46  		require.NoError(t, err)
    47  		ss := NewSignerServer(sd, chainID, mockPV)
    48  
    49  		if start {
    50  			err = ss.Start()
    51  			require.NoError(t, err)
    52  		}
    53  
    54  		tc := signerTestCase{
    55  			chainID:      chainID,
    56  			mockPV:       mockPV,
    57  			signerClient: sc,
    58  			signerServer: ss,
    59  		}
    60  
    61  		testCases = append(testCases, tc)
    62  	}
    63  
    64  	return testCases
    65  }
    66  
    67  func TestSignerClose(t *testing.T) {
    68  	for _, tc := range getSignerTestCases(t, nil, true) {
    69  		err := tc.signerClient.Close()
    70  		assert.NoError(t, err)
    71  
    72  		err = tc.signerServer.Stop()
    73  		assert.NoError(t, err)
    74  	}
    75  }
    76  
    77  func TestSignerGetPubKey(t *testing.T) {
    78  	for _, tc := range getSignerTestCases(t, nil, true) {
    79  		tc := tc
    80  		t.Cleanup(func() {
    81  			if err := tc.signerServer.Stop(); err != nil {
    82  				t.Error(err)
    83  			}
    84  		})
    85  		t.Cleanup(func() {
    86  			if err := tc.signerClient.Close(); err != nil {
    87  				t.Error(err)
    88  			}
    89  		})
    90  
    91  		pubKey, err := tc.signerClient.GetPubKey()
    92  		require.NoError(t, err)
    93  		expectedPubKey, err := tc.mockPV.GetPubKey()
    94  		require.NoError(t, err)
    95  
    96  		assert.Equal(t, expectedPubKey, pubKey)
    97  
    98  		pubKey, err = tc.signerClient.GetPubKey()
    99  		require.NoError(t, err)
   100  		expectedpk, err := tc.mockPV.GetPubKey()
   101  		require.NoError(t, err)
   102  		expectedAddr := expectedpk.Address()
   103  
   104  		assert.Equal(t, expectedAddr, pubKey.Address())
   105  	}
   106  }
   107  
   108  func TestSignerProposal(t *testing.T) {
   109  	for _, tc := range getSignerTestCases(t, nil, true) {
   110  		ts := time.Now()
   111  		hash := tmrand.Bytes(tmhash.Size)
   112  		have := &types.Proposal{
   113  			Type:      tmproto.ProposalType,
   114  			Height:    1,
   115  			Round:     2,
   116  			POLRound:  2,
   117  			BlockID:   types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   118  			Timestamp: ts,
   119  		}
   120  		want := &types.Proposal{
   121  			Type:      tmproto.ProposalType,
   122  			Height:    1,
   123  			Round:     2,
   124  			POLRound:  2,
   125  			BlockID:   types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   126  			Timestamp: ts,
   127  		}
   128  
   129  		tc := tc
   130  		t.Cleanup(func() {
   131  			if err := tc.signerServer.Stop(); err != nil {
   132  				t.Error(err)
   133  			}
   134  		})
   135  		t.Cleanup(func() {
   136  			if err := tc.signerClient.Close(); err != nil {
   137  				t.Error(err)
   138  			}
   139  		})
   140  
   141  		require.NoError(t, tc.mockPV.SignProposal(tc.chainID, want.ToProto()))
   142  		require.NoError(t, tc.signerClient.SignProposal(tc.chainID, have.ToProto()))
   143  
   144  		assert.Equal(t, want.Signature, have.Signature)
   145  	}
   146  }
   147  
   148  func TestSignerGenerateVRFProof(t *testing.T) {
   149  	message := []byte("hello, world")
   150  	for _, tc := range getSignerTestCases(t, nil, true) {
   151  		tc := tc
   152  		t.Cleanup(func() {
   153  			if err := tc.signerServer.Stop(); err != nil {
   154  				t.Error(err)
   155  			}
   156  		})
   157  		t.Cleanup(func() {
   158  			if err := tc.signerClient.Close(); err != nil {
   159  				t.Error(err)
   160  			}
   161  		})
   162  
   163  		proof, err := tc.signerClient.GenerateVRFProof(message)
   164  		require.Nil(t, err)
   165  		require.True(t, len(proof) > 0)
   166  		output, err := vrf.ProofToHash(proof)
   167  		require.Nil(t, err)
   168  		require.NotNil(t, output)
   169  		pubKey, err := tc.signerClient.GetPubKey()
   170  		require.Nil(t, err)
   171  		ed25519PubKey, ok := pubKey.(ed25519.PubKey)
   172  		require.True(t, ok)
   173  		flag, bz := vrf.Verify(curve25519voi.PublicKey(ed25519PubKey), proof, message)
   174  		require.NotNil(t, bz)
   175  		assert.True(t, flag)
   176  	}
   177  }
   178  
   179  func TestSignerVote(t *testing.T) {
   180  	for _, tc := range getSignerTestCases(t, nil, true) {
   181  		ts := time.Now()
   182  		hash := tmrand.Bytes(tmhash.Size)
   183  		valAddr := tmrand.Bytes(crypto.AddressSize)
   184  		want := &types.Vote{
   185  			Type:             tmproto.PrecommitType,
   186  			Height:           1,
   187  			Round:            2,
   188  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   189  			Timestamp:        ts,
   190  			ValidatorAddress: valAddr,
   191  			ValidatorIndex:   1,
   192  		}
   193  
   194  		have := &types.Vote{
   195  			Type:             tmproto.PrecommitType,
   196  			Height:           1,
   197  			Round:            2,
   198  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   199  			Timestamp:        ts,
   200  			ValidatorAddress: valAddr,
   201  			ValidatorIndex:   1,
   202  		}
   203  
   204  		tc := tc
   205  		t.Cleanup(func() {
   206  			if err := tc.signerServer.Stop(); err != nil {
   207  				t.Error(err)
   208  			}
   209  		})
   210  		t.Cleanup(func() {
   211  			if err := tc.signerClient.Close(); err != nil {
   212  				t.Error(err)
   213  			}
   214  		})
   215  
   216  		require.NoError(t, tc.mockPV.SignVote(tc.chainID, want.ToProto()))
   217  		require.NoError(t, tc.signerClient.SignVote(tc.chainID, have.ToProto()))
   218  
   219  		assert.Equal(t, want.Signature, have.Signature)
   220  	}
   221  }
   222  
   223  func TestSignerVoteResetDeadline(t *testing.T) {
   224  	for _, tc := range getSignerTestCases(t, nil, true) {
   225  		ts := time.Now()
   226  		hash := tmrand.Bytes(tmhash.Size)
   227  		valAddr := tmrand.Bytes(crypto.AddressSize)
   228  		want := &types.Vote{
   229  			Type:             tmproto.PrecommitType,
   230  			Height:           1,
   231  			Round:            2,
   232  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   233  			Timestamp:        ts,
   234  			ValidatorAddress: valAddr,
   235  			ValidatorIndex:   1,
   236  		}
   237  
   238  		have := &types.Vote{
   239  			Type:             tmproto.PrecommitType,
   240  			Height:           1,
   241  			Round:            2,
   242  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   243  			Timestamp:        ts,
   244  			ValidatorAddress: valAddr,
   245  			ValidatorIndex:   1,
   246  		}
   247  
   248  		tc := tc
   249  		t.Cleanup(func() {
   250  			if err := tc.signerServer.Stop(); err != nil {
   251  				t.Error(err)
   252  			}
   253  		})
   254  		t.Cleanup(func() {
   255  			if err := tc.signerClient.Close(); err != nil {
   256  				t.Error(err)
   257  			}
   258  		})
   259  
   260  		time.Sleep(testTimeoutReadWrite2o3)
   261  
   262  		require.NoError(t, tc.mockPV.SignVote(tc.chainID, want.ToProto()))
   263  		require.NoError(t, tc.signerClient.SignVote(tc.chainID, have.ToProto()))
   264  		assert.Equal(t, want.Signature, have.Signature)
   265  
   266  		// TODO(jleni): Clarify what is actually being tested
   267  
   268  		// This would exceed the deadline if it was not extended by the previous message
   269  		time.Sleep(testTimeoutReadWrite2o3)
   270  
   271  		require.NoError(t, tc.mockPV.SignVote(tc.chainID, want.ToProto()))
   272  		require.NoError(t, tc.signerClient.SignVote(tc.chainID, have.ToProto()))
   273  		assert.Equal(t, want.Signature, have.Signature)
   274  	}
   275  }
   276  
   277  func TestSignerVoteKeepAlive(t *testing.T) {
   278  	for _, tc := range getSignerTestCases(t, nil, true) {
   279  		ts := time.Now()
   280  		hash := tmrand.Bytes(tmhash.Size)
   281  		valAddr := tmrand.Bytes(crypto.AddressSize)
   282  		want := &types.Vote{
   283  			Type:             tmproto.PrecommitType,
   284  			Height:           1,
   285  			Round:            2,
   286  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   287  			Timestamp:        ts,
   288  			ValidatorAddress: valAddr,
   289  			ValidatorIndex:   1,
   290  		}
   291  
   292  		have := &types.Vote{
   293  			Type:             tmproto.PrecommitType,
   294  			Height:           1,
   295  			Round:            2,
   296  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   297  			Timestamp:        ts,
   298  			ValidatorAddress: valAddr,
   299  			ValidatorIndex:   1,
   300  		}
   301  
   302  		tc := tc
   303  		t.Cleanup(func() {
   304  			if err := tc.signerServer.Stop(); err != nil {
   305  				t.Error(err)
   306  			}
   307  		})
   308  		t.Cleanup(func() {
   309  			if err := tc.signerClient.Close(); err != nil {
   310  				t.Error(err)
   311  			}
   312  		})
   313  
   314  		// Check that even if the client does not request a
   315  		// signature for a long time. The service is still available
   316  
   317  		// in this particular case, we use the dialer logger to ensure that
   318  		// test messages are properly interleaved in the test logs
   319  		tc.signerServer.Logger.Debug("TEST: Forced Wait -------------------------------------------------")
   320  		time.Sleep(testTimeoutReadWrite * 3)
   321  		tc.signerServer.Logger.Debug("TEST: Forced Wait DONE---------------------------------------------")
   322  
   323  		require.NoError(t, tc.mockPV.SignVote(tc.chainID, want.ToProto()))
   324  		require.NoError(t, tc.signerClient.SignVote(tc.chainID, have.ToProto()))
   325  
   326  		assert.Equal(t, want.Signature, have.Signature)
   327  	}
   328  }
   329  
   330  func TestSignerSignProposalErrors(t *testing.T) {
   331  	for _, tc := range getSignerTestCases(t, nil, true) {
   332  		// Replace service with a mock that always fails
   333  		tc.signerServer.privVal = types.NewErroringMockPV()
   334  		tc.mockPV = types.NewErroringMockPV()
   335  
   336  		tc := tc
   337  		t.Cleanup(func() {
   338  			if err := tc.signerServer.Stop(); err != nil {
   339  				t.Error(err)
   340  			}
   341  		})
   342  		t.Cleanup(func() {
   343  			if err := tc.signerClient.Close(); err != nil {
   344  				t.Error(err)
   345  			}
   346  		})
   347  
   348  		ts := time.Now()
   349  		hash := tmrand.Bytes(tmhash.Size)
   350  		proposal := &types.Proposal{
   351  			Type:      tmproto.ProposalType,
   352  			Height:    1,
   353  			Round:     2,
   354  			POLRound:  2,
   355  			BlockID:   types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   356  			Timestamp: ts,
   357  			Signature: []byte("signature"),
   358  		}
   359  
   360  		err := tc.signerClient.SignProposal(tc.chainID, proposal.ToProto())
   361  		require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error())
   362  
   363  		err = tc.mockPV.SignProposal(tc.chainID, proposal.ToProto())
   364  		require.Error(t, err)
   365  
   366  		err = tc.signerClient.SignProposal(tc.chainID, proposal.ToProto())
   367  		require.Error(t, err)
   368  	}
   369  }
   370  
   371  func TestSignerSignVoteErrors(t *testing.T) {
   372  	for _, tc := range getSignerTestCases(t, nil, true) {
   373  		ts := time.Now()
   374  		hash := tmrand.Bytes(tmhash.Size)
   375  		valAddr := tmrand.Bytes(crypto.AddressSize)
   376  		vote := &types.Vote{
   377  			Type:             tmproto.PrecommitType,
   378  			Height:           1,
   379  			Round:            2,
   380  			BlockID:          types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}},
   381  			Timestamp:        ts,
   382  			ValidatorAddress: valAddr,
   383  			ValidatorIndex:   1,
   384  		}
   385  
   386  		// Replace signer service privval with one that always fails
   387  		tc.signerServer.privVal = types.NewErroringMockPV()
   388  		tc.mockPV = types.NewErroringMockPV()
   389  
   390  		tc := tc
   391  		t.Cleanup(func() {
   392  			if err := tc.signerServer.Stop(); err != nil {
   393  				t.Error(err)
   394  			}
   395  		})
   396  		t.Cleanup(func() {
   397  			if err := tc.signerClient.Close(); err != nil {
   398  				t.Error(err)
   399  			}
   400  		})
   401  
   402  		err := tc.signerClient.SignVote(tc.chainID, vote.ToProto())
   403  		require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error())
   404  
   405  		err = tc.mockPV.SignVote(tc.chainID, vote.ToProto())
   406  		require.Error(t, err)
   407  
   408  		err = tc.signerClient.SignVote(tc.chainID, vote.ToProto())
   409  		require.Error(t, err)
   410  	}
   411  }
   412  
   413  func brokenHandler(privVal types.PrivValidator, request ocprivvalproto.Message,
   414  	chainID string) (ocprivvalproto.Message, error) {
   415  	var res ocprivvalproto.Message
   416  	var err error
   417  
   418  	switch r := request.Sum.(type) {
   419  	// This is broken and will answer most requests with a pubkey response
   420  	case *ocprivvalproto.Message_PubKeyRequest:
   421  		res = mustWrapMsg(&privvalproto.PubKeyResponse{PubKey: cryptoproto.PublicKey{}, Error: nil})
   422  	case *ocprivvalproto.Message_SignVoteRequest:
   423  		res = mustWrapMsg(&privvalproto.PubKeyResponse{PubKey: cryptoproto.PublicKey{}, Error: nil})
   424  	case *ocprivvalproto.Message_SignProposalRequest:
   425  		res = mustWrapMsg(&privvalproto.PubKeyResponse{PubKey: cryptoproto.PublicKey{}, Error: nil})
   426  	case *ocprivvalproto.Message_PingRequest:
   427  		err, res = nil, mustWrapMsg(&privvalproto.PingResponse{})
   428  	default:
   429  		err = fmt.Errorf("unknown msg: %v", r)
   430  	}
   431  
   432  	return res, err
   433  }
   434  
   435  func TestSignerUnexpectedResponse(t *testing.T) {
   436  	for _, tc := range getSignerTestCases(t, types.NewMockPV(), false) {
   437  		tc.signerServer.SetRequestHandler(brokenHandler)
   438  		err := tc.signerServer.Start()
   439  		if err != nil {
   440  			panic(err)
   441  		}
   442  
   443  		tc := tc
   444  		t.Cleanup(func() {
   445  			if err := tc.signerServer.Stop(); err != nil {
   446  				t.Error(err)
   447  			}
   448  		})
   449  		t.Cleanup(func() {
   450  			if err := tc.signerClient.Close(); err != nil {
   451  				t.Error(err)
   452  			}
   453  		})
   454  
   455  		ts := time.Now()
   456  		want := &types.Vote{Timestamp: ts, Type: tmproto.PrecommitType}
   457  
   458  		e := tc.signerClient.SignVote(tc.chainID, want.ToProto())
   459  		assert.EqualError(t, e, "empty response")
   460  	}
   461  }