trpc.group/trpc-go/trpc-go@v1.0.3/stream/server.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package stream
    15  
    16  import (
    17  	"context"
    18  	"errors"
    19  	"io"
    20  	"sync"
    21  
    22  	"go.uber.org/atomic"
    23  	"trpc.group/trpc-go/trpc-go/internal/addrutil"
    24  	trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc"
    25  
    26  	trpc "trpc.group/trpc-go/trpc-go"
    27  	"trpc.group/trpc-go/trpc-go/codec"
    28  	"trpc.group/trpc-go/trpc-go/errs"
    29  	icodec "trpc.group/trpc-go/trpc-go/internal/codec"
    30  	"trpc.group/trpc-go/trpc-go/internal/queue"
    31  	"trpc.group/trpc-go/trpc-go/log"
    32  	"trpc.group/trpc-go/trpc-go/server"
    33  	"trpc.group/trpc-go/trpc-go/transport"
    34  )
    35  
    36  // serverStream is a structure provided to the service implementation logic,
    37  // and users use the API of this structure to send and receive streaming messages.
    38  type serverStream struct {
    39  	ctx       context.Context
    40  	streamID  uint32
    41  	opts      *server.Options
    42  	recvQueue *queue.Queue[*response]
    43  	done      chan struct{}
    44  	err       atomic.Error // Carry the server tcp failure information.
    45  	once      sync.Once
    46  	rControl  *receiveControl // Receiver flow control.
    47  	sControl  *sendControl    // Sender flow control.
    48  }
    49  
    50  // SendMsg is the API that users use to send streaming messages.
    51  func (s *serverStream) SendMsg(m interface{}) error {
    52  	if err := s.err.Load(); err != nil {
    53  		return errs.WrapFrameError(err, errs.Code(err), "stream sending error")
    54  	}
    55  	msg := codec.Message(s.ctx)
    56  	ctx, newMsg := codec.WithCloneContextAndMessage(s.ctx)
    57  	defer codec.PutBackMessage(newMsg)
    58  	newMsg.WithLocalAddr(msg.LocalAddr())
    59  	newMsg.WithRemoteAddr(msg.RemoteAddr())
    60  	newMsg.WithCompressType(msg.CompressType())
    61  	newMsg.WithStreamID(s.streamID)
    62  	// Refer to the pb code generated by trpc.proto, common to each language, automatically generated code.
    63  	newMsg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA, s.streamID))
    64  
    65  	var (
    66  		err           error
    67  		reqBodyBuffer []byte
    68  	)
    69  	serializationType, compressType := s.serializationAndCompressType(newMsg)
    70  	if icodec.IsValidSerializationType(serializationType) {
    71  		reqBodyBuffer, err = codec.Marshal(serializationType, m)
    72  		if err != nil {
    73  			return errs.NewFrameError(errs.RetServerEncodeFail, "server codec Marshal: "+err.Error())
    74  		}
    75  	}
    76  
    77  	// compress
    78  	if icodec.IsValidCompressType(compressType) && compressType != codec.CompressTypeNoop {
    79  		reqBodyBuffer, err = codec.Compress(compressType, reqBodyBuffer)
    80  		if err != nil {
    81  			return errs.NewFrameError(errs.RetServerEncodeFail, "server codec Compress: "+err.Error())
    82  		}
    83  	}
    84  
    85  	// Flow control only controls the payload of data.
    86  	if s.sControl != nil {
    87  		if err := s.sControl.GetWindow(uint32(len(reqBodyBuffer))); err != nil {
    88  			return err
    89  		}
    90  	}
    91  
    92  	// encode the entire request.
    93  	reqBuffer, err := s.opts.Codec.Encode(newMsg, reqBodyBuffer)
    94  	if err != nil {
    95  		return errs.NewFrameError(errs.RetServerEncodeFail, "server codec Encode: "+err.Error())
    96  	}
    97  
    98  	// initiate a backend network request.
    99  	return s.opts.StreamTransport.Send(ctx, reqBuffer)
   100  }
   101  
   102  func (s *serverStream) newFrameHead(streamFrameType trpcpb.TrpcStreamFrameType) *trpc.FrameHead {
   103  	return &trpc.FrameHead{
   104  		FrameType:       uint8(trpcpb.TrpcDataFrameType_TRPC_STREAM_FRAME),
   105  		StreamFrameType: uint8(streamFrameType),
   106  		StreamID:        s.streamID,
   107  	}
   108  }
   109  
   110  func (s *serverStream) serializationAndCompressType(msg codec.Msg) (int, int) {
   111  	serializationType := msg.SerializationType()
   112  	compressType := msg.CompressType()
   113  	if icodec.IsValidSerializationType(s.opts.CurrentSerializationType) {
   114  		serializationType = s.opts.CurrentSerializationType
   115  	}
   116  	if icodec.IsValidCompressType(s.opts.CurrentCompressType) {
   117  		compressType = s.opts.CurrentCompressType
   118  	}
   119  	return serializationType, compressType
   120  }
   121  
   122  // RecvMsg receives streaming messages, passes in the structure that needs to receive messages,
   123  // and returns the serialized structure.
   124  func (s *serverStream) RecvMsg(m interface{}) error {
   125  	resp, ok := s.recvQueue.Get()
   126  	if !ok {
   127  		if err := s.err.Load(); err != nil {
   128  			return err
   129  		}
   130  		return errs.NewFrameError(errs.RetServerSystemErr, streamClosed)
   131  	}
   132  	if resp.err != nil {
   133  		return resp.err
   134  	}
   135  	if s.rControl != nil {
   136  		if err := s.rControl.OnRecv(uint32(len(resp.data))); err != nil {
   137  			return err
   138  		}
   139  	}
   140  	// Decompress and deserialize the data frame into a structure.
   141  	return s.decompressAndUnmarshal(resp.data, m)
   142  
   143  }
   144  
   145  // decompressAndUnmarshal decompresses the data frame and deserializes it.
   146  func (s *serverStream) decompressAndUnmarshal(data []byte, m interface{}) error {
   147  	msg := codec.Message(s.ctx)
   148  	var err error
   149  	serializationType, compressType := s.serializationAndCompressType(msg)
   150  	if icodec.IsValidCompressType(compressType) && compressType != codec.CompressTypeNoop {
   151  		data, err = codec.Decompress(compressType, data)
   152  		if err != nil {
   153  			return errs.NewFrameError(errs.RetClientDecodeFail, "server codec Decompress: "+err.Error())
   154  		}
   155  	}
   156  
   157  	// Deserialize the binary body to a specific body structure.
   158  	if icodec.IsValidSerializationType(serializationType) {
   159  		if err := codec.Unmarshal(serializationType, data, m); err != nil {
   160  			return errs.NewFrameError(errs.RetClientDecodeFail, "server codec Unmarshal: "+err.Error())
   161  		}
   162  	}
   163  	return nil
   164  }
   165  
   166  // The CloseSend server closes the stream, where ret represents the close type,
   167  // which is divided into TRPC_STREAM_CLOSE and TRPC_STREAM_RESET.
   168  // message represents the returned message, where error messages can be logged.
   169  func (s *serverStream) CloseSend(closeType, ret int32, message string) error {
   170  	oldMsg := codec.Message(s.ctx)
   171  	ctx, msg := codec.WithCloneContextAndMessage(s.ctx)
   172  	defer codec.PutBackMessage(msg)
   173  	msg.WithLocalAddr(oldMsg.LocalAddr())
   174  	msg.WithRemoteAddr(oldMsg.RemoteAddr())
   175  	msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE, s.streamID))
   176  	msg.WithStreamFrame(&trpcpb.TrpcStreamCloseMeta{
   177  		CloseType: closeType,
   178  		Ret:       ret,
   179  		Msg:       []byte(message),
   180  	})
   181  
   182  	rspBuffer, err := s.opts.Codec.Encode(msg, nil)
   183  	if err != nil {
   184  		return err
   185  	}
   186  	return s.opts.StreamTransport.Send(ctx, rspBuffer)
   187  }
   188  
   189  // newServerStream creates a new server stream, which can send and receive streaming messages.
   190  func newServerStream(ctx context.Context, streamID uint32, opts *server.Options) *serverStream {
   191  	s := &serverStream{
   192  		ctx:      ctx,
   193  		opts:     opts,
   194  		streamID: streamID,
   195  		done:     make(chan struct{}, 1),
   196  	}
   197  	s.recvQueue = queue.New[*response](s.done)
   198  	return s
   199  }
   200  
   201  func (s *serverStream) feedback(w uint32) error {
   202  	oldMsg := codec.Message(s.ctx)
   203  	ctx, msg := codec.WithCloneContextAndMessage(s.ctx)
   204  	defer codec.PutBackMessage(msg)
   205  	msg.WithLocalAddr(oldMsg.LocalAddr())
   206  	msg.WithRemoteAddr(oldMsg.RemoteAddr())
   207  	msg.WithStreamID(s.streamID)
   208  	msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK, s.streamID))
   209  	msg.WithStreamFrame(&trpcpb.TrpcStreamFeedBackMeta{WindowSizeIncrement: w})
   210  
   211  	feedbackBuf, err := s.opts.Codec.Encode(msg, nil)
   212  	if err != nil {
   213  		return err
   214  	}
   215  	return s.opts.StreamTransport.Send(ctx, feedbackBuf)
   216  }
   217  
   218  // Context returns the context of the serverStream structure.
   219  func (s *serverStream) Context() context.Context {
   220  	return s.ctx
   221  }
   222  
   223  // The structure of streamDispatcher is used to distribute streaming data.
   224  type streamDispatcher struct {
   225  	m sync.RWMutex
   226  	// local address + remote address + network
   227  	//  => stream ID
   228  	//    => serverStream
   229  	addrToServerStream map[string]map[uint32]*serverStream
   230  	opts               *server.Options
   231  }
   232  
   233  // DefaultStreamDispatcher is the default implementation of the trpc dispatcher,
   234  // supports the data distribution of the trpc streaming protocol.
   235  var DefaultStreamDispatcher = NewStreamDispatcher()
   236  
   237  // NewStreamDispatcher returns a new dispatcher.
   238  func NewStreamDispatcher() server.StreamHandle {
   239  	return &streamDispatcher{
   240  		addrToServerStream: make(map[string]map[uint32]*serverStream),
   241  	}
   242  }
   243  
   244  // storeServerStream msg contains the socket address of the client connection,
   245  // there are multiple streams under each socket address, and map it to serverStream
   246  // again according to the id of the stream.
   247  func (sd *streamDispatcher) storeServerStream(addr string, streamID uint32, ss *serverStream) {
   248  	sd.m.Lock()
   249  	defer sd.m.Unlock()
   250  	if addrToStreamID, ok := sd.addrToServerStream[addr]; !ok {
   251  		// Does not exist, indicating that a new connection is coming, re-create the structure.
   252  		sd.addrToServerStream[addr] = map[uint32]*serverStream{streamID: ss}
   253  	} else {
   254  		addrToStreamID[streamID] = ss
   255  	}
   256  }
   257  
   258  // deleteServerStream deletes the serverStream from cache.
   259  func (sd *streamDispatcher) deleteServerStream(addr string, streamID uint32) {
   260  	sd.m.Lock()
   261  	defer sd.m.Unlock()
   262  	if addrToStreamID, ok := sd.addrToServerStream[addr]; ok {
   263  		if _, ok = addrToStreamID[streamID]; ok {
   264  			delete(addrToStreamID, streamID)
   265  		}
   266  		if len(addrToStreamID) == 0 {
   267  			delete(sd.addrToServerStream, addr)
   268  		}
   269  	}
   270  }
   271  
   272  // loadServerStream loads the stored serverStream through the socket address
   273  // of the client connection and the id of the stream.
   274  func (sd *streamDispatcher) loadServerStream(addr string, streamID uint32) (*serverStream, error) {
   275  	sd.m.RLock()
   276  	defer sd.m.RUnlock()
   277  	addrToStream, ok := sd.addrToServerStream[addr]
   278  	if !ok {
   279  		return nil, errs.NewFrameError(errs.RetServerSystemErr, noSuchAddr)
   280  	}
   281  
   282  	var ss *serverStream
   283  	if ss, ok = addrToStream[streamID]; !ok {
   284  		return nil, errs.NewFrameError(errs.RetServerSystemErr, noSuchStreamID)
   285  	}
   286  	return ss, nil
   287  }
   288  
   289  // Init initializes some settings of dispatcher.
   290  func (sd *streamDispatcher) Init(opts *server.Options) error {
   291  	sd.opts = opts
   292  	st, ok := sd.opts.Transport.(transport.ServerStreamTransport)
   293  	if !ok {
   294  		return errors.New(streamTransportUnimplemented)
   295  	}
   296  	sd.opts.StreamTransport = st
   297  	sd.opts.ServeOptions = append(sd.opts.ServeOptions,
   298  		transport.WithServerAsync(false), transport.WithCopyFrame(true))
   299  	return nil
   300  }
   301  
   302  // startStreamHandler is used to start the goroutine, execute streamHandler,
   303  // streamHandler is implemented for the specific streaming server.
   304  func (sd *streamDispatcher) startStreamHandler(addr string, streamID uint32,
   305  	ss *serverStream, si *server.StreamServerInfo, sh server.StreamHandler) {
   306  	defer func() {
   307  		sd.deleteServerStream(addr, streamID)
   308  		ss.once.Do(func() { close(ss.done) })
   309  	}()
   310  
   311  	// Execute the implementation code of the server stream.
   312  	var err error
   313  	if ss.opts.StreamFilters != nil {
   314  		err = ss.opts.StreamFilters.Filter(ss, si, sh)
   315  	} else {
   316  		err = sh(ss)
   317  	}
   318  
   319  	var frameworkError *errs.Error
   320  	switch {
   321  	case errors.As(err, &frameworkError):
   322  		err = ss.CloseSend(int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_RESET), int32(frameworkError.Code), frameworkError.Msg)
   323  	case err != nil:
   324  		// return business error.
   325  		err = ss.CloseSend(int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_RESET), 0, err.Error())
   326  	default:
   327  		// Stream is normally closed.
   328  		err = ss.CloseSend(int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_CLOSE), 0, "")
   329  	}
   330  	if err != nil {
   331  		ss.err.Store(err)
   332  		log.Trace(closeSendFail, err)
   333  	}
   334  }
   335  
   336  // setSendControl obtained from the init frame.
   337  func (s *serverStream) setSendControl(msg codec.Msg) (uint32, error) {
   338  	initMeta, ok := msg.StreamFrame().(*trpcpb.TrpcStreamInitMeta)
   339  	if !ok {
   340  		return 0, errors.New(streamFrameInvalid)
   341  	}
   342  
   343  	// This section of logic is compatible with framework implementations in other languages
   344  	// that do not enable flow control, and will be deleted later.
   345  	if initMeta.InitWindowSize == 0 {
   346  		// Compatible with the client without flow control enabled.
   347  		s.rControl = nil
   348  		s.sControl = nil
   349  		return initMeta.InitWindowSize, nil
   350  	}
   351  	s.sControl = newSendControl(initMeta.InitWindowSize, s.done)
   352  	return initMeta.InitWindowSize, nil
   353  }
   354  
   355  // handleInit processes the sent init package.
   356  func (sd *streamDispatcher) handleInit(ctx context.Context,
   357  	sh server.StreamHandler, si *server.StreamServerInfo) ([]byte, error) {
   358  	// The Msg in ctx is passed to us by the upper layer, and we can't make any assumptions about its life cycle.
   359  	// Before creating ServerStream, make a complete copy of Msg.
   360  	oldMsg := codec.Message(ctx)
   361  	ctx, msg := codec.WithNewMessage(ctx)
   362  	codec.CopyMsg(msg, oldMsg)
   363  
   364  	streamID := msg.StreamID()
   365  	ss := newServerStream(ctx, streamID, sd.opts)
   366  	w := getWindowSize(sd.opts.MaxWindowSize)
   367  	ss.rControl = newReceiveControl(w, ss.feedback)
   368  	sd.storeServerStream(addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()), streamID, ss)
   369  
   370  	cw, err := ss.setSendControl(msg)
   371  	if err != nil {
   372  		return nil, err
   373  	}
   374  
   375  	// send init response packet.
   376  	newCtx, newMsg := codec.WithCloneContextAndMessage(ctx)
   377  	defer codec.PutBackMessage(newMsg)
   378  	newMsg.WithLocalAddr(msg.LocalAddr())
   379  	newMsg.WithRemoteAddr(msg.RemoteAddr())
   380  	newMsg.WithStreamID(streamID)
   381  	newMsg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT, ss.streamID))
   382  
   383  	initMeta := &trpcpb.TrpcStreamInitMeta{ResponseMeta: &trpcpb.TrpcStreamInitResponseMeta{}}
   384  	// If the client does not set it, the server should not set it to prevent incompatibility.
   385  	if cw == 0 {
   386  		initMeta.InitWindowSize = 0
   387  	} else {
   388  		initMeta.InitWindowSize = w
   389  	}
   390  	newMsg.WithStreamFrame(initMeta)
   391  
   392  	rspBuffer, err := ss.opts.Codec.Encode(newMsg, nil)
   393  	if err != nil {
   394  		return nil, err
   395  	}
   396  	if err := ss.opts.StreamTransport.Send(newCtx, rspBuffer); err != nil {
   397  		return nil, err
   398  	}
   399  
   400  	// Initiate a goroutine to execute specific business logic.
   401  	go sd.startStreamHandler(addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()), streamID, ss, si, sh)
   402  	return nil, errs.ErrServerNoResponse
   403  }
   404  
   405  // handleData handles data messages.
   406  func (sd *streamDispatcher) handleData(msg codec.Msg, req []byte) ([]byte, error) {
   407  	ss, err := sd.loadServerStream(addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()), msg.StreamID())
   408  	if err != nil {
   409  		return nil, err
   410  	}
   411  	ss.recvQueue.Put(&response{data: req})
   412  	return nil, errs.ErrServerNoResponse
   413  }
   414  
   415  // handleClose handles the Close message.
   416  func (sd *streamDispatcher) handleClose(msg codec.Msg) ([]byte, error) {
   417  	ss, err := sd.loadServerStream(addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()), msg.StreamID())
   418  	if err != nil {
   419  		// The server has sent the Close frame.
   420  		// Since the timing of the Close frame is unpredictable, when the server receives the Close frame from the client,
   421  		// the Close frame may have been sent, causing the resource to be released, no need to respond to this error.
   422  		log.Trace("handleClose loadServerStream fail", err)
   423  		return nil, errs.ErrServerNoResponse
   424  	}
   425  	// is Reset message.
   426  	if msg.ServerRspErr() != nil {
   427  		ss.recvQueue.Put(&response{err: msg.ServerRspErr()})
   428  		return nil, errs.ErrServerNoResponse
   429  	}
   430  	// is a normal Close message
   431  	ss.recvQueue.Put(&response{err: io.EOF})
   432  	return nil, errs.ErrServerNoResponse
   433  }
   434  
   435  // handleError When the connection is wrong, handle the error.
   436  func (sd *streamDispatcher) handleError(msg codec.Msg) ([]byte, error) {
   437  	sd.m.Lock()
   438  	defer sd.m.Unlock()
   439  
   440  	addr := addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr())
   441  	addrToStream, ok := sd.addrToServerStream[addr]
   442  	if !ok {
   443  		return nil, errs.NewFrameError(errs.RetServerSystemErr, noSuchAddr)
   444  	}
   445  	for streamID, ss := range addrToStream {
   446  		ss.err.Store(msg.ServerRspErr())
   447  		ss.once.Do(func() { close(ss.done) })
   448  		delete(addrToStream, streamID)
   449  	}
   450  	delete(sd.addrToServerStream, addr)
   451  	return nil, errs.ErrServerNoResponse
   452  }
   453  
   454  // StreamHandleFunc The processing logic after a complete streaming frame received by the streaming transport.
   455  func (sd *streamDispatcher) StreamHandleFunc(ctx context.Context,
   456  	sh server.StreamHandler, si *server.StreamServerInfo, req []byte) ([]byte, error) {
   457  	msg := codec.Message(ctx)
   458  	frameHead, ok := msg.FrameHead().(*trpc.FrameHead)
   459  	if !ok {
   460  		// If there is no frame head and serverRspErr, the server connection is abnormal
   461  		// and returns to the upper service.
   462  		if msg.ServerRspErr() != nil {
   463  			return sd.handleError(msg)
   464  		}
   465  		return nil, errs.NewFrameError(errs.RetServerSystemErr, frameHeadNotInMsg)
   466  	}
   467  	msg.WithFrameHead(nil)
   468  	return sd.handleByStreamFrameType(ctx, trpcpb.TrpcStreamFrameType(frameHead.StreamFrameType), sh, si, req)
   469  }
   470  
   471  // handleFeedback handles the feedback frame.
   472  func (sd *streamDispatcher) handleFeedback(msg codec.Msg) ([]byte, error) {
   473  	ss, err := sd.loadServerStream(addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()), msg.StreamID())
   474  	if err != nil {
   475  		return nil, err
   476  	}
   477  	fb, ok := msg.StreamFrame().(*trpcpb.TrpcStreamFeedBackMeta)
   478  	if !ok {
   479  		return nil, errors.New(streamFrameInvalid)
   480  	}
   481  	if ss.sControl != nil {
   482  		ss.sControl.UpdateWindow(fb.WindowSizeIncrement)
   483  	}
   484  	return nil, errs.ErrServerNoResponse
   485  }
   486  
   487  // handleByStreamFrameType performs different logic processing according to the type of stream frame.
   488  func (sd *streamDispatcher) handleByStreamFrameType(ctx context.Context, streamFrameType trpcpb.TrpcStreamFrameType,
   489  	sh server.StreamHandler, si *server.StreamServerInfo, req []byte) ([]byte, error) {
   490  	msg := codec.Message(ctx)
   491  	switch streamFrameType {
   492  	case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT:
   493  		return sd.handleInit(ctx, sh, si)
   494  	case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA:
   495  		return sd.handleData(msg, req)
   496  	case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE:
   497  		return sd.handleClose(msg)
   498  	case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK:
   499  		return sd.handleFeedback(msg)
   500  	default:
   501  		return nil, errs.NewFrameError(errs.RetServerSystemErr, unknownFrameType)
   502  	}
   503  }