trpc.group/trpc-go/trpc-go@v1.0.3/pool/multiplexed/multiplexed_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package multiplexed
    15  
    16  import (
    17  	"bytes"
    18  	"context"
    19  	"encoding/binary"
    20  	"errors"
    21  	"fmt"
    22  	"io"
    23  	"log"
    24  	"math"
    25  	"net"
    26  	"strconv"
    27  	"sync"
    28  	"sync/atomic"
    29  	"testing"
    30  	"time"
    31  
    32  	"golang.org/x/sync/errgroup"
    33  	"trpc.group/trpc-go/trpc-go/codec"
    34  
    35  	"github.com/stretchr/testify/assert"
    36  	"github.com/stretchr/testify/require"
    37  	"github.com/stretchr/testify/suite"
    38  )
    39  
    40  func TestMultiplexedSuite(t *testing.T) {
    41  	suite.Run(t, &msuite{})
    42  }
    43  
// msuite is the testify suite for the multiplexed pool. It owns one TCP and
// one UDP test server shared by all tests, plus a counter used to hand out
// request / virtual-connection IDs.
type msuite struct {
	suite.Suite

	// Network names and listen addresses of the shared TCP (network, address)
	// and UDP (udpNetwork, udpAddr) servers; filled in by SetupSuite.
	network    string
	udpNetwork string
	address    string
	udpAddr    string

	// Shared servers, started in SetupSuite and stopped in TearDownSuite.
	ts *tcpServer
	us *udpServer

	// requestID is the last ID handed out; tests take fresh IDs with
	// atomic.AddUint32(&s.requestID, 1).
	requestID uint32
}
    57  
    58  func (s *msuite) SetupSuite() {
    59  	s.ts = newTCPServer()
    60  	s.us = newUDPServer()
    61  
    62  	ctx := context.Background()
    63  	s.ts.start(ctx)
    64  	s.us.start(ctx)
    65  
    66  	s.address = s.ts.ln.Addr().String()
    67  	s.network = s.ts.ln.Addr().Network()
    68  
    69  	s.udpAddr = s.us.conn.LocalAddr().String()
    70  	s.udpNetwork = s.us.conn.LocalAddr().Network()
    71  
    72  	s.requestID = 1
    73  }
    74  
// TearDownSuite stops the shared TCP and UDP servers started by SetupSuite.
func (s *msuite) TearDownSuite() {
	s.ts.stop()
	s.us.stop()
}
    79  
// TearDownTest runs after every test and asks the shared TCP server to close
// the connections it has established, so no connection state carries over
// into the next test.
func (s *msuite) TearDownTest() {
	// Close all the established TCP connections after each test.
	s.ts.closeConnections()
}
    84  
// errDecodeDelimited is returned by lengthDelimitedFramer.Parse when the
// framer was configured with decodeError, to simulate a decode failure.
var errDecodeDelimited = errors.New("decode error")
    86  
// lengthDelimitedFramer is the test frame parser. A frame is an 8-byte header
// (big-endian uint32 body length, then big-endian uint32 request ID) followed
// by the body.
type lengthDelimitedFramer struct {
	// IsStream makes Parse return the header together with the body.
	IsStream bool
	// reader is the io.Reader supplied to New.
	reader io.Reader
	// decodeError makes Parse fail with errDecodeDelimited after reading the
	// header.
	decodeError bool
	// safe is the value reported by IsSafe.
	safe bool
}
    93  
    94  func (f *lengthDelimitedFramer) New(reader io.Reader) codec.Framer {
    95  	return &lengthDelimitedFramer{
    96  		IsStream:    f.IsStream,
    97  		reader:      reader,
    98  		decodeError: f.decodeError,
    99  		safe:        f.safe,
   100  	}
   101  }
   102  
// ReadFrame is a no-op stub that always returns a nil frame and a nil error.
// NOTE(review): presumably only present so the type satisfies codec.Framer
// (New returns it as one); frames in these tests are decoded by Parse, which
// is what gets registered via WithFrameParser.
func (f *lengthDelimitedFramer) ReadFrame() ([]byte, error) {
	return nil, nil
}
   106  
// IsSafe reports the safe flag the framer was created with.
// NOTE(review): presumably tells the caller whether returned frames may be
// used without copying — confirm against the codec safe-framer contract.
func (f *lengthDelimitedFramer) IsSafe() bool {
	return f.safe
}
   110  
   111  func (f *lengthDelimitedFramer) Parse(rc io.Reader) (vid uint32, buf []byte, err error) {
   112  	head := make([]byte, 8)
   113  	num, err := io.ReadFull(rc, head)
   114  	if err != nil {
   115  		return 0, nil, err
   116  	}
   117  
   118  	if f.decodeError {
   119  		return 0, nil, errDecodeDelimited
   120  	}
   121  
   122  	if num != 8 {
   123  		return 0, nil, errors.New("invalid read full num")
   124  	}
   125  
   126  	n := binary.BigEndian.Uint32(head[:4])
   127  	requestID := binary.BigEndian.Uint32(head[4:8])
   128  	body := make([]byte, int(n))
   129  
   130  	num, err = io.ReadFull(rc, body)
   131  	if err != nil {
   132  		return 0, nil, err
   133  	}
   134  
   135  	if num != int(n) {
   136  		return 0, nil, errors.New("invalid read full body")
   137  	}
   138  
   139  	if f.IsStream {
   140  		return requestID, append(head, body...), nil
   141  	}
   142  	return requestID, body, nil
   143  }
   144  
