github.com/Bytom/bytom@v1.1.2-0.20210127130405-ae40204c0b09/p2p/connection/connection_test.go (about)

     1  package connection
     2  
     3  import (
     4  	"net"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  	"github.com/stretchr/testify/require"
    10  )
    11  
    12  func createMConnection(conn net.Conn) *MConnection {
    13  	onReceive := func(chID byte, msgBytes []byte) {
    14  	}
    15  	onError := func(r interface{}) {
    16  	}
    17  	c := createMConnectionWithCallbacks(conn, onReceive, onError)
    18  	return c
    19  }
    20  
    21  func createMConnectionWithCallbacks(conn net.Conn, onReceive func(chID byte, msgBytes []byte), onError func(r interface{})) *MConnection {
    22  	chDescs := []*ChannelDescriptor{&ChannelDescriptor{ID: 0x01, Priority: 1, SendQueueCapacity: 1}}
    23  	c := NewMConnectionWithConfig(conn, chDescs, onReceive, onError, DefaultMConnConfig())
    24  	return c
    25  }
    26  
    27  func TestMConnectionSend(t *testing.T) {
    28  	assert, require := assert.New(t), require.New(t)
    29  
    30  	server, client := net.Pipe()
    31  	defer server.Close()
    32  	defer client.Close()
    33  
    34  	mconn := createMConnection(client)
    35  	_, err := mconn.Start()
    36  	require.Nil(err)
    37  	defer mconn.Stop()
    38  
    39  	msg := "Ant-Man"
    40  	assert.True(mconn.Send(0x01, msg))
    41  	// Note: subsequent Send/TrySend calls could pass because we are reading from
    42  	// the send queue in a separate goroutine.
    43  	server.Read(make([]byte, len(msg)))
    44  	assert.True(mconn.CanSend(0x01))
    45  
    46  	msg = "Spider-Man"
    47  	assert.True(mconn.TrySend(0x01, msg))
    48  	server.Read(make([]byte, len(msg)))
    49  
    50  	assert.False(mconn.CanSend(0x05), "CanSend should return false because channel is unknown")
    51  	assert.False(mconn.Send(0x05, "Absorbing Man"), "Send should return false because channel is unknown")
    52  }
    53  
    54  func TestMConnectionReceive(t *testing.T) {
    55  	assert, require := assert.New(t), require.New(t)
    56  
    57  	server, client := net.Pipe()
    58  	defer server.Close()
    59  	defer client.Close()
    60  
    61  	receivedCh := make(chan []byte)
    62  	errorsCh := make(chan interface{})
    63  	onReceive := func(chID byte, msgBytes []byte) {
    64  		receivedCh <- msgBytes
    65  	}
    66  	onError := func(r interface{}) {
    67  		errorsCh <- r
    68  	}
    69  	mconn1 := createMConnectionWithCallbacks(client, onReceive, onError)
    70  	_, err := mconn1.Start()
    71  	require.Nil(err)
    72  	defer mconn1.Stop()
    73  
    74  	mconn2 := createMConnection(server)
    75  	_, err = mconn2.Start()
    76  	require.Nil(err)
    77  	defer mconn2.Stop()
    78  
    79  	msg := "Cyclops"
    80  	assert.True(mconn2.Send(0x01, msg))
    81  
    82  	select {
    83  	case receivedBytes := <-receivedCh:
    84  		assert.Equal([]byte(msg), receivedBytes[2:]) // first 3 bytes are internal
    85  	case err := <-errorsCh:
    86  		t.Fatalf("Expected %s, got %+v", msg, err)
    87  	case <-time.After(500 * time.Millisecond):
    88  		t.Fatalf("Did not receive %s message in 500ms", msg)
    89  	}
    90  }
    91  
    92  func TestMConnectionStopsAndReturnsError(t *testing.T) {
    93  	assert, require := assert.New(t), require.New(t)
    94  
    95  	server, client := net.Pipe()
    96  	defer server.Close()
    97  	defer client.Close()
    98  
    99  	receivedCh := make(chan []byte)
   100  	errorsCh := make(chan interface{})
   101  	onReceive := func(chID byte, msgBytes []byte) {
   102  		receivedCh <- msgBytes
   103  	}
   104  	onError := func(r interface{}) {
   105  		errorsCh <- r
   106  	}
   107  	mconn := createMConnectionWithCallbacks(client, onReceive, onError)
   108  	_, err := mconn.Start()
   109  	require.Nil(err)
   110  	defer mconn.Stop()
   111  
   112  	client.Close()
   113  
   114  	select {
   115  	case receivedBytes := <-receivedCh:
   116  		t.Fatalf("Expected error, got %v", receivedBytes)
   117  	case err := <-errorsCh:
   118  		assert.NotNil(err)
   119  		assert.False(mconn.IsRunning())
   120  	case <-time.After(500 * time.Millisecond):
   121  		t.Fatal("Did not receive error in 500ms")
   122  	}
   123  }