trpc.group/trpc-go/trpc-go@v1.0.3/transport/tnet/client_transport.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  //go:build linux || freebsd || dragonfly || darwin
    15  // +build linux freebsd dragonfly darwin
    16  
    17  package tnet
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"net"
    23  
    24  	"trpc.group/trpc-go/tnet"
    25  	"trpc.group/trpc-go/tnet/tls"
    26  
    27  	"trpc.group/trpc-go/trpc-go/errs"
    28  	intertls "trpc.group/trpc-go/trpc-go/internal/tls"
    29  	"trpc.group/trpc-go/trpc-go/log"
    30  	"trpc.group/trpc-go/trpc-go/pool/connpool"
    31  	"trpc.group/trpc-go/trpc-go/pool/multiplexed"
    32  	"trpc.group/trpc-go/trpc-go/transport"
    33  	"trpc.group/trpc-go/trpc-go/transport/tnet/multiplex"
    34  )
    35  
    36  func init() {
    37  	transport.RegisterClientTransport(transportName, DefaultClientTransport)
    38  }
    39  
    40  // DefaultClientTransport is the default implementation of tnet client transport.
    41  var DefaultClientTransport = NewClientTransport()
    42  
    43  // DefaultConnPool is default connection pool used by tnet.
    44  var DefaultConnPool = connpool.NewConnectionPool(
    45  	connpool.WithDialFunc(Dial),
    46  	connpool.WithHealthChecker(HealthChecker),
    47  )
    48  
    49  // DefaultMuxPool is default muxtiplex pool used by tnet.
    50  var DefaultMuxPool = multiplex.NewPool(Dial)
    51  
    52  // NewConnectionPool creates a new connection pool. Use it instead
    53  // of connpool.NewConnectionPool when use tnet transport because
    54  // it will dial tnet connection, otherwise error will occur.
    55  func NewConnectionPool(opts ...connpool.Option) connpool.Pool {
    56  	opts = append(opts,
    57  		connpool.WithDialFunc(Dial),
    58  		connpool.WithHealthChecker(HealthChecker))
    59  	return connpool.NewConnectionPool(opts...)
    60  }
    61  
    62  // NewMuxPool creates a new multiplexing pool. Use it instead
    63  // of mux.NewPool when use tnet transport because it will dial tnet connection.
    64  func NewMuxPool(opts ...multiplex.OptPool) multiplexed.Pool {
    65  	return multiplex.NewPool(Dial, opts...)
    66  }
    67  
    68  type clientTransport struct{}
    69  
    70  // NewClientTransport creates a tnet client transport.
    71  func NewClientTransport() transport.ClientTransport {
    72  	return &clientTransport{}
    73  }
    74  
    75  // RoundTrip begins an RPC roundtrip.
    76  func (c *clientTransport) RoundTrip(
    77  	ctx context.Context,
    78  	req []byte,
    79  	opts ...transport.RoundTripOption,
    80  ) ([]byte, error) {
    81  	return c.switchNetworkToRoundTrip(ctx, req, opts...)
    82  }
    83  
    84  func (c *clientTransport) switchNetworkToRoundTrip(
    85  	ctx context.Context,
    86  	req []byte,
    87  	opts ...transport.RoundTripOption,
    88  ) ([]byte, error) {
    89  	option, err := buildRoundTripOptions(opts...)
    90  	if err != nil {
    91  		return nil, err
    92  	}
    93  	if err := canUseTnet(option); err != nil {
    94  		log.Error("switch to gonet default transport, ", err)
    95  		return transport.DefaultClientTransport.RoundTrip(ctx, req, opts...)
    96  	}
    97  	log.Tracef("roundtrip to:%s is using tnet transport, current number of pollers: %d",
    98  		option.Address, tnet.NumPollers())
    99  	if option.EnableMultiplexed {
   100  		return c.multiplex(ctx, req, option)
   101  	}
   102  	switch option.Network {
   103  	case "tcp", "tcp4", "tcp6":
   104  		return c.tcpRoundTrip(ctx, req, option)
   105  	default:
   106  		return nil, errs.NewFrameError(errs.RetClientConnectFail,
   107  			fmt.Sprintf("tnet client transport, doesn't support network [%s]", option.Network))
   108  	}
   109  }
   110  
   111  func buildRoundTripOptions(opts ...transport.RoundTripOption) (*transport.RoundTripOptions, error) {
   112  	rtOpts := &transport.RoundTripOptions{
   113  		Pool:        DefaultConnPool,
   114  		Multiplexed: DefaultMuxPool,
   115  	}
   116  	for _, o := range opts {
   117  		o(rtOpts)
   118  	}
   119  	if rtOpts.FramerBuilder == nil {
   120  		return nil, errs.NewFrameError(errs.RetClientConnectFail, "client transport: framer builder empty")
   121  	}
   122  	return rtOpts, nil
   123  }
   124  
   125  // Dial connects to the address on the named network.
   126  func Dial(opts *connpool.DialOptions) (net.Conn, error) {
   127  	if opts.CACertFile == "" {
   128  		conn, err := tnet.DialTCP(opts.Network, opts.Address, opts.Timeout)
   129  		if err != nil {
   130  			return nil, err
   131  		}
   132  		if err := conn.SetIdleTimeout(opts.IdleTimeout); err != nil {
   133  			return nil, err
   134  		}
   135  		return conn, nil
   136  	}
   137  	if opts.TLSServerName == "" {
   138  		opts.TLSServerName = opts.Address
   139  	}
   140  	tlsConf, err := intertls.GetClientConfig(opts.TLSServerName, opts.CACertFile, opts.TLSCertFile, opts.TLSKeyFile)
   141  	if err != nil {
   142  		return nil, errs.WrapFrameError(err, errs.RetClientDecodeFail, "client dial tnet tls fail")
   143  	}
   144  	return tls.Dial(opts.Network, opts.Address,
   145  		tls.WithClientTLSConfig(tlsConf),
   146  		tls.WithTimeout(opts.Timeout),
   147  		tls.WithClientIdleTimeout(opts.IdleTimeout),
   148  	)
   149  }
   150  
   151  // HealthChecker checks if connection healthy or not.
   152  func HealthChecker(pc *connpool.PoolConn, _ bool) bool {
   153  	c := pc.GetRawConn()
   154  	tc, ok := c.(tnet.Conn)
   155  	if !ok {
   156  		return true
   157  	}
   158  	return tc.IsActive()
   159  }
   160  
   161  func validateTnetConn(conn net.Conn) bool {
   162  	if _, ok := conn.(tnet.Conn); ok {
   163  		return true
   164  	}
   165  	pc, ok := conn.(*connpool.PoolConn)
   166  	if !ok {
   167  		return false
   168  	}
   169  	_, ok = pc.GetRawConn().(tnet.Conn)
   170  	return ok
   171  }
   172  
   173  func validateTnetTLSConn(conn net.Conn) bool {
   174  	if _, ok := conn.(tls.Conn); ok {
   175  		return true
   176  	}
   177  	pc, ok := conn.(*connpool.PoolConn)
   178  	if !ok {
   179  		return false
   180  	}
   181  	_, ok = pc.GetRawConn().(tls.Conn)
   182  	return ok
   183  }
   184  
   185  func canUseTnet(opts *transport.RoundTripOptions) error {
   186  	switch opts.Network {
   187  	case "tcp", "tcp4", "tcp6":
   188  	default:
   189  		return fmt.Errorf("tnet doesn't support network [%s]", opts.Network)
   190  	}
   191  	return nil
   192  }