github.com/jhump/protoreflect@v1.16.0/dynamic/grpcdynamic/stub.go (about)

// Package grpcdynamic provides a dynamic RPC stub. It can be used to invoke RPC
// methods when only their method descriptors are known. The actual request and
// response messages may be dynamic messages.
     4  package grpcdynamic
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"io"
    10  
    11  	"github.com/golang/protobuf/proto"
    12  	"google.golang.org/grpc"
    13  	"google.golang.org/grpc/metadata"
    14  
    15  	"github.com/jhump/protoreflect/desc"
    16  	"github.com/jhump/protoreflect/dynamic"
    17  )
    18  
// Stub is an RPC client stub, used for dynamically dispatching RPCs to a server.
// It pairs the channel that carries the RPCs with the message factory used to
// create response messages.
type Stub struct {
	// channel is the transport used to invoke unary RPCs and to open streams.
	channel Channel
	// mf creates the response messages; it is nil when the stub is built via NewStub.
	mf      *dynamic.MessageFactory
}
    24  
// Channel represents the operations necessary to issue RPCs via gRPC. The
// *grpc.ClientConn type provides this interface and will typically be the
// concrete type used to construct Stubs. But the use of this interface allows
// construction of stubs that use alternate concrete types as the transport for
// RPC operations.
type Channel = grpc.ClientConnInterface
    31  
    32  // NewStub creates a new RPC stub that uses the given channel for dispatching RPCs.
    33  func NewStub(channel Channel) Stub {
    34  	return NewStubWithMessageFactory(channel, nil)
    35  }
    36  
    37  // NewStubWithMessageFactory creates a new RPC stub that uses the given channel for
    38  // dispatching RPCs and the given MessageFactory for creating response messages.
    39  func NewStubWithMessageFactory(channel Channel, mf *dynamic.MessageFactory) Stub {
    40  	return Stub{channel: channel, mf: mf}
    41  }
    42  
    43  func requestMethod(md *desc.MethodDescriptor) string {
    44  	return fmt.Sprintf("/%s/%s", md.GetService().GetFullyQualifiedName(), md.GetName())
    45  }
    46  
    47  // InvokeRpc sends a unary RPC and returns the response. Use this for unary methods.
    48  func (s Stub) InvokeRpc(ctx context.Context, method *desc.MethodDescriptor, request proto.Message, opts ...grpc.CallOption) (proto.Message, error) {
    49  	if method.IsClientStreaming() || method.IsServerStreaming() {
    50  		return nil, fmt.Errorf("InvokeRpc is for unary methods; %q is %s", method.GetFullyQualifiedName(), methodType(method))
    51  	}
    52  	if err := checkMessageType(method.GetInputType(), request); err != nil {
    53  		return nil, err
    54  	}
    55  	resp := s.mf.NewMessage(method.GetOutputType())
    56  	if err := s.channel.Invoke(ctx, requestMethod(method), request, resp, opts...); err != nil {
    57  		return nil, err
    58  	}
    59  	return resp, nil
    60  }
    61  
    62  // InvokeRpcServerStream sends a unary RPC and returns the response stream. Use this for server-streaming methods.
    63  func (s Stub) InvokeRpcServerStream(ctx context.Context, method *desc.MethodDescriptor, request proto.Message, opts ...grpc.CallOption) (*ServerStream, error) {
    64  	if method.IsClientStreaming() || !method.IsServerStreaming() {
    65  		return nil, fmt.Errorf("InvokeRpcServerStream is for server-streaming methods; %q is %s", method.GetFullyQualifiedName(), methodType(method))
    66  	}
    67  	if err := checkMessageType(method.GetInputType(), request); err != nil {
    68  		return nil, err
    69  	}
    70  	ctx, cancel := context.WithCancel(ctx)
    71  	sd := grpc.StreamDesc{
    72  		StreamName:    method.GetName(),
    73  		ServerStreams: method.IsServerStreaming(),
    74  		ClientStreams: method.IsClientStreaming(),
    75  	}
    76  	if cs, err := s.channel.NewStream(ctx, &sd, requestMethod(method), opts...); err != nil {
    77  		cancel()
    78  		return nil, err
    79  	} else {
    80  		err = cs.SendMsg(request)
    81  		if err != nil {
    82  			cancel()
    83  			return nil, err
    84  		}
    85  		err = cs.CloseSend()
    86  		if err != nil {
    87  			cancel()
    88  			return nil, err
    89  		}
    90  		go func() {
    91  			// when the new stream is finished, also cleanup the parent context
    92  			<-cs.Context().Done()
    93  			cancel()
    94  		}()
    95  		return &ServerStream{cs, method.GetOutputType(), s.mf}, nil
    96  	}
    97  }
    98  
    99  // InvokeRpcClientStream creates a new stream that is used to send request messages and, at the end,
   100  // receive the response message. Use this for client-streaming methods.
   101  func (s Stub) InvokeRpcClientStream(ctx context.Context, method *desc.MethodDescriptor, opts ...grpc.CallOption) (*ClientStream, error) {
   102  	if !method.IsClientStreaming() || method.IsServerStreaming() {
   103  		return nil, fmt.Errorf("InvokeRpcClientStream is for client-streaming methods; %q is %s", method.GetFullyQualifiedName(), methodType(method))
   104  	}
   105  	ctx, cancel := context.WithCancel(ctx)
   106  	sd := grpc.StreamDesc{
   107  		StreamName:    method.GetName(),
   108  		ServerStreams: method.IsServerStreaming(),
   109  		ClientStreams: method.IsClientStreaming(),
   110  	}
   111  	if cs, err := s.channel.NewStream(ctx, &sd, requestMethod(method), opts...); err != nil {
   112  		cancel()
   113  		return nil, err
   114  	} else {
   115  		go func() {
   116  			// when the new stream is finished, also cleanup the parent context
   117  			<-cs.Context().Done()
   118  			cancel()
   119  		}()
   120  		return &ClientStream{cs, method, s.mf, cancel}, nil
   121  	}
   122  }
   123  
   124  // InvokeRpcBidiStream creates a new stream that is used to both send request messages and receive response
   125  // messages. Use this for bidi-streaming methods.
   126  func (s Stub) InvokeRpcBidiStream(ctx context.Context, method *desc.MethodDescriptor, opts ...grpc.CallOption) (*BidiStream, error) {
   127  	if !method.IsClientStreaming() || !method.IsServerStreaming() {
   128  		return nil, fmt.Errorf("InvokeRpcBidiStream is for bidi-streaming methods; %q is %s", method.GetFullyQualifiedName(), methodType(method))
   129  	}
   130  	sd := grpc.StreamDesc{
   131  		StreamName:    method.GetName(),
   132  		ServerStreams: method.IsServerStreaming(),
   133  		ClientStreams: method.IsClientStreaming(),
   134  	}
   135  	if cs, err := s.channel.NewStream(ctx, &sd, requestMethod(method), opts...); err != nil {
   136  		return nil, err
   137  	} else {
   138  		return &BidiStream{cs, method.GetInputType(), method.GetOutputType(), s.mf}, nil
   139  	}
   140  }
   141  
   142  func methodType(md *desc.MethodDescriptor) string {
   143  	if md.IsClientStreaming() && md.IsServerStreaming() {
   144  		return "bidi-streaming"
   145  	} else if md.IsClientStreaming() {
   146  		return "client-streaming"
   147  	} else if md.IsServerStreaming() {
   148  		return "server-streaming"
   149  	} else {
   150  		return "unary"
   151  	}
   152  }
   153  
   154  func checkMessageType(md *desc.MessageDescriptor, msg proto.Message) error {
   155  	var typeName string
   156  	if dm, ok := msg.(*dynamic.Message); ok {
   157  		typeName = dm.GetMessageDescriptor().GetFullyQualifiedName()
   158  	} else {
   159  		typeName = proto.MessageName(msg)
   160  	}
   161  	if typeName != md.GetFullyQualifiedName() {
   162  		return fmt.Errorf("expecting message of type %s; got %s", md.GetFullyQualifiedName(), typeName)
   163  	}
   164  	return nil
   165  }
   166  
