trpc.group/trpc-go/trpc-go@v1.0.3/transport/server_transport_stream_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package transport_test
    15  
    16  import (
    17  	"context"
    18  	"encoding/binary"
    19  	"encoding/json"
    20  	"net"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/stretchr/testify/assert"
    25  
    26  	_ "trpc.group/trpc-go/trpc-go"
    27  	"trpc.group/trpc-go/trpc-go/codec"
    28  	"trpc.group/trpc-go/trpc-go/transport"
    29  )
    30  
    31  // TestStreamTCPListenAndServe tests listen and send.
    32  func TestStreamTCPListenAndServe(t *testing.T) {
    33  	st := transport.NewServerStreamTransport()
    34  	go func() {
    35  		err := st.ListenAndServe(context.Background(),
    36  			transport.WithListenNetwork("tcp"),
    37  			transport.WithListenAddress(":12013"),
    38  			transport.WithHandler(&echoHandler{}),
    39  			transport.WithServerFramerBuilder(&multiplexedFramerBuilder{}),
    40  		)
    41  		if err != nil {
    42  			t.Logf("ListenAndServe fail:%v", err)
    43  		}
    44  	}()
    45  
    46  	ctx, f := context.WithTimeout(context.Background(), 200*time.Millisecond)
    47  	defer f()
    48  	req := &helloRequest{
    49  		Name: "trpc",
    50  		Msg:  "HelloWorld",
    51  	}
    52  
    53  	data, err := json.Marshal(req)
    54  	if err != nil {
    55  		t.Fatalf("json marshal fail:%v", err)
    56  	}
    57  	headData := make([]byte, 8)
    58  	binary.BigEndian.PutUint32(headData[:4], defaultStreamID)
    59  	binary.BigEndian.PutUint32(headData[4:8], uint32(len(data)))
    60  	reqData := append(headData, data...)
    61  
    62  	ctx, msg := codec.WithNewMessage(ctx)
    63  	msg.WithStreamID(defaultStreamID)
    64  
    65  	time.Sleep(time.Millisecond * 20)
    66  	ct := transport.NewClientStreamTransport()
    67  	err = ct.Init(ctx, transport.WithDialNetwork("tcp"), transport.WithDialAddress(":12013"),
    68  		transport.WithClientFramerBuilder(&multiplexedFramerBuilder{}),
    69  		transport.WithMsg(msg))
    70  	assert.Nil(t, err)
    71  
    72  	err = ct.Send(ctx, reqData)
    73  	assert.Nil(t, err)
    74  	err = st.Send(ctx, reqData)
    75  	assert.NotNil(t, err)
    76  
    77  	rsp, err := ct.Recv(ctx)
    78  	assert.Nil(t, err)
    79  	assert.NotNil(t, rsp)
    80  	ct.Close(ctx)
    81  	err = ct.Send(ctx, reqData)
    82  	assert.NotNil(t, err)
    83  
    84  }
    85  
    86  // TestStreamTCPListenAndServeFail tests listen and send failures.
    87  func TestStreamTCPListenAndServeFail(t *testing.T) {
    88  	st := transport.NewServerStreamTransport()
    89  	go func() {
    90  		err := st.ListenAndServe(context.Background(),
    91  			transport.WithListenNetwork("tcp"),
    92  			transport.WithListenAddress(":12014"),
    93  			transport.WithHandler(&echoHandler{}),
    94  			transport.WithServerFramerBuilder(&multiplexedFramerBuilder{}),
    95  		)
    96  		if err != nil {
    97  			t.Logf("ListenAndServe fail:%v", err)
    98  		}
    99  	}()
   100  
   101  	ctx, f := context.WithTimeout(context.Background(), 200*time.Millisecond)
   102  	defer f()
   103  	req := &helloRequest{
   104  		Name: "trpc",
   105  		Msg:  "HelloWorld",
   106  	}
   107  
   108  	data, err := json.Marshal(req)
   109  	if err != nil {
   110  		t.Fatalf("json marshal fail:%v", err)
   111  	}
   112  	headData := make([]byte, 8)
   113  	binary.BigEndian.PutUint32(headData[:4], defaultStreamID)
   114  	binary.BigEndian.PutUint32(headData[4:8], uint32(len(data)))
   115  	reqData := append(headData, data...)
   116  
   117  	ctx, msg := codec.WithNewMessage(ctx)
   118  	msg.WithStreamID(defaultStreamID)
   119  
   120  	time.Sleep(time.Millisecond * 20)
   121  	ct := transport.NewClientStreamTransport()
   122  	err = ct.Init(ctx, transport.WithDialNetwork("tcp"), transport.WithDialAddress(":12015"),
   123  		transport.WithClientFramerBuilder(&multiplexedFramerBuilder{}))
   124  	assert.NotNil(t, err)
   125  	err = ct.Send(ctx, reqData)
   126  	assert.NotNil(t, err)
   127  	_, err = ct.Recv(ctx)
   128  	assert.NotNil(t, err)
   129  	ct.Close(ctx)
   130  
   131  	// Test opts pool is nil.
   132  	err = ct.Init(ctx, transport.WithDialPool(nil))
   133  	assert.NotNil(t, err)
   134  
   135  	// Test frame builder is nil.
   136  	err = ct.Init(ctx)
   137  	assert.NotNil(t, err)
   138  
   139  	// test context.
   140  	ct = transport.NewClientStreamTransport()
   141  	err = ct.Init(ctx, transport.WithDialNetwork("tcp"), transport.WithDialAddress(":12014"),
   142  		transport.WithClientFramerBuilder(&multiplexedFramerBuilder{}))
   143  	assert.NotNil(t, err)
   144  
   145  	ctx = context.Background()
   146  	ctx, msg = codec.WithNewMessage(ctx)
   147  	msg.WithStreamID(defaultStreamID)
   148  	ctx, cancel := context.WithCancel(ctx)
   149  	go func() {
   150  		time.Sleep(10 * time.Millisecond)
   151  		cancel()
   152  	}()
   153  	_, err = ct.Recv(ctx)
   154  	// type:framework, code:161, msg:tcp client transport canceled before Write: context canceled
   155  	assert.NotNil(t, err)
   156  
   157  	ctx = context.Background()
   158  	ctx, msg = codec.WithNewMessage(ctx)
   159  	msg.WithStreamID(defaultStreamID)
   160  	ctx, cancel = context.WithTimeout(ctx, 50*time.Millisecond)
   161  	defer cancel()
   162  	_, err = ct.Recv(ctx)
   163  	// type:framework, code:101, msg:tcp client transport timeout before Write: context deadline exceeded
   164  	assert.NotNil(t, err)
   165  
   166  }
   167  
   168  // TestStreamTCPListenAndServeSend tests listen and send failures.
   169  func TestStreamTCPListenAndServeSend(t *testing.T) {
   170  	lnAddr := "127.0.0.1:12016"
   171  	st := transport.NewServerStreamTransport()
   172  	go func() {
   173  		err := st.ListenAndServe(context.Background(),
   174  			transport.WithListenNetwork("tcp"),
   175  			transport.WithListenAddress(lnAddr),
   176  			transport.WithHandler(&echoStreamHandler{}),
   177  			transport.WithServerFramerBuilder(&multiplexedFramerBuilder{}),
   178  		)
   179  		if err != nil {
   180  			t.Logf("ListenAndServe fail:%v", err)
   181  		}
   182  	}()
   183  	time.Sleep(20 * time.Millisecond)
   184  	req := &helloRequest{
   185  		Name: "trpc",
   186  		Msg:  "HelloWorld",
   187  	}
   188  
   189  	data, err := json.Marshal(req)
   190  	if err != nil {
   191  		t.Fatalf("json marshal fail:%v", err)
   192  	}
   193  	headData := make([]byte, 8)
   194  	binary.BigEndian.PutUint32(headData[:4], defaultStreamID)
   195  	binary.BigEndian.PutUint32(headData[4:8], uint32(len(data)))
   196  	reqData := append(headData, data...)
   197  
   198  	ctx := context.Background()
   199  	ctx, msg := codec.WithNewMessage(ctx)
   200  	msg.WithStreamID(defaultStreamID)
   201  	fb := &multiplexedFramerBuilder{}
   202  
   203  	// Test IO EOF.
   204  	port := getFreeAddr("tcp")
   205  	la := "127.0.0.1" + port
   206  	ct := transport.NewClientStreamTransport()
   207  	err = ct.Init(ctx, transport.WithDialNetwork("tcp"), transport.WithDialAddress(lnAddr),
   208  		transport.WithClientFramerBuilder(fb), transport.WithMsg(msg), transport.WithLocalAddr(la))
   209  	assert.Nil(t, err)
   210  	time.Sleep(100 * time.Millisecond)
   211  	raddr, err := net.ResolveTCPAddr("tcp", la)
   212  	assert.Nil(t, err)
   213  	laddr, err := net.ResolveTCPAddr("tcp", lnAddr)
   214  	assert.Nil(t, err)
   215  	msg.WithRemoteAddr(raddr)
   216  	msg.WithLocalAddr(laddr)
   217  	err = st.Send(ctx, reqData)
   218  	assert.Nil(t, err)
   219  }