github.com/Finschia/ostracon@v1.1.5/privval/signer_listener_endpoint_test.go (about)

     1  package privval
     2  
     3  import (
     4  	"github.com/Finschia/ostracon/privval/internal"
     5  	"net"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/stretchr/testify/assert"
    10  	"github.com/stretchr/testify/require"
    11  
    12  	"github.com/Finschia/ostracon/crypto/ed25519"
    13  	"github.com/Finschia/ostracon/libs/log"
    14  	tmnet "github.com/Finschia/ostracon/libs/net"
    15  	tmrand "github.com/Finschia/ostracon/libs/rand"
    16  	"github.com/Finschia/ostracon/types"
    17  )
    18  
    19  var (
    20  	testTimeoutAccept = defaultTimeoutAcceptSeconds * time.Second
    21  
    22  	testTimeoutReadWrite    = 1000 * time.Millisecond // increase timeout for slow test env
    23  	testTimeoutReadWrite2o3 = 60 * time.Millisecond   // 2/3 of the other one
    24  )
    25  
    26  type dialerTestCase struct {
    27  	addr   string
    28  	dialer SocketDialer
    29  }
    30  
    31  // TestSignerRemoteRetryTCPOnly will test connection retry attempts over TCP. We
    32  // don't need this for Unix sockets because the OS instantly knows the state of
    33  // both ends of the socket connection. This basically causes the
    34  // SignerDialerEndpoint.dialer() call inside SignerDialerEndpoint.acceptNewConnection() to return
    35  // successfully immediately, putting an instant stop to any retry attempts.
    36  func TestSignerRemoteRetryTCPOnly(t *testing.T) {
    37  	var (
    38  		attemptCh = make(chan int)
    39  		retries   = 10
    40  	)
    41  
    42  	ln, err := net.Listen("tcp", "127.0.0.1:0")
    43  	require.NoError(t, err)
    44  
    45  	// Continuously Accept connection and close {attempts} times
    46  	go func(ln net.Listener, attemptCh chan<- int) {
    47  		attempts := 0
    48  		for {
    49  			conn, err := ln.Accept()
    50  			require.NoError(t, err)
    51  
    52  			err = conn.Close()
    53  			require.NoError(t, err)
    54  
    55  			attempts++
    56  
    57  			if attempts == retries {
    58  				attemptCh <- attempts
    59  				break
    60  			}
    61  		}
    62  	}(ln, attemptCh)
    63  
    64  	dialerEndpoint := NewSignerDialerEndpoint(
    65  		log.TestingLogger(),
    66  		DialTCPFn(ln.Addr().String(), testTimeoutReadWrite, ed25519.GenPrivKey()),
    67  	)
    68  	SignerDialerEndpointTimeoutReadWrite(time.Millisecond)(dialerEndpoint)
    69  	SignerDialerEndpointConnRetries(retries)(dialerEndpoint)
    70  
    71  	chainID := tmrand.Str(12)
    72  	mockPV := types.NewMockPV()
    73  	signerServer := NewSignerServer(dialerEndpoint, chainID, mockPV)
    74  
    75  	err = signerServer.Start()
    76  	require.NoError(t, err)
    77  	t.Cleanup(func() {
    78  		if err := signerServer.Stop(); err != nil {
    79  			t.Error(err)
    80  		}
    81  	})
    82  
    83  	select {
    84  	case attempts := <-attemptCh:
    85  		assert.Equal(t, retries, attempts)
    86  	case <-time.After(1500 * time.Millisecond):
    87  		t.Error("expected remote to observe connection attempts")
    88  	}
    89  }
    90  
    91  func TestRetryConnToRemoteSigner(t *testing.T) {
    92  	for _, tc := range getDialerTestCases(t) {
    93  		var (
    94  			logger           = log.TestingLogger()
    95  			chainID          = tmrand.Str(12)
    96  			mockPV           = types.NewMockPV()
    97  			endpointIsOpenCh = make(chan struct{})
    98  			thisConnTimeout  = testTimeoutReadWrite
    99  			listenerEndpoint = newSignerListenerEndpoint(logger, tc.addr, thisConnTimeout)
   100  		)
   101  
   102  		dialerEndpoint := NewSignerDialerEndpoint(
   103  			logger,
   104  			tc.dialer,
   105  		)
   106  		SignerDialerEndpointTimeoutReadWrite(testTimeoutReadWrite)(dialerEndpoint)
   107  		SignerDialerEndpointConnRetries(10)(dialerEndpoint)
   108  
   109  		signerServer := NewSignerServer(dialerEndpoint, chainID, mockPV)
   110  
   111  		startListenerEndpointAsync(t, listenerEndpoint, endpointIsOpenCh)
   112  		t.Cleanup(func() {
   113  			if err := listenerEndpoint.Stop(); err != nil {
   114  				t.Error(err)
   115  			}
   116  		})
   117  
   118  		require.NoError(t, signerServer.Start())
   119  		assert.True(t, signerServer.IsRunning())
   120  		<-endpointIsOpenCh
   121  		if err := signerServer.Stop(); err != nil {
   122  			t.Error(err)
   123  		}
   124  
   125  		dialerEndpoint2 := NewSignerDialerEndpoint(
   126  			logger,
   127  			tc.dialer,
   128  		)
   129  		signerServer2 := NewSignerServer(dialerEndpoint2, chainID, mockPV)
   130  
   131  		// let some pings pass
   132  		require.NoError(t, signerServer2.Start())
   133  		assert.True(t, signerServer2.IsRunning())
   134  		t.Cleanup(func() {
   135  			if err := signerServer2.Stop(); err != nil {
   136  				t.Error(err)
   137  			}
   138  		})
   139  
   140  		// give the client some time to re-establish the conn to the remote signer
   141  		// should see sth like this in the logs:
   142  		//
   143  		// E[10016-01-10|17:12:46.128] Ping                                         err="remote signer timed out"
   144  		// I[10016-01-10|17:16:42.447] Re-created connection to remote signer       impl=SocketVal
   145  		time.Sleep(testTimeoutReadWrite * 2)
   146  	}
   147  }
   148  
   149  func newSignerListenerEndpoint(logger log.Logger, addr string, timeoutReadWrite time.Duration) *SignerListenerEndpoint {
   150  	proto, address := tmnet.ProtocolAndAddress(addr)
   151  
   152  	ln, err := net.Listen(proto, address)
   153  	logger.Info("SignerListener: Listening", "proto", proto, "address", address)
   154  	if err != nil {
   155  		panic(err)
   156  	}
   157  
   158  	var listener net.Listener
   159  
   160  	if proto == "unix" {
   161  		unixLn := NewUnixListener(ln)
   162  		UnixListenerTimeoutAccept(testTimeoutAccept)(unixLn)
   163  		UnixListenerTimeoutReadWrite(timeoutReadWrite)(unixLn)
   164  		listener = unixLn
   165  	} else {
   166  		tcpLn := NewTCPListener(ln, ed25519.GenPrivKey())
   167  		TCPListenerTimeoutAccept(testTimeoutAccept)(tcpLn)
   168  		TCPListenerTimeoutReadWrite(timeoutReadWrite)(tcpLn)
   169  		listener = tcpLn
   170  	}
   171  
   172  	return NewSignerListenerEndpoint(
   173  		logger,
   174  		listener,
   175  		SignerListenerEndpointTimeoutReadWrite(testTimeoutReadWrite),
   176  	)
   177  }
   178  
   179  func startListenerEndpointAsync(t *testing.T, sle *SignerListenerEndpoint, endpointIsOpenCh chan struct{}) {
   180  	go func(sle *SignerListenerEndpoint) {
   181  		require.NoError(t, sle.Start())
   182  		assert.True(t, sle.IsRunning())
   183  		close(endpointIsOpenCh)
   184  	}(sle)
   185  }
   186  
   187  func getMockEndpoints(
   188  	t *testing.T,
   189  	addr string,
   190  	socketDialer SocketDialer,
   191  ) (*SignerListenerEndpoint, *SignerDialerEndpoint) {
   192  
   193  	var (
   194  		logger           = log.TestingLogger()
   195  		endpointIsOpenCh = make(chan struct{})
   196  
   197  		dialerEndpoint = NewSignerDialerEndpoint(
   198  			logger,
   199  			socketDialer,
   200  		)
   201  
   202  		listenerEndpoint = newSignerListenerEndpoint(logger, addr, testTimeoutReadWrite)
   203  	)
   204  
   205  	SignerDialerEndpointTimeoutReadWrite(testTimeoutReadWrite)(dialerEndpoint)
   206  	SignerDialerEndpointConnRetries(1e6)(dialerEndpoint)
   207  
   208  	startListenerEndpointAsync(t, listenerEndpoint, endpointIsOpenCh)
   209  
   210  	require.NoError(t, dialerEndpoint.Start())
   211  	assert.True(t, dialerEndpoint.IsRunning())
   212  
   213  	<-endpointIsOpenCh
   214  
   215  	return listenerEndpoint, dialerEndpoint
   216  }
   217  
   218  func TestSignerListenerEndpointAllowAddressSetIpFilterForTCP(t *testing.T) {
   219  	cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress("tcp", []string{"127.0.0.1"}))
   220  	_, ok := cut.connFilter.(*internal.IpFilter)
   221  	assert.True(t, ok)
   222  }
   223  
   224  func TestSignerListenerEndpointAllowAddressSetNullObjectFilterForUDS(t *testing.T) {
   225  	cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress("unix", []string{"don't care"}))
   226  	_, ok := cut.connFilter.(*internal.NullObject)
   227  	assert.True(t, ok)
   228  }