trpc.group/trpc-go/trpc-go@v1.0.2/stream/client.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  // Package stream contains streaming client and server APIs.
    15  package stream
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"io"
    21  	"sync"
    22  	"sync/atomic"
    23  
    24  	trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc"
    25  
    26  	trpc "trpc.group/trpc-go/trpc-go"
    27  	"trpc.group/trpc-go/trpc-go/client"
    28  	"trpc.group/trpc-go/trpc-go/codec"
    29  	"trpc.group/trpc-go/trpc-go/errs"
    30  	icodec "trpc.group/trpc-go/trpc-go/internal/codec"
    31  	"trpc.group/trpc-go/trpc-go/internal/queue"
    32  	"trpc.group/trpc-go/trpc-go/transport"
    33  )
    34  
    35  // Client is the Streaming client interface, NewStream is its only method.
    36  type Client interface {
    37  	// NewStream returns a client stream, which the user uses to call Recv and Send to send,
    38  	// receive and send streaming messages.
    39  	NewStream(ctx context.Context, desc *client.ClientStreamDesc,
    40  		method string, opt ...client.Option) (client.ClientStream, error)
    41  }
    42  
    43  // DefaultStreamClient is the default streaming client.
    44  var DefaultStreamClient = NewStreamClient()
    45  
    46  // NewStreamClient returns a streaming client.
    47  func NewStreamClient() Client {
    48  	// Streaming ID from 0-99 is reserved ID, used as control ID.
    49  	return &streamClient{streamID: uint32(99)}
    50  }
    51  
    52  // an implementation of streamClient Client.
    53  type streamClient struct {
    54  	streamID uint32
    55  }
    56  
    57  // The specific implementation of ClientStream.
    58  type clientStream struct {
    59  	desc      *client.ClientStreamDesc
    60  	method    string
    61  	sc        *streamClient
    62  	ctx       context.Context
    63  	opts      *client.Options
    64  	streamID  uint32
    65  	stream    client.Stream
    66  	recvQueue *queue.Queue[*response]
    67  	closed    uint32
    68  	closeCh   chan struct{}
    69  	closeOnce sync.Once
    70  }
    71  
    72  // NewStream creates a new stream through which users send and receive messages.
    73  func (c *streamClient) NewStream(ctx context.Context, desc *client.ClientStreamDesc,
    74  	method string, opt ...client.Option) (client.ClientStream, error) {
    75  	return c.newStream(ctx, desc, method, opt...)
    76  }
    77  
    78  // newStream creates a new stream through which users send and receive messages.
    79  func (c *streamClient) newStream(ctx context.Context, desc *client.ClientStreamDesc,
    80  	method string, opt ...client.Option) (client.ClientStream, error) {
    81  	ctx, _ = codec.EnsureMessage(ctx)
    82  	cs := &clientStream{
    83  		desc:      desc,
    84  		method:    method,
    85  		sc:        c,
    86  		streamID:  atomic.AddUint32(&c.streamID, 1),
    87  		ctx:       ctx,
    88  		closeCh:   make(chan struct{}, 1),
    89  		recvQueue: queue.New[*response](ctx.Done()),
    90  		stream:    client.NewStream(),
    91  	}
    92  	if err := cs.prepare(opt...); err != nil {
    93  		return nil, err
    94  	}
    95  	if cs.opts.StreamFilters != nil {
    96  		return cs.opts.StreamFilters.Filter(cs.ctx, cs.desc, cs.invoke)
    97  	}
    98  	return cs.invoke(cs.ctx, cs.desc)
    99  }
   100  
   101  // Context returns the Context of the current stream.
   102  func (cs *clientStream) Context() context.Context {
   103  	return cs.ctx
   104  }
   105  
   106  // RecvMsg receives the message, if there is no message it will get stuck.
   107  // RecvMsg and SendMsg are concurrency safe, but two RecvMsg are not concurrency safe.
   108  func (cs *clientStream) RecvMsg(m interface{}) error {
   109  	if err := cs.recv(m); err != nil {
   110  		return err
   111  	}
   112  	if cs.desc.ServerStreams {
   113  		// Subsequent messages should be received by subsequent RecvMsg calls.
   114  		return nil
   115  	}
   116  	// Special handling for non-server-stream rpcs.
   117  	// This recv expects EOF or errors.
   118  	err := cs.recv(m)
   119  	if err == nil {
   120  		return errs.NewFrameError(errs.RetClientStreamReadEnd,
   121  			"client streaming protocol violation: get <nil>, want <EOF>")
   122  	}
   123  	if err == io.EOF {
   124  		return nil
   125  	}
   126  	return err
   127  }
   128  
   129  func (cs *clientStream) recv(m interface{}) error {
   130  	resp, ok := cs.recvQueue.Get()
   131  	if !ok {
   132  		return cs.dealContextDone()
   133  	}
   134  	if resp.err != nil {
   135  		return resp.err
   136  	}
   137  	// Gather flow control information.
   138  	if err := cs.recvFlowCtl(len(resp.data)); err != nil {
   139  		return err
   140  	}
   141  
   142  	serializationType := codec.Message(cs.ctx).SerializationType()
   143  	if icodec.IsValidSerializationType(cs.opts.CurrentSerializationType) {
   144  		serializationType = cs.opts.CurrentSerializationType
   145  	}
   146  	if err := codec.Unmarshal(serializationType, resp.data, m); err != nil {
   147  		return errs.NewFrameError(errs.RetClientDecodeFail, "client codec Unmarshal: "+err.Error())
   148  	}
   149  	return nil
   150  }
   151  
   152  func (cs *clientStream) recvFlowCtl(n int) error {
   153  	if cs.opts.RControl == nil {
   154  		return nil
   155  	}
   156  	// If the bottom layer has received the Close frame, then no feedback frame will be returned.
   157  	if atomic.LoadUint32(&cs.closed) == 1 {
   158  		return nil
   159  	}
   160  	if err := cs.opts.RControl.OnRecv(uint32(n)); err != nil {
   161  		// Avoid receiving the Close frame after entering OnRecv, and make another judgment.
   162  		if atomic.LoadUint32(&cs.closed) == 1 {
   163  			return nil
   164  		}
   165  		return err
   166  	}
   167  	return nil
   168  }
   169  
   170  // dealContextDone returns the final error message according to the Error type of the context.
   171  func (cs *clientStream) dealContextDone() error {
   172  	if cs.ctx.Err() == context.Canceled {
   173  		return errs.NewFrameError(errs.RetClientCanceled, "tcp client stream canceled before recv: "+cs.ctx.Err().Error())
   174  	}
   175  	if cs.ctx.Err() == context.DeadlineExceeded {
   176  		return errs.NewFrameError(errs.RetClientTimeout,
   177  			"tcp client stream canceled timeout before recv: "+cs.ctx.Err().Error())
   178  	}
   179  	return nil
   180  }
   181  
   182  // SendMsg is the specific implementation of sending a message.
   183  // RecvMsg and SendMsg are concurrency safe, but two SendMsg are not concurrency safe.
   184  func (cs *clientStream) SendMsg(m interface{}) error {
   185  	ctx, msg := codec.WithCloneContextAndMessage(cs.ctx)
   186  	defer codec.PutBackMessage(msg)
   187  	msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA, cs.streamID))
   188  	msg.WithStreamID(cs.streamID)
   189  	msg.WithClientRPCName(cs.method)
   190  	msg.WithCalleeMethod(icodec.MethodFromRPCName(cs.method))
   191  	return cs.stream.Send(ctx, m)
   192  }
   193  
   194  func newFrameHead(t trpcpb.TrpcStreamFrameType, id uint32) *trpc.FrameHead {
   195  	return &trpc.FrameHead{
   196  		FrameType:       uint8(trpcpb.TrpcDataFrameType_TRPC_STREAM_FRAME),
   197  		StreamFrameType: uint8(t),
   198  		StreamID:        id,
   199  	}
   200  }
   201  
   202  // CloseSend normally closes the sender, no longer sends messages, only accepts messages.
   203  func (cs *clientStream) CloseSend() error {
   204  	ctx, msg := codec.WithCloneContextAndMessage(cs.ctx)
   205  	defer codec.PutBackMessage(msg)
   206  	msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE, cs.streamID))
   207  	msg.WithStreamID(cs.streamID)
   208  	msg.WithStreamFrame(&trpcpb.TrpcStreamCloseMeta{
   209  		CloseType: int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_CLOSE),
   210  		Ret:       0,
   211  	})
   212  	return cs.stream.Send(ctx, nil)
   213  }
   214  
   215  func (cs *clientStream) prepare(opt ...client.Option) error {
   216  	msg := codec.Message(cs.ctx)
   217  	msg.WithClientRPCName(cs.method)
   218  	msg.WithCalleeMethod(icodec.MethodFromRPCName(cs.method))
   219  	msg.WithStreamID(cs.streamID)
   220  
   221  	opt = append([]client.Option{client.WithStreamTransport(transport.DefaultClientStreamTransport)}, opt...)
   222  	opts, err := cs.stream.Init(cs.ctx, opt...)
   223  	if err != nil {
   224  		return err
   225  	}
   226  	cs.opts = opts
   227  	return nil
   228  }
   229  
   230  func (cs *clientStream) invoke(ctx context.Context, _ *client.ClientStreamDesc) (client.ClientStream, error) {
   231  	if err := cs.stream.Invoke(ctx); err != nil {
   232  		return nil, err
   233  	}
   234  	w := getWindowSize(cs.opts.MaxWindowSize)
   235  	newCtx, newMsg := codec.WithCloneContextAndMessage(ctx)
   236  	defer codec.PutBackMessage(newMsg)
   237  	copyMetaData(newMsg, codec.Message(cs.ctx))
   238  	newMsg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT, cs.streamID))
   239  	newMsg.WithClientRPCName(cs.method)
   240  	newMsg.WithCalleeMethod(icodec.MethodFromRPCName(cs.method))
   241  	newMsg.WithStreamID(cs.streamID)
   242  	newMsg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{
   243  		RequestMeta:    &trpcpb.TrpcStreamInitRequestMeta{},
   244  		InitWindowSize: w,
   245  	})
   246  	cs.opts.RControl = newReceiveControl(w, cs.feedback)
   247  	// Send the init message out.
   248  	if err := cs.stream.Send(newCtx, nil); err != nil {
   249  		return nil, err
   250  	}
   251  	// After init is sent, the server will return directly.
   252  	if _, err := cs.stream.Recv(newCtx); err != nil {
   253  		return nil, err
   254  	}
   255  	if newMsg.ClientRspErr() != nil {
   256  		return nil, newMsg.ClientRspErr()
   257  	}
   258  
   259  	initWindowSize := defaultInitWindowSize
   260  	if initRspMeta, ok := newMsg.StreamFrame().(*trpcpb.TrpcStreamInitMeta); ok {
   261  		// If the server has feedback, use the server's window, if not, use the default window.
   262  		initWindowSize = initRspMeta.GetInitWindowSize()
   263  	}
   264  	cs.configSendControl(initWindowSize)
   265  
   266  	// Start the dispatch goroutine loop to send packets.
   267  	go cs.dispatch()
   268  	return cs, nil
   269  }
   270  
   271  // configSendControl configs Send Control according to initWindowSize.
   272  func (cs *clientStream) configSendControl(initWindowSize uint32) {
   273  	if initWindowSize == 0 {
   274  		// Disable flow control, compatible with the server without flow control enabled, delete this logic later.
   275  		cs.opts.RControl = nil
   276  		cs.opts.SControl = nil
   277  		return
   278  	}
   279  	cs.opts.SControl = newSendControl(initWindowSize, cs.ctx.Done(), cs.closeCh)
   280  }
   281  
   282  // feedback send feedback frame.
   283  func (cs *clientStream) feedback(i uint32) error {
   284  	ctx, msg := codec.WithCloneContextAndMessage(cs.ctx)
   285  	defer codec.PutBackMessage(msg)
   286  	msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK, cs.streamID))
   287  	msg.WithStreamID(cs.streamID)
   288  	msg.WithClientRPCName(cs.method)
   289  	msg.WithCalleeMethod(icodec.MethodFromRPCName(cs.method))
   290  	msg.WithStreamFrame(&trpcpb.TrpcStreamFeedBackMeta{WindowSizeIncrement: i})
   291  	return cs.stream.Send(ctx, nil)
   292  }
   293  
   294  // handleFrame performs different logical processing according to the type of frame.
   295  func (cs *clientStream) handleFrame(ctx context.Context, resp *response,
   296  	respData []byte, frameHead *trpc.FrameHead) error {
   297  	msg := codec.Message(ctx)
   298  	switch trpcpb.TrpcStreamFrameType(frameHead.StreamFrameType) {
   299  	case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA:
   300  		// Get the data and return it to the client.
   301  		resp.data = respData
   302  		resp.err = nil
   303  		cs.recvQueue.Put(resp)
   304  		return nil
   305  	case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE:
   306  		// Close, it should be judged as Reset or Close.
   307  		resp.data = nil
   308  		var err error
   309  		if msg.ClientRspErr() != nil {
   310  			// Description is Reset.
   311  			err = msg.ClientRspErr()
   312  		} else {
   313  			err = io.EOF
   314  		}
   315  		resp.err = err
   316  		cs.recvQueue.Put(resp)
   317  		return err
   318  	case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK:
   319  		cs.handleFeedback(msg)
   320  		return nil
   321  	default:
   322  		return nil
   323  	}
   324  }
   325  
   326  // handleFeedback handles the feedback frame.
   327  func (cs *clientStream) handleFeedback(msg codec.Msg) {
   328  	if feedbackFrame, ok := msg.StreamFrame().(*trpcpb.TrpcStreamFeedBackMeta); ok && cs.opts.SControl != nil {
   329  		cs.opts.SControl.UpdateWindow(feedbackFrame.WindowSizeIncrement)
   330  	}
   331  }
   332  
   333  // dispatch is used to distribute the received data packets, receive them in a loop,
   334  // and then distribute the data packets according to different data types.
   335  func (cs *clientStream) dispatch() {
   336  	defer func() {
   337  		cs.opts.StreamTransport.Close(cs.ctx)
   338  		cs.close()
   339  	}()
   340  	for {
   341  		ctx, msg := codec.WithCloneContextAndMessage(cs.ctx)
   342  		msg.WithStreamID(cs.streamID)
   343  		respData, err := cs.stream.Recv(ctx)
   344  		if err != nil {
   345  			// return to client on error.
   346  			cs.recvQueue.Put(&response{
   347  				err: errs.Wrap(err, errs.RetClientStreamReadEnd, streamClosed),
   348  			})
   349  			return
   350  		}
   351  
   352  		frameHead, ok := msg.FrameHead().(*trpc.FrameHead)
   353  		if !ok {
   354  			cs.recvQueue.Put(&response{
   355  				err: errors.New(frameHeadInvalid),
   356  			})
   357  			return
   358  		}
   359  
   360  		if err := cs.handleFrame(ctx, &response{}, respData, frameHead); err != nil {
   361  			// If there is a Close frame, the dispatch goroutine ends.
   362  			return
   363  		}
   364  	}
   365  }
   366  
   367  func (cs *clientStream) close() {
   368  	cs.closeOnce.Do(func() {
   369  		atomic.StoreUint32(&cs.closed, 1)
   370  		close(cs.closeCh)
   371  	})
   372  }
   373  
   374  func copyMetaData(dst codec.Msg, src codec.Msg) {
   375  	if src.ClientMetaData() != nil {
   376  		dst.WithClientMetaData(src.ClientMetaData().Clone())
   377  	}
   378  }