// ServerStream represents a response stream from a server. Messages in the stream can be queried
// as can header and trailer metadata sent by the server.
type ServerStream struct {
	// stream is the underlying gRPC client stream that all operations delegate to.
	stream   grpc.ClientStream
	// respType describes the messages created by RecvMsg.
	respType *desc.MessageDescriptor
	// mf creates the response messages.
	mf       *dynamic.MessageFactory
}
   174  
// Header returns any header metadata sent by the server (blocks if necessary until headers are
// received). It delegates to the underlying gRPC client stream.
func (s *ServerStream) Header() (metadata.MD, error) {
	return s.stream.Header()
}
   180  
// Trailer returns the trailer metadata sent by the server. It must only be called after
// RecvMsg returns a non-nil error (which may be EOF for normal completion of stream).
// It delegates to the underlying gRPC client stream.
func (s *ServerStream) Trailer() metadata.MD {
	return s.stream.Trailer()
}
   186  
// Context returns the context associated with this streaming operation, as
// reported by the underlying gRPC client stream.
func (s *ServerStream) Context() context.Context {
	return s.stream.Context()
}
   191  
   192  // RecvMsg returns the next message in the response stream or an error. If the stream
   193  // has completed normally, the error is io.EOF. Otherwise, the error indicates the
   194  // nature of the abnormal termination of the stream.
   195  func (s *ServerStream) RecvMsg() (proto.Message, error) {
   196  	resp := s.mf.NewMessage(s.respType)
   197  	if err := s.stream.RecvMsg(resp); err != nil {
   198  		return nil, err
   199  	} else {
   200  		return resp, nil
   201  	}
   202  }
   203  
