trpc.group/trpc-go/trpc-go@v1.0.2/stream/server.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package stream
    15  
    16  import (
    17  	"context"
    18  	"errors"
    19  	"io"
    20  	"net"
    21  	"sync"
    22  
    23  	trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc"
    24  
    25  	trpc "trpc.group/trpc-go/trpc-go"
    26  	"trpc.group/trpc-go/trpc-go/codec"
    27  	"trpc.group/trpc-go/trpc-go/errs"
    28  	icodec "trpc.group/trpc-go/trpc-go/internal/codec"
    29  	"trpc.group/trpc-go/trpc-go/internal/queue"
    30  	"trpc.group/trpc-go/trpc-go/log"
    31  	"trpc.group/trpc-go/trpc-go/server"
    32  	"trpc.group/trpc-go/trpc-go/transport"
    33  )
    34  
// serverStream is a structure provided to the service implementation logic,
// and users use the API of this structure to send and receive streaming messages.
type serverStream struct {
	// ctx holds the codec message that every outgoing frame is cloned from.
	ctx context.Context
	// streamID identifies this stream; it is written into each frame head sent.
	streamID uint32
	// opts supplies the Codec and StreamTransport used for encoding and sending.
	opts *server.Options
	// recvQueue is filled by the dispatcher (data/close frames) and drained by RecvMsg.
	recvQueue *queue.Queue[*response]
	// done is closed exactly once (guarded by once) when the handler returns
	// or the connection reports an error.
	done chan struct{}
	err  error // Carry the server tcp failure information.
	// once guards the close of done.
	once     sync.Once
	rControl *receiveControl // Receiver flow control.
	sControl *sendControl    // Sender flow control.
}
    48  
    49  // SendMsg is the API that users use to send streaming messages.
    50  func (s *serverStream) SendMsg(m interface{}) error {
    51  	msg := codec.Message(s.ctx)
    52  	ctx, newMsg := codec.WithCloneContextAndMessage(s.ctx)
    53  	defer codec.PutBackMessage(newMsg)
    54  	newMsg.WithLocalAddr(msg.LocalAddr())
    55  	newMsg.WithRemoteAddr(msg.RemoteAddr())
    56  	newMsg.WithStreamID(s.streamID)
    57  	// Refer to the pb code generated by trpc.proto, common to each language, automatically generated code.
    58  	newMsg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA, s.streamID))
    59  
    60  	var (
    61  		err           error
    62  		reqBodyBuffer []byte
    63  	)
    64  	serializationType, compressType := s.serializationAndCompressType(msg)
    65  	if icodec.IsValidSerializationType(serializationType) {
    66  		reqBodyBuffer, err = codec.Marshal(serializationType, m)
    67  		if err != nil {
    68  			return errs.NewFrameError(errs.RetServerEncodeFail, "server codec Marshal: "+err.Error())
    69  		}
    70  	}
    71  
    72  	// compress
    73  	if icodec.IsValidCompressType(compressType) && compressType != codec.CompressTypeNoop {
    74  		reqBodyBuffer, err = codec.Compress(compressType, reqBodyBuffer)
    75  		if err != nil {
    76  			return errs.NewFrameError(errs.RetServerEncodeFail, "server codec Compress: "+err.Error())
    77  		}
    78  	}
    79  
    80  	// Flow control only controls the payload of data.
    81  	if s.sControl != nil {
    82  		if err := s.sControl.GetWindow(uint32(len(reqBodyBuffer))); err != nil {
    83  			return err
    84  		}
    85  	}
    86  
    87  	// encode the entire request.
    88  	reqBuffer, err := s.opts.Codec.Encode(newMsg, reqBodyBuffer)
    89  	if err != nil {
    90  		return errs.NewFrameError(errs.RetServerEncodeFail, "server codec Encode: "+err.Error())
    91  	}
    92  
    93  	// initiate a backend network request.
    94  	return s.opts.StreamTransport.Send(ctx, reqBuffer)
    95  }
    96  
    97  func (s *serverStream) newFrameHead(streamFrameType trpcpb.TrpcStreamFrameType) *trpc.FrameHead {
    98  	return &trpc.FrameHead{
    99  		FrameType:       uint8(trpcpb.TrpcDataFrameType_TRPC_STREAM_FRAME),
   100  		StreamFrameType: uint8(streamFrameType),
   101  		StreamID:        s.streamID,
   102  	}
   103  }
   104  
   105  func (s *serverStream) serializationAndCompressType(msg codec.Msg) (int, int) {
   106  	serializationType := msg.SerializationType()
   107  	compressType := msg.CompressType()
   108  	if icodec.IsValidSerializationType(s.opts.CurrentSerializationType) {
   109  		serializationType = s.opts.CurrentSerializationType
   110  	}
   111  	if icodec.IsValidCompressType(s.opts.CurrentCompressType) {
   112  		compressType = s.opts.CurrentCompressType
   113  	}
   114  	return serializationType, compressType
   115  }
   116  
   117  // RecvMsg receives streaming messages, passes in the structure that needs to receive messages,
   118  // and returns the serialized structure.
   119  func (s *serverStream) RecvMsg(m interface{}) error {
   120  	resp, ok := s.recvQueue.Get()
   121  	if !ok {
   122  		if s.err != nil {
   123  			return s.err
   124  		}
   125  		return errs.NewFrameError(errs.RetServerSystemErr, streamClosed)
   126  	}
   127  	if resp.err != nil {
   128  		return resp.err
   129  	}
   130  	if s.rControl != nil {
   131  		if err := s.rControl.OnRecv(uint32(len(resp.data))); err != nil {
   132  			return err
   133  		}
   134  	}
   135  	// Decompress and deserialize the data frame into a structure.
   136  	return s.decompressAndUnmarshal(resp.data, m)
   137  
   138  }
   139  
   140  // decompressAndUnmarshal decompresses the data frame and deserializes it.
   141  func (s *serverStream) decompressAndUnmarshal(data []byte, m interface{}) error {
   142  	msg := codec.Message(s.ctx)
   143  	var err error
   144  	serializationType, compressType := s.serializationAndCompressType(msg)
   145  	if icodec.IsValidCompressType(compressType) && compressType != codec.CompressTypeNoop {
   146  		data, err = codec.Decompress(compressType, data)
   147  		if err != nil {
   148  			return errs.NewFrameError(errs.RetClientDecodeFail, "server codec Decompress: "+err.Error())
   149  		}
   150  	}
   151  
   152  	// Deserialize the binary body to a specific body structure.
   153  	if icodec.IsValidSerializationType(serializationType) {
   154  		if err := codec.Unmarshal(serializationType, data, m); err != nil {
   155  			return errs.NewFrameError(errs.RetClientDecodeFail, "server codec Unmarshal: "+err.Error())
   156  		}
   157  	}
   158  	return nil
   159  }
   160  
   161  // The CloseSend server closes the stream, where ret represents the close type,
   162  // which is divided into TRPC_STREAM_CLOSE and TRPC_STREAM_RESET.
   163  // message represents the returned message, where error messages can be logged.
   164  func (s *serverStream) CloseSend(closeType, ret int32, message string) error {
   165  	oldMsg := codec.Message(s.ctx)
   166  	ctx, msg := codec.WithCloneContextAndMessage(s.ctx)
   167  	defer codec.PutBackMessage(msg)
   168  	msg.WithLocalAddr(oldMsg.LocalAddr())
   169  	msg.WithRemoteAddr(oldMsg.RemoteAddr())
   170  	msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE, s.streamID))
   171  	msg.WithStreamFrame(&trpcpb.TrpcStreamCloseMeta{
   172  		CloseType: closeType,
   173  		Ret:       ret,
   174  		Msg:       []byte(message),
   175  	})
   176  
   177  	rspBuffer, err := s.opts.Codec.Encode(msg, nil)
   178  	if err != nil {
   179  		return err
   180  	}
   181  	return s.opts.StreamTransport.Send(ctx, rspBuffer)
   182  }
   183  
   184  // newServerStream creates a new server stream, which can send and receive streaming messages.
   185  func newServerStream(ctx context.Context, streamID uint32, opts *server.Options) *serverStream {
   186  	s := &serverStream{
   187  		ctx:      ctx,
   188  		opts:     opts,
   189  		streamID: streamID,
   190  		done:     make(chan struct{}, 1),
   191  	}
   192  	s.recvQueue = queue.New[*response](s.done)
   193  	return s
   194  }
   195  
   196  func (s *serverStream) feedback(w uint32) error {
   197  	oldMsg := codec.Message(s.ctx)
   198  	ctx, msg := codec.WithCloneContextAndMessage(s.ctx)
   199  	defer codec.PutBackMessage(msg)
   200  	msg.WithLocalAddr(oldMsg.LocalAddr())
   201  	msg.WithRemoteAddr(oldMsg.RemoteAddr())
   202  	msg.WithStreamID(s.streamID)
   203  	msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK, s.streamID))
   204  	msg.WithStreamFrame(&trpcpb.TrpcStreamFeedBackMeta{WindowSizeIncrement: w})
   205  
   206  	feedbackBuf, err := s.opts.Codec.Encode(msg, nil)
   207  	if err != nil {
   208  		return err
   209  	}
   210  	return s.opts.StreamTransport.Send(ctx, feedbackBuf)
   211  }
   212  