// delimitedRequest is the input to lengthDelimitedFramer.Encode.
type delimitedRequest struct {
	requestID uint32 // written into the frame header; Parse returns it as the VID
	body      []byte // frame payload
}
   149  
   150  func (f *lengthDelimitedFramer) Encode(req *delimitedRequest) ([]byte, error) {
   151  	l := len(req.body)
   152  	buf := bytes.NewBuffer(make([]byte, 0, 8+l))
   153  	if err := binary.Write(buf, binary.BigEndian, uint32(l)); err != nil {
   154  		return nil, err
   155  	}
   156  	if err := binary.Write(buf, binary.BigEndian, req.requestID); err != nil {
   157  		return nil, err
   158  	}
   159  
   160  	if err := binary.Write(buf, binary.BigEndian, req.body); err != nil {
   161  		return nil, err
   162  	}
   163  
   164  	return buf.Bytes(), nil
   165  }
   166  
   167  func (s *msuite) TestMultiplexedDecodeErr() {
   168  	tests := []struct {
   169  		network string
   170  		address string
   171  		wantErr error
   172  	}{
   173  		{s.network, s.address, errDecodeDelimited},
   174  		{s.udpNetwork, s.udpAddr, context.DeadlineExceeded},
   175  	}
   176  
   177  	for _, tt := range tests {
   178  		id := atomic.AddUint32(&s.requestID, 1)
   179  		ld := &lengthDelimitedFramer{
   180  			decodeError: true,
   181  		}
   182  		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   183  		m := New()
   184  		opts := NewGetOptions()
   185  		opts.WithVID(id)
   186  		opts.WithFrameParser(ld)
   187  		vc, err := m.GetMuxConn(ctx, tt.network, tt.address, opts)
   188  		assert.Nil(s.T(), err)
   189  		body := []byte("hello world")
   190  		buf, err := ld.Encode(&delimitedRequest{
   191  			body:      body,
   192  			requestID: id,
   193  		})
   194  		require.Nil(s.T(), err)
   195  		require.Nil(s.T(), vc.Write(buf))
   196  		_, err = vc.Read()
   197  		assert.Equal(s.T(), err, tt.wantErr)
   198  		cancel()
   199  	}
   200  }
   201  
   202  func (s *msuite) TestMultiplexedGetConcurrent() {
   203  	count := 10
   204  	ld := &lengthDelimitedFramer{}
   205  	m := New()
   206  	tests := []struct {
   207  		network string
   208  		address string
   209  	}{
   210  		{s.network, s.address},
   211  		{s.udpNetwork, s.udpAddr},
   212  	}
   213  	for _, tt := range tests {
   214  		wg := sync.WaitGroup{}
   215  		wg.Add(count)
   216  		for i := 0; i < count; i++ {
   217  			go func(i int) {
   218  				defer wg.Done()
   219  				ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   220  				id := atomic.AddUint32(&s.requestID, 1)
   221  				opts := NewGetOptions()
   222  				opts.WithVID(id)
   223  				opts.WithFrameParser(ld)
   224  				vc, err := m.GetMuxConn(ctx, tt.network, tt.address, opts)
   225  				assert.Nil(s.T(), err)
   226  				body := []byte("hello world" + strconv.Itoa(i))
   227  				buf, err := ld.Encode(&delimitedRequest{
   228  					body:      body,
   229  					requestID: id,
   230  				})
   231  				assert.Nil(s.T(), err)
   232  				assert.Nil(s.T(), vc.Write(buf))
   233  				rsp, err := vc.Read()
   234  				assert.Nil(s.T(), err)
   235  				assert.Equal(s.T(), rsp, body)
   236  				cancel()
   237  			}(i)
   238  		}
   239  		wg.Wait()
   240  	}
   241  }
   242  
   243  func (s *msuite) TestMultiplexedGet() {
   244  	id := atomic.AddUint32(&s.requestID, 1)
   245  	ld := &lengthDelimitedFramer{}
   246  
   247  	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Second)
   248  	defer cancel()
   249  
   250  	m := New(WithConnectNumber(4), WithDropFull(true), WithQueueSize(50000))
   251  	opts := NewGetOptions()
   252  	opts.WithVID(id)
   253  	opts.WithFrameParser(ld)
   254  	vc, err := m.GetMuxConn(ctx, s.network, s.address, opts)
   255  	assert.Nil(s.T(), err)
   256  
   257  	body := []byte("hello world")
   258  	buf, err := ld.Encode(&delimitedRequest{
   259  		body:      body,
   260  		requestID: id,
   261  	})
   262  	assert.Nil(s.T(), err)
   263  	assert.Nil(s.T(), vc.Write(buf))
   264  
   265  	rsp, err := vc.Read()
   266  	assert.Nil(s.T(), err)
   267  	assert.Equal(s.T(), rsp, body)
   268  }
   269  
   270  func (s *msuite) TestMultiplexedGetWithSafeFramer() {
   271  	id := atomic.AddUint32(&s.requestID, 1)
   272  	ld := &lengthDelimitedFramer{safe: true}
   273  
   274  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   275  	defer cancel()
   276  
   277  	m := New(WithConnectNumber(4), WithDropFull(true), WithQueueSize(50000))
   278  	opts := NewGetOptions()
   279  	opts.WithVID(id)
   280  	opts.WithFrameParser(ld)
   281  	vc, err := m.GetMuxConn(ctx, s.network, s.address, opts)
   282  	assert.Nil(s.T(), err)
   283  
   284  	body := []byte("hello world")
   285  	buf, err := ld.Encode(&delimitedRequest{
   286  		body:      body,
   287  		requestID: id,
   288  	})
   289  	assert.Nil(s.T(), err)
   290  	assert.Nil(s.T(), vc.Write(buf))
   291  
   292  	rsp, err := vc.Read()
   293  	assert.Nil(s.T(), err)
   294  	assert.Equal(s.T(), rsp, body)
   295  }
   296  
   297  func (s *msuite) TestNoFramerParser() {
   298  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   299  	defer cancel()
   300  	m := New()
   301  	opts := NewGetOptions()
   302  	opts.WithVID(atomic.AddUint32(&s.requestID, 1))
   303  	_, err := m.GetMuxConn(ctx, s.network, s.address, opts)
   304  	assert.Equal(s.T(), err, ErrFrameParserNil)
   305  }
   306  
   307  func (s *msuite) TestContextDeadline() {
   308  	id := atomic.AddUint32(&s.requestID, 1)
   309  	ld := &lengthDelimitedFramer{}
   310  
   311  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   312  	defer cancel()
   313  
   314  	m := New()
   315  	opts := NewGetOptions()
   316  	opts.WithVID(id)
   317  	opts.WithFrameParser(ld)
   318  	vc, err := m.GetMuxConn(ctx, s.network, s.address, opts)
   319  	assert.Nil(s.T(), err)
   320  	_, err = vc.Read()
   321  	assert.Equal(s.T(), err, context.DeadlineExceeded)
   322  	err = vc.Write([]byte("hello world"))
   323  	assert.Equal(s.T(), err, context.DeadlineExceeded)
   324  
   325  	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
   326  	defer cancel()
   327  	vc, err = m.GetMuxConn(ctx, s.network, s.address, opts)
   328  	assert.Nil(s.T(), err)
   329  
   330  	body := []byte("hello world")
   331  	buf, err := ld.Encode(&delimitedRequest{
   332  		body:      body,
   333  		requestID: id,
   334  	})
   335  	assert.Nil(s.T(), err)
   336  	assert.Nil(s.T(), vc.Write(buf))
   337  
   338  	rsp, err := vc.Read()
   339  	assert.Nil(s.T(), err)
   340  	assert.Equal(s.T(), rsp, body)
   341  }
   342  
   343  func (s *msuite) TestCloseConnection() {
   344  	id := atomic.AddUint32(&s.requestID, 1)
   345  	ld := &lengthDelimitedFramer{}
   346  
   347  	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
   348  	defer cancel()
   349  
   350  	m := New(WithConnectNumber(1))
   351  	opts := NewGetOptions()
   352  	opts.WithVID(id)
   353  	opts.WithFrameParser(ld)
   354  	_, err := m.GetMuxConn(ctx, s.network, s.address, opts)
   355  	assert.Nil(s.T(), err)
   356  
   357  	time.Sleep(500 * time.Millisecond)
   358  	v, ok := m.concreteConns.Load(makeNodeKey(s.network, s.address))
   359  	assert.True(s.T(), ok)
   360  	cs := v.(*Connections)
   361  	cs.conns[0].close(errors.New("fake error"), false)
   362  	_, ok = m.concreteConns.Load(makeNodeKey(s.network, s.address))
   363  	assert.False(s.T(), ok)
   364  }
   365  
   366  func (s *msuite) TestDuplicatedClose() {
   367  	id := atomic.AddUint32(&s.requestID, 1)
   368  	ld := &lengthDelimitedFramer{}
   369  
   370  	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
   371  	defer cancel()
   372  	m := New(WithConnectNumber(1))
   373  	opts := NewGetOptions()
   374  	opts.WithVID(id)
   375  	opts.WithFrameParser(ld)
   376  	vc, err := m.GetMuxConn(ctx, s.network, s.address, opts)
   377  	assert.Nil(s.T(), err)
   378  
   379  	body := []byte("hello world")
   380  	buf, err := ld.Encode(&delimitedRequest{
   381  		body:      body,
   382  		requestID: id,
   383  	})
   384  	assert.Nil(s.T(), err)
   385  	assert.Nil(s.T(), vc.Write(buf))
   386  
   387  	rsp, err := vc.Read()
   388  	assert.Nil(s.T(), err)
   389  	assert.Equal(s.T(), rsp, body)
   390  
   391  	v, ok := m.concreteConns.Load(makeNodeKey(s.network, s.address))
   392  	assert.True(s.T(), ok)
   393  	cs := v.(*Connections)
   394  	err1 := errors.New("error1")
   395  	err2 := errors.New("error2")
   396  	c := cs.conns[0]
   397  	c.close(err1, false)
   398  	c.close(err2, false)
   399  
   400  	_, err = vc.Read()
   401  	assert.Equal(s.T(), err, err1)
   402  }
   403  
   404  func (s *msuite) TestGetFail() {
   405  	ld := &lengthDelimitedFramer{}
   406  
   407  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   408  	defer cancel()
   409  
   410  	m := New()
   411  	opts := NewGetOptions()
   412  	opts.WithVID(atomic.AddUint32(&s.requestID, 1))
   413  	opts.WithFrameParser(ld)
   414  	_, err := m.GetMuxConn(ctx, s.network, s.address, opts)
   415  	assert.Nil(s.T(), err)
   416  
   417  	m.concreteConns.Store(makeNodeKey(s.network, s.address), &Connection{})
   418  	_, err = m.GetMuxConn(ctx, s.network, s.address, opts)
   419  	assert.NotNil(s.T(), err)
   420  }
   421  
   422  func (s *msuite) TestContextCancel() {
   423  	id := atomic.AddUint32(&s.requestID, 1)
   424  	ld := &lengthDelimitedFramer{}
   425  
   426  	// get with cancel.
   427  	ctx, cancel := context.WithCancel(context.Background())
   428  	cancel()
   429  	m := New()
   430  	opts := NewGetOptions()
   431  	opts.WithVID(id)
   432  	opts.WithFrameParser(ld)
   433  	_, err := m.GetMuxConn(ctx, s.network, s.address, opts)
   434  	assert.NotNil(s.T(), err)
   435  }
   436  
   437  // test when send fails.
   438  func (s *msuite) TestSendFail() {
   439  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   440  	defer cancel()
   441  	m := New(WithDropFull(true), WithQueueSize(1))
   442  	opts := NewGetOptions()
   443  	opts.WithVID(atomic.AddUint32(&s.requestID, 1))
   444  	opts.WithFrameParser(&emptyFrameParser{})
   445  	vc, err := m.GetMuxConn(ctx, s.network, s.address, opts)
   446  	assert.Nil(s.T(), err)
   447  
   448  	body := []byte("hello world")
   449  	err = vc.Write(body)
   450  	assert.Nil(s.T(), err)
   451  	err = vc.Write(body)
   452  	assert.NotNil(s.T(), err)
   453  }
   454  
   455  func (s *msuite) TestWriteErrorCleanVirtualConnection() {
   456  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   457  	defer cancel()
   458  	m := New(WithDropFull(true), WithQueueSize(0))
   459  	opts := NewGetOptions()
   460  	opts.WithVID(atomic.AddUint32(&s.requestID, 1))
   461  	opts.WithFrameParser(&emptyFrameParser{})
   462  	mc, err := m.GetMuxConn(ctx, s.network, s.address, opts)
   463  	assert.Nil(s.T(), err)
   464  	vc, ok := mc.(*VirtualConnection)
   465  	assert.True(s.T(), ok)
   466  
   467  	body := []byte("hello world")
   468  	err = vc.Write(body)
   469  	assert.NotNil(s.T(), err)
   470  	assert.Len(s.T(), vc.conn.virConns, 0)
   471  }
   472  
   473  func (s *msuite) TestReadErrorCleanVirtualConnection() {
   474  	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
   475  	defer cancel()
   476  	m := New(WithDropFull(true), WithQueueSize(0))
   477  	opts := NewGetOptions()
   478  	opts.WithVID(atomic.AddUint32(&s.requestID, 1))
   479  	opts.WithFrameParser(&lengthDelimitedFramer{})
   480  	mc, err := m.GetMuxConn(ctx, s.network, s.address, opts)
   481  	assert.Nil(s.T(), err)
   482  	vc, ok := mc.(*VirtualConnection)
   483  	assert.True(s.T(), ok)
   484  
   485  	time.Sleep(time.Millisecond * 100)
   486  	_, err = vc.Read()
   487  	assert.NotNil(s.T(), err)
   488  	assert.Len(s.T(), vc.conn.virConns, 0)
   489  }
   490  
   491  func (s *msuite) TestUdpMultiplexedReadTimeout() {
   492  	ld := &lengthDelimitedFramer{}
   493  
   494  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   495  	defer cancel()
   496  	m := New()
   497  	opts := NewGetOptions()
   498  	opts.WithVID(atomic.AddUint32(&s.requestID, 1))
   499  	opts.WithFrameParser(ld)
   500  	vc, err := m.GetMuxConn(ctx, "udp", s.udpAddr, opts)
   501  	assert.Nil(s.T(), err)
   502  	_, err = vc.Read()
   503  	assert.Equal(s.T(), err, ctx.Err())
   504  }
   505  
   506  func (s *msuite) TestMultiplexedServerFail() {
   507  	tests := []struct {
   508  		network string
   509  		address string
   510  		exists  bool
   511  	}{
   512  		{s.network, "invalid address", false},
   513  		{s.udpNetwork, "invalid address", false},
   514  	}
   515  
   516  	for _, tt := range tests {
   517  		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   518  		defer cancel()
   519  		m := New(
   520  			WithConnectNumber(1),
   521  			// On windows, it will try to use up all the timeout to do the dialling.
   522  			// So limit the dial timeout.
   523  			WithDialTimeout(time.Millisecond),
   524  		)
   525  		opts := NewGetOptions()
   526  		opts.WithVID(atomic.AddUint32(&s.requestID, 1))
   527  		opts.WithFrameParser(&emptyFrameParser{})
   528  		_, err := m.GetMuxConn(ctx, tt.network, tt.address, opts)
   529  		s.T().Logf("m.GetMuxConn err: %+v\n", err)
   530  		// Because of possible out of order execution of goroutines,
   531  		// the error may or may not be nil.
   532  		if err != nil {
   533  			// If it is non-nil, it must be an expelled error.
   534  			require.True(s.T(), errors.Is(err, ErrConnectionsHaveBeenExpelled))
   535  		}
   536  		time.Sleep(10 * time.Millisecond)
   537  		_, ok := m.concreteConns.Load(makeNodeKey(tt.network, tt.address))
   538  		assert.Equal(s.T(), tt.exists, ok)
   539  	}
   540  }
   541  
   542  func (s *msuite) TestMultiplexedConcurrentGetInvalidAddr() {
   543  	const (
   544  		network     = "tcp"
   545  		invalidAddr = "invalid addr"
   546  	)
   547  	msg := codec.Message(context.Background())
   548  	msg.WithRequestID(atomic.AddUint32(&s.requestID, 1))
   549  
   550  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   551  	defer cancel()
   552  	m := New(WithConnectNumber(1))
   553  	opts := NewGetOptions()
   554  	opts.WithFrameParser(&emptyFrameParser{})
   555  	start := time.Now()
   556  	for n := 1; ; n++ {
   557  		if time.Since(start) > time.Second*10 {
   558  			require.FailNow(s.T(), "expected expelled error in 10s")
   559  		}
   560  		var eg errgroup.Group
   561  		for i := 0; i < n; i++ {
   562  			eg.Go(func() error {
   563  				_, err := m.GetMuxConn(ctx, network, invalidAddr, opts)
   564  				return err
   565  			})
   566  		}
   567  		if err := eg.Wait(); err != nil {
   568  			s.T().Logf("ok, m.GetMuxConn error: %+v\n", err)
   569  			break
   570  		}
   571  	}
   572  }
   573  
   574  func (s *msuite) TestWithLocalAddr() {
   575  	tests := []struct {
   576  		network string
   577  		address string
   578  	}{
   579  		{s.network, s.address},
   580  		{s.udpNetwork, s.udpAddr},
   581  	}
   582  	localAddr := "127.0.0.1"
   583  
   584  	for _, tt := range tests {
   585  		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   586  		defer cancel()
   587  		m := New()
   588  		opts := NewGetOptions()
   589  		opts.WithVID(atomic.AddUint32(&s.requestID, 1))
   590  		opts.WithLocalAddr(localAddr + ":")
   591  		ld := &lengthDelimitedFramer{}
   592  		opts.WithFrameParser(ld)
   593  		body := []byte("hello world")
   594  		buf, err := ld.Encode(&delimitedRequest{
   595  			body:      body,
   596  			requestID: s.requestID,
   597  		})
   598  		assert.Nil(s.T(), err)
   599  		mc, err := m.GetMuxConn(ctx, tt.network, tt.address, opts)
   600  		assert.Nil(s.T(), err)
   601  		vc, ok := mc.(*VirtualConnection)
   602  		assert.True(s.T(), ok)
   603  		assert.Nil(s.T(), vc.Write(buf))
   604  		assert.Nil(s.T(), err)
   605  		_, err = vc.Read()
   606  		assert.Nil(s.T(), err)
   607  		if tt.network == s.network {
   608  			conn := vc.conn.getRawConn()
   609  			realAddr := conn.LocalAddr().(*net.TCPAddr).IP.String()
   610  			assert.Equal(s.T(), realAddr, localAddr)
   611  		} else if tt.network == s.udpNetwork {
   612  			realAddr := vc.conn.packetConn.LocalAddr().(*net.UDPAddr).IP.String()
   613  			assert.Equal(s.T(), realAddr, localAddr)
   614  		}
   615  	}
   616  }
   617  
   618  func (s *msuite) TestTCPReconnect() {
   619  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   620  	defer cancel()
   621  	m := New(WithConnectNumber(1))
   622  	opts := NewGetOptions()
   623  	opts.WithVID(atomic.AddUint32(&s.requestID, 1))
   624  	ld := &lengthDelimitedFramer{}
   625  	opts.WithFrameParser(ld)
   626  	body := []byte("hello world")
   627  	buf, err := ld.Encode(&delimitedRequest{
   628  		body:      body,
   629  		requestID: s.requestID,
   630  	})
   631  	assert.Nil(s.T(), err)
   632  	vc, err := m.GetMuxConn(ctx, s.network, s.address, opts)
   633  	assert.Nil(s.T(), err)
   634  	assert.Nil(s.T(), vc.Write(buf))
   635  	_, err = vc.Read()
   636  	assert.Nil(s.T(), err)
   637  
   638  	// close conn
   639  	val, ok := m.concreteConns.Load(makeNodeKey(s.network, s.address))
   640  	assert.True(s.T(), ok)
   641  	c := val.(*Connections).conns[0]
   642  	conn := c.getRawConn()
   643  	conn.Close()
   644  	time.Sleep(100 * time.Millisecond)
   645  	vc, err = m.GetMuxConn(ctx, s.network, s.address, opts)
   646  	assert.Nil(s.T(), err)
   647  	assert.Nil(s.T(), vc.Write(buf))
   648  	_, err = vc.Read()
   649  	assert.Nil(s.T(), err)
   650  	_, ok = m.concreteConns.Load(makeNodeKey(s.network, s.address))
   651  	assert.True(s.T(), ok)
   652  
   653  	// timeout after reconnected
   654  	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
   655  	defer done()
   656  	vc, err = m.GetMuxConn(ctx, s.network, s.address, opts)
   657  	assert.Nil(s.T(), err)
   658  	_, err = vc.Read()
   659  	assert.ErrorIs(s.T(), err, context.DeadlineExceeded)
   660  }
   661  
   662  func (s *msuite) TestTCPReconnectMaxReconnectCount() {
   663  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   664  	defer cancel()
   665  	m := New(WithConnectNumber(1))
   666  	opts := NewGetOptions()
   667  	opts.WithVID(atomic.AddUint32(&s.requestID, 1))
   668  	ld := &lengthDelimitedFramer{}
   669  	opts.WithFrameParser(ld)
   670  	_, err := m.GetMuxConn(ctx, s.network, "invalid address", opts)
   671  	assert.Nil(s.T(), err)
   672  	time.Sleep(time.Second)
   673  	_, ok := m.concreteConns.Load(makeNodeKey(s.network, "invalid address"))
   674  	assert.False(s.T(), ok)
   675  }
   676  
   677  func (s *msuite) TestStreamMultiplexd() {
   678  	id := atomic.AddUint32(&s.requestID, 1)
   679  
   680  	ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
   681  	defer cancel()
   682  
   683  	m := New()
   684  	opts := NewGetOptions()
   685  	opts.WithVID(id)
   686  	ld := &lengthDelimitedFramer{IsStream: true}
   687  	opts.WithFrameParser(ld)
   688  	vc, err := m.GetMuxConn(ctx, s.network, s.address, opts)
   689  	assert.Nil(s.T(), err)
   690  	assert.NotNil(s.T(), vc)
   691  
   692  	body := []byte("hello world")
   693  	buf, err := ld.Encode(&delimitedRequest{
   694  		body:      body,
   695  		requestID: id,
   696  	})
   697  	assert.Nil(s.T(), err)
   698  	assert.Nil(s.T(), vc.Write(buf))
   699  
   700  	rsp, err := vc.Read()
   701  	assert.Nil(s.T(), err)
   702  	assert.Equal(s.T(), buf, rsp)
   703  }
   704  
   705  func (s *msuite) TestStreamMultiplexd_Addr() {
   706  	streamID := atomic.AddUint32(&s.requestID, 1)
   707  
   708  	ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
   709  	defer cancel()
   710  
   711  	m := New()
   712  	opts := NewGetOptions()
   713  	opts.WithVID(streamID)
   714  	ld := &lengthDelimitedFramer{IsStream: true}
   715  	opts.WithFrameParser(ld)
   716  	vc, err := m.GetMuxConn(ctx, s.network, s.address, opts)
   717  	assert.Nil(s.T(), err)
   718  	assert.NotNil(s.T(), vc)
   719  	time.Sleep(50 * time.Millisecond)
   720  
   721  	la := vc.LocalAddr()
   722  	assert.NotNil(s.T(), la)
   723  
   724  	ra := vc.RemoteAddr()
   725  	assert.Equal(s.T(), s.address, ra.String())
   726  }
   727  
   728  func (s *msuite) TestStreamMultiplexd_MaxVirConnPerConn() {
   729  	ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
   730  	defer cancel()
   731  
   732  	m := New(WithMaxVirConnsPerConn(4))
   733  	opts := NewGetOptions()
   734  	ld := &lengthDelimitedFramer{IsStream: true}
   735  	opts.WithFrameParser(ld)
   736  	var cs *Connections
   737  	for i := 0; i < 10; i++ {
   738  		id := atomic.AddUint32(&s.requestID, 1)
   739  		opts.WithVID(id)
   740  		vc, err := m.GetMuxConn(ctx, s.network, s.address, opts)
   741  		assert.Nil(s.T(), err)
   742  		assert.NotNil(s.T(), vc)
   743  		conns, ok := m.concreteConns.Load(makeNodeKey(s.network, s.address))
   744  		require.True(s.T(), ok)
   745  		cs, ok = conns.(*Connections)
   746  		require.True(s.T(), ok)
   747  
   748  		body := []byte("hello world")
   749  		buf, err := ld.Encode(&delimitedRequest{
   750  			body:      body,
   751  			requestID: uint32(id),
   752  		})
   753  		assert.Nil(s.T(), err)
   754  		assert.Nil(s.T(), vc.Write(buf))
   755  
   756  		rsp, err := vc.Read()
   757  		assert.Nil(s.T(), err)
   758  		assert.Equal(s.T(), buf, rsp)
   759  	}
   760  	assert.Equal(s.T(), 3, len(cs.conns))
   761  }
   762  
   763  func (s *msuite) TestStreamMultiplexd_MaxIdleConnPerHost() {
   764  	ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
   765  	defer cancel()
   766  
   767  	m := New(WithMaxVirConnsPerConn(2), WithMaxIdleConnsPerHost(3))
   768  	opts := NewGetOptions()
   769  	ld := &lengthDelimitedFramer{IsStream: true}
   770  	opts.WithFrameParser(ld)
   771  
   772  	vcs := make([]MuxConn, 0)
   773  	for i := 0; i < 10; i++ {
   774  		id := atomic.AddUint32(&s.requestID, 1)
   775  		opts.WithVID(id)
   776  		vc, err := m.GetMuxConn(ctx, s.network, s.address, opts)
   777  		assert.Nil(s.T(), err)
   778  		vcs = append(vcs, vc)
   779  	}
   780  	conns, ok := m.concreteConns.Load(makeNodeKey(s.network, s.address))
   781  	require.True(s.T(), ok)
   782  	cs, ok := conns.(*Connections)
   783  	require.True(s.T(), ok)
   784  	assert.Equal(s.T(), 5, len(cs.conns))
   785  	for i := 0; i < 10; i++ {
   786  		vcs[i].Close()
   787  	}
   788  	assert.Equal(s.T(), 3, len(cs.conns))
   789  }
   790  
   791  func (s *msuite) TestMultiplexedGetConcurrent_MaxIdleConnPerHost() {
   792  	count := 100
   793  	ld := &lengthDelimitedFramer{}
   794  	m := New(WithMaxVirConnsPerConn(20), WithMaxIdleConnsPerHost(2))
   795  	tests := []struct {
   796  		network string
   797  		address string
   798  	}{
   799  		{s.network, s.address},
   800  		{s.udpNetwork, s.udpAddr},
   801  	}
   802  	for _, tt := range tests {
   803  		wg := sync.WaitGroup{}
   804  		wg.Add(count)
   805  		for i := 0; i < count; i++ {
   806  			go func(i int) {
   807  				defer wg.Done()
   808  				ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   809  				id := atomic.AddUint32(&s.requestID, 1)
   810  				opts := NewGetOptions()
   811  				opts.WithVID(id)
   812  				opts.WithFrameParser(ld)
   813  				vc, err := m.GetMuxConn(ctx, tt.network, tt.address, opts)
   814  				assert.Nil(s.T(), err)
   815  				body := []byte("hello world" + strconv.Itoa(i))
   816  				buf, err := ld.Encode(&delimitedRequest{
   817  					body:      body,
   818  					requestID: id,
   819  				})
   820  				assert.Nil(s.T(), err)
   821  				assert.Nil(s.T(), vc.Write(buf))
   822  				rsp, err := vc.Read()
   823  				assert.Nil(s.T(), err)
   824  				assert.Equal(s.T(), rsp, body)
   825  				vc.Close()
   826  				cancel()
   827  			}(i)
   828  			if i%50 == 0 {
   829  				time.Sleep(50 * time.Millisecond)
   830  			}
   831  		}
   832  		wg.Wait()
   833  	}
   834  }
   835  
   836  func (s *msuite) TestMultiplexedReconnectOnConnectError() {
   837  	ctx := context.Background()
   838  	ts := newTCPServer()
   839  	ts.start(ctx)
   840  	defer ts.stop()
   841  	m := New(
   842  		WithConnectNumber(1),
   843  		// On windows, it will try to use up all the timeout to do the dialling.
   844  		// So limit the dial timeout.
   845  		WithDialTimeout(time.Millisecond*10),
   846  	)
   847  	ctx, cancel := context.WithTimeout(ctx, time.Second)
   848  	defer cancel()
   849  	opts := NewGetOptions()
   850  	opts.WithVID(atomic.AddUint32(&s.requestID, 1))
   851  	readTrigger := make(chan struct{})
   852  	readErr := make(chan error)
   853  	opts.WithFrameParser(&triggeredReadFramerBuilder{readTrigger: readTrigger, readErr: readErr})
   854  	mc, err := m.GetMuxConn(ctx, s.network, ts.ln.Addr().String(), opts)
   855  	require.Nil(s.T(), err)
   856  	vc, ok := mc.(*VirtualConnection)
   857  	assert.True(s.T(), ok)
   858  	<-readTrigger                     // Wait for the first read.
   859  	require.Nil(s.T(), ts.ln.Close()) // Then close the server.
   860  	readErr <- errAlwaysFail          // Fail the first read to trigger reconnection.
   861  	require.Eventually(s.T(),
   862  		func() bool { return maxReconnectCount+1 == vc.conn.reconnectCount },
   863  		time.Second, 10*time.Millisecond)
   864  }
   865  
   866  func (s *msuite) TestMultiplexedReconnectOnReadError() {
   867  	preInitialBackoff := initialBackoff
   868  	preMaxBackoff := maxBackoff
   869  	preMaxReconnectCount := maxReconnectCount
   870  	preResetInterval := reconnectCountResetInterval
   871  	defer func() {
   872  		initialBackoff = preInitialBackoff
   873  		maxBackoff = preMaxBackoff
   874  		maxReconnectCount = preMaxReconnectCount
   875  		reconnectCountResetInterval = preResetInterval
   876  	}()
   877  	initialBackoff = time.Microsecond
   878  	maxBackoff = 50 * time.Microsecond
   879  	maxReconnectCount = 5
   880  	reconnectCountResetInterval = time.Hour
   881  
   882  	m := New(
   883  		WithConnectNumber(1),
   884  		// On windows, it will try to use up all the timeout to do the dialling.
   885  		// So limit the dial timeout.
   886  		WithDialTimeout(time.Millisecond*10),
   887  	)
   888  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   889  	defer cancel()
   890  	opts := NewGetOptions()
   891  	calledAt := make([]time.Time, 0, maxReconnectCount)
   892  	opts.WithVID(atomic.AddUint32(&s.requestID, 1))
   893  	opts.WithFrameParser(&errFramerBuilder{readFrameCalledAt: &calledAt})
   894  	mc, err := m.GetMuxConn(ctx, s.network, s.address, opts)
   895  	require.Nil(s.T(), err)
   896  	vc, ok := mc.(*VirtualConnection)
   897  	assert.True(s.T(), ok)
   898  	require.Eventually(s.T(),
   899  		func() bool { return maxReconnectCount+1 == vc.conn.reconnectCount },
   900  		3*time.Second, time.Second,
   901  		fmt.Sprintf("final status: maxReconnectCount+1=%d, vc.conn.reconnectCount=%d",
   902  			maxReconnectCount+1, vc.conn.reconnectCount))
   903  	require.Eventually(s.T(),
   904  		func() bool { return maxReconnectCount+1 == len(calledAt) },
   905  		3*time.Second, 50*time.Millisecond,
   906  		fmt.Sprintf("final status: maxReconnectCount+1=%d, len(calledAt)=%d",
   907  			maxReconnectCount+1, len(calledAt)))
   908  	var differences []float64
   909  	for i := 1; i < len(calledAt); i++ {
   910  		delay := calledAt[i].Sub(calledAt[i-1])
   911  		expectedBackoff := (initialBackoff * time.Duration(i))
   912  		s.T().Logf("calledAt delay: %2dms, expect: %2dms (between %d and %d)\n",
   913  			delay.Milliseconds(), expectedBackoff.Milliseconds(), i-1, i)
   914  		differences = append(differences, float64(delay-expectedBackoff))
   915  	}
   916  	require.Equal(s.T(), maxReconnectCount+1, len(calledAt),
   917  		"the actual times called is %d, expect %d", len(calledAt), maxReconnectCount+1)
   918  	s.T().Logf("differences: %+v", differences)
   919  	s.T().Logf("mean of differences between real retry delay and the calculated backoff: %vns", mean(differences))
   920  	ss := std(differences)
   921  	s.T().Logf("std of differences between real retry delay and the calculated backoff: %vns", ss)
   922  	const expectedStdLimit = time.Second
   923  	require.Less(s.T(), ss, float64(expectedStdLimit),
   924  		"standard deviation of differences between real retry delay and calculated backoff is expected to be within %s",
   925  		expectedStdLimit)
   926  }
   927  
   928  func (s *msuite) TestMultiplexedReconnectOnWriteError() {
   929  	ctx := context.Background()
   930  	ts := newTCPServer()
   931  	ts.start(ctx)
   932  	defer ts.stop()
   933  	m := New(
   934  		WithConnectNumber(1),
   935  		// On windows, it will try to use up all the timeout to do the dialling.
   936  		// So limit the dial timeout.
   937  		WithDialTimeout(time.Millisecond*10),
   938  	)
   939  	ctx, cancel := context.WithTimeout(ctx, time.Second)
   940  	defer cancel()
   941  	opts := NewGetOptions()
   942  	opts.WithVID(atomic.AddUint32(&s.requestID, 1))
   943  	readTrigger := make(chan struct{})
   944  	readErr := make(chan error)
   945  	opts.WithFrameParser(&triggeredReadFramerBuilder{readTrigger: readTrigger, readErr: readErr})
   946  	mc, err := m.GetMuxConn(ctx, s.network, ts.ln.Addr().String(), opts)
   947  	require.Nil(s.T(), err)
   948  	vc, ok := mc.(*VirtualConnection)
   949  	assert.True(s.T(), ok)
   950  	<-readTrigger                                    // Wait for the first read.
   951  	require.Nil(s.T(), vc.conn.getRawConn().Close()) // Now close the underlying connection.
   952  	require.Nil(s.T(), vc.Write([]byte("hello")))    // Then this write will trigger a reconnection on write error.
   953  	// Now we are cool to check that a reconnection is triggered.
   954  	require.Eventually(s.T(),
   955  		func() bool { return 1 == vc.conn.reconnectCount },
   956  		time.Second, 10*time.Millisecond)
   957  }
   958  
   959  func TestMultiplexedDestroyMayCauseGoroutineLeak(t *testing.T) {
   960  	l, err := net.Listen("tcp", ":")
   961  	require.Nil(t, err)
   962  	const connNum = 2
   963  	acceptedConns, acceptErrs := make(chan net.Conn, connNum*2), make(chan error)
   964  	var closedConns uint32
   965  	go func() {
   966  		for {
   967  			c, err := l.Accept()
   968  			if err != nil {
   969  				acceptErrs <- err
   970  				return
   971  			}
   972  			acceptedConns <- c
   973  			go func() {
   974  				_, _ = io.Copy(c, c)
   975  				atomic.AddUint32(&closedConns, 1)
   976  			}()
   977  		}
   978  	}()
   979  
   980  	fb := fixedLenFrameBuilder{packetLen: 2}
   981  	dialTimeout := time.Millisecond * 50
   982  	m := New(
   983  		WithConnectNumber(connNum),
   984  		// replace the too long default 1s dail timeout.
   985  		WithDialTimeout(dialTimeout))
   986  	getVirtualConn := func(requestID uint32) (MuxConn, error) {
   987  		getOptions := NewGetOptions()
   988  		getOptions.WithVID(requestID)
   989  		getOptions.WithFrameParser(&fb)
   990  		return m.GetMuxConn(context.Background(), l.Addr().Network(), l.Addr().String(), getOptions)
   991  	}
   992  
   993  	vc, err := getVirtualConn(1)
   994  	require.Nil(t, err)
   995  	require.Nil(t, vc.Write(fb.EncodeWithRequestID(1, []byte("1a"))))
   996  	read, err := vc.Read()
   997  	require.Nil(t, err)
   998  	require.Equal(t, []byte("1a"), read)
   999  	vc.Close()
  1000  
  1001  	var (
  1002  		c1 net.Conn
  1003  		c2 net.Conn
  1004  	)
  1005  	select {
  1006  	case c1 = <-acceptedConns:
  1007  	case <-time.After(time.Second):
  1008  		require.FailNow(t, "should accept a connection")
  1009  	}
  1010  	select {
  1011  	case c2 = <-acceptedConns:
  1012  	case <-time.After(time.Second):
  1013  		require.FailNow(t, "multiplexed should establish two concreteConns")
  1014  	}
  1015  
  1016  	require.Nil(t, l.Close())
  1017  	<-acceptErrs
  1018  	require.Nil(t, c1.Close())
  1019  	// on windows, connecting to closed listener returns an error until dial timeout, not immediately.
  1020  	// we should sleep additional dialTimeout * maxReconnectCount to wait all retry finished.
  1021  	time.Sleep((maxBackoff + dialTimeout) * time.Duration(maxReconnectCount))
  1022  	require.Equal(t, uint32(1), atomic.LoadUint32(&closedConns))
  1023  
  1024  	vc, err = getVirtualConn(2)
  1025  	require.Nil(t, err)
  1026  	require.Nil(t, vc.Write(fb.EncodeWithRequestID(2, []byte("2a"))))
  1027  	require.EqualValues(t, 1, atomic.LoadUint32(&closedConns))
  1028  	read, err = vc.Read()
  1029  	require.Nil(t, err)
  1030  	require.Equal(t, []byte("2a"), read)
  1031  	require.Nil(t, err)
  1032  	require.Nil(t, c2.Close())
  1033  }
  1034  
  1035  func mean(v []float64) float64 {
  1036  	n := len(v)
  1037  	if n == 0 {
  1038  		return 0
  1039  	}
  1040  	var res float64
  1041  	for i := 0; i < n; i++ {
  1042  		res += v[i]
  1043  	}
  1044  	return res / float64(n)
  1045  }
  1046  
  1047  func variance(v []float64) float64 {
  1048  	n := len(v)
  1049  	if n <= 1 {
  1050  		return 0
  1051  	}
  1052  	var res float64
  1053  	m := mean(v)
  1054  	for i := 0; i < n; i++ {
  1055  		res += (v[i] - m) * (v[i] - m)
  1056  	}
  1057  	return res / float64(n-1)
  1058  }
  1059  
  1060  func std(v []float64) float64 {
  1061  	return math.Sqrt(variance(v))
  1062  }
  1063  