// ClientStream represents a request stream from a client to a server. Messages can be sent on
// the stream and, when done, the unary server message and header and trailer metadata can be queried.
type ClientStream struct {
	// stream is the underlying gRPC client stream that all operations delegate to.
	stream grpc.ClientStream
	// method supplies the input type checked by SendMsg and the output type
	// created by CloseAndReceive.
	method *desc.MethodDescriptor
	// mf creates the response message.
	mf     *dynamic.MessageFactory
	// cancel cancels the context the stream was created with; CloseAndReceive
	// calls it when the server sends more than one response message.
	cancel context.CancelFunc
}
   212  
// Header returns any header metadata sent by the server (blocks if necessary until headers are
// received). It delegates to the underlying gRPC client stream.
func (s *ClientStream) Header() (metadata.MD, error) {
	return s.stream.Header()
}
   218  
// Trailer returns the trailer metadata sent by the server. It should only be called after
// CloseAndReceive has returned, since that is the only method of this type that receives
// from the stream. It delegates to the underlying gRPC client stream.
func (s *ClientStream) Trailer() metadata.MD {
	return s.stream.Trailer()
}
   224  
// Context returns the context associated with this streaming operation, as
// reported by the underlying gRPC client stream.
func (s *ClientStream) Context() context.Context {
	return s.stream.Context()
}
   229  
   230  // SendMsg sends a request message to the server.
   231  func (s *ClientStream) SendMsg(m proto.Message) error {
   232  	if err := checkMessageType(s.method.GetInputType(), m); err != nil {
   233  		return err
   234  	}
   235  	return s.stream.SendMsg(m)
   236  }
   237  
   238  // CloseAndReceive closes the outgoing request stream and then blocks for the server's response.
   239  func (s *ClientStream) CloseAndReceive() (proto.Message, error) {
   240  	if err := s.stream.CloseSend(); err != nil {
   241  		return nil, err
   242  	}
   243  	resp := s.mf.NewMessage(s.method.GetOutputType())
   244  	if err := s.stream.RecvMsg(resp); err != nil {
   245  		return nil, err
   246  	}
   247  	// make sure we get EOF for a second message
   248  	if err := s.stream.RecvMsg(resp); err != io.EOF {
   249  		if err == nil {
   250  			s.cancel()
   251  			return nil, fmt.Errorf("client-streaming method %q returned more than one response message", s.method.GetFullyQualifiedName())
   252  		} else {
   253  			return nil, err
   254  		}
   255  	}
   256  	return resp, nil
   257  }
   258  
// BidiStream represents a bi-directional stream for sending messages to and receiving
// messages from a server. The header and trailer metadata sent by the server can also be
// queried.
type BidiStream struct {
	// stream is the underlying gRPC client stream that all operations delegate to.
	stream   grpc.ClientStream
	// reqType is the message type that SendMsg requires of outgoing messages.
	reqType  *desc.MessageDescriptor
	// respType describes the messages created by RecvMsg.
	respType *desc.MessageDescriptor
	// mf creates the response messages.
	mf       *dynamic.MessageFactory
}
   268  
// Header returns any header metadata sent by the server (blocks if necessary until headers are
// received). It delegates to the underlying gRPC client stream.
func (s *BidiStream) Header() (metadata.MD, error) {
	return s.stream.Header()
}
   274  
// Trailer returns the trailer metadata sent by the server. It must only be called after
// RecvMsg returns a non-nil error (which may be EOF for normal completion of stream).
// It delegates to the underlying gRPC client stream.
func (s *BidiStream) Trailer() metadata.MD {
	return s.stream.Trailer()
}
   280  
// Context returns the context associated with this streaming operation, as
// reported by the underlying gRPC client stream.
func (s *BidiStream) Context() context.Context {
	return s.stream.Context()
}
   285  
   286  // SendMsg sends a request message to the server.
   287  func (s *BidiStream) SendMsg(m proto.Message) error {
   288  	if err := checkMessageType(s.reqType, m); err != nil {
   289  		return err
   290  	}
   291  	return s.stream.SendMsg(m)
   292  }
   293  
// CloseSend indicates the request stream has ended. Invoke this after all request messages
// are sent (even if there are zero such messages). It delegates to the underlying gRPC
// client stream and returns whatever error that reports.
func (s *BidiStream) CloseSend() error {
	return s.stream.CloseSend()
}
   299  
   300  // RecvMsg returns the next message in the response stream or an error. If the stream
   301  // has completed normally, the error is io.EOF. Otherwise, the error indicates the
   302  // nature of the abnormal termination of the stream.
   303  func (s *BidiStream) RecvMsg() (proto.Message, error) {
   304  	resp := s.mf.NewMessage(s.respType)
   305  	if err := s.stream.RecvMsg(resp); err != nil {
   306  		return nil, err
   307  	} else {
   308  		return resp, nil
   309  	}
   310  }