trpc.group/trpc-go/trpc-go@v1.0.3/transport/client_transport_tcp.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package transport
    15  
import (
	"context"
	"errors"
	"io"
	"net"
	"time"

	"trpc.group/trpc-go/trpc-go/codec"
	"trpc.group/trpc-go/trpc-go/errs"
	"trpc.group/trpc-go/trpc-go/internal/report"
	"trpc.group/trpc-go/trpc-go/pool/connpool"
	"trpc.group/trpc-go/trpc-go/pool/multiplexed"
	"trpc.group/trpc-go/trpc-go/rpcz"
)
    28  
    29  // tcpRoundTrip sends tcp request. It supports send, sendAndRcv, keepalive and multiplex.
    30  func (c *clientTransport) tcpRoundTrip(ctx context.Context, reqData []byte,
    31  	opts *RoundTripOptions) ([]byte, error) {
    32  	if opts.Pool == nil {
    33  		return nil, errs.NewFrameError(errs.RetClientConnectFail,
    34  			"tcp client transport: connection pool empty")
    35  	}
    36  
    37  	if opts.FramerBuilder == nil {
    38  		return nil, errs.NewFrameError(errs.RetClientConnectFail,
    39  			"tcp client transport: framer builder empty")
    40  	}
    41  
    42  	conn, err := c.dialTCP(ctx, opts)
    43  	if err != nil {
    44  		return nil, err
    45  	}
    46  	// TCP connection is exclusively multiplexed. Close determines whether connection should be put
    47  	// back into the connection pool to be reused.
    48  	defer conn.Close()
    49  	msg := codec.Message(ctx)
    50  	msg.WithRemoteAddr(conn.RemoteAddr())
    51  	msg.WithLocalAddr(conn.LocalAddr())
    52  
    53  	if ctx.Err() == context.Canceled {
    54  		return nil, errs.NewFrameError(errs.RetClientCanceled,
    55  			"tcp client transport canceled before Write: "+ctx.Err().Error())
    56  	}
    57  	if ctx.Err() == context.DeadlineExceeded {
    58  		return nil, errs.NewFrameError(errs.RetClientTimeout,
    59  			"tcp client transport timeout before Write: "+ctx.Err().Error())
    60  	}
    61  
    62  	report.TCPClientTransportSendSize.Set(float64(len(reqData)))
    63  	span := rpcz.SpanFromContext(ctx)
    64  	_, end := span.NewChild("SendMessage")
    65  	err = c.tcpWriteFrame(ctx, conn, reqData)
    66  	end.End()
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  
    71  	_, end = span.NewChild("ReceiveMessage")
    72  	rspData, err := c.tcpReadFrame(conn, opts)
    73  	end.End()
    74  	return rspData, err
    75  }
    76  
    77  // dialTCP establishes a TCP connection.
    78  func (c *clientTransport) dialTCP(ctx context.Context, opts *RoundTripOptions) (net.Conn, error) {
    79  	// If ctx has canceled or timeout, just return.
    80  	if ctx.Err() == context.Canceled {
    81  		return nil, errs.NewFrameError(errs.RetClientCanceled,
    82  			"client canceled before tcp dial: "+ctx.Err().Error())
    83  	}
    84  	if ctx.Err() == context.DeadlineExceeded {
    85  		return nil, errs.NewFrameError(errs.RetClientTimeout,
    86  			"client timeout before tcp dial: "+ctx.Err().Error())
    87  	}
    88  	var timeout time.Duration
    89  	d, ok := ctx.Deadline()
    90  	if ok {
    91  		timeout = time.Until(d)
    92  	}
    93  
    94  	var conn net.Conn
    95  	var err error
    96  	// Short connection mode, directly dial a connection.
    97  	if opts.DisableConnectionPool {
    98  		// The connection is established using the minimum of ctx timeout and connecting timeout.
    99  		if opts.DialTimeout > 0 && opts.DialTimeout < timeout {
   100  			timeout = opts.DialTimeout
   101  		}
   102  		conn, err = connpool.Dial(&connpool.DialOptions{
   103  			Network:       opts.Network,
   104  			Address:       opts.Address,
   105  			LocalAddr:     opts.LocalAddr,
   106  			Timeout:       timeout,
   107  			CACertFile:    opts.CACertFile,
   108  			TLSCertFile:   opts.TLSCertFile,
   109  			TLSKeyFile:    opts.TLSKeyFile,
   110  			TLSServerName: opts.TLSServerName,
   111  		})
   112  		if err != nil {
   113  			return nil, errs.NewFrameError(errs.RetClientConnectFail,
   114  				"tcp client transport dial: "+err.Error())
   115  		}
   116  		if ok {
   117  			conn.SetDeadline(d)
   118  		}
   119  		return conn, nil
   120  	}
   121  
   122  	// Connection pool mode, get connection from pool.
   123  	getOpts := connpool.NewGetOptions()
   124  	getOpts.WithContext(ctx)
   125  	getOpts.WithFramerBuilder(opts.FramerBuilder)
   126  	getOpts.WithDialTLS(opts.TLSCertFile, opts.TLSKeyFile, opts.CACertFile, opts.TLSServerName)
   127  	getOpts.WithLocalAddr(opts.LocalAddr)
   128  	getOpts.WithDialTimeout(opts.DialTimeout)
   129  	getOpts.WithProtocol(opts.Protocol)
   130  	conn, err = opts.Pool.Get(opts.Network, opts.Address, getOpts)
   131  	if err != nil {
   132  		return nil, errs.NewFrameError(errs.RetClientConnectFail,
   133  			"tcp client transport connection pool: "+err.Error())
   134  	}
   135  	if ok {
   136  		conn.SetDeadline(d)
   137  	}
   138  	return conn, nil
   139  }
   140  
   141  // tcpWriteReqData writes the tcp frame.
   142  func (c *clientTransport) tcpWriteFrame(ctx context.Context, conn net.Conn, reqData []byte) error {
   143  	// Send package in a loop.
   144  	sentNum := 0
   145  	num := 0
   146  	var err error
   147  	for sentNum < len(reqData) {
   148  		num, err = conn.Write(reqData[sentNum:])
   149  		if err != nil {
   150  			if e, ok := err.(net.Error); ok && e.Timeout() {
   151  				return errs.NewFrameError(errs.RetClientTimeout,
   152  					"tcp client transport Write: "+err.Error())
   153  			}
   154  			return errs.NewFrameError(errs.RetClientNetErr,
   155  				"tcp client transport Write: "+err.Error())
   156  		}
   157  		sentNum += num
   158  	}
   159  	return nil
   160  }
   161  
   162  // tcpReadFrame reads the tcp frame.
   163  func (c *clientTransport) tcpReadFrame(conn net.Conn, opts *RoundTripOptions) ([]byte, error) {
   164  	// send only.
   165  	if opts.ReqType == SendOnly {
   166  		return nil, errs.ErrClientNoResponse
   167  	}
   168  
   169  	var fr codec.Framer
   170  	if opts.DisableConnectionPool {
   171  		// Do not create new Framer for each connection in connection pool.
   172  		fr = opts.FramerBuilder.New(codec.NewReader(conn))
   173  	} else {
   174  		// The Framer is bound to conn in the connection pool.
   175  		var ok bool
   176  		fr, ok = conn.(codec.Framer)
   177  		if !ok {
   178  			return nil, errs.NewFrameError(errs.RetClientConnectFail,
   179  				"tcp client transport: framer not implemented")
   180  		}
   181  	}
   182  
   183  	rspData, err := fr.ReadFrame()
   184  	if err != nil {
   185  		if e, ok := err.(net.Error); ok && e.Timeout() {
   186  			return nil, errs.NewFrameError(errs.RetClientTimeout,
   187  				"tcp client transport ReadFrame: "+err.Error())
   188  		}
   189  		return nil, errs.NewFrameError(errs.RetClientReadFrameErr,
   190  			"tcp client transport ReadFrame: "+err.Error())
   191  	}
   192  	report.TCPClientTransportReceiveSize.Set(float64(len(rspData)))
   193  	return rspData, nil
   194  }
   195  
   196  // multiplexed handle multiplexed request.
   197  func (c *clientTransport) multiplexed(ctx context.Context, req []byte, opts *RoundTripOptions) ([]byte, error) {
   198  	if opts.FramerBuilder == nil {
   199  		return nil, errs.NewFrameError(errs.RetClientConnectFail,
   200  			"tcp client transport: framer builder empty")
   201  	}
   202  	getOpts := multiplexed.NewGetOptions()
   203  	getOpts.WithVID(opts.Msg.RequestID())
   204  	fp, ok := opts.FramerBuilder.(multiplexed.FrameParser)
   205  	if !ok {
   206  		return nil, errs.NewFrameError(errs.RetClientConnectFail,
   207  			"frame builder does not implement multiplexed.FrameParser")
   208  	}
   209  	getOpts.WithFrameParser(fp)
   210  	getOpts.WithDialTLS(opts.TLSCertFile, opts.TLSKeyFile, opts.CACertFile, opts.TLSServerName)
   211  	getOpts.WithLocalAddr(opts.LocalAddr)
   212  	conn, err := opts.Multiplexed.GetMuxConn(ctx, opts.Network, opts.Address, getOpts)
   213  	if err != nil {
   214  		return nil, err
   215  	}
   216  	defer conn.Close()
   217  	msg := codec.Message(ctx)
   218  	msg.WithRemoteAddr(conn.RemoteAddr())
   219  
   220  	if err := conn.Write(req); err != nil {
   221  		return nil, errs.NewFrameError(errs.RetClientNetErr,
   222  			"tcp client multiplexed transport Write: "+err.Error())
   223  	}
   224  
   225  	// SendOnly does not need to read response.
   226  	if opts.ReqType == codec.SendOnly {
   227  		return nil, errs.ErrClientNoResponse
   228  	}
   229  
   230  	buf, err := conn.Read()
   231  	if err != nil {
   232  		if err == context.Canceled {
   233  			return nil, errs.NewFrameError(errs.RetClientCanceled,
   234  				"tcp client multiplexed transport ReadFrame: "+err.Error())
   235  		}
   236  		if err == context.DeadlineExceeded {
   237  			return nil, errs.NewFrameError(errs.RetClientTimeout,
   238  				"tcp client multiplexed transport ReadFrame: "+err.Error())
   239  		}
   240  		return nil, errs.NewFrameError(errs.RetClientNetErr,
   241  			"tcp client multiplexed transport ReadFrame: "+err.Error())
   242  	}
   243  	return buf, nil
   244  }