trpc.group/trpc-go/trpc-go@v1.0.3/transport/tnet/client_transport_tcp.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 //go:build linux || freebsd || dragonfly || darwin 15 // +build linux freebsd dragonfly darwin 16 17 package tnet 18 19 import ( 20 "context" 21 "errors" 22 "fmt" 23 "net" 24 "time" 25 26 "trpc.group/trpc-go/trpc-go/codec" 27 "trpc.group/trpc-go/trpc-go/errs" 28 "trpc.group/trpc-go/trpc-go/internal/report" 29 "trpc.group/trpc-go/trpc-go/log" 30 "trpc.group/trpc-go/trpc-go/pool/connpool" 31 "trpc.group/trpc-go/trpc-go/pool/multiplexed" 32 "trpc.group/trpc-go/trpc-go/transport" 33 ) 34 35 func (c *clientTransport) tcpRoundTrip(ctx context.Context, reqData []byte, 36 opts *transport.RoundTripOptions) ([]byte, error) { 37 // Dial a TCP connection 38 conn, err := dialTCP(ctx, opts) 39 if err != nil { 40 return nil, err 41 } 42 defer conn.Close() 43 msg := codec.Message(ctx) 44 msg.WithRemoteAddr(conn.RemoteAddr()) 45 msg.WithLocalAddr(conn.LocalAddr()) 46 47 if err := checkContextErr(ctx); err != nil { 48 return nil, fmt.Errorf("before Write: %w", err) 49 } 50 51 report.TCPClientTransportSendSize.Set(float64(len(reqData))) 52 // Send a request. 53 if err := tcpWriteFrame(conn, reqData); err != nil { 54 return nil, err 55 } 56 // Receive a response. 57 return tcpReadFrame(conn, opts) 58 } 59 60 func dialTCP(ctx context.Context, opts *transport.RoundTripOptions) (net.Conn, error) { 61 if err := checkContextErr(ctx); err != nil { 62 return nil, fmt.Errorf("before tcp dial, %w", err) 63 } 64 var timeout time.Duration 65 d, isSetDeadline := ctx.Deadline() 66 if isSetDeadline { 67 timeout = time.Until(d) 68 } 69 70 var conn net.Conn 71 var err error 72 // Short connection mode, directly dial a connection. 73 if opts.DisableConnectionPool { 74 if opts.DialTimeout > 0 && opts.DialTimeout < timeout { 75 timeout = opts.DialTimeout 76 } 77 conn, err = Dial(&connpool.DialOptions{ 78 Network: opts.Network, 79 Address: opts.Address, 80 LocalAddr: opts.LocalAddr, 81 Timeout: timeout, 82 CACertFile: opts.CACertFile, 83 TLSCertFile: opts.TLSCertFile, 84 TLSKeyFile: opts.TLSKeyFile, 85 TLSServerName: opts.TLSServerName, 86 }) 87 if err != nil { 88 return nil, errs.WrapFrameError(err, errs.RetClientConnectFail, "tcp client transport dial") 89 } 90 // Set a deadline for subsequent reading on the connection. 91 if isSetDeadline { 92 if err := conn.SetReadDeadline(d); err != nil { 93 log.Tracef("client SetReadDeadline failed %v", err) 94 } 95 } 96 return conn, nil 97 } 98 99 // Connection pool mode, get connection from pool. 100 getOpts := connpool.NewGetOptions() 101 getOpts.WithContext(ctx) 102 getOpts.WithFramerBuilder(opts.FramerBuilder) 103 getOpts.WithDialTLS(opts.TLSCertFile, opts.TLSKeyFile, opts.CACertFile, opts.TLSServerName) 104 getOpts.WithLocalAddr(opts.LocalAddr) 105 getOpts.WithDialTimeout(opts.DialTimeout) 106 getOpts.WithProtocol(opts.Protocol) 107 conn, err = opts.Pool.Get(opts.Network, opts.Address, getOpts) 108 if err != nil { 109 return nil, errs.WrapFrameError(err, errs.RetClientConnectFail, "tcp client transport connection pool") 110 } 111 // The created connection must be a tnet connection. 112 if !validateTnetConn(conn) && !validateTnetTLSConn(conn) { 113 return nil, errs.NewFrameError(errs.RetClientConnectFail, "tnet transport doesn't support non tnet.Conn") 114 } 115 if err := conn.SetReadDeadline(d); err != nil { 116 log.Tracef("client SetReadDeadline failed %v", err) 117 } 118 return conn, nil 119 } 120 121 func tcpWriteFrame(conn net.Conn, reqData []byte) error { 122 // When writing data on a tnet connection, there will be no partial write success, 123 // only complete success or complete failure. 124 _, err := conn.Write(reqData) 125 if err != nil { 126 return wrapNetError("tcp client tnet transport Write", err) 127 } 128 return nil 129 } 130 131 func tcpReadFrame(conn net.Conn, opts *transport.RoundTripOptions) ([]byte, error) { 132 if opts.ReqType == transport.SendOnly { 133 return nil, errs.ErrClientNoResponse 134 } 135 136 var fr codec.Framer 137 // The connection retrieved from the connection pool has already implemented the Framer interface. 138 if opts.DisableConnectionPool { 139 fr = opts.FramerBuilder.New(codec.NewReader(conn)) 140 } else { 141 var ok bool 142 fr, ok = conn.(codec.Framer) 143 if !ok { 144 return nil, errs.NewFrameError(errs.RetClientConnectFail, 145 "tcp client transport: framer not implemented") 146 } 147 } 148 149 rspData, err := fr.ReadFrame() 150 if err != nil { 151 return nil, wrapNetError("tcp client transport ReadFrame", err) 152 } 153 report.TCPClientTransportReceiveSize.Set(float64(len(rspData))) 154 return rspData, nil 155 } 156 157 func wrapNetError(msg string, err error) error { 158 if err == nil { 159 return nil 160 } 161 if e, ok := err.(net.Error); ok && e.Timeout() { 162 return errs.WrapFrameError(err, errs.RetClientTimeout, msg) 163 } 164 return errs.WrapFrameError(err, errs.RetClientNetErr, msg) 165 } 166 167 func checkContextErr(ctx context.Context) error { 168 if errors.Is(ctx.Err(), context.Canceled) { 169 return errs.WrapFrameError(ctx.Err(), errs.RetClientCanceled, "client canceled") 170 } 171 if errors.Is(ctx.Err(), context.DeadlineExceeded) { 172 return errs.WrapFrameError(ctx.Err(), errs.RetClientTimeout, "client timeout") 173 } 174 return nil 175 } 176 func (c *clientTransport) multiplex(ctx context.Context, req []byte, opts *transport.RoundTripOptions) ([]byte, error) { 177 getOpts := multiplexed.NewGetOptions() 178 getOpts.WithVID(opts.Msg.RequestID()) 179 fp, ok := opts.FramerBuilder.(multiplexed.FrameParser) 180 if !ok { 181 return nil, errs.NewFrameError(errs.RetClientConnectFail, 182 "frame builder does not implement multiplexed.FrameParser") 183 } 184 getOpts.WithFrameParser(fp) 185 getOpts.WithDialTLS(opts.TLSCertFile, opts.TLSKeyFile, opts.CACertFile, opts.TLSServerName) 186 getOpts.WithLocalAddr(opts.LocalAddr) 187 conn, err := opts.Multiplexed.GetMuxConn(ctx, opts.Network, opts.Address, getOpts) 188 if err != nil { 189 return nil, errs.WrapFrameError(err, errs.RetClientNetErr, "tcp client get multiplex connection failed") 190 } 191 defer conn.Close() 192 msg := codec.Message(ctx) 193 msg.WithRemoteAddr(conn.RemoteAddr()) 194 195 if err := conn.Write(req); err != nil { 196 return nil, errs.WrapFrameError(err, errs.RetClientNetErr, "tcp client multiplex write failed") 197 } 198 199 // no need to receive response when request type is SendOnly. 200 if opts.ReqType == codec.SendOnly { 201 return nil, errs.ErrClientNoResponse 202 } 203 204 buf, err := conn.Read() 205 if err != nil { 206 if err == context.Canceled { 207 return nil, errs.NewFrameError(errs.RetClientCanceled, 208 "tcp tnet multiplexed ReadFrame: "+err.Error()) 209 } 210 if err == context.DeadlineExceeded { 211 return nil, errs.NewFrameError(errs.RetClientTimeout, 212 "tcp tnet multiplexed ReadFrame: "+err.Error()) 213 } 214 return nil, errs.NewFrameError(errs.RetClientNetErr, 215 "tcp tnet multiplexed ReadFrame: "+err.Error()) 216 } 217 return buf, nil 218 }