trpc.group/trpc-go/trpc-go@v1.0.3/transport/tnet/client_transport_tcp.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  //go:build linux || freebsd || dragonfly || darwin
    15  // +build linux freebsd dragonfly darwin
    16  
    17  package tnet
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"net"
    24  	"time"
    25  
    26  	"trpc.group/trpc-go/trpc-go/codec"
    27  	"trpc.group/trpc-go/trpc-go/errs"
    28  	"trpc.group/trpc-go/trpc-go/internal/report"
    29  	"trpc.group/trpc-go/trpc-go/log"
    30  	"trpc.group/trpc-go/trpc-go/pool/connpool"
    31  	"trpc.group/trpc-go/trpc-go/pool/multiplexed"
    32  	"trpc.group/trpc-go/trpc-go/transport"
    33  )
    34  
    35  func (c *clientTransport) tcpRoundTrip(ctx context.Context, reqData []byte,
    36  	opts *transport.RoundTripOptions) ([]byte, error) {
    37  	// Dial a TCP connection
    38  	conn, err := dialTCP(ctx, opts)
    39  	if err != nil {
    40  		return nil, err
    41  	}
    42  	defer conn.Close()
    43  	msg := codec.Message(ctx)
    44  	msg.WithRemoteAddr(conn.RemoteAddr())
    45  	msg.WithLocalAddr(conn.LocalAddr())
    46  
    47  	if err := checkContextErr(ctx); err != nil {
    48  		return nil, fmt.Errorf("before Write: %w", err)
    49  	}
    50  
    51  	report.TCPClientTransportSendSize.Set(float64(len(reqData)))
    52  	// Send a request.
    53  	if err := tcpWriteFrame(conn, reqData); err != nil {
    54  		return nil, err
    55  	}
    56  	// Receive a response.
    57  	return tcpReadFrame(conn, opts)
    58  }
    59  
    60  func dialTCP(ctx context.Context, opts *transport.RoundTripOptions) (net.Conn, error) {
    61  	if err := checkContextErr(ctx); err != nil {
    62  		return nil, fmt.Errorf("before tcp dial, %w", err)
    63  	}
    64  	var timeout time.Duration
    65  	d, isSetDeadline := ctx.Deadline()
    66  	if isSetDeadline {
    67  		timeout = time.Until(d)
    68  	}
    69  
    70  	var conn net.Conn
    71  	var err error
    72  	// Short connection mode, directly dial a connection.
    73  	if opts.DisableConnectionPool {
    74  		if opts.DialTimeout > 0 && opts.DialTimeout < timeout {
    75  			timeout = opts.DialTimeout
    76  		}
    77  		conn, err = Dial(&connpool.DialOptions{
    78  			Network:       opts.Network,
    79  			Address:       opts.Address,
    80  			LocalAddr:     opts.LocalAddr,
    81  			Timeout:       timeout,
    82  			CACertFile:    opts.CACertFile,
    83  			TLSCertFile:   opts.TLSCertFile,
    84  			TLSKeyFile:    opts.TLSKeyFile,
    85  			TLSServerName: opts.TLSServerName,
    86  		})
    87  		if err != nil {
    88  			return nil, errs.WrapFrameError(err, errs.RetClientConnectFail, "tcp client transport dial")
    89  		}
    90  		// Set a deadline for subsequent reading on the connection.
    91  		if isSetDeadline {
    92  			if err := conn.SetReadDeadline(d); err != nil {
    93  				log.Tracef("client SetReadDeadline failed %v", err)
    94  			}
    95  		}
    96  		return conn, nil
    97  	}
    98  
    99  	// Connection pool mode, get connection from pool.
   100  	getOpts := connpool.NewGetOptions()
   101  	getOpts.WithContext(ctx)
   102  	getOpts.WithFramerBuilder(opts.FramerBuilder)
   103  	getOpts.WithDialTLS(opts.TLSCertFile, opts.TLSKeyFile, opts.CACertFile, opts.TLSServerName)
   104  	getOpts.WithLocalAddr(opts.LocalAddr)
   105  	getOpts.WithDialTimeout(opts.DialTimeout)
   106  	getOpts.WithProtocol(opts.Protocol)
   107  	conn, err = opts.Pool.Get(opts.Network, opts.Address, getOpts)
   108  	if err != nil {
   109  		return nil, errs.WrapFrameError(err, errs.RetClientConnectFail, "tcp client transport connection pool")
   110  	}
   111  	// The created connection must be a tnet connection.
   112  	if !validateTnetConn(conn) && !validateTnetTLSConn(conn) {
   113  		return nil, errs.NewFrameError(errs.RetClientConnectFail, "tnet transport doesn't support non tnet.Conn")
   114  	}
   115  	if err := conn.SetReadDeadline(d); err != nil {
   116  		log.Tracef("client SetReadDeadline failed %v", err)
   117  	}
   118  	return conn, nil
   119  }
   120  
   121  func tcpWriteFrame(conn net.Conn, reqData []byte) error {
   122  	// When writing data on a tnet connection, there will be no partial write success,
   123  	// only complete success or complete failure.
   124  	_, err := conn.Write(reqData)
   125  	if err != nil {
   126  		return wrapNetError("tcp client tnet transport Write", err)
   127  	}
   128  	return nil
   129  }
   130  
   131  func tcpReadFrame(conn net.Conn, opts *transport.RoundTripOptions) ([]byte, error) {
   132  	if opts.ReqType == transport.SendOnly {
   133  		return nil, errs.ErrClientNoResponse
   134  	}
   135  
   136  	var fr codec.Framer
   137  	// The connection retrieved from the connection pool has already implemented the Framer interface.
   138  	if opts.DisableConnectionPool {
   139  		fr = opts.FramerBuilder.New(codec.NewReader(conn))
   140  	} else {
   141  		var ok bool
   142  		fr, ok = conn.(codec.Framer)
   143  		if !ok {
   144  			return nil, errs.NewFrameError(errs.RetClientConnectFail,
   145  				"tcp client transport: framer not implemented")
   146  		}
   147  	}
   148  
   149  	rspData, err := fr.ReadFrame()
   150  	if err != nil {
   151  		return nil, wrapNetError("tcp client transport ReadFrame", err)
   152  	}
   153  	report.TCPClientTransportReceiveSize.Set(float64(len(rspData)))
   154  	return rspData, nil
   155  }
   156  
   157  func wrapNetError(msg string, err error) error {
   158  	if err == nil {
   159  		return nil
   160  	}
   161  	if e, ok := err.(net.Error); ok && e.Timeout() {
   162  		return errs.WrapFrameError(err, errs.RetClientTimeout, msg)
   163  	}
   164  	return errs.WrapFrameError(err, errs.RetClientNetErr, msg)
   165  }
   166  
   167  func checkContextErr(ctx context.Context) error {
   168  	if errors.Is(ctx.Err(), context.Canceled) {
   169  		return errs.WrapFrameError(ctx.Err(), errs.RetClientCanceled, "client canceled")
   170  	}
   171  	if errors.Is(ctx.Err(), context.DeadlineExceeded) {
   172  		return errs.WrapFrameError(ctx.Err(), errs.RetClientTimeout, "client timeout")
   173  	}
   174  	return nil
   175  }
   176  func (c *clientTransport) multiplex(ctx context.Context, req []byte, opts *transport.RoundTripOptions) ([]byte, error) {
   177  	getOpts := multiplexed.NewGetOptions()
   178  	getOpts.WithVID(opts.Msg.RequestID())
   179  	fp, ok := opts.FramerBuilder.(multiplexed.FrameParser)
   180  	if !ok {
   181  		return nil, errs.NewFrameError(errs.RetClientConnectFail,
   182  			"frame builder does not implement multiplexed.FrameParser")
   183  	}
   184  	getOpts.WithFrameParser(fp)
   185  	getOpts.WithDialTLS(opts.TLSCertFile, opts.TLSKeyFile, opts.CACertFile, opts.TLSServerName)
   186  	getOpts.WithLocalAddr(opts.LocalAddr)
   187  	conn, err := opts.Multiplexed.GetMuxConn(ctx, opts.Network, opts.Address, getOpts)
   188  	if err != nil {
   189  		return nil, errs.WrapFrameError(err, errs.RetClientNetErr, "tcp client get multiplex connection failed")
   190  	}
   191  	defer conn.Close()
   192  	msg := codec.Message(ctx)
   193  	msg.WithRemoteAddr(conn.RemoteAddr())
   194  
   195  	if err := conn.Write(req); err != nil {
   196  		return nil, errs.WrapFrameError(err, errs.RetClientNetErr, "tcp client multiplex write failed")
   197  	}
   198  
   199  	// no need to receive response when request type is SendOnly.
   200  	if opts.ReqType == codec.SendOnly {
   201  		return nil, errs.ErrClientNoResponse
   202  	}
   203  
   204  	buf, err := conn.Read()
   205  	if err != nil {
   206  		if err == context.Canceled {
   207  			return nil, errs.NewFrameError(errs.RetClientCanceled,
   208  				"tcp tnet multiplexed ReadFrame: "+err.Error())
   209  		}
   210  		if err == context.DeadlineExceeded {
   211  			return nil, errs.NewFrameError(errs.RetClientTimeout,
   212  				"tcp tnet multiplexed ReadFrame: "+err.Error())
   213  		}
   214  		return nil, errs.NewFrameError(errs.RetClientNetErr,
   215  			"tcp tnet multiplexed ReadFrame: "+err.Error())
   216  	}
   217  	return buf, nil
   218  }