// Context returns the context of the serverStream structure.
// It is the context the stream was created with in newServerStream.
func (s *serverStream) Context() context.Context {
	return s.ctx
}
   217  
// The structure of streamDispatcher is used to distribute streaming data.
// Incoming frames are routed to the serverStream registered for the frame's
// remote address and stream ID.
type streamDispatcher struct {
	// m guards streamIDToServerStream.
	m sync.RWMutex
	// streamIDToServerStream maps a client connection address to the streams
	// open on that connection, keyed by stream ID.
	streamIDToServerStream map[net.Addr]map[uint32]*serverStream
	// opts is set by Init and handed to every serverStream the dispatcher creates.
	opts *server.Options
}
   224  
// DefaultStreamDispatcher is the default implementation of the trpc dispatcher,
// supports the data distribution of the trpc streaming protocol.
// It is created once at package initialization by NewStreamDispatcher.
var DefaultStreamDispatcher = NewStreamDispatcher()
   228  
   229  // NewStreamDispatcher returns a new dispatcher.
   230  func NewStreamDispatcher() server.StreamHandle {
   231  	return &streamDispatcher{
   232  		streamIDToServerStream: make(map[net.Addr]map[uint32]*serverStream),
   233  	}
   234  }
   235  
   236  // storeServerStream msg contains the socket address of the client connection,
   237  // there are multiple streams under each socket address, and map it to serverStream
   238  // again according to the id of the stream.
   239  func (sd *streamDispatcher) storeServerStream(addr net.Addr, streamID uint32, ss *serverStream) {
   240  	sd.m.Lock()
   241  	defer sd.m.Unlock()
   242  	if addrToStreamID, ok := sd.streamIDToServerStream[addr]; !ok {
   243  		// Does not exist, indicating that a new connection is coming, re-create the structure.
   244  		sd.streamIDToServerStream[addr] = map[uint32]*serverStream{streamID: ss}
   245  	} else {
   246  		addrToStreamID[streamID] = ss
   247  	}
   248  }
   249  
   250  // deleteServerStream deletes the serverStream from cache.
   251  func (sd *streamDispatcher) deleteServerStream(addr net.Addr, streamID uint32) {
   252  	sd.m.Lock()
   253  	defer sd.m.Unlock()
   254  	if addrToStreamID, ok := sd.streamIDToServerStream[addr]; ok {
   255  		if _, ok = addrToStreamID[streamID]; ok {
   256  			delete(addrToStreamID, streamID)
   257  		}
   258  		if len(addrToStreamID) == 0 {
   259  			delete(sd.streamIDToServerStream, addr)
   260  		}
   261  	}
   262  }
   263  
   264  // loadServerStream loads the stored serverStream through the socket address
   265  // of the client connection and the id of the stream.
   266  func (sd *streamDispatcher) loadServerStream(addr net.Addr, streamID uint32) (*serverStream, error) {
   267  	sd.m.RLock()
   268  	defer sd.m.RUnlock()
   269  	addrToStream, ok := sd.streamIDToServerStream[addr]
   270  	if !ok || addr == nil {
   271  		return nil, errs.NewFrameError(errs.RetServerSystemErr, noSuchAddr)
   272  	}
   273  
   274  	var ss *serverStream
   275  	if ss, ok = addrToStream[streamID]; !ok {
   276  		return nil, errs.NewFrameError(errs.RetServerSystemErr, noSuchStreamID)
   277  	}
   278  	return ss, nil
   279  }
   280  
   281  // Init initializes some settings of dispatcher.
   282  func (sd *streamDispatcher) Init(opts *server.Options) error {
   283  	sd.opts = opts
   284  	st, ok := sd.opts.Transport.(transport.ServerStreamTransport)
   285  	if !ok {
   286  		return errors.New(streamTransportUnimplemented)
   287  	}
   288  	sd.opts.StreamTransport = st
   289  	sd.opts.ServeOptions = append(sd.opts.ServeOptions,
   290  		transport.WithServerAsync(false), transport.WithCopyFrame(true))
   291  	return nil
   292  }
   293  
   294  // startStreamHandler is used to start the goroutine, execute streamHandler,
   295  // streamHandler is implemented for the specific streaming server.
   296  func (sd *streamDispatcher) startStreamHandler(addr net.Addr, streamID uint32,
   297  	ss *serverStream, si *server.StreamServerInfo, sh server.StreamHandler) {
   298  	defer func() {
   299  		sd.deleteServerStream(addr, streamID)
   300  		ss.once.Do(func() { close(ss.done) })
   301  	}()
   302  
   303  	// Execute the implementation code of the server stream.
   304  	var err error
   305  	if ss.opts.StreamFilters != nil {
   306  		err = ss.opts.StreamFilters.Filter(ss, si, sh)
   307  	} else {
   308  		err = sh(ss)
   309  	}
   310  
   311  	var frameworkError *errs.Error
   312  	switch {
   313  	case errors.As(err, &frameworkError):
   314  		err = ss.CloseSend(int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_RESET), int32(frameworkError.Code), frameworkError.Msg)
   315  	case err != nil:
   316  		// return business error.
   317  		err = ss.CloseSend(int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_RESET), 0, err.Error())
   318  	default:
   319  		// Stream is normally closed.
   320  		err = ss.CloseSend(int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_CLOSE), 0, "")
   321  	}
   322  	if err != nil {
   323  		ss.err = err
   324  		log.Trace(closeSendFail, err)
   325  	}
   326  }
   327  
   328  // setSendControl obtained from the init frame.
   329  func (s *serverStream) setSendControl(msg codec.Msg) (uint32, error) {
   330  	initMeta, ok := msg.StreamFrame().(*trpcpb.TrpcStreamInitMeta)
   331  	if !ok {
   332  		return 0, errors.New(streamFrameInvalid)
   333  	}
   334  
   335  	// This section of logic is compatible with framework implementations in other languages
   336  	// that do not enable flow control, and will be deleted later.
   337  	if initMeta.InitWindowSize == 0 {
   338  		// Compatible with the client without flow control enabled.
   339  		s.rControl = nil
   340  		s.sControl = nil
   341  		return initMeta.InitWindowSize, nil
   342  	}
   343  	s.sControl = newSendControl(initMeta.InitWindowSize, s.done)
   344  	return initMeta.InitWindowSize, nil
   345  }
   346  
   347  // handleInit processes the sent init package.
   348  func (sd *streamDispatcher) handleInit(ctx context.Context,
   349  	sh server.StreamHandler, si *server.StreamServerInfo) ([]byte, error) {
   350  	// The Msg in ctx is passed to us by the upper layer, and we can't make any assumptions about its life cycle.
   351  	// Before creating ServerStream, make a complete copy of Msg.
   352  	oldMsg := codec.Message(ctx)
   353  	ctx, msg := codec.WithNewMessage(ctx)
   354  	codec.CopyMsg(msg, oldMsg)
   355  
   356  	streamID := msg.StreamID()
   357  	ss := newServerStream(ctx, streamID, sd.opts)
   358  	w := getWindowSize(sd.opts.MaxWindowSize)
   359  	ss.rControl = newReceiveControl(w, ss.feedback)
   360  	sd.storeServerStream(msg.RemoteAddr(), streamID, ss)
   361  
   362  	cw, err := ss.setSendControl(msg)
   363  	if err != nil {
   364  		return nil, err
   365  	}
   366  
   367  	// send init response packet.
   368  	newCtx, newMsg := codec.WithCloneContextAndMessage(ctx)
   369  	defer codec.PutBackMessage(newMsg)
   370  	newMsg.WithLocalAddr(msg.LocalAddr())
   371  	newMsg.WithRemoteAddr(msg.RemoteAddr())
   372  	newMsg.WithStreamID(streamID)
   373  	newMsg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT, ss.streamID))
   374  
   375  	initMeta := &trpcpb.TrpcStreamInitMeta{ResponseMeta: &trpcpb.TrpcStreamInitResponseMeta{}}
   376  	// If the client does not set it, the server should not set it to prevent incompatibility.
   377  	if cw == 0 {
   378  		initMeta.InitWindowSize = 0
   379  	} else {
   380  		initMeta.InitWindowSize = w
   381  	}
   382  	newMsg.WithStreamFrame(initMeta)
   383  
   384  	rspBuffer, err := ss.opts.Codec.Encode(newMsg, nil)
   385  	if err != nil {
   386  		return nil, err
   387  	}
   388  	if err := ss.opts.StreamTransport.Send(newCtx, rspBuffer); err != nil {
   389  		return nil, err
   390  	}
   391  
   392  	// Initiate a goroutine to execute specific business logic.
   393  	go sd.startStreamHandler(msg.RemoteAddr(), streamID, ss, si, sh)
   394  	return nil, errs.ErrServerNoResponse
   395  }
   396  
   397  // handleData handles data messages.
   398  func (sd *streamDispatcher) handleData(msg codec.Msg, req []byte) ([]byte, error) {
   399  	ss, err := sd.loadServerStream(msg.RemoteAddr(), msg.StreamID())
   400  	if err != nil {
   401  		return nil, err
   402  	}
   403  	ss.recvQueue.Put(&response{data: req})
   404  	return nil, errs.ErrServerNoResponse
   405  }
   406  
   407  // handleClose handles the Close message.
   408  func (sd *streamDispatcher) handleClose(msg codec.Msg) ([]byte, error) {
   409  	ss, err := sd.loadServerStream(msg.RemoteAddr(), msg.StreamID())
   410  	if err != nil {
   411  		// The server has sent the Close frame.
   412  		// Since the timing of the Close frame is unpredictable, when the server receives the Close frame from the client,
   413  		// the Close frame may have been sent, causing the resource to be released, no need to respond to this error.
   414  		log.Trace("handleClose loadServerStream fail", err)
   415  		return nil, errs.ErrServerNoResponse
   416  	}
   417  	// is Reset message.
   418  	if msg.ServerRspErr() != nil {
   419  		ss.recvQueue.Put(&response{err: msg.ServerRspErr()})
   420  		return nil, errs.ErrServerNoResponse
   421  	}
   422  	// is a normal Close message
   423  	ss.recvQueue.Put(&response{err: io.EOF})
   424  	return nil, errs.ErrServerNoResponse
   425  }
   426  
   427  // handleError When the connection is wrong, handle the error.
   428  func (sd *streamDispatcher) handleError(msg codec.Msg) ([]byte, error) {
   429  	sd.m.Lock()
   430  	defer sd.m.Unlock()
   431  
   432  	addr := msg.RemoteAddr()
   433  	addrToStream, ok := sd.streamIDToServerStream[addr]
   434  	if !ok || addr == nil {
   435  		return nil, errs.NewFrameError(errs.RetServerSystemErr, noSuchAddr)
   436  	}
   437  	for streamID, ss := range addrToStream {
   438  		ss.err = msg.ServerRspErr()
   439  		ss.once.Do(func() { close(ss.done) })
   440  		delete(addrToStream, streamID)
   441  	}
   442  	delete(sd.streamIDToServerStream, addr)
   443  	return nil, errs.ErrServerNoResponse
   444  }
   445  
   446  // StreamHandleFunc The processing logic after a complete streaming frame received by the streaming transport.
   447  func (sd *streamDispatcher) StreamHandleFunc(ctx context.Context,
   448  	sh server.StreamHandler, si *server.StreamServerInfo, req []byte) ([]byte, error) {
   449  	msg := codec.Message(ctx)
   450  	frameHead, ok := msg.FrameHead().(*trpc.FrameHead)
   451  	if !ok {
   452  		// If there is no frame head and serverRspErr, the server connection is abnormal
   453  		// and returns to the upper service.
   454  		if msg.ServerRspErr() != nil {
   455  			return sd.handleError(msg)
   456  		}
   457  		return nil, errs.NewFrameError(errs.RetServerSystemErr, frameHeadNotInMsg)
   458  	}
   459  	msg.WithFrameHead(nil)
   460  	return sd.handleByStreamFrameType(ctx, trpcpb.TrpcStreamFrameType(frameHead.StreamFrameType), sh, si, req)
   461  }
   462  
   463  // handleFeedback handles the feedback frame.
   464  func (sd *streamDispatcher) handleFeedback(msg codec.Msg) ([]byte, error) {
   465  	ss, err := sd.loadServerStream(msg.RemoteAddr(), msg.StreamID())
   466  	if err != nil {
   467  		return nil, err
   468  	}
   469  	fb, ok := msg.StreamFrame().(*trpcpb.TrpcStreamFeedBackMeta)
   470  	if !ok {
   471  		return nil, errors.New(streamFrameInvalid)
   472  	}
   473  	if ss.sControl != nil {
   474  		ss.sControl.UpdateWindow(fb.WindowSizeIncrement)
   475  	}
   476  	return nil, errs.ErrServerNoResponse
   477  }
   478  
   479  // handleByStreamFrameType performs different logic processing according to the type of stream frame.
   480  func (sd *streamDispatcher) handleByStreamFrameType(ctx context.Context, streamFrameType trpcpb.TrpcStreamFrameType,
   481  	sh server.StreamHandler, si *server.StreamServerInfo, req []byte) ([]byte, error) {
   482  	msg := codec.Message(ctx)
   483  	switch streamFrameType {
   484  	case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT:
   485  		return sd.handleInit(ctx, sh, si)
   486  	case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA:
   487  		return sd.handleData(msg, req)
   488  	case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE:
   489  		return sd.handleClose(msg)
   490  	case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK:
   491  		return sd.handleFeedback(msg)
   492  	default:
   493  		return nil, errs.NewFrameError(errs.RetServerSystemErr, unknownFrameType)
   494  	}
   495  }