// errFramerBuilder is a frame parser whose frames can never be read
// successfully. Every Parse call appends the current time to
// *readFrameCalledAt, so tests can measure the delay between successive read
// attempts (i.e. the reconnect backoff).
// NOTE(review): Parse appends to the slice from whichever goroutine calls it,
// while tests read the same slice concurrently. Nothing here synchronizes
// those accesses; confirm under -race.
type errFramerBuilder struct {
	readFrameCalledAt *[]time.Time
}
  1067  
  1068  func (fb *errFramerBuilder) New(io.Reader) codec.Framer {
  1069  	return &errFramer{
  1070  		calledAt: fb.readFrameCalledAt,
  1071  	}
  1072  }
  1073  
  1074  func (fb *errFramerBuilder) Parse(rc io.Reader) (vid uint32, buf []byte, err error) {
  1075  	*fb.readFrameCalledAt = append(*fb.readFrameCalledAt, time.Now())
  1076  	buf, err = fb.New(rc).ReadFrame()
  1077  	if err != nil {
  1078  		return 0, nil, err
  1079  	}
  1080  	return 0, buf, nil
  1081  }
  1082  
  1083  var errAlwaysFail = errors.New("always fail")
  1084  
  1085  type errFramer struct {
  1086  	calledAt *[]time.Time
  1087  }
  1088  
  1089  // ReadFrame implements codec.Framer.
  1090  func (f *errFramer) ReadFrame() ([]byte, error) {
  1091  	return nil, errAlwaysFail
  1092  }
  1093  
