trpc.group/trpc-go/trpc-go@v1.0.2/stream/server.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package stream 15 16 import ( 17 "context" 18 "errors" 19 "io" 20 "net" 21 "sync" 22 23 trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" 24 25 trpc "trpc.group/trpc-go/trpc-go" 26 "trpc.group/trpc-go/trpc-go/codec" 27 "trpc.group/trpc-go/trpc-go/errs" 28 icodec "trpc.group/trpc-go/trpc-go/internal/codec" 29 "trpc.group/trpc-go/trpc-go/internal/queue" 30 "trpc.group/trpc-go/trpc-go/log" 31 "trpc.group/trpc-go/trpc-go/server" 32 "trpc.group/trpc-go/trpc-go/transport" 33 ) 34 35 // serverStream is a structure provided to the service implementation logic, 36 // and users use the API of this structure to send and receive streaming messages. 37 type serverStream struct { 38 ctx context.Context 39 streamID uint32 40 opts *server.Options 41 recvQueue *queue.Queue[*response] 42 done chan struct{} 43 err error // Carry the server tcp failure information. 44 once sync.Once 45 rControl *receiveControl // Receiver flow control. 46 sControl *sendControl // Sender flow control. 47 } 48 49 // SendMsg is the API that users use to send streaming messages. 50 func (s *serverStream) SendMsg(m interface{}) error { 51 msg := codec.Message(s.ctx) 52 ctx, newMsg := codec.WithCloneContextAndMessage(s.ctx) 53 defer codec.PutBackMessage(newMsg) 54 newMsg.WithLocalAddr(msg.LocalAddr()) 55 newMsg.WithRemoteAddr(msg.RemoteAddr()) 56 newMsg.WithStreamID(s.streamID) 57 // Refer to the pb code generated by trpc.proto, common to each language, automatically generated code. 58 newMsg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA, s.streamID)) 59 60 var ( 61 err error 62 reqBodyBuffer []byte 63 ) 64 serializationType, compressType := s.serializationAndCompressType(msg) 65 if icodec.IsValidSerializationType(serializationType) { 66 reqBodyBuffer, err = codec.Marshal(serializationType, m) 67 if err != nil { 68 return errs.NewFrameError(errs.RetServerEncodeFail, "server codec Marshal: "+err.Error()) 69 } 70 } 71 72 // compress 73 if icodec.IsValidCompressType(compressType) && compressType != codec.CompressTypeNoop { 74 reqBodyBuffer, err = codec.Compress(compressType, reqBodyBuffer) 75 if err != nil { 76 return errs.NewFrameError(errs.RetServerEncodeFail, "server codec Compress: "+err.Error()) 77 } 78 } 79 80 // Flow control only controls the payload of data. 81 if s.sControl != nil { 82 if err := s.sControl.GetWindow(uint32(len(reqBodyBuffer))); err != nil { 83 return err 84 } 85 } 86 87 // encode the entire request. 88 reqBuffer, err := s.opts.Codec.Encode(newMsg, reqBodyBuffer) 89 if err != nil { 90 return errs.NewFrameError(errs.RetServerEncodeFail, "server codec Encode: "+err.Error()) 91 } 92 93 // initiate a backend network request. 94 return s.opts.StreamTransport.Send(ctx, reqBuffer) 95 } 96 97 func (s *serverStream) newFrameHead(streamFrameType trpcpb.TrpcStreamFrameType) *trpc.FrameHead { 98 return &trpc.FrameHead{ 99 FrameType: uint8(trpcpb.TrpcDataFrameType_TRPC_STREAM_FRAME), 100 StreamFrameType: uint8(streamFrameType), 101 StreamID: s.streamID, 102 } 103 } 104 105 func (s *serverStream) serializationAndCompressType(msg codec.Msg) (int, int) { 106 serializationType := msg.SerializationType() 107 compressType := msg.CompressType() 108 if icodec.IsValidSerializationType(s.opts.CurrentSerializationType) { 109 serializationType = s.opts.CurrentSerializationType 110 } 111 if icodec.IsValidCompressType(s.opts.CurrentCompressType) { 112 compressType = s.opts.CurrentCompressType 113 } 114 return serializationType, compressType 115 } 116 117 // RecvMsg receives streaming messages, passes in the structure that needs to receive messages, 118 // and returns the serialized structure. 119 func (s *serverStream) RecvMsg(m interface{}) error { 120 resp, ok := s.recvQueue.Get() 121 if !ok { 122 if s.err != nil { 123 return s.err 124 } 125 return errs.NewFrameError(errs.RetServerSystemErr, streamClosed) 126 } 127 if resp.err != nil { 128 return resp.err 129 } 130 if s.rControl != nil { 131 if err := s.rControl.OnRecv(uint32(len(resp.data))); err != nil { 132 return err 133 } 134 } 135 // Decompress and deserialize the data frame into a structure. 136 return s.decompressAndUnmarshal(resp.data, m) 137 138 } 139 140 // decompressAndUnmarshal decompresses the data frame and deserializes it. 141 func (s *serverStream) decompressAndUnmarshal(data []byte, m interface{}) error { 142 msg := codec.Message(s.ctx) 143 var err error 144 serializationType, compressType := s.serializationAndCompressType(msg) 145 if icodec.IsValidCompressType(compressType) && compressType != codec.CompressTypeNoop { 146 data, err = codec.Decompress(compressType, data) 147 if err != nil { 148 return errs.NewFrameError(errs.RetClientDecodeFail, "server codec Decompress: "+err.Error()) 149 } 150 } 151 152 // Deserialize the binary body to a specific body structure. 153 if icodec.IsValidSerializationType(serializationType) { 154 if err := codec.Unmarshal(serializationType, data, m); err != nil { 155 return errs.NewFrameError(errs.RetClientDecodeFail, "server codec Unmarshal: "+err.Error()) 156 } 157 } 158 return nil 159 } 160 161 // The CloseSend server closes the stream, where ret represents the close type, 162 // which is divided into TRPC_STREAM_CLOSE and TRPC_STREAM_RESET. 163 // message represents the returned message, where error messages can be logged. 164 func (s *serverStream) CloseSend(closeType, ret int32, message string) error { 165 oldMsg := codec.Message(s.ctx) 166 ctx, msg := codec.WithCloneContextAndMessage(s.ctx) 167 defer codec.PutBackMessage(msg) 168 msg.WithLocalAddr(oldMsg.LocalAddr()) 169 msg.WithRemoteAddr(oldMsg.RemoteAddr()) 170 msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE, s.streamID)) 171 msg.WithStreamFrame(&trpcpb.TrpcStreamCloseMeta{ 172 CloseType: closeType, 173 Ret: ret, 174 Msg: []byte(message), 175 }) 176 177 rspBuffer, err := s.opts.Codec.Encode(msg, nil) 178 if err != nil { 179 return err 180 } 181 return s.opts.StreamTransport.Send(ctx, rspBuffer) 182 } 183 184 // newServerStream creates a new server stream, which can send and receive streaming messages. 185 func newServerStream(ctx context.Context, streamID uint32, opts *server.Options) *serverStream { 186 s := &serverStream{ 187 ctx: ctx, 188 opts: opts, 189 streamID: streamID, 190 done: make(chan struct{}, 1), 191 } 192 s.recvQueue = queue.New[*response](s.done) 193 return s 194 } 195 196 func (s *serverStream) feedback(w uint32) error { 197 oldMsg := codec.Message(s.ctx) 198 ctx, msg := codec.WithCloneContextAndMessage(s.ctx) 199 defer codec.PutBackMessage(msg) 200 msg.WithLocalAddr(oldMsg.LocalAddr()) 201 msg.WithRemoteAddr(oldMsg.RemoteAddr()) 202 msg.WithStreamID(s.streamID) 203 msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK, s.streamID)) 204 msg.WithStreamFrame(&trpcpb.TrpcStreamFeedBackMeta{WindowSizeIncrement: w}) 205 206 feedbackBuf, err := s.opts.Codec.Encode(msg, nil) 207 if err != nil { 208 return err 209 } 210 return s.opts.StreamTransport.Send(ctx, feedbackBuf) 211 } 212 213 // Context returns the context of the serverStream structure. 214 func (s *serverStream) Context() context.Context { 215 return s.ctx 216 } 217 218 // The structure of streamDispatcher is used to distribute streaming data. 219 type streamDispatcher struct { 220 m sync.RWMutex 221 streamIDToServerStream map[net.Addr]map[uint32]*serverStream 222 opts *server.Options 223 } 224 225 // DefaultStreamDispatcher is the default implementation of the trpc dispatcher, 226 // supports the data distribution of the trpc streaming protocol. 227 var DefaultStreamDispatcher = NewStreamDispatcher() 228 229 // NewStreamDispatcher returns a new dispatcher. 230 func NewStreamDispatcher() server.StreamHandle { 231 return &streamDispatcher{ 232 streamIDToServerStream: make(map[net.Addr]map[uint32]*serverStream), 233 } 234 } 235 236 // storeServerStream msg contains the socket address of the client connection, 237 // there are multiple streams under each socket address, and map it to serverStream 238 // again according to the id of the stream. 239 func (sd *streamDispatcher) storeServerStream(addr net.Addr, streamID uint32, ss *serverStream) { 240 sd.m.Lock() 241 defer sd.m.Unlock() 242 if addrToStreamID, ok := sd.streamIDToServerStream[addr]; !ok { 243 // Does not exist, indicating that a new connection is coming, re-create the structure. 244 sd.streamIDToServerStream[addr] = map[uint32]*serverStream{streamID: ss} 245 } else { 246 addrToStreamID[streamID] = ss 247 } 248 } 249 250 // deleteServerStream deletes the serverStream from cache. 251 func (sd *streamDispatcher) deleteServerStream(addr net.Addr, streamID uint32) { 252 sd.m.Lock() 253 defer sd.m.Unlock() 254 if addrToStreamID, ok := sd.streamIDToServerStream[addr]; ok { 255 if _, ok = addrToStreamID[streamID]; ok { 256 delete(addrToStreamID, streamID) 257 } 258 if len(addrToStreamID) == 0 { 259 delete(sd.streamIDToServerStream, addr) 260 } 261 } 262 } 263 264 // loadServerStream loads the stored serverStream through the socket address 265 // of the client connection and the id of the stream. 266 func (sd *streamDispatcher) loadServerStream(addr net.Addr, streamID uint32) (*serverStream, error) { 267 sd.m.RLock() 268 defer sd.m.RUnlock() 269 addrToStream, ok := sd.streamIDToServerStream[addr] 270 if !ok || addr == nil { 271 return nil, errs.NewFrameError(errs.RetServerSystemErr, noSuchAddr) 272 } 273 274 var ss *serverStream 275 if ss, ok = addrToStream[streamID]; !ok { 276 return nil, errs.NewFrameError(errs.RetServerSystemErr, noSuchStreamID) 277 } 278 return ss, nil 279 } 280 281 // Init initializes some settings of dispatcher. 282 func (sd *streamDispatcher) Init(opts *server.Options) error { 283 sd.opts = opts 284 st, ok := sd.opts.Transport.(transport.ServerStreamTransport) 285 if !ok { 286 return errors.New(streamTransportUnimplemented) 287 } 288 sd.opts.StreamTransport = st 289 sd.opts.ServeOptions = append(sd.opts.ServeOptions, 290 transport.WithServerAsync(false), transport.WithCopyFrame(true)) 291 return nil 292 } 293 294 // startStreamHandler is used to start the goroutine, execute streamHandler, 295 // streamHandler is implemented for the specific streaming server. 296 func (sd *streamDispatcher) startStreamHandler(addr net.Addr, streamID uint32, 297 ss *serverStream, si *server.StreamServerInfo, sh server.StreamHandler) { 298 defer func() { 299 sd.deleteServerStream(addr, streamID) 300 ss.once.Do(func() { close(ss.done) }) 301 }() 302 303 // Execute the implementation code of the server stream. 304 var err error 305 if ss.opts.StreamFilters != nil { 306 err = ss.opts.StreamFilters.Filter(ss, si, sh) 307 } else { 308 err = sh(ss) 309 } 310 311 var frameworkError *errs.Error 312 switch { 313 case errors.As(err, &frameworkError): 314 err = ss.CloseSend(int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_RESET), int32(frameworkError.Code), frameworkError.Msg) 315 case err != nil: 316 // return business error. 317 err = ss.CloseSend(int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_RESET), 0, err.Error()) 318 default: 319 // Stream is normally closed. 320 err = ss.CloseSend(int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_CLOSE), 0, "") 321 } 322 if err != nil { 323 ss.err = err 324 log.Trace(closeSendFail, err) 325 } 326 } 327 328 // setSendControl obtained from the init frame. 329 func (s *serverStream) setSendControl(msg codec.Msg) (uint32, error) { 330 initMeta, ok := msg.StreamFrame().(*trpcpb.TrpcStreamInitMeta) 331 if !ok { 332 return 0, errors.New(streamFrameInvalid) 333 } 334 335 // This section of logic is compatible with framework implementations in other languages 336 // that do not enable flow control, and will be deleted later. 337 if initMeta.InitWindowSize == 0 { 338 // Compatible with the client without flow control enabled. 339 s.rControl = nil 340 s.sControl = nil 341 return initMeta.InitWindowSize, nil 342 } 343 s.sControl = newSendControl(initMeta.InitWindowSize, s.done) 344 return initMeta.InitWindowSize, nil 345 } 346 347 // handleInit processes the sent init package. 348 func (sd *streamDispatcher) handleInit(ctx context.Context, 349 sh server.StreamHandler, si *server.StreamServerInfo) ([]byte, error) { 350 // The Msg in ctx is passed to us by the upper layer, and we can't make any assumptions about its life cycle. 351 // Before creating ServerStream, make a complete copy of Msg. 352 oldMsg := codec.Message(ctx) 353 ctx, msg := codec.WithNewMessage(ctx) 354 codec.CopyMsg(msg, oldMsg) 355 356 streamID := msg.StreamID() 357 ss := newServerStream(ctx, streamID, sd.opts) 358 w := getWindowSize(sd.opts.MaxWindowSize) 359 ss.rControl = newReceiveControl(w, ss.feedback) 360 sd.storeServerStream(msg.RemoteAddr(), streamID, ss) 361 362 cw, err := ss.setSendControl(msg) 363 if err != nil { 364 return nil, err 365 } 366 367 // send init response packet. 368 newCtx, newMsg := codec.WithCloneContextAndMessage(ctx) 369 defer codec.PutBackMessage(newMsg) 370 newMsg.WithLocalAddr(msg.LocalAddr()) 371 newMsg.WithRemoteAddr(msg.RemoteAddr()) 372 newMsg.WithStreamID(streamID) 373 newMsg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT, ss.streamID)) 374 375 initMeta := &trpcpb.TrpcStreamInitMeta{ResponseMeta: &trpcpb.TrpcStreamInitResponseMeta{}} 376 // If the client does not set it, the server should not set it to prevent incompatibility. 377 if cw == 0 { 378 initMeta.InitWindowSize = 0 379 } else { 380 initMeta.InitWindowSize = w 381 } 382 newMsg.WithStreamFrame(initMeta) 383 384 rspBuffer, err := ss.opts.Codec.Encode(newMsg, nil) 385 if err != nil { 386 return nil, err 387 } 388 if err := ss.opts.StreamTransport.Send(newCtx, rspBuffer); err != nil { 389 return nil, err 390 } 391 392 // Initiate a goroutine to execute specific business logic. 393 go sd.startStreamHandler(msg.RemoteAddr(), streamID, ss, si, sh) 394 return nil, errs.ErrServerNoResponse 395 } 396 397 // handleData handles data messages. 398 func (sd *streamDispatcher) handleData(msg codec.Msg, req []byte) ([]byte, error) { 399 ss, err := sd.loadServerStream(msg.RemoteAddr(), msg.StreamID()) 400 if err != nil { 401 return nil, err 402 } 403 ss.recvQueue.Put(&response{data: req}) 404 return nil, errs.ErrServerNoResponse 405 } 406 407 // handleClose handles the Close message. 408 func (sd *streamDispatcher) handleClose(msg codec.Msg) ([]byte, error) { 409 ss, err := sd.loadServerStream(msg.RemoteAddr(), msg.StreamID()) 410 if err != nil { 411 // The server has sent the Close frame. 412 // Since the timing of the Close frame is unpredictable, when the server receives the Close frame from the client, 413 // the Close frame may have been sent, causing the resource to be released, no need to respond to this error. 414 log.Trace("handleClose loadServerStream fail", err) 415 return nil, errs.ErrServerNoResponse 416 } 417 // is Reset message. 418 if msg.ServerRspErr() != nil { 419 ss.recvQueue.Put(&response{err: msg.ServerRspErr()}) 420 return nil, errs.ErrServerNoResponse 421 } 422 // is a normal Close message 423 ss.recvQueue.Put(&response{err: io.EOF}) 424 return nil, errs.ErrServerNoResponse 425 } 426 427 // handleError When the connection is wrong, handle the error. 428 func (sd *streamDispatcher) handleError(msg codec.Msg) ([]byte, error) { 429 sd.m.Lock() 430 defer sd.m.Unlock() 431 432 addr := msg.RemoteAddr() 433 addrToStream, ok := sd.streamIDToServerStream[addr] 434 if !ok || addr == nil { 435 return nil, errs.NewFrameError(errs.RetServerSystemErr, noSuchAddr) 436 } 437 for streamID, ss := range addrToStream { 438 ss.err = msg.ServerRspErr() 439 ss.once.Do(func() { close(ss.done) }) 440 delete(addrToStream, streamID) 441 } 442 delete(sd.streamIDToServerStream, addr) 443 return nil, errs.ErrServerNoResponse 444 } 445 446 // StreamHandleFunc The processing logic after a complete streaming frame received by the streaming transport. 447 func (sd *streamDispatcher) StreamHandleFunc(ctx context.Context, 448 sh server.StreamHandler, si *server.StreamServerInfo, req []byte) ([]byte, error) { 449 msg := codec.Message(ctx) 450 frameHead, ok := msg.FrameHead().(*trpc.FrameHead) 451 if !ok { 452 // If there is no frame head and serverRspErr, the server connection is abnormal 453 // and returns to the upper service. 454 if msg.ServerRspErr() != nil { 455 return sd.handleError(msg) 456 } 457 return nil, errs.NewFrameError(errs.RetServerSystemErr, frameHeadNotInMsg) 458 } 459 msg.WithFrameHead(nil) 460 return sd.handleByStreamFrameType(ctx, trpcpb.TrpcStreamFrameType(frameHead.StreamFrameType), sh, si, req) 461 } 462 463 // handleFeedback handles the feedback frame. 464 func (sd *streamDispatcher) handleFeedback(msg codec.Msg) ([]byte, error) { 465 ss, err := sd.loadServerStream(msg.RemoteAddr(), msg.StreamID()) 466 if err != nil { 467 return nil, err 468 } 469 fb, ok := msg.StreamFrame().(*trpcpb.TrpcStreamFeedBackMeta) 470 if !ok { 471 return nil, errors.New(streamFrameInvalid) 472 } 473 if ss.sControl != nil { 474 ss.sControl.UpdateWindow(fb.WindowSizeIncrement) 475 } 476 return nil, errs.ErrServerNoResponse 477 } 478 479 // handleByStreamFrameType performs different logic processing according to the type of stream frame. 480 func (sd *streamDispatcher) handleByStreamFrameType(ctx context.Context, streamFrameType trpcpb.TrpcStreamFrameType, 481 sh server.StreamHandler, si *server.StreamServerInfo, req []byte) ([]byte, error) { 482 msg := codec.Message(ctx) 483 switch streamFrameType { 484 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT: 485 return sd.handleInit(ctx, sh, si) 486 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA: 487 return sd.handleData(msg, req) 488 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE: 489 return sd.handleClose(msg) 490 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK: 491 return sd.handleFeedback(msg) 492 default: 493 return nil, errs.NewFrameError(errs.RetServerSystemErr, unknownFrameType) 494 } 495 }