github.com/evdatsion/aphelion-dpos-bft@v0.32.1/p2p/conn/connection_test.go (about)

     1  package conn
     2  
     3  import (
     4  	"bytes"
     5  	"net"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/fortytw2/leaktest"
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/stretchr/testify/require"
    12  
    13  	amino "github.com/evdatsion/go-amino"
    14  	"github.com/evdatsion/aphelion-dpos-bft/libs/log"
    15  )
    16  
    17  const maxPingPongPacketSize = 1024 // bytes
    18  
    19  func createTestMConnection(conn net.Conn) *MConnection {
    20  	onReceive := func(chID byte, msgBytes []byte) {
    21  	}
    22  	onError := func(r interface{}) {
    23  	}
    24  	c := createMConnectionWithCallbacks(conn, onReceive, onError)
    25  	c.SetLogger(log.TestingLogger())
    26  	return c
    27  }
    28  
    29  func createMConnectionWithCallbacks(conn net.Conn, onReceive func(chID byte, msgBytes []byte), onError func(r interface{})) *MConnection {
    30  	cfg := DefaultMConnConfig()
    31  	cfg.PingInterval = 90 * time.Millisecond
    32  	cfg.PongTimeout = 45 * time.Millisecond
    33  	chDescs := []*ChannelDescriptor{{ID: 0x01, Priority: 1, SendQueueCapacity: 1}}
    34  	c := NewMConnectionWithConfig(conn, chDescs, onReceive, onError, cfg)
    35  	c.SetLogger(log.TestingLogger())
    36  	return c
    37  }
    38  
    39  func TestMConnectionSendFlushStop(t *testing.T) {
    40  	server, client := NetPipe()
    41  	defer server.Close() // nolint: errcheck
    42  	defer client.Close() // nolint: errcheck
    43  
    44  	clientConn := createTestMConnection(client)
    45  	err := clientConn.Start()
    46  	require.Nil(t, err)
    47  	defer clientConn.Stop()
    48  
    49  	msg := []byte("abc")
    50  	assert.True(t, clientConn.Send(0x01, msg))
    51  
    52  	aminoMsgLength := 14
    53  
    54  	// start the reader in a new routine, so we can flush
    55  	errCh := make(chan error)
    56  	go func() {
    57  		msgB := make([]byte, aminoMsgLength)
    58  		_, err := server.Read(msgB)
    59  		if err != nil {
    60  			t.Fatal(err)
    61  		}
    62  		errCh <- err
    63  	}()
    64  
    65  	// stop the conn - it should flush all conns
    66  	clientConn.FlushStop()
    67  
    68  	timer := time.NewTimer(3 * time.Second)
    69  	select {
    70  	case <-errCh:
    71  	case <-timer.C:
    72  		t.Error("timed out waiting for msgs to be read")
    73  	}
    74  }
    75  
    76  func TestMConnectionSend(t *testing.T) {
    77  	server, client := NetPipe()
    78  	defer server.Close() // nolint: errcheck
    79  	defer client.Close() // nolint: errcheck
    80  
    81  	mconn := createTestMConnection(client)
    82  	err := mconn.Start()
    83  	require.Nil(t, err)
    84  	defer mconn.Stop()
    85  
    86  	msg := []byte("Ant-Man")
    87  	assert.True(t, mconn.Send(0x01, msg))
    88  	// Note: subsequent Send/TrySend calls could pass because we are reading from
    89  	// the send queue in a separate goroutine.
    90  	_, err = server.Read(make([]byte, len(msg)))
    91  	if err != nil {
    92  		t.Error(err)
    93  	}
    94  	assert.True(t, mconn.CanSend(0x01))
    95  
    96  	msg = []byte("Spider-Man")
    97  	assert.True(t, mconn.TrySend(0x01, msg))
    98  	_, err = server.Read(make([]byte, len(msg)))
    99  	if err != nil {
   100  		t.Error(err)
   101  	}
   102  
   103  	assert.False(t, mconn.CanSend(0x05), "CanSend should return false because channel is unknown")
   104  	assert.False(t, mconn.Send(0x05, []byte("Absorbing Man")), "Send should return false because channel is unknown")
   105  }
   106  
   107  func TestMConnectionReceive(t *testing.T) {
   108  	server, client := NetPipe()
   109  	defer server.Close() // nolint: errcheck
   110  	defer client.Close() // nolint: errcheck
   111  
   112  	receivedCh := make(chan []byte)
   113  	errorsCh := make(chan interface{})
   114  	onReceive := func(chID byte, msgBytes []byte) {
   115  		receivedCh <- msgBytes
   116  	}
   117  	onError := func(r interface{}) {
   118  		errorsCh <- r
   119  	}
   120  	mconn1 := createMConnectionWithCallbacks(client, onReceive, onError)
   121  	err := mconn1.Start()
   122  	require.Nil(t, err)
   123  	defer mconn1.Stop()
   124  
   125  	mconn2 := createTestMConnection(server)
   126  	err = mconn2.Start()
   127  	require.Nil(t, err)
   128  	defer mconn2.Stop()
   129  
   130  	msg := []byte("Cyclops")
   131  	assert.True(t, mconn2.Send(0x01, msg))
   132  
   133  	select {
   134  	case receivedBytes := <-receivedCh:
   135  		assert.Equal(t, []byte(msg), receivedBytes)
   136  	case err := <-errorsCh:
   137  		t.Fatalf("Expected %s, got %+v", msg, err)
   138  	case <-time.After(500 * time.Millisecond):
   139  		t.Fatalf("Did not receive %s message in 500ms", msg)
   140  	}
   141  }
   142  
   143  func TestMConnectionStatus(t *testing.T) {
   144  	server, client := NetPipe()
   145  	defer server.Close() // nolint: errcheck
   146  	defer client.Close() // nolint: errcheck
   147  
   148  	mconn := createTestMConnection(client)
   149  	err := mconn.Start()
   150  	require.Nil(t, err)
   151  	defer mconn.Stop()
   152  
   153  	status := mconn.Status()
   154  	assert.NotNil(t, status)
   155  	assert.Zero(t, status.Channels[0].SendQueueSize)
   156  }
   157  
   158  func TestMConnectionPongTimeoutResultsInError(t *testing.T) {
   159  	server, client := net.Pipe()
   160  	defer server.Close()
   161  	defer client.Close()
   162  
   163  	receivedCh := make(chan []byte)
   164  	errorsCh := make(chan interface{})
   165  	onReceive := func(chID byte, msgBytes []byte) {
   166  		receivedCh <- msgBytes
   167  	}
   168  	onError := func(r interface{}) {
   169  		errorsCh <- r
   170  	}
   171  	mconn := createMConnectionWithCallbacks(client, onReceive, onError)
   172  	err := mconn.Start()
   173  	require.Nil(t, err)
   174  	defer mconn.Stop()
   175  
   176  	serverGotPing := make(chan struct{})
   177  	go func() {
   178  		// read ping
   179  		var pkt PacketPing
   180  		_, err = cdc.UnmarshalBinaryLengthPrefixedReader(server, &pkt, maxPingPongPacketSize)
   181  		assert.Nil(t, err)
   182  		serverGotPing <- struct{}{}
   183  	}()
   184  	<-serverGotPing
   185  
   186  	pongTimerExpired := mconn.config.PongTimeout + 20*time.Millisecond
   187  	select {
   188  	case msgBytes := <-receivedCh:
   189  		t.Fatalf("Expected error, but got %v", msgBytes)
   190  	case err := <-errorsCh:
   191  		assert.NotNil(t, err)
   192  	case <-time.After(pongTimerExpired):
   193  		t.Fatalf("Expected to receive error after %v", pongTimerExpired)
   194  	}
   195  }
   196  
   197  func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) {
   198  	server, client := net.Pipe()
   199  	defer server.Close()
   200  	defer client.Close()
   201  
   202  	receivedCh := make(chan []byte)
   203  	errorsCh := make(chan interface{})
   204  	onReceive := func(chID byte, msgBytes []byte) {
   205  		receivedCh <- msgBytes
   206  	}
   207  	onError := func(r interface{}) {
   208  		errorsCh <- r
   209  	}
   210  	mconn := createMConnectionWithCallbacks(client, onReceive, onError)
   211  	err := mconn.Start()
   212  	require.Nil(t, err)
   213  	defer mconn.Stop()
   214  
   215  	// sending 3 pongs in a row (abuse)
   216  	_, err = server.Write(cdc.MustMarshalBinaryLengthPrefixed(PacketPong{}))
   217  	require.Nil(t, err)
   218  	_, err = server.Write(cdc.MustMarshalBinaryLengthPrefixed(PacketPong{}))
   219  	require.Nil(t, err)
   220  	_, err = server.Write(cdc.MustMarshalBinaryLengthPrefixed(PacketPong{}))
   221  	require.Nil(t, err)
   222  
   223  	serverGotPing := make(chan struct{})
   224  	go func() {
   225  		// read ping (one byte)
   226  		var (
   227  			packet Packet
   228  			err    error
   229  		)
   230  		_, err = cdc.UnmarshalBinaryLengthPrefixedReader(server, &packet, maxPingPongPacketSize)
   231  		require.Nil(t, err)
   232  		serverGotPing <- struct{}{}
   233  		// respond with pong
   234  		_, err = server.Write(cdc.MustMarshalBinaryLengthPrefixed(PacketPong{}))
   235  		require.Nil(t, err)
   236  	}()
   237  	<-serverGotPing
   238  
   239  	pongTimerExpired := mconn.config.PongTimeout + 20*time.Millisecond
   240  	select {
   241  	case msgBytes := <-receivedCh:
   242  		t.Fatalf("Expected no data, but got %v", msgBytes)
   243  	case err := <-errorsCh:
   244  		t.Fatalf("Expected no error, but got %v", err)
   245  	case <-time.After(pongTimerExpired):
   246  		assert.True(t, mconn.IsRunning())
   247  	}
   248  }
   249  
   250  func TestMConnectionMultiplePings(t *testing.T) {
   251  	server, client := net.Pipe()
   252  	defer server.Close()
   253  	defer client.Close()
   254  
   255  	receivedCh := make(chan []byte)
   256  	errorsCh := make(chan interface{})
   257  	onReceive := func(chID byte, msgBytes []byte) {
   258  		receivedCh <- msgBytes
   259  	}
   260  	onError := func(r interface{}) {
   261  		errorsCh <- r
   262  	}
   263  	mconn := createMConnectionWithCallbacks(client, onReceive, onError)
   264  	err := mconn.Start()
   265  	require.Nil(t, err)
   266  	defer mconn.Stop()
   267  
   268  	// sending 3 pings in a row (abuse)
   269  	// see https://github.com/evdatsion/aphelion-dpos-bft/issues/1190
   270  	_, err = server.Write(cdc.MustMarshalBinaryLengthPrefixed(PacketPing{}))
   271  	require.Nil(t, err)
   272  	var pkt PacketPong
   273  	_, err = cdc.UnmarshalBinaryLengthPrefixedReader(server, &pkt, maxPingPongPacketSize)
   274  	require.Nil(t, err)
   275  	_, err = server.Write(cdc.MustMarshalBinaryLengthPrefixed(PacketPing{}))
   276  	require.Nil(t, err)
   277  	_, err = cdc.UnmarshalBinaryLengthPrefixedReader(server, &pkt, maxPingPongPacketSize)
   278  	require.Nil(t, err)
   279  	_, err = server.Write(cdc.MustMarshalBinaryLengthPrefixed(PacketPing{}))
   280  	require.Nil(t, err)
   281  	_, err = cdc.UnmarshalBinaryLengthPrefixedReader(server, &pkt, maxPingPongPacketSize)
   282  	require.Nil(t, err)
   283  
   284  	assert.True(t, mconn.IsRunning())
   285  }
   286  
   287  func TestMConnectionPingPongs(t *testing.T) {
   288  	// check that we are not leaking any go-routines
   289  	defer leaktest.CheckTimeout(t, 10*time.Second)()
   290  
   291  	server, client := net.Pipe()
   292  
   293  	defer server.Close()
   294  	defer client.Close()
   295  
   296  	receivedCh := make(chan []byte)
   297  	errorsCh := make(chan interface{})
   298  	onReceive := func(chID byte, msgBytes []byte) {
   299  		receivedCh <- msgBytes
   300  	}
   301  	onError := func(r interface{}) {
   302  		errorsCh <- r
   303  	}
   304  	mconn := createMConnectionWithCallbacks(client, onReceive, onError)
   305  	err := mconn.Start()
   306  	require.Nil(t, err)
   307  	defer mconn.Stop()
   308  
   309  	serverGotPing := make(chan struct{})
   310  	go func() {
   311  		// read ping
   312  		var pkt PacketPing
   313  		_, err = cdc.UnmarshalBinaryLengthPrefixedReader(server, &pkt, maxPingPongPacketSize)
   314  		require.Nil(t, err)
   315  		serverGotPing <- struct{}{}
   316  		// respond with pong
   317  		_, err = server.Write(cdc.MustMarshalBinaryLengthPrefixed(PacketPong{}))
   318  		require.Nil(t, err)
   319  
   320  		time.Sleep(mconn.config.PingInterval)
   321  
   322  		// read ping
   323  		_, err = cdc.UnmarshalBinaryLengthPrefixedReader(server, &pkt, maxPingPongPacketSize)
   324  		require.Nil(t, err)
   325  		// respond with pong
   326  		_, err = server.Write(cdc.MustMarshalBinaryLengthPrefixed(PacketPong{}))
   327  		require.Nil(t, err)
   328  	}()
   329  	<-serverGotPing
   330  
   331  	pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 2
   332  	select {
   333  	case msgBytes := <-receivedCh:
   334  		t.Fatalf("Expected no data, but got %v", msgBytes)
   335  	case err := <-errorsCh:
   336  		t.Fatalf("Expected no error, but got %v", err)
   337  	case <-time.After(2 * pongTimerExpired):
   338  		assert.True(t, mconn.IsRunning())
   339  	}
   340  }
   341  
   342  func TestMConnectionStopsAndReturnsError(t *testing.T) {
   343  	server, client := NetPipe()
   344  	defer server.Close() // nolint: errcheck
   345  	defer client.Close() // nolint: errcheck
   346  
   347  	receivedCh := make(chan []byte)
   348  	errorsCh := make(chan interface{})
   349  	onReceive := func(chID byte, msgBytes []byte) {
   350  		receivedCh <- msgBytes
   351  	}
   352  	onError := func(r interface{}) {
   353  		errorsCh <- r
   354  	}
   355  	mconn := createMConnectionWithCallbacks(client, onReceive, onError)
   356  	err := mconn.Start()
   357  	require.Nil(t, err)
   358  	defer mconn.Stop()
   359  
   360  	if err := client.Close(); err != nil {
   361  		t.Error(err)
   362  	}
   363  
   364  	select {
   365  	case receivedBytes := <-receivedCh:
   366  		t.Fatalf("Expected error, got %v", receivedBytes)
   367  	case err := <-errorsCh:
   368  		assert.NotNil(t, err)
   369  		assert.False(t, mconn.IsRunning())
   370  	case <-time.After(500 * time.Millisecond):
   371  		t.Fatal("Did not receive error in 500ms")
   372  	}
   373  }
   374  
   375  func newClientAndServerConnsForReadErrors(t *testing.T, chOnErr chan struct{}) (*MConnection, *MConnection) {
   376  	server, client := NetPipe()
   377  
   378  	onReceive := func(chID byte, msgBytes []byte) {}
   379  	onError := func(r interface{}) {}
   380  
   381  	// create client conn with two channels
   382  	chDescs := []*ChannelDescriptor{
   383  		{ID: 0x01, Priority: 1, SendQueueCapacity: 1},
   384  		{ID: 0x02, Priority: 1, SendQueueCapacity: 1},
   385  	}
   386  	mconnClient := NewMConnection(client, chDescs, onReceive, onError)
   387  	mconnClient.SetLogger(log.TestingLogger().With("module", "client"))
   388  	err := mconnClient.Start()
   389  	require.Nil(t, err)
   390  
   391  	// create server conn with 1 channel
   392  	// it fires on chOnErr when there's an error
   393  	serverLogger := log.TestingLogger().With("module", "server")
   394  	onError = func(r interface{}) {
   395  		chOnErr <- struct{}{}
   396  	}
   397  	mconnServer := createMConnectionWithCallbacks(server, onReceive, onError)
   398  	mconnServer.SetLogger(serverLogger)
   399  	err = mconnServer.Start()
   400  	require.Nil(t, err)
   401  	return mconnClient, mconnServer
   402  }
   403  
   404  func expectSend(ch chan struct{}) bool {
   405  	after := time.After(time.Second * 5)
   406  	select {
   407  	case <-ch:
   408  		return true
   409  	case <-after:
   410  		return false
   411  	}
   412  }
   413  
   414  func TestMConnectionReadErrorBadEncoding(t *testing.T) {
   415  	chOnErr := make(chan struct{})
   416  	mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr)
   417  	defer mconnClient.Stop()
   418  	defer mconnServer.Stop()
   419  
   420  	client := mconnClient.conn
   421  
   422  	// send badly encoded msgPacket
   423  	bz := cdc.MustMarshalBinaryLengthPrefixed(PacketMsg{})
   424  	bz[4] += 0x01 // Invalid prefix bytes.
   425  
   426  	// Write it.
   427  	_, err := client.Write(bz)
   428  	assert.Nil(t, err)
   429  	assert.True(t, expectSend(chOnErr), "badly encoded msgPacket")
   430  }
   431  
   432  func TestMConnectionReadErrorUnknownChannel(t *testing.T) {
   433  	chOnErr := make(chan struct{})
   434  	mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr)
   435  	defer mconnClient.Stop()
   436  	defer mconnServer.Stop()
   437  
   438  	msg := []byte("Ant-Man")
   439  
   440  	// fail to send msg on channel unknown by client
   441  	assert.False(t, mconnClient.Send(0x03, msg))
   442  
   443  	// send msg on channel unknown by the server.
   444  	// should cause an error
   445  	assert.True(t, mconnClient.Send(0x02, msg))
   446  	assert.True(t, expectSend(chOnErr), "unknown channel")
   447  }
   448  
   449  func TestMConnectionReadErrorLongMessage(t *testing.T) {
   450  	chOnErr := make(chan struct{})
   451  	chOnRcv := make(chan struct{})
   452  
   453  	mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr)
   454  	defer mconnClient.Stop()
   455  	defer mconnServer.Stop()
   456  
   457  	mconnServer.onReceive = func(chID byte, msgBytes []byte) {
   458  		chOnRcv <- struct{}{}
   459  	}
   460  
   461  	client := mconnClient.conn
   462  
   463  	// send msg thats just right
   464  	var err error
   465  	var buf = new(bytes.Buffer)
   466  	var packet = PacketMsg{
   467  		ChannelID: 0x01,
   468  		EOF:       1,
   469  		Bytes:     make([]byte, mconnClient.config.MaxPacketMsgPayloadSize),
   470  	}
   471  	_, err = cdc.MarshalBinaryLengthPrefixedWriter(buf, packet)
   472  	assert.Nil(t, err)
   473  	_, err = client.Write(buf.Bytes())
   474  	assert.Nil(t, err)
   475  	assert.True(t, expectSend(chOnRcv), "msg just right")
   476  
   477  	// send msg thats too long
   478  	buf = new(bytes.Buffer)
   479  	packet = PacketMsg{
   480  		ChannelID: 0x01,
   481  		EOF:       1,
   482  		Bytes:     make([]byte, mconnClient.config.MaxPacketMsgPayloadSize+100),
   483  	}
   484  	_, err = cdc.MarshalBinaryLengthPrefixedWriter(buf, packet)
   485  	assert.Nil(t, err)
   486  	_, err = client.Write(buf.Bytes())
   487  	assert.NotNil(t, err)
   488  	assert.True(t, expectSend(chOnErr), "msg too long")
   489  }
   490  
   491  func TestMConnectionReadErrorUnknownMsgType(t *testing.T) {
   492  	chOnErr := make(chan struct{})
   493  	mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr)
   494  	defer mconnClient.Stop()
   495  	defer mconnServer.Stop()
   496  
   497  	// send msg with unknown msg type
   498  	err := amino.EncodeUvarint(mconnClient.conn, 4)
   499  	assert.Nil(t, err)
   500  	_, err = mconnClient.conn.Write([]byte{0xFF, 0xFF, 0xFF, 0xFF})
   501  	assert.Nil(t, err)
   502  	assert.True(t, expectSend(chOnErr), "unknown msg type")
   503  }
   504  
   505  func TestMConnectionTrySend(t *testing.T) {
   506  	server, client := NetPipe()
   507  	defer server.Close()
   508  	defer client.Close()
   509  
   510  	mconn := createTestMConnection(client)
   511  	err := mconn.Start()
   512  	require.Nil(t, err)
   513  	defer mconn.Stop()
   514  
   515  	msg := []byte("Semicolon-Woman")
   516  	resultCh := make(chan string, 2)
   517  	assert.True(t, mconn.TrySend(0x01, msg))
   518  	server.Read(make([]byte, len(msg)))
   519  	assert.True(t, mconn.CanSend(0x01))
   520  	assert.True(t, mconn.TrySend(0x01, msg))
   521  	assert.False(t, mconn.CanSend(0x01))
   522  	go func() {
   523  		mconn.TrySend(0x01, msg)
   524  		resultCh <- "TrySend"
   525  	}()
   526  	assert.False(t, mconn.CanSend(0x01))
   527  	assert.False(t, mconn.TrySend(0x01, msg))
   528  	assert.Equal(t, "TrySend", <-resultCh)
   529  }