github.com/vipernet-xyz/tm@v0.34.24/privval/signer_listener_endpoint_test.go (about)

     1  package privval
     2  
     3  import (
     4  	"net"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  	"github.com/stretchr/testify/require"
    10  
    11  	"github.com/vipernet-xyz/tm/crypto/ed25519"
    12  	"github.com/vipernet-xyz/tm/libs/log"
    13  	tmnet "github.com/vipernet-xyz/tm/libs/net"
    14  	tmrand "github.com/vipernet-xyz/tm/libs/rand"
    15  	"github.com/vipernet-xyz/tm/types"
    16  )
    17  
    18  var (
    19  	testTimeoutAccept = defaultTimeoutAcceptSeconds * time.Second
    20  
    21  	testTimeoutReadWrite    = 100 * time.Millisecond
    22  	testTimeoutReadWrite2o3 = 60 * time.Millisecond // 2/3 of the other one
    23  )
    24  
// dialerTestCase pairs an address with a SocketDialer for table-driven tests.
// TestRetryConnToRemoteSigner listens on addr and dials through dialer.
// NOTE(review): presumably dialer is built to reach addr (see
// getDialerTestCases, defined outside this view) — confirm.
type dialerTestCase struct {
	addr   string
	dialer SocketDialer
}
    29  
    30  // TestSignerRemoteRetryTCPOnly will test connection retry attempts over TCP. We
    31  // don't need this for Unix sockets because the OS instantly knows the state of
    32  // both ends of the socket connection. This basically causes the
    33  // SignerDialerEndpoint.dialer() call inside SignerDialerEndpoint.acceptNewConnection() to return
    34  // successfully immediately, putting an instant stop to any retry attempts.
    35  func TestSignerRemoteRetryTCPOnly(t *testing.T) {
    36  	var (
    37  		attemptCh = make(chan int)
    38  		retries   = 10
    39  	)
    40  
    41  	ln, err := net.Listen("tcp", "127.0.0.1:0")
    42  	require.NoError(t, err)
    43  
    44  	// Continuously Accept connection and close {attempts} times
    45  	go func(ln net.Listener, attemptCh chan<- int) {
    46  		attempts := 0
    47  		for {
    48  			conn, err := ln.Accept()
    49  			require.NoError(t, err)
    50  
    51  			err = conn.Close()
    52  			require.NoError(t, err)
    53  
    54  			attempts++
    55  
    56  			if attempts == retries {
    57  				attemptCh <- attempts
    58  				break
    59  			}
    60  		}
    61  	}(ln, attemptCh)
    62  
    63  	dialerEndpoint := NewSignerDialerEndpoint(
    64  		log.TestingLogger(),
    65  		DialTCPFn(ln.Addr().String(), testTimeoutReadWrite, ed25519.GenPrivKey()),
    66  	)
    67  	SignerDialerEndpointTimeoutReadWrite(time.Millisecond)(dialerEndpoint)
    68  	SignerDialerEndpointConnRetries(retries)(dialerEndpoint)
    69  
    70  	chainID := tmrand.Str(12)
    71  	mockPV := types.NewMockPV()
    72  	signerServer := NewSignerServer(dialerEndpoint, chainID, mockPV)
    73  
    74  	err = signerServer.Start()
    75  	require.NoError(t, err)
    76  	t.Cleanup(func() {
    77  		if err := signerServer.Stop(); err != nil {
    78  			t.Error(err)
    79  		}
    80  	})
    81  
    82  	select {
    83  	case attempts := <-attemptCh:
    84  		assert.Equal(t, retries, attempts)
    85  	case <-time.After(1500 * time.Millisecond):
    86  		t.Error("expected remote to observe connection attempts")
    87  	}
    88  }
    89  
    90  func TestRetryConnToRemoteSigner(t *testing.T) {
    91  	for _, tc := range getDialerTestCases(t) {
    92  		var (
    93  			logger           = log.TestingLogger()
    94  			chainID          = tmrand.Str(12)
    95  			mockPV           = types.NewMockPV()
    96  			endpointIsOpenCh = make(chan struct{})
    97  			thisConnTimeout  = testTimeoutReadWrite
    98  			listenerEndpoint = newSignerListenerEndpoint(logger, tc.addr, thisConnTimeout)
    99  		)
   100  
   101  		dialerEndpoint := NewSignerDialerEndpoint(
   102  			logger,
   103  			tc.dialer,
   104  		)
   105  		SignerDialerEndpointTimeoutReadWrite(testTimeoutReadWrite)(dialerEndpoint)
   106  		SignerDialerEndpointConnRetries(10)(dialerEndpoint)
   107  
   108  		signerServer := NewSignerServer(dialerEndpoint, chainID, mockPV)
   109  
   110  		startListenerEndpointAsync(t, listenerEndpoint, endpointIsOpenCh)
   111  		t.Cleanup(func() {
   112  			if err := listenerEndpoint.Stop(); err != nil {
   113  				t.Error(err)
   114  			}
   115  		})
   116  
   117  		require.NoError(t, signerServer.Start())
   118  		assert.True(t, signerServer.IsRunning())
   119  		<-endpointIsOpenCh
   120  		if err := signerServer.Stop(); err != nil {
   121  			t.Error(err)
   122  		}
   123  
   124  		dialerEndpoint2 := NewSignerDialerEndpoint(
   125  			logger,
   126  			tc.dialer,
   127  		)
   128  		signerServer2 := NewSignerServer(dialerEndpoint2, chainID, mockPV)
   129  
   130  		// let some pings pass
   131  		require.NoError(t, signerServer2.Start())
   132  		assert.True(t, signerServer2.IsRunning())
   133  		t.Cleanup(func() {
   134  			if err := signerServer2.Stop(); err != nil {
   135  				t.Error(err)
   136  			}
   137  		})
   138  
   139  		// give the client some time to re-establish the conn to the remote signer
   140  		// should see sth like this in the logs:
   141  		//
   142  		// E[10016-01-10|17:12:46.128] Ping                                         err="remote signer timed out"
   143  		// I[10016-01-10|17:16:42.447] Re-created connection to remote signer       impl=SocketVal
   144  		time.Sleep(testTimeoutReadWrite * 2)
   145  	}
   146  }
   147  
   148  func newSignerListenerEndpoint(logger log.Logger, addr string, timeoutReadWrite time.Duration) *SignerListenerEndpoint {
   149  	proto, address := tmnet.ProtocolAndAddress(addr)
   150  
   151  	ln, err := net.Listen(proto, address)
   152  	logger.Info("SignerListener: Listening", "proto", proto, "address", address)
   153  	if err != nil {
   154  		panic(err)
   155  	}
   156  
   157  	var listener net.Listener
   158  
   159  	if proto == "unix" {
   160  		unixLn := NewUnixListener(ln)
   161  		UnixListenerTimeoutAccept(testTimeoutAccept)(unixLn)
   162  		UnixListenerTimeoutReadWrite(timeoutReadWrite)(unixLn)
   163  		listener = unixLn
   164  	} else {
   165  		tcpLn := NewTCPListener(ln, ed25519.GenPrivKey())
   166  		TCPListenerTimeoutAccept(testTimeoutAccept)(tcpLn)
   167  		TCPListenerTimeoutReadWrite(timeoutReadWrite)(tcpLn)
   168  		listener = tcpLn
   169  	}
   170  
   171  	return NewSignerListenerEndpoint(
   172  		logger,
   173  		listener,
   174  		SignerListenerEndpointTimeoutReadWrite(testTimeoutReadWrite),
   175  	)
   176  }
   177  
   178  func startListenerEndpointAsync(t *testing.T, sle *SignerListenerEndpoint, endpointIsOpenCh chan struct{}) {
   179  	go func(sle *SignerListenerEndpoint) {
   180  		require.NoError(t, sle.Start())
   181  		assert.True(t, sle.IsRunning())
   182  		close(endpointIsOpenCh)
   183  	}(sle)
   184  }
   185  
   186  func getMockEndpoints(
   187  	t *testing.T,
   188  	addr string,
   189  	socketDialer SocketDialer,
   190  ) (*SignerListenerEndpoint, *SignerDialerEndpoint) {
   191  
   192  	var (
   193  		logger           = log.TestingLogger()
   194  		endpointIsOpenCh = make(chan struct{})
   195  
   196  		dialerEndpoint = NewSignerDialerEndpoint(
   197  			logger,
   198  			socketDialer,
   199  		)
   200  
   201  		listenerEndpoint = newSignerListenerEndpoint(logger, addr, testTimeoutReadWrite)
   202  	)
   203  
   204  	SignerDialerEndpointTimeoutReadWrite(testTimeoutReadWrite)(dialerEndpoint)
   205  	SignerDialerEndpointConnRetries(1e6)(dialerEndpoint)
   206  
   207  	startListenerEndpointAsync(t, listenerEndpoint, endpointIsOpenCh)
   208  
   209  	require.NoError(t, dialerEndpoint.Start())
   210  	assert.True(t, dialerEndpoint.IsRunning())
   211  
   212  	<-endpointIsOpenCh
   213  
   214  	return listenerEndpoint, dialerEndpoint
   215  }