trpc.group/trpc-go/trpc-go@v1.0.3/transport/tnet/multiplex/multiplex_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  //go:build linux || freebsd || dragonfly || darwin
    15  // +build linux freebsd dragonfly darwin
    16  
    17  package multiplex_test
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"encoding/binary"
    23  	"errors"
    24  	"io"
    25  	"net"
    26  	"sync"
    27  	"sync/atomic"
    28  	"testing"
    29  	"time"
    30  
    31  	"github.com/stretchr/testify/require"
    32  
    33  	"trpc.group/trpc-go/trpc-go/pool/connpool"
    34  	"trpc.group/trpc-go/trpc-go/pool/multiplexed"
    35  	"trpc.group/trpc-go/trpc-go/transport/tnet"
    36  	"trpc.group/trpc-go/trpc-go/transport/tnet/multiplex"
    37  )
    38  
    39  var (
    40  	helloworld = []byte("hello world")
    41  	reqID      uint32
    42  )
    43  
    44  var (
    45  	_ (multiplexed.FrameParser) = (*simpleFrameParser)(nil)
    46  )
    47  
    48  /*
    49  |   4 byte  |  4 byte  | bodyLen byte |
    50  |  bodyLen  |    id    |      body    |
    51  */
    52  type simpleFrameParser struct {
    53  	isParseFail bool
    54  }
    55  
    56  func (fr *simpleFrameParser) Parse(reader io.Reader) (uint32, []byte, error) {
    57  	head := make([]byte, 8)
    58  	n, err := io.ReadFull(reader, head)
    59  	if err != nil {
    60  		return 0, nil, err
    61  	}
    62  
    63  	if fr.isParseFail {
    64  		return 0, nil, errors.New("decode fail")
    65  	}
    66  
    67  	if n != 8 {
    68  		return 0, nil, errors.New("invalid read full num")
    69  	}
    70  
    71  	bodyLen := binary.BigEndian.Uint32(head[:4])
    72  	id := binary.BigEndian.Uint32(head[4:8])
    73  	body := make([]byte, int(bodyLen))
    74  
    75  	n, err = io.ReadFull(reader, body)
    76  	if err != nil {
    77  		return 0, nil, err
    78  	}
    79  
    80  	if n != int(bodyLen) {
    81  		return 0, nil, errors.New("invalid read full body")
    82  	}
    83  
    84  	return id, body, nil
    85  }
    86  
    87  func encodeFrame(id uint32, body []byte) []byte {
    88  	bodyLen := len(body)
    89  	buf := bytes.NewBuffer(make([]byte, 0, 8+bodyLen))
    90  	if err := binary.Write(buf, binary.BigEndian, uint32(bodyLen)); err != nil {
    91  		panic(err)
    92  	}
    93  	if err := binary.Write(buf, binary.BigEndian, uint32(id)); err != nil {
    94  		panic(err)
    95  	}
    96  
    97  	if _, err := buf.Write(body); err != nil {
    98  		panic(err)
    99  	}
   100  
   101  	return buf.Bytes()
   102  }
   103  
   104  func getReqID() uint32 {
   105  	return atomic.AddUint32(&reqID, 1)
   106  }
   107  
   108  func echo(c net.Conn) {
   109  	io.Copy(c, c)
   110  }
   111  
   112  func beginServer(t *testing.T, handle func(net.Conn)) (net.Addr, context.CancelFunc) {
   113  	ctx, cancel := context.WithCancel(context.Background())
   114  	addrCh := make(chan net.Addr, 1)
   115  	go func() {
   116  		l, err := net.Listen("tcp", "127.0.0.1:0")
   117  		require.Nil(t, err)
   118  		addrCh <- l.Addr()
   119  		go func() {
   120  			for {
   121  				conn, err := l.Accept()
   122  				if err != nil {
   123  					require.NotNil(t, ctx.Err())
   124  					return
   125  				}
   126  				go handle(conn)
   127  			}
   128  		}()
   129  		<-ctx.Done()
   130  		l.Close()
   131  	}()
   132  	addr := <-addrCh
   133  	return addr, cancel
   134  }
   135  
   136  func TestBasic(t *testing.T) {
   137  	addr, cancel := beginServer(t, echo)
   138  	defer cancel()
   139  
   140  	getOpts := func() (uint32, multiplexed.GetOptions) {
   141  		id := getReqID()
   142  		opts := multiplexed.NewGetOptions()
   143  		opts.WithFrameParser(&simpleFrameParser{})
   144  		opts.WithVID(id)
   145  		return id, opts
   146  	}
   147  
   148  	t.Run("Multiple Conns Concurrent Read Write", func(t *testing.T) {
   149  		pool := multiplex.NewPool(
   150  			tnet.Dial,
   151  			multiplex.WithEnableMetrics(),
   152  			multiplex.WithMaxConcurrentVirConnsPerConn(500),
   153  		)
   154  		var wg sync.WaitGroup
   155  		for i := 0; i < 100; i++ {
   156  			wg.Add(1)
   157  			go func() {
   158  				defer wg.Done()
   159  				for i := 0; i < 100; i++ {
   160  					id, opts := getOpts()
   161  					conn, err := pool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts)
   162  					require.Nil(t, err)
   163  
   164  					err = conn.Write(encodeFrame(id, helloworld))
   165  					require.Nil(t, err)
   166  					b, err := conn.Read()
   167  					require.Nil(t, err)
   168  					require.Equal(t, helloworld, b)
   169  					conn.Close()
   170  				}
   171  			}()
   172  		}
   173  		wg.Wait()
   174  	})
   175  }
   176  
   177  func TestGetConnection(t *testing.T) {
   178  	addr, cancel := beginServer(t, echo)
   179  	defer cancel()
   180  	muxPool := multiplex.NewPool(tnet.Dial)
   181  
   182  	getOpts := func() multiplexed.GetOptions {
   183  		opts := multiplexed.NewGetOptions()
   184  		opts.WithFrameParser(&simpleFrameParser{})
   185  		opts.WithVID(getReqID())
   186  		return opts
   187  	}
   188  
   189  	t.Run("Get Once", func(t *testing.T) {
   190  		opts := getOpts()
   191  		conn, err := muxPool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts)
   192  		require.Nil(t, err)
   193  		conn.Close()
   194  	})
   195  	t.Run("Get Multiple Succeed", func(t *testing.T) {
   196  		opts := getOpts()
   197  		conn, err := muxPool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts)
   198  		require.Nil(t, err)
   199  		conn.Close()
   200  		localAddr := conn.LocalAddr()
   201  		for i := 0; i < 9; i++ {
   202  			opts := getOpts()
   203  			conn, err := muxPool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts)
   204  			require.Nil(t, err)
   205  			require.Equal(t, localAddr, conn.LocalAddr())
   206  			conn.Close()
   207  		}
   208  	})
   209  	t.Run("Exceed MaxConcurrentVirConns", func(t *testing.T) {
   210  		muxPool := multiplex.NewPool(tnet.Dial, multiplex.WithMaxConcurrentVirConnsPerConn(1))
   211  
   212  		opts := getOpts()
   213  		c1, err := muxPool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts)
   214  		require.Nil(t, err)
   215  		defer c1.Close()
   216  
   217  		opts = getOpts()
   218  		c2, err := muxPool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts)
   219  		require.Nil(t, err)
   220  		require.NotEqual(t, c1.LocalAddr(), c2.LocalAddr())
   221  		defer c2.Close()
   222  	})
   223  	t.Run("Request ID Already Exist", func(t *testing.T) {
   224  		opts := getOpts()
   225  		c1, err := muxPool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts)
   226  		require.Nil(t, err)
   227  		defer c1.Close()
   228  
   229  		_, err = muxPool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts)
   230  		require.Equal(t, multiplex.ErrDuplicateID, err)
   231  	})
   232  	t.Run("Empty FrameParser", func(t *testing.T) {
   233  		opts := getOpts()
   234  		opts.WithFrameParser(nil)
   235  		_, err := muxPool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts)
   236  		require.Contains(t, "frame parser is not provided", err.Error())
   237  	})
   238  }
   239  
   240  func TestDial(t *testing.T) {
   241  	addr, cancel := beginServer(t, echo)
   242  	defer cancel()
   243  
   244  	getOpts := func() (context.Context, context.CancelFunc, multiplexed.GetOptions) {
   245  		ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(200*time.Millisecond))
   246  		opts := multiplexed.NewGetOptions()
   247  		opts.WithFrameParser(&simpleFrameParser{})
   248  		opts.WithVID(getReqID())
   249  		return ctx, cancel, opts
   250  	}
   251  
   252  	t.Run("Dial Succeed", func(t *testing.T) {
   253  		muxPool := multiplex.NewPool(tnet.Dial)
   254  		ctx, cancel, opts := getOpts()
   255  		defer cancel()
   256  		conn, err := muxPool.GetMuxConn(ctx, addr.Network(), addr.String(), opts)
   257  		require.Nil(t, err)
   258  		conn.Close()
   259  	})
   260  	t.Run("Dial Timeout", func(t *testing.T) {
   261  		muxPool := multiplex.NewPool(func(opts *connpool.DialOptions) (net.Conn, error) {
   262  			time.Sleep(time.Second)
   263  			return nil, errors.New("dial fail")
   264  		})
   265  		ctx, cancel, opts := getOpts()
   266  		defer cancel()
   267  		_, err := muxPool.GetMuxConn(ctx, addr.Network(), addr.String(), opts)
   268  		require.Equal(t, context.DeadlineExceeded, err)
   269  	})
   270  	t.Run("Dial Error", func(t *testing.T) {
   271  		muxPool := multiplex.NewPool(func(opts *connpool.DialOptions) (net.Conn, error) {
   272  			return nil, errors.New("dial error")
   273  		})
   274  		ctx, cancel, opts := getOpts()
   275  		defer cancel()
   276  		_, err := muxPool.GetMuxConn(ctx, addr.Network(), addr.String(), opts)
   277  		require.Equal(t, errors.New("dial error"), err)
   278  	})
   279  	t.Run("Dial Gonet", func(t *testing.T) {
   280  		muxPool := multiplex.NewPool(func(opts *connpool.DialOptions) (net.Conn, error) {
   281  			return net.Dial(opts.Network, opts.Address)
   282  		})
   283  		ctx, cancel, opts := getOpts()
   284  		defer cancel()
   285  		_, err := muxPool.GetMuxConn(ctx, addr.Network(), addr.String(), opts)
   286  		require.Contains(t, "dialed connection must implements tnet.Conn", err.Error())
   287  	})
   288  }
   289  
   290  func TestClose(t *testing.T) {
   291  	muxPool := multiplex.NewPool(tnet.Dial)
   292  	getOpts := func() (uint32, multiplexed.GetOptions) {
   293  		id := getReqID()
   294  		opts := multiplexed.NewGetOptions()
   295  		opts.WithFrameParser(&simpleFrameParser{})
   296  		opts.WithVID(id)
   297  		return id, opts
   298  	}
   299  
   300  	t.Run("Server Close Conn After Accept", func(t *testing.T) {
   301  		addr, cancel := beginServer(t, func(c net.Conn) {
   302  			c.Close()
   303  		})
   304  		defer cancel()
   305  		var wg sync.WaitGroup
   306  		for i := 0; i < 1000; i++ {
   307  			wg.Add(1)
   308  			_, opts := getOpts()
   309  			go func() {
   310  				defer wg.Done()
   311  				conn, err := muxPool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts)
   312  				if err != nil {
   313  					return
   314  				}
   315  				_, err = conn.Read()
   316  				require.Contains(t, err.Error(), multiplex.ErrConnClosed.Error())
   317  				err = conn.Write(nil)
   318  				require.Contains(t, err.Error(), multiplex.ErrConnClosed.Error())
   319  				conn.Close()
   320  			}()
   321  		}
   322  		wg.Wait()
   323  	})
   324  
   325  	t.Run("Decode Fail", func(t *testing.T) {
   326  		addr, cancel := beginServer(t, echo)
   327  		defer cancel()
   328  		// return error when decode fail.
   329  		for i := 0; i < 5; i++ {
   330  			id, opts := getOpts()
   331  			opts.WithFrameParser(&simpleFrameParser{isParseFail: true})
   332  			conn, err := muxPool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts)
   333  			require.Nil(t, err)
   334  
   335  			err = conn.Write(encodeFrame(id, helloworld))
   336  			require.Nil(t, err)
   337  			_, err = conn.Read()
   338  			require.Contains(t, err.Error(), "decode fail")
   339  			conn.Close()
   340  		}
   341  		// return nil when decode succeed.
   342  		for i := 0; i < 5; i++ {
   343  			id, opts := getOpts()
   344  			opts.WithFrameParser(&simpleFrameParser{})
   345  			conn, err := muxPool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts)
   346  			require.Nil(t, err)
   347  
   348  			err = conn.Write(encodeFrame(id, helloworld))
   349  			require.Nil(t, err)
   350  			_, err = conn.Read()
   351  			require.Nil(t, err)
   352  			conn.Close()
   353  		}
   354  	})
   355  }