trpc.group/trpc-go/trpc-go@v1.0.3/transport/client_transport_stream.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package transport 15 16 import ( 17 "context" 18 "sync" 19 20 "trpc.group/trpc-go/trpc-go/codec" 21 "trpc.group/trpc-go/trpc-go/errs" 22 "trpc.group/trpc-go/trpc-go/pool/multiplexed" 23 ) 24 25 const ( 26 defaultMaxConcurrentStreams = 1000 27 defaultMaxIdleConnsPerHost = 2 28 ) 29 30 // DefaultClientStreamTransport is the default client stream transport. 31 var DefaultClientStreamTransport = NewClientStreamTransport() 32 33 // NewClientStreamTransport creates a new ClientStreamTransport. 34 func NewClientStreamTransport(opts ...ClientStreamTransportOption) ClientStreamTransport { 35 options := &cstOptions{ 36 maxConcurrentStreams: defaultMaxConcurrentStreams, 37 maxIdleConnsPerHost: defaultMaxIdleConnsPerHost, 38 } 39 for _, opt := range opts { 40 opt(options) 41 } 42 t := &clientStreamTransport{ 43 // Map streamID to connection. On the client side, ensure that the streamID is 44 // incremented and unique, otherwise the map of addr must be added. 45 streamIDToConn: make(map[uint32]multiplexed.MuxConn), 46 m: &sync.RWMutex{}, 47 multiplexedPool: multiplexed.New( 48 multiplexed.WithMaxVirConnsPerConn(options.maxConcurrentStreams), 49 multiplexed.WithMaxIdleConnsPerHost(options.maxIdleConnsPerHost), 50 ), 51 } 52 return t 53 } 54 55 // cstOptions is the client stream transport options. 56 type cstOptions struct { 57 maxConcurrentStreams int 58 maxIdleConnsPerHost int 59 } 60 61 // ClientStreamTransportOption sets properties of ClientStreamTransport. 62 type ClientStreamTransportOption func(*cstOptions) 63 64 // WithMaxConcurrentStreams sets the maximum concurrent streams in each TCP connection. 65 func WithMaxConcurrentStreams(n int) ClientStreamTransportOption { 66 return func(opts *cstOptions) { 67 opts.maxConcurrentStreams = n 68 } 69 } 70 71 // WithMaxIdleConnsPerHost sets the maximum idle connections per host. 72 func WithMaxIdleConnsPerHost(n int) ClientStreamTransportOption { 73 return func(opts *cstOptions) { 74 opts.maxIdleConnsPerHost = n 75 } 76 } 77 78 // clientStreamTransport keeps compatibility with the original client transport. 79 type clientStreamTransport struct { 80 streamIDToConn map[uint32]multiplexed.MuxConn 81 m *sync.RWMutex 82 multiplexedPool multiplexed.Pool 83 } 84 85 // Init inits clientStreamTransport. It gets a connection from the multiplexing pool. A stream is 86 // corresponding to a virtual connection, which provides the interface for the stream. 87 func (c *clientStreamTransport) Init(ctx context.Context, roundTripOpts ...RoundTripOption) error { 88 opts, err := c.getOptions(ctx, roundTripOpts...) 89 if err != nil { 90 return err 91 } 92 // If ctx has been canceled or timeout, just return. 93 if ctx.Err() == context.Canceled { 94 return errs.NewFrameError(errs.RetClientCanceled, 95 "client canceled before tcp dial: "+ctx.Err().Error()) 96 } 97 if ctx.Err() == context.DeadlineExceeded { 98 return errs.NewFrameError(errs.RetClientTimeout, 99 "client timeout before tcp dial: "+ctx.Err().Error()) 100 } 101 msg := opts.Msg 102 streamID := msg.StreamID() 103 104 getOpts := multiplexed.NewGetOptions() 105 getOpts.WithVID(streamID) 106 fp, ok := opts.FramerBuilder.(multiplexed.FrameParser) 107 if !ok { 108 return errs.NewFrameError(errs.RetClientConnectFail, 109 "frame builder does not implement multiplexed.FrameParser") 110 } 111 getOpts.WithFrameParser(fp) 112 getOpts.WithDialTLS(opts.TLSCertFile, opts.TLSKeyFile, opts.CACertFile, opts.TLSServerName) 113 getOpts.WithLocalAddr(opts.LocalAddr) 114 conn, err := opts.Multiplexed.GetMuxConn(ctx, opts.Network, opts.Address, getOpts) 115 if err != nil { 116 return errs.NewFrameError(errs.RetClientConnectFail, 117 "tcp client transport multiplexd pool: "+err.Error()) 118 } 119 msg.WithRemoteAddr(conn.RemoteAddr()) 120 msg.WithLocalAddr(conn.LocalAddr()) 121 c.m.Lock() 122 c.streamIDToConn[streamID] = conn 123 c.m.Unlock() 124 return nil 125 } 126 127 // Send sends stream data and provides interface for stream. 128 func (c *clientStreamTransport) Send(ctx context.Context, req []byte, roundTripOpts ...RoundTripOption) error { 129 msg := codec.Message(ctx) 130 streamID := msg.StreamID() 131 // StreamID is uniquely generated by stream client. 132 c.m.RLock() 133 cc := c.streamIDToConn[streamID] 134 c.m.RUnlock() 135 if cc == nil { 136 return errs.NewFrameError(errs.RetServerSystemErr, "Connection is Closed") 137 } 138 if err := cc.Write(req); err != nil { 139 return err 140 } 141 return nil 142 } 143 144 // Recv receives stream data and provides interface for stream. 145 func (c *clientStreamTransport) Recv(ctx context.Context, roundTripOpts ...RoundTripOption) ([]byte, error) { 146 cc, err := c.getConnect(ctx, roundTripOpts...) 147 if err != nil { 148 return nil, err 149 } 150 151 select { 152 case <-ctx.Done(): 153 if ctx.Err() == context.Canceled { 154 return nil, errs.NewFrameError(errs.RetClientCanceled, 155 "tcp client transport canceled before Write: "+ctx.Err().Error()) 156 } 157 if ctx.Err() == context.DeadlineExceeded { 158 return nil, errs.NewFrameError(errs.RetClientTimeout, 159 "tcp client transport timeout before Write: "+ctx.Err().Error()) 160 } 161 default: 162 } 163 return cc.Read() 164 } 165 166 // Close closes connections and cleans up. 167 func (c *clientStreamTransport) Close(ctx context.Context) { 168 msg := codec.Message(ctx) 169 streamID := msg.StreamID() 170 c.m.Lock() 171 defer c.m.Unlock() 172 if conn, ok := c.streamIDToConn[streamID]; ok { 173 conn.Close() 174 delete(c.streamIDToConn, streamID) 175 } 176 } 177 178 // getOptions inits RoundTripOptions and does some basic check. 179 func (c *clientStreamTransport) getOptions(ctx context.Context, 180 roundTripOpts ...RoundTripOption) (*RoundTripOptions, error) { 181 opts := &RoundTripOptions{ 182 Multiplexed: c.multiplexedPool, 183 } 184 185 // use roundTripOpts to modify opts. 186 for _, o := range roundTripOpts { 187 o(opts) 188 } 189 190 if opts.Multiplexed == nil { 191 return nil, errs.NewFrameError(errs.RetClientConnectFail, 192 "tcp client transport: multiplexd pool empty") 193 } 194 195 if opts.FramerBuilder == nil { 196 return nil, errs.NewFrameError(errs.RetClientConnectFail, 197 "tcp client transport: framer builder empty") 198 } 199 200 if opts.Msg == nil { 201 return nil, errs.NewFrameError(errs.RetClientConnectFail, 202 "tcp client transport: message empty") 203 } 204 return opts, nil 205 } 206 207 func (c *clientStreamTransport) getConnect(ctx context.Context, 208 roundTripOpts ...RoundTripOption) (multiplexed.MuxConn, error) { 209 msg := codec.Message(ctx) 210 streamID := msg.StreamID() 211 c.m.RLock() 212 cc := c.streamIDToConn[streamID] 213 c.m.RUnlock() 214 if cc == nil { 215 return nil, errs.NewFrameError(errs.RetServerSystemErr, "Stream is not inited yet") 216 } 217 return cc, nil 218 }