// triggeredReadFramerBuilder builds triggeredReadFramers that all share the two
// channels below, so a test can control each frame read:
//   - readTrigger receives a value each time a framer enters ReadFrame.
//   - readErr supplies the error that the pending ReadFrame then returns.
type triggeredReadFramerBuilder struct {
	readTrigger chan struct{}
	readErr     chan error
}
  1098  
  1099  func (fb *triggeredReadFramerBuilder) New(io.Reader) codec.Framer {
  1100  	return &triggeredReadFramer{
  1101  		readTrigger: fb.readTrigger,
  1102  		readErr:     fb.readErr,
  1103  	}
  1104  }
  1105  
  1106  func (fb *triggeredReadFramerBuilder) Parse(rc io.Reader) (vid uint32, buf []byte, err error) {
  1107  	buf, err = fb.New(rc).ReadFrame()
  1108  	if err != nil {
  1109  		return 0, nil, err
  1110  	}
  1111  	return 0, buf, nil
  1112  }
  1113  
  1114  type triggeredReadFramer struct {
  1115  	readTrigger chan struct{}
  1116  	readErr     chan error
  1117  }
  1118  
  1119  // ReadFrame implements codec.Framer.
  1120  func (f *triggeredReadFramer) ReadFrame() ([]byte, error) {
  1121  	f.readTrigger <- struct{}{}
  1122  	err := <-f.readErr
  1123  	return nil, err
  1124  }
  1125  
