trpc.group/trpc-go/trpc-go@v1.0.3/stream/client.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  // Package stream contains streaming client and server APIs.
    15  package stream
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"fmt"
    21  	"io"
    22  	"sync"
    23  	"sync/atomic"
    24  
    25  	trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc"
    26  
    27  	trpc "trpc.group/trpc-go/trpc-go"
    28  	"trpc.group/trpc-go/trpc-go/client"
    29  	"trpc.group/trpc-go/trpc-go/codec"
    30  	"trpc.group/trpc-go/trpc-go/errs"
    31  	icodec "trpc.group/trpc-go/trpc-go/internal/codec"
    32  	"trpc.group/trpc-go/trpc-go/internal/queue"
    33  	"trpc.group/trpc-go/trpc-go/transport"
    34  )
    35  
    36  // Client is the Streaming client interface, NewStream is its only method.
    37  type Client interface {
    38  	// NewStream returns a client stream, which the user uses to call Recv and Send to send,
    39  	// receive and send streaming messages.
    40  	NewStream(ctx context.Context, desc *client.ClientStreamDesc,
    41  		method string, opt ...client.Option) (client.ClientStream, error)
    42  }
    43  
    44  // DefaultStreamClient is the default streaming client.
    45  var DefaultStreamClient = NewStreamClient()
    46  
    47  // NewStreamClient returns a streaming client.
    48  func NewStreamClient() Client {
    49  	// Streaming ID from 0-99 is reserved ID, used as control ID.
    50  	return &streamClient{streamID: uint32(99)}
    51  }
    52  
    53  // an implementation of streamClient Client.
    54  type streamClient struct {
    55  	streamID uint32
    56  }
    57  
    58  // The specific implementation of ClientStream.
    59  type clientStream struct {
    60  	desc      *client.ClientStreamDesc
    61  	method    string
    62  	sc        *streamClient
    63  	ctx       context.Context
    64  	opts      *client.Options
    65  	streamID  uint32
    66  	stream    client.Stream
    67  	recvQueue *queue.Queue[*response]
    68  	closed    uint32
    69  	closeCh   chan struct{}
    70  	closeOnce sync.Once
    71  }
    72  
    73  // NewStream creates a new stream through which users send and receive messages.
    74  func (c *streamClient) NewStream(ctx context.Context, desc *client.ClientStreamDesc,
    75  	method string, opt ...client.Option) (client.ClientStream, error) {
    76  	return c.newStream(ctx, desc, method, opt...)
    77  }
    78  
    79  // newStream creates a new stream through which users send and receive messages.
    80  func (c *streamClient) newStream(ctx context.Context, desc *client.ClientStreamDesc,
    81  	method string, opt ...client.Option) (client.ClientStream, error) {
    82  	ctx, _ = codec.EnsureMessage(ctx)
    83  	cs := &clientStream{
    84  		desc:      desc,
    85  		method:    method,
    86  		sc:        c,
    87  		streamID:  atomic.AddUint32(&c.streamID, 1),
    88  		ctx:       ctx,
    89  		closeCh:   make(chan struct{}, 1),
    90  		recvQueue: queue.New[*response](ctx.Done()),
    91  		stream:    client.NewStream(),
    92  	}
    93  	if err := cs.prepare(opt...); err != nil {
    94  		return nil, err
    95  	}
    96  	if cs.opts.StreamFilters != nil {
    97  		return cs.opts.StreamFilters.Filter(cs.ctx, cs.desc, cs.invoke)
    98  	}
    99  	return cs.invoke(cs.ctx, cs.desc)
   100  }
   101  
   102  // Context returns the Context of the current stream.
   103  func (cs *clientStream) Context() context.Context {
   104  	return cs.ctx
   105  }
   106  
   107  // RecvMsg receives the message, if there is no message it will get stuck.
   108  // RecvMsg and SendMsg are concurrency safe, but two RecvMsg are not concurrency safe.
   109  func (cs *clientStream) RecvMsg(m interface{}) error {
   110  	if err := cs.recv(m); err != nil {
   111  		return err
   112  	}
   113  	if cs.desc.ServerStreams {
   114  		// Subsequent messages should be received by subsequent RecvMsg calls.
   115  		return nil
   116  	}
   117  	// Special handling for non-server-stream rpcs.
   118  	// This recv expects EOF or errors.
   119  	err := cs.recv(m)
   120  	if err == nil {
   121  		return errs.NewFrameError(errs.RetClientStreamReadEnd,
   122  			"client streaming protocol violation: get <nil>, want <EOF>")
   123  	}
   124  	if err == io.EOF {
   125  		return nil
   126  	}
   127  	return err
   128  }
   129  
   130  func (cs *clientStream) recv(m interface{}) error {
   131  	resp, ok := cs.recvQueue.Get()
   132  	if !ok {
   133  		return cs.dealContextDone()
   134  	}
   135  	if resp.err != nil {
   136  		return resp.err
   137  	}
   138  	// Gather flow control information.
   139  	if err := cs.recvFlowCtl(len(resp.data)); err != nil {
   140  		return err
   141  	}
   142  
   143  	serializationType := codec.Message(cs.ctx).SerializationType()
   144  	if icodec.IsValidSerializationType(cs.opts.CurrentSerializationType) {
   145  		serializationType = cs.opts.CurrentSerializationType
   146  	}
   147  	if err := codec.Unmarshal(serializationType, resp.data, m); err != nil {
   148  		return errs.NewFrameError(errs.RetClientDecodeFail, "client codec Unmarshal: "+err.Error())
   149  	}
   150  	return nil
   151  }
   152  
   153  func (cs *clientStream) recvFlowCtl(n int) error {
   154  	if cs.opts.RControl == nil {
   155  		return nil
   156  	}
   157  	// If the bottom layer has received the Close frame, then no feedback frame will be returned.
   158  	if atomic.LoadUint32(&cs.closed) == 1 {
   159  		return nil
   160  	}
   161  	if err := cs.opts.RControl.OnRecv(uint32(n)); err != nil {
   162  		// Avoid receiving the Close frame after entering OnRecv, and make another judgment.
   163  		if atomic.LoadUint32(&cs.closed) == 1 {
   164  			return nil
   165  		}
   166  		return err
   167  	}
   168  	return nil
   169  }
   170  
   171  // dealContextDone returns the final error message according to the Error type of the context.
   172  func (cs *clientStream) dealContextDone() error {
   173  	if cs.ctx.Err() == context.Canceled {
   174  		return errs.NewFrameError(errs.RetClientCanceled, "tcp client stream canceled before recv: "+cs.ctx.Err().Error())
   175  	}
   176  	if cs.ctx.Err() == context.DeadlineExceeded {
   177  		return errs.NewFrameError(errs.RetClientTimeout,
   178  			"tcp client stream canceled timeout before recv: "+cs.ctx.Err().Error())
   179  	}
   180  	return nil
   181  }
   182  
   183  // SendMsg is the specific implementation of sending a message.
   184  // RecvMsg and SendMsg are concurrency safe, but two SendMsg are not concurrency safe.
   185  func (cs *clientStream) SendMsg(m interface{}) error {
   186  	ctx, msg := codec.WithCloneContextAndMessage(cs.ctx)
   187  	defer codec.PutBackMessage(msg)
   188  	msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA, cs.streamID))
   189  	msg.WithStreamID(cs.streamID)
   190  	msg.WithClientRPCName(cs.method)
   191  	msg.WithCompressType(codec.Message(cs.ctx).CompressType())
   192  	return cs.stream.Send(ctx, m)
   193  }
   194  
   195  func newFrameHead(t trpcpb.TrpcStreamFrameType, id uint32) *trpc.FrameHead {
   196  	return &trpc.FrameHead{
   197  		FrameType:       uint8(trpcpb.TrpcDataFrameType_TRPC_STREAM_FRAME),
   198  		StreamFrameType: uint8(t),
   199  		StreamID:        id,
   200  	}
   201  }
   202  
   203  // CloseSend normally closes the sender, no longer sends messages, only accepts messages.
   204  func (cs *clientStream) CloseSend() error {
   205  	ctx, msg := codec.WithCloneContextAndMessage(cs.ctx)
   206  	defer codec.PutBackMessage(msg)
   207  	msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE, cs.streamID))
   208  	msg.WithStreamID(cs.streamID)
   209  	msg.WithStreamFrame(&trpcpb.TrpcStreamCloseMeta{
   210  		CloseType: int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_CLOSE),
   211  		Ret:       0,
   212  	})
   213  	return cs.stream.Send(ctx, nil)
   214  }
   215  
   216  func (cs *clientStream) prepare(opt ...client.Option) error {
   217  	msg := codec.Message(cs.ctx)
   218  	msg.WithClientRPCName(cs.method)
   219  	msg.WithStreamID(cs.streamID)
   220  
   221  	opt = append([]client.Option{client.WithStreamTransport(transport.DefaultClientStreamTransport)}, opt...)
   222  	opts, err := cs.stream.Init(cs.ctx, opt...)
   223  	if err != nil {
   224  		return err
   225  	}
   226  	cs.opts = opts
   227  	return nil
   228  }
   229  
   230  func (cs *clientStream) invoke(ctx context.Context, _ *client.ClientStreamDesc) (client.ClientStream, error) {
   231  	if err := cs.stream.Invoke(ctx); err != nil {
   232  		return nil, err
   233  	}
   234  	w := getWindowSize(cs.opts.MaxWindowSize)
   235  	newCtx, newMsg := codec.WithCloneContextAndMessage(ctx)
   236  	defer codec.PutBackMessage(newMsg)
   237  	copyMetaData(newMsg, codec.Message(cs.ctx))
   238  	newMsg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT, cs.streamID))
   239  	newMsg.WithClientRPCName(cs.method)
   240  	newMsg.WithStreamID(cs.streamID)
   241  	newMsg.WithCompressType(codec.Message(cs.ctx).CompressType())
   242  	newMsg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{
   243  		RequestMeta:    &trpcpb.TrpcStreamInitRequestMeta{},
   244  		InitWindowSize: w,
   245  	})
   246  	cs.opts.RControl = newReceiveControl(w, cs.feedback)
   247  	// Send the init message out.
   248  	if err := cs.stream.Send(newCtx, nil); err != nil {
   249  		return nil, err
   250  	}
   251  	// After init is sent, the server will return directly.
   252  	if _, err := cs.stream.Recv(newCtx); err != nil {
   253  		return nil, err
   254  	}
   255  	initRspMeta, ok := newMsg.StreamFrame().(*trpcpb.TrpcStreamInitMeta)
   256  	if !ok {
   257  		return nil, fmt.Errorf("client stream (method = %s, streamID = %d) recv "+
   258  			"unexpected frame type: %T, expected: %T",
   259  			cs.method, cs.streamID, newMsg.StreamFrame(), (*trpcpb.TrpcStreamInitMeta)(nil))
   260  	}
   261  	initWindowSize := initRspMeta.GetInitWindowSize()
   262  	cs.configSendControl(initWindowSize)
   263  
   264  	// Start the dispatch goroutine loop to send packets.
   265  	go cs.dispatch()
   266  	return cs, nil
   267  }
   268  
   269  // configSendControl configs Send Control according to initWindowSize.
   270  func (cs *clientStream) configSendControl(initWindowSize uint32) {
   271  	if initWindowSize == 0 {
   272  		// Disable flow control, compatible with the server without flow control enabled, delete this logic later.
   273  		cs.opts.RControl = nil
   274  		cs.opts.SControl = nil
   275  		return
   276  	}
   277  	cs.opts.SControl = newSendControl(initWindowSize, cs.ctx.Done(), cs.closeCh)
   278  }
   279  
   280  // feedback send feedback frame.
   281  func (cs *clientStream) feedback(i uint32) error {
   282  	ctx, msg := codec.WithCloneContextAndMessage(cs.ctx)
   283  	defer codec.PutBackMessage(msg)
   284  	msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK, cs.streamID))
   285  	msg.WithStreamID(cs.streamID)
   286  	msg.WithClientRPCName(cs.method)
   287  	msg.WithStreamFrame(&trpcpb.TrpcStreamFeedBackMeta{WindowSizeIncrement: i})
   288  	return cs.stream.Send(ctx, nil)
   289  }
   290  
   291  // handleFrame performs different logical processing according to the type of frame.
   292  func (cs *clientStream) handleFrame(ctx context.Context, resp *response,
   293  	respData []byte, frameHead *trpc.FrameHead) error {
   294  	msg := codec.Message(ctx)
   295  	switch trpcpb.TrpcStreamFrameType(frameHead.StreamFrameType) {
   296  	case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA:
   297  		// Get the data and return it to the client.
   298  		resp.data = respData
   299  		resp.err = nil
   300  		cs.recvQueue.Put(resp)
   301  		return nil
   302  	case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE:
   303  		// Close, it should be judged as Reset or Close.
   304  		resp.data = nil
   305  		var err error
   306  		if msg.ClientRspErr() != nil {
   307  			// Description is Reset.
   308  			err = msg.ClientRspErr()
   309  		} else {
   310  			err = io.EOF
   311  		}
   312  		resp.err = err
   313  		cs.recvQueue.Put(resp)
   314  		return err
   315  	case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK:
   316  		cs.handleFeedback(msg)
   317  		return nil
   318  	default:
   319  		return nil
   320  	}
   321  }
   322  
   323  // handleFeedback handles the feedback frame.
   324  func (cs *clientStream) handleFeedback(msg codec.Msg) {
   325  	if feedbackFrame, ok := msg.StreamFrame().(*trpcpb.TrpcStreamFeedBackMeta); ok && cs.opts.SControl != nil {
   326  		cs.opts.SControl.UpdateWindow(feedbackFrame.WindowSizeIncrement)
   327  	}
   328  }
   329  
   330  // dispatch is used to distribute the received data packets, receive them in a loop,
   331  // and then distribute the data packets according to different data types.
   332  func (cs *clientStream) dispatch() {
   333  	defer func() {
   334  		cs.opts.StreamTransport.Close(cs.ctx)
   335  		cs.close()
   336  	}()
   337  	for {
   338  		ctx, msg := codec.WithCloneContextAndMessage(cs.ctx)
   339  		msg.WithCompressType(codec.Message(cs.ctx).CompressType())
   340  		msg.WithStreamID(cs.streamID)
   341  		respData, err := cs.stream.Recv(ctx)
   342  		if err != nil {
   343  			// return to client on error.
   344  			cs.recvQueue.Put(&response{
   345  				err: errs.WrapFrameError(err, errs.RetClientStreamReadEnd, streamClosed),
   346  			})
   347  			return
   348  		}
   349  
   350  		frameHead, ok := msg.FrameHead().(*trpc.FrameHead)
   351  		if !ok {
   352  			cs.recvQueue.Put(&response{
   353  				err: errors.New(frameHeadInvalid),
   354  			})
   355  			return
   356  		}
   357  
   358  		if err := cs.handleFrame(ctx, &response{}, respData, frameHead); err != nil {
   359  			// If there is a Close frame, the dispatch goroutine ends.
   360  			return
   361  		}
   362  	}
   363  }
   364  
   365  func (cs *clientStream) close() {
   366  	cs.closeOnce.Do(func() {
   367  		atomic.StoreUint32(&cs.closed, 1)
   368  		close(cs.closeCh)
   369  	})
   370  }
   371  
   372  func copyMetaData(dst codec.Msg, src codec.Msg) {
   373  	if src.ClientMetaData() != nil {
   374  		dst.WithClientMetaData(src.ClientMetaData().Clone())
   375  	}
   376  }