github.com/evdatsion/aphelion-dpos-bft@v0.32.1/privval/signer_validator_endpoint_test.go (about)

     1  package privval
     2  
     3  import (
     4  	"fmt"
     5  	"net"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/stretchr/testify/assert"
    10  	"github.com/stretchr/testify/require"
    11  
    12  	"github.com/evdatsion/aphelion-dpos-bft/crypto/ed25519"
    13  	cmn "github.com/evdatsion/aphelion-dpos-bft/libs/common"
    14  	"github.com/evdatsion/aphelion-dpos-bft/libs/log"
    15  
    16  	"github.com/evdatsion/aphelion-dpos-bft/types"
    17  )
    18  
    19  var (
    20  	testTimeoutAccept = defaultTimeoutAcceptSeconds * time.Second
    21  
    22  	testTimeoutReadWrite    = 100 * time.Millisecond
    23  	testTimeoutReadWrite2o3 = 66 * time.Millisecond // 2/3 of the other one
    24  
    25  	testTimeoutHeartbeat    = 10 * time.Millisecond
    26  	testTimeoutHeartbeat3o2 = 6 * time.Millisecond // 3/2 of the other one
    27  )
    28  
    29  type socketTestCase struct {
    30  	addr   string
    31  	dialer SocketDialer
    32  }
    33  
    34  func socketTestCases(t *testing.T) []socketTestCase {
    35  	tcpAddr := fmt.Sprintf("tcp://%s", testFreeTCPAddr(t))
    36  	unixFilePath, err := testUnixAddr()
    37  	require.NoError(t, err)
    38  	unixAddr := fmt.Sprintf("unix://%s", unixFilePath)
    39  	return []socketTestCase{
    40  		{
    41  			addr:   tcpAddr,
    42  			dialer: DialTCPFn(tcpAddr, testTimeoutReadWrite, ed25519.GenPrivKey()),
    43  		},
    44  		{
    45  			addr:   unixAddr,
    46  			dialer: DialUnixFn(unixFilePath),
    47  		},
    48  	}
    49  }
    50  
    51  func TestSocketPVAddress(t *testing.T) {
    52  	for _, tc := range socketTestCases(t) {
    53  		// Execute the test within a closure to ensure the deferred statements
    54  		// are called between each for loop iteration, for isolated test cases.
    55  		func() {
    56  			var (
    57  				chainID                            = cmn.RandStr(12)
    58  				validatorEndpoint, serviceEndpoint = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer)
    59  			)
    60  			defer validatorEndpoint.Stop()
    61  			defer serviceEndpoint.Stop()
    62  
    63  			serviceAddr := serviceEndpoint.privVal.GetPubKey().Address()
    64  			validatorAddr := validatorEndpoint.GetPubKey().Address()
    65  
    66  			assert.Equal(t, serviceAddr, validatorAddr)
    67  		}()
    68  	}
    69  }
    70  
    71  func TestSocketPVPubKey(t *testing.T) {
    72  	for _, tc := range socketTestCases(t) {
    73  		func() {
    74  			var (
    75  				chainID                            = cmn.RandStr(12)
    76  				validatorEndpoint, serviceEndpoint = testSetupSocketPair(
    77  					t,
    78  					chainID,
    79  					types.NewMockPV(),
    80  					tc.addr,
    81  					tc.dialer)
    82  			)
    83  			defer validatorEndpoint.Stop()
    84  			defer serviceEndpoint.Stop()
    85  
    86  			clientKey := validatorEndpoint.GetPubKey()
    87  			privvalPubKey := serviceEndpoint.privVal.GetPubKey()
    88  
    89  			assert.Equal(t, privvalPubKey, clientKey)
    90  		}()
    91  	}
    92  }
    93  
    94  func TestSocketPVProposal(t *testing.T) {
    95  	for _, tc := range socketTestCases(t) {
    96  		func() {
    97  			var (
    98  				chainID                            = cmn.RandStr(12)
    99  				validatorEndpoint, serviceEndpoint = testSetupSocketPair(
   100  					t,
   101  					chainID,
   102  					types.NewMockPV(),
   103  					tc.addr,
   104  					tc.dialer)
   105  
   106  				ts             = time.Now()
   107  				privProposal   = &types.Proposal{Timestamp: ts}
   108  				clientProposal = &types.Proposal{Timestamp: ts}
   109  			)
   110  			defer validatorEndpoint.Stop()
   111  			defer serviceEndpoint.Stop()
   112  
   113  			require.NoError(t, serviceEndpoint.privVal.SignProposal(chainID, privProposal))
   114  			require.NoError(t, validatorEndpoint.SignProposal(chainID, clientProposal))
   115  
   116  			assert.Equal(t, privProposal.Signature, clientProposal.Signature)
   117  		}()
   118  	}
   119  }
   120  
   121  func TestSocketPVVote(t *testing.T) {
   122  	for _, tc := range socketTestCases(t) {
   123  		func() {
   124  			var (
   125  				chainID                            = cmn.RandStr(12)
   126  				validatorEndpoint, serviceEndpoint = testSetupSocketPair(
   127  					t,
   128  					chainID,
   129  					types.NewMockPV(),
   130  					tc.addr,
   131  					tc.dialer)
   132  
   133  				ts    = time.Now()
   134  				vType = types.PrecommitType
   135  				want  = &types.Vote{Timestamp: ts, Type: vType}
   136  				have  = &types.Vote{Timestamp: ts, Type: vType}
   137  			)
   138  			defer validatorEndpoint.Stop()
   139  			defer serviceEndpoint.Stop()
   140  
   141  			require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want))
   142  			require.NoError(t, validatorEndpoint.SignVote(chainID, have))
   143  			assert.Equal(t, want.Signature, have.Signature)
   144  		}()
   145  	}
   146  }
   147  
   148  func TestSocketPVVoteResetDeadline(t *testing.T) {
   149  	for _, tc := range socketTestCases(t) {
   150  		func() {
   151  			var (
   152  				chainID                            = cmn.RandStr(12)
   153  				validatorEndpoint, serviceEndpoint = testSetupSocketPair(
   154  					t,
   155  					chainID,
   156  					types.NewMockPV(),
   157  					tc.addr,
   158  					tc.dialer)
   159  
   160  				ts    = time.Now()
   161  				vType = types.PrecommitType
   162  				want  = &types.Vote{Timestamp: ts, Type: vType}
   163  				have  = &types.Vote{Timestamp: ts, Type: vType}
   164  			)
   165  			defer validatorEndpoint.Stop()
   166  			defer serviceEndpoint.Stop()
   167  
   168  			time.Sleep(testTimeoutReadWrite2o3)
   169  
   170  			require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want))
   171  			require.NoError(t, validatorEndpoint.SignVote(chainID, have))
   172  			assert.Equal(t, want.Signature, have.Signature)
   173  
   174  			// This would exceed the deadline if it was not extended by the previous message
   175  			time.Sleep(testTimeoutReadWrite2o3)
   176  
   177  			require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want))
   178  			require.NoError(t, validatorEndpoint.SignVote(chainID, have))
   179  			assert.Equal(t, want.Signature, have.Signature)
   180  		}()
   181  	}
   182  }
   183  
   184  func TestSocketPVVoteKeepalive(t *testing.T) {
   185  	for _, tc := range socketTestCases(t) {
   186  		func() {
   187  			var (
   188  				chainID                            = cmn.RandStr(12)
   189  				validatorEndpoint, serviceEndpoint = testSetupSocketPair(
   190  					t,
   191  					chainID,
   192  					types.NewMockPV(),
   193  					tc.addr,
   194  					tc.dialer)
   195  
   196  				ts    = time.Now()
   197  				vType = types.PrecommitType
   198  				want  = &types.Vote{Timestamp: ts, Type: vType}
   199  				have  = &types.Vote{Timestamp: ts, Type: vType}
   200  			)
   201  			defer validatorEndpoint.Stop()
   202  			defer serviceEndpoint.Stop()
   203  
   204  			time.Sleep(testTimeoutReadWrite * 2)
   205  
   206  			require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want))
   207  			require.NoError(t, validatorEndpoint.SignVote(chainID, have))
   208  			assert.Equal(t, want.Signature, have.Signature)
   209  		}()
   210  	}
   211  }
   212  
   213  func TestSocketPVDeadline(t *testing.T) {
   214  	for _, tc := range socketTestCases(t) {
   215  		func() {
   216  			var (
   217  				listenc           = make(chan struct{})
   218  				thisConnTimeout   = 100 * time.Millisecond
   219  				validatorEndpoint = newSignerValidatorEndpoint(log.TestingLogger(), tc.addr, thisConnTimeout)
   220  			)
   221  
   222  			go func(sc *SignerValidatorEndpoint) {
   223  				defer close(listenc)
   224  
   225  				// Note: the TCP connection times out at the accept() phase,
   226  				// whereas the Unix domain sockets connection times out while
   227  				// attempting to fetch the remote signer's public key.
   228  				assert.True(t, IsConnTimeout(sc.Start()))
   229  
   230  				assert.False(t, sc.IsRunning())
   231  			}(validatorEndpoint)
   232  
   233  			for {
   234  				_, err := cmn.Connect(tc.addr)
   235  				if err == nil {
   236  					break
   237  				}
   238  			}
   239  
   240  			<-listenc
   241  		}()
   242  	}
   243  }
   244  
   245  func TestRemoteSignVoteErrors(t *testing.T) {
   246  	for _, tc := range socketTestCases(t) {
   247  		func() {
   248  			var (
   249  				chainID                            = cmn.RandStr(12)
   250  				validatorEndpoint, serviceEndpoint = testSetupSocketPair(
   251  					t,
   252  					chainID,
   253  					types.NewErroringMockPV(),
   254  					tc.addr,
   255  					tc.dialer)
   256  
   257  				ts    = time.Now()
   258  				vType = types.PrecommitType
   259  				vote  = &types.Vote{Timestamp: ts, Type: vType}
   260  			)
   261  			defer validatorEndpoint.Stop()
   262  			defer serviceEndpoint.Stop()
   263  
   264  			err := validatorEndpoint.SignVote("", vote)
   265  			require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error())
   266  
   267  			err = serviceEndpoint.privVal.SignVote(chainID, vote)
   268  			require.Error(t, err)
   269  			err = validatorEndpoint.SignVote(chainID, vote)
   270  			require.Error(t, err)
   271  		}()
   272  	}
   273  }
   274  
   275  func TestRemoteSignProposalErrors(t *testing.T) {
   276  	for _, tc := range socketTestCases(t) {
   277  		func() {
   278  			var (
   279  				chainID                            = cmn.RandStr(12)
   280  				validatorEndpoint, serviceEndpoint = testSetupSocketPair(
   281  					t,
   282  					chainID,
   283  					types.NewErroringMockPV(),
   284  					tc.addr,
   285  					tc.dialer)
   286  
   287  				ts       = time.Now()
   288  				proposal = &types.Proposal{Timestamp: ts}
   289  			)
   290  			defer validatorEndpoint.Stop()
   291  			defer serviceEndpoint.Stop()
   292  
   293  			err := validatorEndpoint.SignProposal("", proposal)
   294  			require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error())
   295  
   296  			err = serviceEndpoint.privVal.SignProposal(chainID, proposal)
   297  			require.Error(t, err)
   298  
   299  			err = validatorEndpoint.SignProposal(chainID, proposal)
   300  			require.Error(t, err)
   301  		}()
   302  	}
   303  }
   304  
   305  func TestErrUnexpectedResponse(t *testing.T) {
   306  	for _, tc := range socketTestCases(t) {
   307  		func() {
   308  			var (
   309  				logger  = log.TestingLogger()
   310  				chainID = cmn.RandStr(12)
   311  				readyCh = make(chan struct{})
   312  				errCh   = make(chan error, 1)
   313  
   314  				serviceEndpoint = NewSignerServiceEndpoint(
   315  					logger,
   316  					chainID,
   317  					types.NewMockPV(),
   318  					tc.dialer,
   319  				)
   320  
   321  				validatorEndpoint = newSignerValidatorEndpoint(
   322  					logger,
   323  					tc.addr,
   324  					testTimeoutReadWrite)
   325  			)
   326  
   327  			testStartEndpoint(t, readyCh, validatorEndpoint)
   328  			defer validatorEndpoint.Stop()
   329  			SignerServiceEndpointTimeoutReadWrite(time.Millisecond)(serviceEndpoint)
   330  			SignerServiceEndpointConnRetries(100)(serviceEndpoint)
   331  			// we do not want to Start() the remote signer here and instead use the connection to
   332  			// reply with intentionally wrong replies below:
   333  			rsConn, err := serviceEndpoint.connect()
   334  			defer rsConn.Close()
   335  			require.NoError(t, err)
   336  			require.NotNil(t, rsConn)
   337  			// send over public key to get the remote signer running:
   338  			go testReadWriteResponse(t, &PubKeyResponse{}, rsConn)
   339  			<-readyCh
   340  
   341  			// Proposal:
   342  			go func(errc chan error) {
   343  				errc <- validatorEndpoint.SignProposal(chainID, &types.Proposal{})
   344  			}(errCh)
   345  
   346  			// read request and write wrong response:
   347  			go testReadWriteResponse(t, &SignedVoteResponse{}, rsConn)
   348  			err = <-errCh
   349  			require.Error(t, err)
   350  			require.Equal(t, err, ErrUnexpectedResponse)
   351  
   352  			// Vote:
   353  			go func(errc chan error) {
   354  				errc <- validatorEndpoint.SignVote(chainID, &types.Vote{})
   355  			}(errCh)
   356  			// read request and write wrong response:
   357  			go testReadWriteResponse(t, &SignedProposalResponse{}, rsConn)
   358  			err = <-errCh
   359  			require.Error(t, err)
   360  			require.Equal(t, err, ErrUnexpectedResponse)
   361  		}()
   362  	}
   363  }
   364  
   365  func TestRetryConnToRemoteSigner(t *testing.T) {
   366  	for _, tc := range socketTestCases(t) {
   367  		func() {
   368  			var (
   369  				logger  = log.TestingLogger()
   370  				chainID = cmn.RandStr(12)
   371  				readyCh = make(chan struct{})
   372  
   373  				serviceEndpoint = NewSignerServiceEndpoint(
   374  					logger,
   375  					chainID,
   376  					types.NewMockPV(),
   377  					tc.dialer,
   378  				)
   379  				thisConnTimeout   = testTimeoutReadWrite
   380  				validatorEndpoint = newSignerValidatorEndpoint(logger, tc.addr, thisConnTimeout)
   381  			)
   382  			// Ping every:
   383  			SignerValidatorEndpointSetHeartbeat(testTimeoutHeartbeat)(validatorEndpoint)
   384  
   385  			SignerServiceEndpointTimeoutReadWrite(testTimeoutReadWrite)(serviceEndpoint)
   386  			SignerServiceEndpointConnRetries(10)(serviceEndpoint)
   387  
   388  			testStartEndpoint(t, readyCh, validatorEndpoint)
   389  			defer validatorEndpoint.Stop()
   390  			require.NoError(t, serviceEndpoint.Start())
   391  			assert.True(t, serviceEndpoint.IsRunning())
   392  
   393  			<-readyCh
   394  			time.Sleep(testTimeoutHeartbeat * 2)
   395  
   396  			serviceEndpoint.Stop()
   397  			rs2 := NewSignerServiceEndpoint(
   398  				logger,
   399  				chainID,
   400  				types.NewMockPV(),
   401  				tc.dialer,
   402  			)
   403  			// let some pings pass
   404  			time.Sleep(testTimeoutHeartbeat3o2)
   405  			require.NoError(t, rs2.Start())
   406  			assert.True(t, rs2.IsRunning())
   407  			defer rs2.Stop()
   408  
   409  			// give the client some time to re-establish the conn to the remote signer
   410  			// should see sth like this in the logs:
   411  			//
   412  			// E[10016-01-10|17:12:46.128] Ping                                         err="remote signer timed out"
   413  			// I[10016-01-10|17:16:42.447] Re-created connection to remote signer       impl=SocketVal
   414  			time.Sleep(testTimeoutReadWrite * 2)
   415  		}()
   416  	}
   417  }
   418  
   419  func newSignerValidatorEndpoint(logger log.Logger, addr string, timeoutReadWrite time.Duration) *SignerValidatorEndpoint {
   420  	proto, address := cmn.ProtocolAndAddress(addr)
   421  
   422  	ln, err := net.Listen(proto, address)
   423  	logger.Info("Listening at", "proto", proto, "address", address)
   424  	if err != nil {
   425  		panic(err)
   426  	}
   427  
   428  	var listener net.Listener
   429  
   430  	if proto == "unix" {
   431  		unixLn := NewUnixListener(ln)
   432  		UnixListenerTimeoutAccept(testTimeoutAccept)(unixLn)
   433  		UnixListenerTimeoutReadWrite(timeoutReadWrite)(unixLn)
   434  		listener = unixLn
   435  	} else {
   436  		tcpLn := NewTCPListener(ln, ed25519.GenPrivKey())
   437  		TCPListenerTimeoutAccept(testTimeoutAccept)(tcpLn)
   438  		TCPListenerTimeoutReadWrite(timeoutReadWrite)(tcpLn)
   439  		listener = tcpLn
   440  	}
   441  
   442  	return NewSignerValidatorEndpoint(logger, listener)
   443  }
   444  
   445  func testSetupSocketPair(
   446  	t *testing.T,
   447  	chainID string,
   448  	privValidator types.PrivValidator,
   449  	addr string,
   450  	socketDialer SocketDialer,
   451  ) (*SignerValidatorEndpoint, *SignerServiceEndpoint) {
   452  	var (
   453  		logger          = log.TestingLogger()
   454  		privVal         = privValidator
   455  		readyc          = make(chan struct{})
   456  		serviceEndpoint = NewSignerServiceEndpoint(
   457  			logger,
   458  			chainID,
   459  			privVal,
   460  			socketDialer,
   461  		)
   462  
   463  		thisConnTimeout   = testTimeoutReadWrite
   464  		validatorEndpoint = newSignerValidatorEndpoint(logger, addr, thisConnTimeout)
   465  	)
   466  
   467  	SignerValidatorEndpointSetHeartbeat(testTimeoutHeartbeat)(validatorEndpoint)
   468  	SignerServiceEndpointTimeoutReadWrite(testTimeoutReadWrite)(serviceEndpoint)
   469  	SignerServiceEndpointConnRetries(1e6)(serviceEndpoint)
   470  
   471  	testStartEndpoint(t, readyc, validatorEndpoint)
   472  
   473  	require.NoError(t, serviceEndpoint.Start())
   474  	assert.True(t, serviceEndpoint.IsRunning())
   475  
   476  	<-readyc
   477  
   478  	return validatorEndpoint, serviceEndpoint
   479  }
   480  
   481  func testReadWriteResponse(t *testing.T, resp RemoteSignerMsg, rsConn net.Conn) {
   482  	_, err := readMsg(rsConn)
   483  	require.NoError(t, err)
   484  
   485  	err = writeMsg(rsConn, resp)
   486  	require.NoError(t, err)
   487  }
   488  
   489  func testStartEndpoint(t *testing.T, readyCh chan struct{}, sc *SignerValidatorEndpoint) {
   490  	go func(sc *SignerValidatorEndpoint) {
   491  		require.NoError(t, sc.Start())
   492  		assert.True(t, sc.IsRunning())
   493  
   494  		readyCh <- struct{}{}
   495  	}(sc)
   496  }
   497  
   498  // testFreeTCPAddr claims a free port so we don't block on listener being ready.
   499  func testFreeTCPAddr(t *testing.T) string {
   500  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   501  	require.NoError(t, err)
   502  	defer ln.Close()
   503  
   504  	return fmt.Sprintf("127.0.0.1:%d", ln.Addr().(*net.TCPAddr).Port)
   505  }