// fixedLenFrameBuilder frames each message as a 4-byte big-endian request ID
// followed by exactly packetLen payload bytes.
type fixedLenFrameBuilder struct {
	packetLen int // payload size in bytes, excluding the 4-byte ID header
}
  1129  
  1130  func (fb *fixedLenFrameBuilder) New(r io.Reader) codec.Framer {
  1131  	return &fixedLenFramer{
  1132  		decode: fb.Decode,
  1133  		buf:    make([]byte, 4+fb.packetLen), // uint64 request id + packet len
  1134  		r:      r,
  1135  	}
  1136  }
  1137  
  1138  func (fb *fixedLenFrameBuilder) Parse(rc io.Reader) (vid uint32, buf []byte, err error) {
  1139  	buf = make([]byte, 4+fb.packetLen)
  1140  	n, err := rc.Read(buf)
  1141  	if err != nil {
  1142  		return 0, nil, err
  1143  	}
  1144  	id, bts, err := fb.Decode(buf[:n])
  1145  	if err != nil {
  1146  		return 0, nil, err
  1147  	}
  1148  	return id, bts, nil
  1149  }
  1150  
  1151  func (*fixedLenFrameBuilder) EncodeWithRequestID(id uint32, buf []byte) []byte {
  1152  	bts := make([]byte, 4+len(buf))
  1153  	binary.BigEndian.PutUint32(bts[:4], id)
  1154  	copy(bts[4:], buf)
  1155  	return bts
  1156  }
  1157  
  1158  func (*fixedLenFrameBuilder) Decode(bts []byte) (uint32, []byte, error) {
  1159  	if l := len(bts); l < 4 {
  1160  		return 0, nil, fmt.Errorf("bts len %d must not be lesser than 8, content: %q", l, bts)
  1161  	}
  1162  	return binary.BigEndian.Uint32(bts), bts[4:], nil
  1163  }
  1164  
  1165  type fixedLenFramer struct {
  1166  	decode func([]byte) (uint32, []byte, error)
  1167  	buf    []byte
  1168  	r      io.Reader
  1169  }
  1170  
  1171  func (f *fixedLenFramer) ReadFrame() ([]byte, error) {
  1172  	return nil, errors.New("should not be used by multiplexed")
  1173  }
  1174  
  1175  func newTCPServer() *tcpServer {
  1176  	return &tcpServer{}
  1177  }
  1178  
  1179  type tcpServer struct {
  1180  	cancel        context.CancelFunc
  1181  	ln            net.Listener
  1182  	concreteConns []net.Conn
  1183  }
  1184  
  1185  func (s *tcpServer) start(ctx context.Context) error {
  1186  	var err error
  1187  	s.ln, err = net.Listen("tcp", "127.0.0.1:0")
  1188  	if err != nil {
  1189  		return err
  1190  	}
  1191  	ctx, s.cancel = context.WithCancel(ctx)
  1192  	go func() {
  1193  		for {
  1194  			select {
  1195  			case <-ctx.Done():
  1196  				return
  1197  			default:
  1198  			}
  1199  			conn, err := s.ln.Accept()
  1200  			if err != nil {
  1201  				log.Println("l.Accept err: ", err)
  1202  				return
  1203  			}
  1204  			s.concreteConns = append(s.concreteConns, conn)
  1205  
  1206  			go func() {
  1207  				select {
  1208  				case <-ctx.Done():
  1209  					return
  1210  				default:
  1211  				}
  1212  				io.Copy(conn, conn)
  1213  			}()
  1214  		}
  1215  	}()
  1216  	return nil
  1217  }
  1218  
  1219  func (s *tcpServer) stop() {
  1220  	s.cancel()
  1221  	s.closeConnections()
  1222  	s.ln.Close()
  1223  }
  1224  
  1225  func (s *tcpServer) closeConnections() {
  1226  	for i := range s.concreteConns {
  1227  		s.concreteConns[i].Close()
  1228  	}
  1229  	s.concreteConns = s.concreteConns[:0]
  1230  }
  1231  
  1232  func newUDPServer() *udpServer {
  1233  	return &udpServer{}
  1234  }
  1235  
  1236  type udpServer struct {
  1237  	cancel context.CancelFunc
  1238  	conn   net.PacketConn
  1239  }
  1240  
  1241  func (s *udpServer) start(ctx context.Context) error {
  1242  	var err error
  1243  	s.conn, err = net.ListenPacket("udp", "127.0.0.1:0")
  1244  	if err != nil {
  1245  		return err
  1246  	}
  1247  	ctx, s.cancel = context.WithCancel(ctx)
  1248  	go func() {
  1249  		buf := make([]byte, 65535)
  1250  		for {
  1251  			select {
  1252  			case <-ctx.Done():
  1253  				return
  1254  			default:
  1255  			}
  1256  			n, addr, err := s.conn.ReadFrom(buf)
  1257  			if err != nil {
  1258  				log.Println("l.ReadFrom err: ", err)
  1259  				return
  1260  			}
  1261  
  1262  			s.conn.WriteTo(buf[:n], addr)
  1263  		}
  1264  	}()
  1265  	return nil
  1266  }
  1267  
  1268  func (s *udpServer) stop() {
  1269  	s.cancel()
  1270  	s.conn.Close()
  1271  }