trpc.group/trpc-go/trpc-go@v1.0.3/pool/connpool/checker_unix_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package connpool
    15  
    16  import (
    17  	"errors"
    18  	"io"
    19  	"net"
    20  	"strings"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/stretchr/testify/require"
    25  	"trpc.group/trpc-go/trpc-go/codec"
    26  )
    27  
    28  const network = "tcp"
    29  
    30  func TestRemoteEOF(t *testing.T) {
    31  	var s server
    32  	require.Nil(t, s.init())
    33  
    34  	p := NewConnectionPool(
    35  		WithDialFunc(func(opts *DialOptions) (net.Conn, error) {
    36  			return net.Dial(opts.Network, opts.Address)
    37  		}),
    38  		WithHealthChecker(mockChecker),
    39  		WithForceClose(true))
    40  	defer closePool(t, p)
    41  
    42  	pc, err := p.Get(network, s.addr, GetOptions{CustomReader: codec.NewReader, DialTimeout: time.Second})
    43  	require.Nil(t, err)
    44  
    45  	clientConn := pc.(*PoolConn).GetRawConn()
    46  	serverConn := <-s.serverConns
    47  
    48  	require.Nil(t, serverConn.Close())
    49  	buf := make([]byte, 1)
    50  	require.Eventually(t, func() bool {
    51  		return errors.Is(checkConnErr(clientConn, buf), io.EOF)
    52  	}, time.Second, time.Millisecond)
    53  	require.Nil(t, pc.Close())
    54  }
    55  
    56  func TestUnexpectedRead(t *testing.T) {
    57  	var s server
    58  	require.Nil(t, s.init())
    59  
    60  	p := NewConnectionPool(
    61  		WithDialFunc(func(opts *DialOptions) (net.Conn, error) {
    62  			return net.Dial(opts.Network, opts.Address)
    63  		}),
    64  		WithHealthChecker(mockChecker))
    65  	defer closePool(t, p)
    66  
    67  	pc, err := p.Get(network, s.addr, GetOptions{CustomReader: codec.NewReader, DialTimeout: time.Second})
    68  	require.Nil(t, err)
    69  
    70  	clientConn := pc.(*PoolConn).GetRawConn()
    71  	serverConn := <-s.serverConns
    72  
    73  	require.Nil(t, pc.Close())
    74  	data := []byte("test")
    75  	n, err := serverConn.Write(data)
    76  	require.Nil(t, err)
    77  	require.Equal(t, len(data), n)
    78  
    79  	buf := make([]byte, 1)
    80  	require.Eventually(t, func() bool {
    81  		return strings.Contains(
    82  			checkConnErr(clientConn, buf).Error(),
    83  			ErrUnexpectedRead.Error())
    84  	}, time.Second, time.Millisecond)
    85  	require.Nil(t, serverConn.Close())
    86  }
    87  
    88  func TestEAGAIN(t *testing.T) {
    89  	var s server
    90  	require.Nil(t, s.init())
    91  
    92  	p := NewConnectionPool(
    93  		WithDialFunc(func(opts *DialOptions) (net.Conn, error) {
    94  			return net.Dial(opts.Network, opts.Address)
    95  		}),
    96  		WithHealthChecker(mockChecker),
    97  		WithForceClose(true))
    98  	defer closePool(t, p)
    99  
   100  	pc, err := p.Get(network, s.addr, GetOptions{CustomReader: codec.NewReader, DialTimeout: time.Second})
   101  	require.Nil(t, err)
   102  
   103  	clientConn := pc.(*PoolConn).GetRawConn()
   104  
   105  	buf := make([]byte, 100)
   106  	err2 := checkConnErr(clientConn, buf)
   107  	require.Nil(t, err2)
   108  
   109  	require.Nil(t, pc.Close())
   110  	require.Nil(t, (<-s.serverConns).Close())
   111  }
   112  
   113  type server struct {
   114  	serverConns chan net.Conn
   115  	addr        string
   116  }
   117  
   118  func (s *server) init() error {
   119  	s.serverConns = make(chan net.Conn)
   120  
   121  	l, err := net.Listen(network, ":0")
   122  	if err != nil {
   123  		return err
   124  	}
   125  	s.addr = l.Addr().String()
   126  
   127  	go func() {
   128  		for {
   129  			conn, err := l.Accept()
   130  			if err != nil {
   131  				panic(err)
   132  			}
   133  			s.serverConns <- conn
   134  		}
   135  	}()
   136  
   137  	return nil
   138  }