trpc.group/trpc-go/trpc-go@v1.0.3/stream/server.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package stream 15 16 import ( 17 "context" 18 "errors" 19 "io" 20 "sync" 21 22 "go.uber.org/atomic" 23 "trpc.group/trpc-go/trpc-go/internal/addrutil" 24 trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" 25 26 trpc "trpc.group/trpc-go/trpc-go" 27 "trpc.group/trpc-go/trpc-go/codec" 28 "trpc.group/trpc-go/trpc-go/errs" 29 icodec "trpc.group/trpc-go/trpc-go/internal/codec" 30 "trpc.group/trpc-go/trpc-go/internal/queue" 31 "trpc.group/trpc-go/trpc-go/log" 32 "trpc.group/trpc-go/trpc-go/server" 33 "trpc.group/trpc-go/trpc-go/transport" 34 ) 35 36 // serverStream is a structure provided to the service implementation logic, 37 // and users use the API of this structure to send and receive streaming messages. 38 type serverStream struct { 39 ctx context.Context 40 streamID uint32 41 opts *server.Options 42 recvQueue *queue.Queue[*response] 43 done chan struct{} 44 err atomic.Error // Carry the server tcp failure information. 45 once sync.Once 46 rControl *receiveControl // Receiver flow control. 47 sControl *sendControl // Sender flow control. 48 } 49 50 // SendMsg is the API that users use to send streaming messages. 51 func (s *serverStream) SendMsg(m interface{}) error { 52 if err := s.err.Load(); err != nil { 53 return errs.WrapFrameError(err, errs.Code(err), "stream sending error") 54 } 55 msg := codec.Message(s.ctx) 56 ctx, newMsg := codec.WithCloneContextAndMessage(s.ctx) 57 defer codec.PutBackMessage(newMsg) 58 newMsg.WithLocalAddr(msg.LocalAddr()) 59 newMsg.WithRemoteAddr(msg.RemoteAddr()) 60 newMsg.WithCompressType(msg.CompressType()) 61 newMsg.WithStreamID(s.streamID) 62 // Refer to the pb code generated by trpc.proto, common to each language, automatically generated code. 63 newMsg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA, s.streamID)) 64 65 var ( 66 err error 67 reqBodyBuffer []byte 68 ) 69 serializationType, compressType := s.serializationAndCompressType(newMsg) 70 if icodec.IsValidSerializationType(serializationType) { 71 reqBodyBuffer, err = codec.Marshal(serializationType, m) 72 if err != nil { 73 return errs.NewFrameError(errs.RetServerEncodeFail, "server codec Marshal: "+err.Error()) 74 } 75 } 76 77 // compress 78 if icodec.IsValidCompressType(compressType) && compressType != codec.CompressTypeNoop { 79 reqBodyBuffer, err = codec.Compress(compressType, reqBodyBuffer) 80 if err != nil { 81 return errs.NewFrameError(errs.RetServerEncodeFail, "server codec Compress: "+err.Error()) 82 } 83 } 84 85 // Flow control only controls the payload of data. 86 if s.sControl != nil { 87 if err := s.sControl.GetWindow(uint32(len(reqBodyBuffer))); err != nil { 88 return err 89 } 90 } 91 92 // encode the entire request. 93 reqBuffer, err := s.opts.Codec.Encode(newMsg, reqBodyBuffer) 94 if err != nil { 95 return errs.NewFrameError(errs.RetServerEncodeFail, "server codec Encode: "+err.Error()) 96 } 97 98 // initiate a backend network request. 99 return s.opts.StreamTransport.Send(ctx, reqBuffer) 100 } 101 102 func (s *serverStream) newFrameHead(streamFrameType trpcpb.TrpcStreamFrameType) *trpc.FrameHead { 103 return &trpc.FrameHead{ 104 FrameType: uint8(trpcpb.TrpcDataFrameType_TRPC_STREAM_FRAME), 105 StreamFrameType: uint8(streamFrameType), 106 StreamID: s.streamID, 107 } 108 } 109 110 func (s *serverStream) serializationAndCompressType(msg codec.Msg) (int, int) { 111 serializationType := msg.SerializationType() 112 compressType := msg.CompressType() 113 if icodec.IsValidSerializationType(s.opts.CurrentSerializationType) { 114 serializationType = s.opts.CurrentSerializationType 115 } 116 if icodec.IsValidCompressType(s.opts.CurrentCompressType) { 117 compressType = s.opts.CurrentCompressType 118 } 119 return serializationType, compressType 120 } 121 122 // RecvMsg receives streaming messages, passes in the structure that needs to receive messages, 123 // and returns the serialized structure. 124 func (s *serverStream) RecvMsg(m interface{}) error { 125 resp, ok := s.recvQueue.Get() 126 if !ok { 127 if err := s.err.Load(); err != nil { 128 return err 129 } 130 return errs.NewFrameError(errs.RetServerSystemErr, streamClosed) 131 } 132 if resp.err != nil { 133 return resp.err 134 } 135 if s.rControl != nil { 136 if err := s.rControl.OnRecv(uint32(len(resp.data))); err != nil { 137 return err 138 } 139 } 140 // Decompress and deserialize the data frame into a structure. 141 return s.decompressAndUnmarshal(resp.data, m) 142 143 } 144 145 // decompressAndUnmarshal decompresses the data frame and deserializes it. 146 func (s *serverStream) decompressAndUnmarshal(data []byte, m interface{}) error { 147 msg := codec.Message(s.ctx) 148 var err error 149 serializationType, compressType := s.serializationAndCompressType(msg) 150 if icodec.IsValidCompressType(compressType) && compressType != codec.CompressTypeNoop { 151 data, err = codec.Decompress(compressType, data) 152 if err != nil { 153 return errs.NewFrameError(errs.RetClientDecodeFail, "server codec Decompress: "+err.Error()) 154 } 155 } 156 157 // Deserialize the binary body to a specific body structure. 158 if icodec.IsValidSerializationType(serializationType) { 159 if err := codec.Unmarshal(serializationType, data, m); err != nil { 160 return errs.NewFrameError(errs.RetClientDecodeFail, "server codec Unmarshal: "+err.Error()) 161 } 162 } 163 return nil 164 } 165 166 // The CloseSend server closes the stream, where ret represents the close type, 167 // which is divided into TRPC_STREAM_CLOSE and TRPC_STREAM_RESET. 168 // message represents the returned message, where error messages can be logged. 169 func (s *serverStream) CloseSend(closeType, ret int32, message string) error { 170 oldMsg := codec.Message(s.ctx) 171 ctx, msg := codec.WithCloneContextAndMessage(s.ctx) 172 defer codec.PutBackMessage(msg) 173 msg.WithLocalAddr(oldMsg.LocalAddr()) 174 msg.WithRemoteAddr(oldMsg.RemoteAddr()) 175 msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE, s.streamID)) 176 msg.WithStreamFrame(&trpcpb.TrpcStreamCloseMeta{ 177 CloseType: closeType, 178 Ret: ret, 179 Msg: []byte(message), 180 }) 181 182 rspBuffer, err := s.opts.Codec.Encode(msg, nil) 183 if err != nil { 184 return err 185 } 186 return s.opts.StreamTransport.Send(ctx, rspBuffer) 187 } 188 189 // newServerStream creates a new server stream, which can send and receive streaming messages. 190 func newServerStream(ctx context.Context, streamID uint32, opts *server.Options) *serverStream { 191 s := &serverStream{ 192 ctx: ctx, 193 opts: opts, 194 streamID: streamID, 195 done: make(chan struct{}, 1), 196 } 197 s.recvQueue = queue.New[*response](s.done) 198 return s 199 } 200 201 func (s *serverStream) feedback(w uint32) error { 202 oldMsg := codec.Message(s.ctx) 203 ctx, msg := codec.WithCloneContextAndMessage(s.ctx) 204 defer codec.PutBackMessage(msg) 205 msg.WithLocalAddr(oldMsg.LocalAddr()) 206 msg.WithRemoteAddr(oldMsg.RemoteAddr()) 207 msg.WithStreamID(s.streamID) 208 msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK, s.streamID)) 209 msg.WithStreamFrame(&trpcpb.TrpcStreamFeedBackMeta{WindowSizeIncrement: w}) 210 211 feedbackBuf, err := s.opts.Codec.Encode(msg, nil) 212 if err != nil { 213 return err 214 } 215 return s.opts.StreamTransport.Send(ctx, feedbackBuf) 216 } 217 218 // Context returns the context of the serverStream structure. 219 func (s *serverStream) Context() context.Context { 220 return s.ctx 221 } 222 223 // The structure of streamDispatcher is used to distribute streaming data. 224 type streamDispatcher struct { 225 m sync.RWMutex 226 // local address + remote address + network 227 // => stream ID 228 // => serverStream 229 addrToServerStream map[string]map[uint32]*serverStream 230 opts *server.Options 231 } 232 233 // DefaultStreamDispatcher is the default implementation of the trpc dispatcher, 234 // supports the data distribution of the trpc streaming protocol. 235 var DefaultStreamDispatcher = NewStreamDispatcher() 236 237 // NewStreamDispatcher returns a new dispatcher. 238 func NewStreamDispatcher() server.StreamHandle { 239 return &streamDispatcher{ 240 addrToServerStream: make(map[string]map[uint32]*serverStream), 241 } 242 } 243 244 // storeServerStream msg contains the socket address of the client connection, 245 // there are multiple streams under each socket address, and map it to serverStream 246 // again according to the id of the stream. 247 func (sd *streamDispatcher) storeServerStream(addr string, streamID uint32, ss *serverStream) { 248 sd.m.Lock() 249 defer sd.m.Unlock() 250 if addrToStreamID, ok := sd.addrToServerStream[addr]; !ok { 251 // Does not exist, indicating that a new connection is coming, re-create the structure. 252 sd.addrToServerStream[addr] = map[uint32]*serverStream{streamID: ss} 253 } else { 254 addrToStreamID[streamID] = ss 255 } 256 } 257 258 // deleteServerStream deletes the serverStream from cache. 259 func (sd *streamDispatcher) deleteServerStream(addr string, streamID uint32) { 260 sd.m.Lock() 261 defer sd.m.Unlock() 262 if addrToStreamID, ok := sd.addrToServerStream[addr]; ok { 263 if _, ok = addrToStreamID[streamID]; ok { 264 delete(addrToStreamID, streamID) 265 } 266 if len(addrToStreamID) == 0 { 267 delete(sd.addrToServerStream, addr) 268 } 269 } 270 } 271 272 // loadServerStream loads the stored serverStream through the socket address 273 // of the client connection and the id of the stream. 274 func (sd *streamDispatcher) loadServerStream(addr string, streamID uint32) (*serverStream, error) { 275 sd.m.RLock() 276 defer sd.m.RUnlock() 277 addrToStream, ok := sd.addrToServerStream[addr] 278 if !ok { 279 return nil, errs.NewFrameError(errs.RetServerSystemErr, noSuchAddr) 280 } 281 282 var ss *serverStream 283 if ss, ok = addrToStream[streamID]; !ok { 284 return nil, errs.NewFrameError(errs.RetServerSystemErr, noSuchStreamID) 285 } 286 return ss, nil 287 } 288 289 // Init initializes some settings of dispatcher. 290 func (sd *streamDispatcher) Init(opts *server.Options) error { 291 sd.opts = opts 292 st, ok := sd.opts.Transport.(transport.ServerStreamTransport) 293 if !ok { 294 return errors.New(streamTransportUnimplemented) 295 } 296 sd.opts.StreamTransport = st 297 sd.opts.ServeOptions = append(sd.opts.ServeOptions, 298 transport.WithServerAsync(false), transport.WithCopyFrame(true)) 299 return nil 300 } 301 302 // startStreamHandler is used to start the goroutine, execute streamHandler, 303 // streamHandler is implemented for the specific streaming server. 304 func (sd *streamDispatcher) startStreamHandler(addr string, streamID uint32, 305 ss *serverStream, si *server.StreamServerInfo, sh server.StreamHandler) { 306 defer func() { 307 sd.deleteServerStream(addr, streamID) 308 ss.once.Do(func() { close(ss.done) }) 309 }() 310 311 // Execute the implementation code of the server stream. 312 var err error 313 if ss.opts.StreamFilters != nil { 314 err = ss.opts.StreamFilters.Filter(ss, si, sh) 315 } else { 316 err = sh(ss) 317 } 318 319 var frameworkError *errs.Error 320 switch { 321 case errors.As(err, &frameworkError): 322 err = ss.CloseSend(int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_RESET), int32(frameworkError.Code), frameworkError.Msg) 323 case err != nil: 324 // return business error. 325 err = ss.CloseSend(int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_RESET), 0, err.Error()) 326 default: 327 // Stream is normally closed. 328 err = ss.CloseSend(int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_CLOSE), 0, "") 329 } 330 if err != nil { 331 ss.err.Store(err) 332 log.Trace(closeSendFail, err) 333 } 334 } 335 336 // setSendControl obtained from the init frame. 337 func (s *serverStream) setSendControl(msg codec.Msg) (uint32, error) { 338 initMeta, ok := msg.StreamFrame().(*trpcpb.TrpcStreamInitMeta) 339 if !ok { 340 return 0, errors.New(streamFrameInvalid) 341 } 342 343 // This section of logic is compatible with framework implementations in other languages 344 // that do not enable flow control, and will be deleted later. 345 if initMeta.InitWindowSize == 0 { 346 // Compatible with the client without flow control enabled. 347 s.rControl = nil 348 s.sControl = nil 349 return initMeta.InitWindowSize, nil 350 } 351 s.sControl = newSendControl(initMeta.InitWindowSize, s.done) 352 return initMeta.InitWindowSize, nil 353 } 354 355 // handleInit processes the sent init package. 356 func (sd *streamDispatcher) handleInit(ctx context.Context, 357 sh server.StreamHandler, si *server.StreamServerInfo) ([]byte, error) { 358 // The Msg in ctx is passed to us by the upper layer, and we can't make any assumptions about its life cycle. 359 // Before creating ServerStream, make a complete copy of Msg. 360 oldMsg := codec.Message(ctx) 361 ctx, msg := codec.WithNewMessage(ctx) 362 codec.CopyMsg(msg, oldMsg) 363 364 streamID := msg.StreamID() 365 ss := newServerStream(ctx, streamID, sd.opts) 366 w := getWindowSize(sd.opts.MaxWindowSize) 367 ss.rControl = newReceiveControl(w, ss.feedback) 368 sd.storeServerStream(addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()), streamID, ss) 369 370 cw, err := ss.setSendControl(msg) 371 if err != nil { 372 return nil, err 373 } 374 375 // send init response packet. 376 newCtx, newMsg := codec.WithCloneContextAndMessage(ctx) 377 defer codec.PutBackMessage(newMsg) 378 newMsg.WithLocalAddr(msg.LocalAddr()) 379 newMsg.WithRemoteAddr(msg.RemoteAddr()) 380 newMsg.WithStreamID(streamID) 381 newMsg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT, ss.streamID)) 382 383 initMeta := &trpcpb.TrpcStreamInitMeta{ResponseMeta: &trpcpb.TrpcStreamInitResponseMeta{}} 384 // If the client does not set it, the server should not set it to prevent incompatibility. 385 if cw == 0 { 386 initMeta.InitWindowSize = 0 387 } else { 388 initMeta.InitWindowSize = w 389 } 390 newMsg.WithStreamFrame(initMeta) 391 392 rspBuffer, err := ss.opts.Codec.Encode(newMsg, nil) 393 if err != nil { 394 return nil, err 395 } 396 if err := ss.opts.StreamTransport.Send(newCtx, rspBuffer); err != nil { 397 return nil, err 398 } 399 400 // Initiate a goroutine to execute specific business logic. 401 go sd.startStreamHandler(addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()), streamID, ss, si, sh) 402 return nil, errs.ErrServerNoResponse 403 } 404 405 // handleData handles data messages. 406 func (sd *streamDispatcher) handleData(msg codec.Msg, req []byte) ([]byte, error) { 407 ss, err := sd.loadServerStream(addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()), msg.StreamID()) 408 if err != nil { 409 return nil, err 410 } 411 ss.recvQueue.Put(&response{data: req}) 412 return nil, errs.ErrServerNoResponse 413 } 414 415 // handleClose handles the Close message. 416 func (sd *streamDispatcher) handleClose(msg codec.Msg) ([]byte, error) { 417 ss, err := sd.loadServerStream(addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()), msg.StreamID()) 418 if err != nil { 419 // The server has sent the Close frame. 420 // Since the timing of the Close frame is unpredictable, when the server receives the Close frame from the client, 421 // the Close frame may have been sent, causing the resource to be released, no need to respond to this error. 422 log.Trace("handleClose loadServerStream fail", err) 423 return nil, errs.ErrServerNoResponse 424 } 425 // is Reset message. 426 if msg.ServerRspErr() != nil { 427 ss.recvQueue.Put(&response{err: msg.ServerRspErr()}) 428 return nil, errs.ErrServerNoResponse 429 } 430 // is a normal Close message 431 ss.recvQueue.Put(&response{err: io.EOF}) 432 return nil, errs.ErrServerNoResponse 433 } 434 435 // handleError When the connection is wrong, handle the error. 436 func (sd *streamDispatcher) handleError(msg codec.Msg) ([]byte, error) { 437 sd.m.Lock() 438 defer sd.m.Unlock() 439 440 addr := addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()) 441 addrToStream, ok := sd.addrToServerStream[addr] 442 if !ok { 443 return nil, errs.NewFrameError(errs.RetServerSystemErr, noSuchAddr) 444 } 445 for streamID, ss := range addrToStream { 446 ss.err.Store(msg.ServerRspErr()) 447 ss.once.Do(func() { close(ss.done) }) 448 delete(addrToStream, streamID) 449 } 450 delete(sd.addrToServerStream, addr) 451 return nil, errs.ErrServerNoResponse 452 } 453 454 // StreamHandleFunc The processing logic after a complete streaming frame received by the streaming transport. 455 func (sd *streamDispatcher) StreamHandleFunc(ctx context.Context, 456 sh server.StreamHandler, si *server.StreamServerInfo, req []byte) ([]byte, error) { 457 msg := codec.Message(ctx) 458 frameHead, ok := msg.FrameHead().(*trpc.FrameHead) 459 if !ok { 460 // If there is no frame head and serverRspErr, the server connection is abnormal 461 // and returns to the upper service. 462 if msg.ServerRspErr() != nil { 463 return sd.handleError(msg) 464 } 465 return nil, errs.NewFrameError(errs.RetServerSystemErr, frameHeadNotInMsg) 466 } 467 msg.WithFrameHead(nil) 468 return sd.handleByStreamFrameType(ctx, trpcpb.TrpcStreamFrameType(frameHead.StreamFrameType), sh, si, req) 469 } 470 471 // handleFeedback handles the feedback frame. 472 func (sd *streamDispatcher) handleFeedback(msg codec.Msg) ([]byte, error) { 473 ss, err := sd.loadServerStream(addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()), msg.StreamID()) 474 if err != nil { 475 return nil, err 476 } 477 fb, ok := msg.StreamFrame().(*trpcpb.TrpcStreamFeedBackMeta) 478 if !ok { 479 return nil, errors.New(streamFrameInvalid) 480 } 481 if ss.sControl != nil { 482 ss.sControl.UpdateWindow(fb.WindowSizeIncrement) 483 } 484 return nil, errs.ErrServerNoResponse 485 } 486 487 // handleByStreamFrameType performs different logic processing according to the type of stream frame. 488 func (sd *streamDispatcher) handleByStreamFrameType(ctx context.Context, streamFrameType trpcpb.TrpcStreamFrameType, 489 sh server.StreamHandler, si *server.StreamServerInfo, req []byte) ([]byte, error) { 490 msg := codec.Message(ctx) 491 switch streamFrameType { 492 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT: 493 return sd.handleInit(ctx, sh, si) 494 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA: 495 return sd.handleData(msg, req) 496 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE: 497 return sd.handleClose(msg) 498 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK: 499 return sd.handleFeedback(msg) 500 default: 501 return nil, errs.NewFrameError(errs.RetServerSystemErr, unknownFrameType) 502 } 503 }