trpc.group/trpc-go/trpc-go@v1.0.3/stream/client.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 // Package stream contains streaming client and server APIs. 15 package stream 16 17 import ( 18 "context" 19 "errors" 20 "fmt" 21 "io" 22 "sync" 23 "sync/atomic" 24 25 trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" 26 27 trpc "trpc.group/trpc-go/trpc-go" 28 "trpc.group/trpc-go/trpc-go/client" 29 "trpc.group/trpc-go/trpc-go/codec" 30 "trpc.group/trpc-go/trpc-go/errs" 31 icodec "trpc.group/trpc-go/trpc-go/internal/codec" 32 "trpc.group/trpc-go/trpc-go/internal/queue" 33 "trpc.group/trpc-go/trpc-go/transport" 34 ) 35 36 // Client is the Streaming client interface, NewStream is its only method. 37 type Client interface { 38 // NewStream returns a client stream, which the user uses to call Recv and Send to send, 39 // receive and send streaming messages. 40 NewStream(ctx context.Context, desc *client.ClientStreamDesc, 41 method string, opt ...client.Option) (client.ClientStream, error) 42 } 43 44 // DefaultStreamClient is the default streaming client. 45 var DefaultStreamClient = NewStreamClient() 46 47 // NewStreamClient returns a streaming client. 48 func NewStreamClient() Client { 49 // Streaming ID from 0-99 is reserved ID, used as control ID. 50 return &streamClient{streamID: uint32(99)} 51 } 52 53 // an implementation of streamClient Client. 54 type streamClient struct { 55 streamID uint32 56 } 57 58 // The specific implementation of ClientStream. 59 type clientStream struct { 60 desc *client.ClientStreamDesc 61 method string 62 sc *streamClient 63 ctx context.Context 64 opts *client.Options 65 streamID uint32 66 stream client.Stream 67 recvQueue *queue.Queue[*response] 68 closed uint32 69 closeCh chan struct{} 70 closeOnce sync.Once 71 } 72 73 // NewStream creates a new stream through which users send and receive messages. 74 func (c *streamClient) NewStream(ctx context.Context, desc *client.ClientStreamDesc, 75 method string, opt ...client.Option) (client.ClientStream, error) { 76 return c.newStream(ctx, desc, method, opt...) 77 } 78 79 // newStream creates a new stream through which users send and receive messages. 80 func (c *streamClient) newStream(ctx context.Context, desc *client.ClientStreamDesc, 81 method string, opt ...client.Option) (client.ClientStream, error) { 82 ctx, _ = codec.EnsureMessage(ctx) 83 cs := &clientStream{ 84 desc: desc, 85 method: method, 86 sc: c, 87 streamID: atomic.AddUint32(&c.streamID, 1), 88 ctx: ctx, 89 closeCh: make(chan struct{}, 1), 90 recvQueue: queue.New[*response](ctx.Done()), 91 stream: client.NewStream(), 92 } 93 if err := cs.prepare(opt...); err != nil { 94 return nil, err 95 } 96 if cs.opts.StreamFilters != nil { 97 return cs.opts.StreamFilters.Filter(cs.ctx, cs.desc, cs.invoke) 98 } 99 return cs.invoke(cs.ctx, cs.desc) 100 } 101 102 // Context returns the Context of the current stream. 103 func (cs *clientStream) Context() context.Context { 104 return cs.ctx 105 } 106 107 // RecvMsg receives the message, if there is no message it will get stuck. 108 // RecvMsg and SendMsg are concurrency safe, but two RecvMsg are not concurrency safe. 109 func (cs *clientStream) RecvMsg(m interface{}) error { 110 if err := cs.recv(m); err != nil { 111 return err 112 } 113 if cs.desc.ServerStreams { 114 // Subsequent messages should be received by subsequent RecvMsg calls. 115 return nil 116 } 117 // Special handling for non-server-stream rpcs. 118 // This recv expects EOF or errors. 119 err := cs.recv(m) 120 if err == nil { 121 return errs.NewFrameError(errs.RetClientStreamReadEnd, 122 "client streaming protocol violation: get <nil>, want <EOF>") 123 } 124 if err == io.EOF { 125 return nil 126 } 127 return err 128 } 129 130 func (cs *clientStream) recv(m interface{}) error { 131 resp, ok := cs.recvQueue.Get() 132 if !ok { 133 return cs.dealContextDone() 134 } 135 if resp.err != nil { 136 return resp.err 137 } 138 // Gather flow control information. 139 if err := cs.recvFlowCtl(len(resp.data)); err != nil { 140 return err 141 } 142 143 serializationType := codec.Message(cs.ctx).SerializationType() 144 if icodec.IsValidSerializationType(cs.opts.CurrentSerializationType) { 145 serializationType = cs.opts.CurrentSerializationType 146 } 147 if err := codec.Unmarshal(serializationType, resp.data, m); err != nil { 148 return errs.NewFrameError(errs.RetClientDecodeFail, "client codec Unmarshal: "+err.Error()) 149 } 150 return nil 151 } 152 153 func (cs *clientStream) recvFlowCtl(n int) error { 154 if cs.opts.RControl == nil { 155 return nil 156 } 157 // If the bottom layer has received the Close frame, then no feedback frame will be returned. 158 if atomic.LoadUint32(&cs.closed) == 1 { 159 return nil 160 } 161 if err := cs.opts.RControl.OnRecv(uint32(n)); err != nil { 162 // Avoid receiving the Close frame after entering OnRecv, and make another judgment. 163 if atomic.LoadUint32(&cs.closed) == 1 { 164 return nil 165 } 166 return err 167 } 168 return nil 169 } 170 171 // dealContextDone returns the final error message according to the Error type of the context. 172 func (cs *clientStream) dealContextDone() error { 173 if cs.ctx.Err() == context.Canceled { 174 return errs.NewFrameError(errs.RetClientCanceled, "tcp client stream canceled before recv: "+cs.ctx.Err().Error()) 175 } 176 if cs.ctx.Err() == context.DeadlineExceeded { 177 return errs.NewFrameError(errs.RetClientTimeout, 178 "tcp client stream canceled timeout before recv: "+cs.ctx.Err().Error()) 179 } 180 return nil 181 } 182 183 // SendMsg is the specific implementation of sending a message. 184 // RecvMsg and SendMsg are concurrency safe, but two SendMsg are not concurrency safe. 185 func (cs *clientStream) SendMsg(m interface{}) error { 186 ctx, msg := codec.WithCloneContextAndMessage(cs.ctx) 187 defer codec.PutBackMessage(msg) 188 msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA, cs.streamID)) 189 msg.WithStreamID(cs.streamID) 190 msg.WithClientRPCName(cs.method) 191 msg.WithCompressType(codec.Message(cs.ctx).CompressType()) 192 return cs.stream.Send(ctx, m) 193 } 194 195 func newFrameHead(t trpcpb.TrpcStreamFrameType, id uint32) *trpc.FrameHead { 196 return &trpc.FrameHead{ 197 FrameType: uint8(trpcpb.TrpcDataFrameType_TRPC_STREAM_FRAME), 198 StreamFrameType: uint8(t), 199 StreamID: id, 200 } 201 } 202 203 // CloseSend normally closes the sender, no longer sends messages, only accepts messages. 204 func (cs *clientStream) CloseSend() error { 205 ctx, msg := codec.WithCloneContextAndMessage(cs.ctx) 206 defer codec.PutBackMessage(msg) 207 msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE, cs.streamID)) 208 msg.WithStreamID(cs.streamID) 209 msg.WithStreamFrame(&trpcpb.TrpcStreamCloseMeta{ 210 CloseType: int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_CLOSE), 211 Ret: 0, 212 }) 213 return cs.stream.Send(ctx, nil) 214 } 215 216 func (cs *clientStream) prepare(opt ...client.Option) error { 217 msg := codec.Message(cs.ctx) 218 msg.WithClientRPCName(cs.method) 219 msg.WithStreamID(cs.streamID) 220 221 opt = append([]client.Option{client.WithStreamTransport(transport.DefaultClientStreamTransport)}, opt...) 222 opts, err := cs.stream.Init(cs.ctx, opt...) 223 if err != nil { 224 return err 225 } 226 cs.opts = opts 227 return nil 228 } 229 230 func (cs *clientStream) invoke(ctx context.Context, _ *client.ClientStreamDesc) (client.ClientStream, error) { 231 if err := cs.stream.Invoke(ctx); err != nil { 232 return nil, err 233 } 234 w := getWindowSize(cs.opts.MaxWindowSize) 235 newCtx, newMsg := codec.WithCloneContextAndMessage(ctx) 236 defer codec.PutBackMessage(newMsg) 237 copyMetaData(newMsg, codec.Message(cs.ctx)) 238 newMsg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT, cs.streamID)) 239 newMsg.WithClientRPCName(cs.method) 240 newMsg.WithStreamID(cs.streamID) 241 newMsg.WithCompressType(codec.Message(cs.ctx).CompressType()) 242 newMsg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{ 243 RequestMeta: &trpcpb.TrpcStreamInitRequestMeta{}, 244 InitWindowSize: w, 245 }) 246 cs.opts.RControl = newReceiveControl(w, cs.feedback) 247 // Send the init message out. 248 if err := cs.stream.Send(newCtx, nil); err != nil { 249 return nil, err 250 } 251 // After init is sent, the server will return directly. 252 if _, err := cs.stream.Recv(newCtx); err != nil { 253 return nil, err 254 } 255 initRspMeta, ok := newMsg.StreamFrame().(*trpcpb.TrpcStreamInitMeta) 256 if !ok { 257 return nil, fmt.Errorf("client stream (method = %s, streamID = %d) recv "+ 258 "unexpected frame type: %T, expected: %T", 259 cs.method, cs.streamID, newMsg.StreamFrame(), (*trpcpb.TrpcStreamInitMeta)(nil)) 260 } 261 initWindowSize := initRspMeta.GetInitWindowSize() 262 cs.configSendControl(initWindowSize) 263 264 // Start the dispatch goroutine loop to send packets. 265 go cs.dispatch() 266 return cs, nil 267 } 268 269 // configSendControl configs Send Control according to initWindowSize. 270 func (cs *clientStream) configSendControl(initWindowSize uint32) { 271 if initWindowSize == 0 { 272 // Disable flow control, compatible with the server without flow control enabled, delete this logic later. 273 cs.opts.RControl = nil 274 cs.opts.SControl = nil 275 return 276 } 277 cs.opts.SControl = newSendControl(initWindowSize, cs.ctx.Done(), cs.closeCh) 278 } 279 280 // feedback send feedback frame. 281 func (cs *clientStream) feedback(i uint32) error { 282 ctx, msg := codec.WithCloneContextAndMessage(cs.ctx) 283 defer codec.PutBackMessage(msg) 284 msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK, cs.streamID)) 285 msg.WithStreamID(cs.streamID) 286 msg.WithClientRPCName(cs.method) 287 msg.WithStreamFrame(&trpcpb.TrpcStreamFeedBackMeta{WindowSizeIncrement: i}) 288 return cs.stream.Send(ctx, nil) 289 } 290 291 // handleFrame performs different logical processing according to the type of frame. 292 func (cs *clientStream) handleFrame(ctx context.Context, resp *response, 293 respData []byte, frameHead *trpc.FrameHead) error { 294 msg := codec.Message(ctx) 295 switch trpcpb.TrpcStreamFrameType(frameHead.StreamFrameType) { 296 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA: 297 // Get the data and return it to the client. 298 resp.data = respData 299 resp.err = nil 300 cs.recvQueue.Put(resp) 301 return nil 302 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE: 303 // Close, it should be judged as Reset or Close. 304 resp.data = nil 305 var err error 306 if msg.ClientRspErr() != nil { 307 // Description is Reset. 308 err = msg.ClientRspErr() 309 } else { 310 err = io.EOF 311 } 312 resp.err = err 313 cs.recvQueue.Put(resp) 314 return err 315 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK: 316 cs.handleFeedback(msg) 317 return nil 318 default: 319 return nil 320 } 321 } 322 323 // handleFeedback handles the feedback frame. 324 func (cs *clientStream) handleFeedback(msg codec.Msg) { 325 if feedbackFrame, ok := msg.StreamFrame().(*trpcpb.TrpcStreamFeedBackMeta); ok && cs.opts.SControl != nil { 326 cs.opts.SControl.UpdateWindow(feedbackFrame.WindowSizeIncrement) 327 } 328 } 329 330 // dispatch is used to distribute the received data packets, receive them in a loop, 331 // and then distribute the data packets according to different data types. 332 func (cs *clientStream) dispatch() { 333 defer func() { 334 cs.opts.StreamTransport.Close(cs.ctx) 335 cs.close() 336 }() 337 for { 338 ctx, msg := codec.WithCloneContextAndMessage(cs.ctx) 339 msg.WithCompressType(codec.Message(cs.ctx).CompressType()) 340 msg.WithStreamID(cs.streamID) 341 respData, err := cs.stream.Recv(ctx) 342 if err != nil { 343 // return to client on error. 344 cs.recvQueue.Put(&response{ 345 err: errs.WrapFrameError(err, errs.RetClientStreamReadEnd, streamClosed), 346 }) 347 return 348 } 349 350 frameHead, ok := msg.FrameHead().(*trpc.FrameHead) 351 if !ok { 352 cs.recvQueue.Put(&response{ 353 err: errors.New(frameHeadInvalid), 354 }) 355 return 356 } 357 358 if err := cs.handleFrame(ctx, &response{}, respData, frameHead); err != nil { 359 // If there is a Close frame, the dispatch goroutine ends. 360 return 361 } 362 } 363 } 364 365 func (cs *clientStream) close() { 366 cs.closeOnce.Do(func() { 367 atomic.StoreUint32(&cs.closed, 1) 368 close(cs.closeCh) 369 }) 370 } 371 372 func copyMetaData(dst codec.Msg, src codec.Msg) { 373 if src.ClientMetaData() != nil { 374 dst.WithClientMetaData(src.ClientMetaData().Clone()) 375 } 376 }