github.com/cloudwego/kitex@v0.9.0/client/stream.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package client
    18  
    19  import (
    20  	"context"
    21  	"io"
    22  	"sync/atomic"
    23  
    24  	"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata"
    25  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    26  
    27  	"github.com/cloudwego/kitex/pkg/endpoint"
    28  	"github.com/cloudwego/kitex/pkg/remote"
    29  	"github.com/cloudwego/kitex/pkg/remote/remotecli"
    30  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    31  	"github.com/cloudwego/kitex/pkg/streaming"
    32  )
    33  
    34  // Streaming client streaming interface for code generate
    35  type Streaming interface {
    36  	Stream(ctx context.Context, method string, request, response interface{}) error
    37  }
    38  
    39  // Stream implements the Streaming interface
    40  func (kc *kClient) Stream(ctx context.Context, method string, request, response interface{}) error {
    41  	if !kc.inited {
    42  		panic("client not initialized")
    43  	}
    44  	if kc.closed {
    45  		panic("client is already closed")
    46  	}
    47  	if ctx == nil {
    48  		panic("ctx is nil")
    49  	}
    50  	var ri rpcinfo.RPCInfo
    51  	ctx, ri, _ = kc.initRPCInfo(ctx, method, 0, nil)
    52  
    53  	rpcinfo.AsMutableRPCConfig(ri.Config()).SetInteractionMode(rpcinfo.Streaming)
    54  	ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri)
    55  
    56  	ctx = kc.opt.TracerCtl.DoStart(ctx, ri)
    57  	return kc.sEps(ctx, request, response)
    58  }
    59  
    60  func (kc *kClient) invokeSendEndpoint() endpoint.SendEndpoint {
    61  	return func(stream streaming.Stream, req interface{}) (err error) {
    62  		return stream.SendMsg(req)
    63  	}
    64  }
    65  
    66  func (kc *kClient) invokeRecvEndpoint() endpoint.RecvEndpoint {
    67  	return func(stream streaming.Stream, resp interface{}) (err error) {
    68  		return stream.RecvMsg(resp)
    69  	}
    70  }
    71  
    72  func (kc *kClient) invokeStreamingEndpoint() (endpoint.Endpoint, error) {
    73  	handler, err := kc.opt.RemoteOpt.CliHandlerFactory.NewTransHandler(kc.opt.RemoteOpt)
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  	for _, h := range kc.opt.MetaHandlers {
    78  		if shdlr, ok := h.(remote.StreamingMetaHandler); ok {
    79  			kc.opt.RemoteOpt.StreamingMetaHandlers = append(kc.opt.RemoteOpt.StreamingMetaHandlers, shdlr)
    80  		}
    81  	}
    82  
    83  	recvEndpoint := kc.opt.Streaming.BuildRecvInvokeChain(kc.invokeRecvEndpoint())
    84  	sendEndpoint := kc.opt.Streaming.BuildSendInvokeChain(kc.invokeSendEndpoint())
    85  
    86  	return func(ctx context.Context, req, resp interface{}) (err error) {
    87  		// req and resp as &streaming.Stream
    88  		ri := rpcinfo.GetRPCInfo(ctx)
    89  		st, err := remotecli.NewStream(ctx, ri, handler, kc.opt.RemoteOpt)
    90  		if err != nil {
    91  			return
    92  		}
    93  		clientStream := newStream(st, kc, ri, kc.getStreamingMode(ri), sendEndpoint, recvEndpoint)
    94  		resp.(*streaming.Result).Stream = clientStream
    95  		return
    96  	}, nil
    97  }
    98  
    99  func (kc *kClient) getStreamingMode(ri rpcinfo.RPCInfo) serviceinfo.StreamingMode {
   100  	methodInfo := kc.svcInfo.MethodInfo(ri.Invocation().MethodName())
   101  	if methodInfo == nil {
   102  		return serviceinfo.StreamingNone
   103  	}
   104  	return methodInfo.StreamingMode()
   105  }
   106  
   107  type stream struct {
   108  	stream streaming.Stream
   109  	kc     *kClient
   110  	ri     rpcinfo.RPCInfo
   111  
   112  	streamingMode serviceinfo.StreamingMode
   113  	sendEndpoint  endpoint.SendEndpoint
   114  	recvEndpoint  endpoint.RecvEndpoint
   115  	finished      uint32
   116  }
   117  
   118  var _ streaming.WithDoFinish = (*stream)(nil)
   119  
   120  func newStream(s streaming.Stream, kc *kClient, ri rpcinfo.RPCInfo,
   121  	mode serviceinfo.StreamingMode, sendEP endpoint.SendEndpoint, recvEP endpoint.RecvEndpoint,
   122  ) *stream {
   123  	return &stream{
   124  		stream:        s,
   125  		kc:            kc,
   126  		ri:            ri,
   127  		streamingMode: mode,
   128  		sendEndpoint:  sendEP,
   129  		recvEndpoint:  recvEP,
   130  	}
   131  }
   132  
   133  func (s *stream) SetTrailer(metadata.MD) {
   134  	panic("this method should only be used in server side stream!")
   135  }
   136  
   137  func (s *stream) SetHeader(metadata.MD) error {
   138  	panic("this method should only be used in server side stream!")
   139  }
   140  
   141  func (s *stream) SendHeader(metadata.MD) error {
   142  	panic("this method should only be used in server side stream!")
   143  }
   144  
   145  // Header returns the header metadata sent by the server if any.
   146  // If a non-nil error is returned, stream.DoFinish() will be called to record the EndOfStream
   147  func (s *stream) Header() (md metadata.MD, err error) {
   148  	if md, err = s.stream.Header(); err != nil {
   149  		s.DoFinish(err)
   150  	}
   151  	return
   152  }
   153  
   154  func (s *stream) Trailer() metadata.MD {
   155  	return s.stream.Trailer()
   156  }
   157  
   158  func (s *stream) Context() context.Context {
   159  	return s.stream.Context()
   160  }
   161  
   162  // RecvMsg receives a message from the server.
   163  // If an error is returned, stream.DoFinish() will be called to record the end of stream
   164  func (s *stream) RecvMsg(m interface{}) (err error) {
   165  	err = s.recvEndpoint(s.stream, m)
   166  	if err == nil {
   167  		err = s.ri.Invocation().BizStatusErr()
   168  	}
   169  	if err != nil || s.streamingMode == serviceinfo.StreamingClient {
   170  		s.DoFinish(err)
   171  	}
   172  	return
   173  }
   174  
   175  // SendMsg sends a message to the server.
   176  // If an error is returned, stream.DoFinish() will be called to record the end of stream
   177  func (s *stream) SendMsg(m interface{}) (err error) {
   178  	if err = s.sendEndpoint(s.stream, m); err != nil {
   179  		s.DoFinish(err)
   180  	}
   181  	return
   182  }
   183  
   184  // Close will send a frame with EndStream=true to the server.
   185  // It will always return a nil
   186  func (s *stream) Close() error {
   187  	return s.stream.Close()
   188  }
   189  
   190  // DoFinish implements the streaming.WithDoFinish interface, and it records the end of stream
   191  func (s *stream) DoFinish(err error) {
   192  	if atomic.SwapUint32(&s.finished, 1) == 1 {
   193  		// already called
   194  		return
   195  	}
   196  	if err == io.EOF {
   197  		err = nil
   198  	}
   199  	ctx := s.Context()
   200  	ri := rpcinfo.GetRPCInfo(ctx)
   201  	s.kc.opt.TracerCtl.DoFinish(ctx, ri, err)
   202  }