github.com/bakjos/protoreflect@v1.9.2/dynamic/grpcdynamic/stub.go (about)

     1  // Package grpcdynamic provides a dynamic RPC stub. It can be used to invoke RPC
     2  // methods where only method descriptors are known. The actual request and response
     3  // messages may be dynamic messages.
     4  package grpcdynamic
     5  
     6  import (
     7  	"fmt"
     8  	"io"
     9  
    10  	"github.com/golang/protobuf/proto"
    11  	"golang.org/x/net/context"
    12  	"google.golang.org/grpc"
    13  	"google.golang.org/grpc/metadata"
    14  
    15  	"github.com/bakjos/protoreflect/desc"
    16  	"github.com/bakjos/protoreflect/dynamic"
    17  )
    18  
    19  // Stub is an RPC client stub, used for dynamically dispatching RPCs to a server.
    20  type Stub struct {
    21  	channel Channel
    22  	mf      *dynamic.MessageFactory
    23  }
    24  
    25  // Channel represents the operations necessary to issue RPCs via gRPC. The
    26  // *grpc.ClientConn type provides this interface and will typically the concrete
    27  // type used to construct Stubs. But the use of this interface allows
    28  // construction of stubs that use alternate concrete types as the transport for
    29  // RPC operations.
    30  type Channel interface {
    31  	Invoke(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error
    32  	NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error)
    33  }
    34  
    35  var _ Channel = (*grpc.ClientConn)(nil)
    36  
    37  // NewStub creates a new RPC stub that uses the given channel for dispatching RPCs.
    38  func NewStub(channel Channel) Stub {
    39  	return NewStubWithMessageFactory(channel, nil)
    40  }
    41  
    42  // NewStubWithMessageFactory creates a new RPC stub that uses the given channel for
    43  // dispatching RPCs and the given MessageFactory for creating response messages.
    44  func NewStubWithMessageFactory(channel Channel, mf *dynamic.MessageFactory) Stub {
    45  	return Stub{channel: channel, mf: mf}
    46  }
    47  
    48  func requestMethod(md *desc.MethodDescriptor) string {
    49  	return fmt.Sprintf("/%s/%s", md.GetService().GetFullyQualifiedName(), md.GetName())
    50  }
    51  
    52  // InvokeRpc sends a unary RPC and returns the response. Use this for unary methods.
    53  func (s Stub) InvokeRpc(ctx context.Context, method *desc.MethodDescriptor, request proto.Message, opts ...grpc.CallOption) (proto.Message, error) {
    54  	if method.IsClientStreaming() || method.IsServerStreaming() {
    55  		return nil, fmt.Errorf("InvokeRpc is for unary methods; %q is %s", method.GetFullyQualifiedName(), methodType(method))
    56  	}
    57  	if err := checkMessageType(method.GetInputType(), request); err != nil {
    58  		return nil, err
    59  	}
    60  	resp := s.mf.NewMessage(method.GetOutputType())
    61  	if err := s.channel.Invoke(ctx, requestMethod(method), request, resp, opts...); err != nil {
    62  		return nil, err
    63  	}
    64  	return resp, nil
    65  }
    66  
    67  // InvokeRpcServerStream sends a unary RPC and returns the response stream. Use this for server-streaming methods.
    68  func (s Stub) InvokeRpcServerStream(ctx context.Context, method *desc.MethodDescriptor, request proto.Message, opts ...grpc.CallOption) (*ServerStream, error) {
    69  	if method.IsClientStreaming() || !method.IsServerStreaming() {
    70  		return nil, fmt.Errorf("InvokeRpcServerStream is for server-streaming methods; %q is %s", method.GetFullyQualifiedName(), methodType(method))
    71  	}
    72  	if err := checkMessageType(method.GetInputType(), request); err != nil {
    73  		return nil, err
    74  	}
    75  	ctx, cancel := context.WithCancel(ctx)
    76  	sd := grpc.StreamDesc{
    77  		StreamName:    method.GetName(),
    78  		ServerStreams: method.IsServerStreaming(),
    79  		ClientStreams: method.IsClientStreaming(),
    80  	}
    81  	if cs, err := s.channel.NewStream(ctx, &sd, requestMethod(method), opts...); err != nil {
    82  		return nil, err
    83  	} else {
    84  		err = cs.SendMsg(request)
    85  		if err != nil {
    86  			cancel()
    87  			return nil, err
    88  		}
    89  		err = cs.CloseSend()
    90  		if err != nil {
    91  			cancel()
    92  			return nil, err
    93  		}
    94  		return &ServerStream{cs, method.GetOutputType(), s.mf}, nil
    95  	}
    96  }
    97  
    98  // InvokeRpcClientStream creates a new stream that is used to send request messages and, at the end,
    99  // receive the response message. Use this for client-streaming methods.
   100  func (s Stub) InvokeRpcClientStream(ctx context.Context, method *desc.MethodDescriptor, opts ...grpc.CallOption) (*ClientStream, error) {
   101  	if !method.IsClientStreaming() || method.IsServerStreaming() {
   102  		return nil, fmt.Errorf("InvokeRpcClientStream is for client-streaming methods; %q is %s", method.GetFullyQualifiedName(), methodType(method))
   103  	}
   104  	ctx, cancel := context.WithCancel(ctx)
   105  	sd := grpc.StreamDesc{
   106  		StreamName:    method.GetName(),
   107  		ServerStreams: method.IsServerStreaming(),
   108  		ClientStreams: method.IsClientStreaming(),
   109  	}
   110  	if cs, err := s.channel.NewStream(ctx, &sd, requestMethod(method), opts...); err != nil {
   111  		return nil, err
   112  	} else {
   113  		return &ClientStream{cs, method, s.mf, cancel}, nil
   114  	}
   115  }
   116  
   117  // InvokeRpcBidiStream creates a new stream that is used to both send request messages and receive response
   118  // messages. Use this for bidi-streaming methods.
   119  func (s Stub) InvokeRpcBidiStream(ctx context.Context, method *desc.MethodDescriptor, opts ...grpc.CallOption) (*BidiStream, error) {
   120  	if !method.IsClientStreaming() || !method.IsServerStreaming() {
   121  		return nil, fmt.Errorf("InvokeRpcBidiStream is for bidi-streaming methods; %q is %s", method.GetFullyQualifiedName(), methodType(method))
   122  	}
   123  	sd := grpc.StreamDesc{
   124  		StreamName:    method.GetName(),
   125  		ServerStreams: method.IsServerStreaming(),
   126  		ClientStreams: method.IsClientStreaming(),
   127  	}
   128  	if cs, err := s.channel.NewStream(ctx, &sd, requestMethod(method), opts...); err != nil {
   129  		return nil, err
   130  	} else {
   131  		return &BidiStream{cs, method.GetInputType(), method.GetOutputType(), s.mf}, nil
   132  	}
   133  }
   134  
   135  func methodType(md *desc.MethodDescriptor) string {
   136  	if md.IsClientStreaming() && md.IsServerStreaming() {
   137  		return "bidi-streaming"
   138  	} else if md.IsClientStreaming() {
   139  		return "client-streaming"
   140  	} else if md.IsServerStreaming() {
   141  		return "server-streaming"
   142  	} else {
   143  		return "unary"
   144  	}
   145  }
   146  
   147  func checkMessageType(md *desc.MessageDescriptor, msg proto.Message) error {
   148  	var typeName string
   149  	if dm, ok := msg.(*dynamic.Message); ok {
   150  		typeName = dm.GetMessageDescriptor().GetFullyQualifiedName()
   151  	} else {
   152  		typeName = proto.MessageName(msg)
   153  	}
   154  	if typeName != md.GetFullyQualifiedName() {
   155  		return fmt.Errorf("expecting message of type %s; got %s", md.GetFullyQualifiedName(), typeName)
   156  	}
   157  	return nil
   158  }
   159  
   160  // ServerStream represents a response stream from a server. Messages in the stream can be queried
   161  // as can header and trailer metadata sent by the server.
   162  type ServerStream struct {
   163  	stream   grpc.ClientStream
   164  	respType *desc.MessageDescriptor
   165  	mf       *dynamic.MessageFactory
   166  }
   167  
   168  // Header returns any header metadata sent by the server (blocks if necessary until headers are
   169  // received).
   170  func (s *ServerStream) Header() (metadata.MD, error) {
   171  	return s.stream.Header()
   172  }
   173  
   174  // Trailer returns the trailer metadata sent by the server. It must only be called after
   175  // RecvMsg returns a non-nil error (which may be EOF for normal completion of stream).
   176  func (s *ServerStream) Trailer() metadata.MD {
   177  	return s.stream.Trailer()
   178  }
   179  
   180  // Context returns the context associated with this streaming operation.
   181  func (s *ServerStream) Context() context.Context {
   182  	return s.stream.Context()
   183  }
   184  
   185  // RecvMsg returns the next message in the response stream or an error. If the stream
   186  // has completed normally, the error is io.EOF. Otherwise, the error indicates the
   187  // nature of the abnormal termination of the stream.
   188  func (s *ServerStream) RecvMsg() (proto.Message, error) {
   189  	resp := s.mf.NewMessage(s.respType)
   190  	if err := s.stream.RecvMsg(resp); err != nil {
   191  		return nil, err
   192  	} else {
   193  		return resp, nil
   194  	}
   195  }
   196  
   197  // ClientStream represents a response stream from a client. Messages in the stream can be sent
   198  // and, when done, the unary server message and header and trailer metadata can be queried.
   199  type ClientStream struct {
   200  	stream grpc.ClientStream
   201  	method *desc.MethodDescriptor
   202  	mf     *dynamic.MessageFactory
   203  	cancel context.CancelFunc
   204  }
   205  
   206  // Header returns any header metadata sent by the server (blocks if necessary until headers are
   207  // received).
   208  func (s *ClientStream) Header() (metadata.MD, error) {
   209  	return s.stream.Header()
   210  }
   211  
   212  // Trailer returns the trailer metadata sent by the server. It must only be called after
   213  // RecvMsg returns a non-nil error (which may be EOF for normal completion of stream).
   214  func (s *ClientStream) Trailer() metadata.MD {
   215  	return s.stream.Trailer()
   216  }
   217  
   218  // Context returns the context associated with this streaming operation.
   219  func (s *ClientStream) Context() context.Context {
   220  	return s.stream.Context()
   221  }
   222  
   223  // SendMsg sends a request message to the server.
   224  func (s *ClientStream) SendMsg(m proto.Message) error {
   225  	if err := checkMessageType(s.method.GetInputType(), m); err != nil {
   226  		return err
   227  	}
   228  	return s.stream.SendMsg(m)
   229  }
   230  
   231  // CloseAndReceive closes the outgoing request stream and then blocks for the server's response.
   232  func (s *ClientStream) CloseAndReceive() (proto.Message, error) {
   233  	if err := s.stream.CloseSend(); err != nil {
   234  		return nil, err
   235  	}
   236  	resp := s.mf.NewMessage(s.method.GetOutputType())
   237  	if err := s.stream.RecvMsg(resp); err != nil {
   238  		return nil, err
   239  	}
   240  	// make sure we get EOF for a second message
   241  	if err := s.stream.RecvMsg(resp); err != io.EOF {
   242  		if err == nil {
   243  			s.cancel()
   244  			return nil, fmt.Errorf("client-streaming method %q returned more than one response message", s.method.GetFullyQualifiedName())
   245  		} else {
   246  			return nil, err
   247  		}
   248  	}
   249  	return resp, nil
   250  }
   251  
   252  // BidiStream represents a bi-directional stream for sending messages to and receiving
   253  // messages from a server. The header and trailer metadata sent by the server can also be
   254  // queried.
   255  type BidiStream struct {
   256  	stream   grpc.ClientStream
   257  	reqType  *desc.MessageDescriptor
   258  	respType *desc.MessageDescriptor
   259  	mf       *dynamic.MessageFactory
   260  }
   261  
   262  // Header returns any header metadata sent by the server (blocks if necessary until headers are
   263  // received).
   264  func (s *BidiStream) Header() (metadata.MD, error) {
   265  	return s.stream.Header()
   266  }
   267  
   268  // Trailer returns the trailer metadata sent by the server. It must only be called after
   269  // RecvMsg returns a non-nil error (which may be EOF for normal completion of stream).
   270  func (s *BidiStream) Trailer() metadata.MD {
   271  	return s.stream.Trailer()
   272  }
   273  
   274  // Context returns the context associated with this streaming operation.
   275  func (s *BidiStream) Context() context.Context {
   276  	return s.stream.Context()
   277  }
   278  
   279  // SendMsg sends a request message to the server.
   280  func (s *BidiStream) SendMsg(m proto.Message) error {
   281  	if err := checkMessageType(s.reqType, m); err != nil {
   282  		return err
   283  	}
   284  	return s.stream.SendMsg(m)
   285  }
   286  
   287  // CloseSend indicates the request stream has ended. Invoke this after all request messages
   288  // are sent (even if there are zero such messages).
   289  func (s *BidiStream) CloseSend() error {
   290  	return s.stream.CloseSend()
   291  }
   292  
   293  // RecvMsg returns the next message in the response stream or an error. If the stream
   294  // has completed normally, the error is io.EOF. Otherwise, the error indicates the
   295  // nature of the abnormal termination of the stream.
   296  func (s *BidiStream) RecvMsg() (proto.Message, error) {
   297  	resp := s.mf.NewMessage(s.respType)
   298  	if err := s.stream.RecvMsg(resp); err != nil {
   299  		return nil, err
   300  	} else {
   301  		return resp, nil
   302  	}
   303  }