github.com/bakjos/protoreflect@v1.9.2/dynamic/grpcdynamic/stub.go (about) 1 // Package grpcdynamic provides a dynamic RPC stub. It can be used to invoke RPC 2 // method where only method descriptors are known. The actual request and response 3 // messages may be dynamic messages. 4 package grpcdynamic 5 6 import ( 7 "fmt" 8 "io" 9 10 "github.com/golang/protobuf/proto" 11 "golang.org/x/net/context" 12 "google.golang.org/grpc" 13 "google.golang.org/grpc/metadata" 14 15 "github.com/bakjos/protoreflect/desc" 16 "github.com/bakjos/protoreflect/dynamic" 17 ) 18 19 // Stub is an RPC client stub, used for dynamically dispatching RPCs to a server. 20 type Stub struct { 21 channel Channel 22 mf *dynamic.MessageFactory 23 } 24 25 // Channel represents the operations necessary to issue RPCs via gRPC. The 26 // *grpc.ClientConn type provides this interface and will typically the concrete 27 // type used to construct Stubs. But the use of this interface allows 28 // construction of stubs that use alternate concrete types as the transport for 29 // RPC operations. 30 type Channel interface { 31 Invoke(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error 32 NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) 33 } 34 35 var _ Channel = (*grpc.ClientConn)(nil) 36 37 // NewStub creates a new RPC stub that uses the given channel for dispatching RPCs. 38 func NewStub(channel Channel) Stub { 39 return NewStubWithMessageFactory(channel, nil) 40 } 41 42 // NewStubWithMessageFactory creates a new RPC stub that uses the given channel for 43 // dispatching RPCs and the given MessageFactory for creating response messages. 44 func NewStubWithMessageFactory(channel Channel, mf *dynamic.MessageFactory) Stub { 45 return Stub{channel: channel, mf: mf} 46 } 47 48 func requestMethod(md *desc.MethodDescriptor) string { 49 return fmt.Sprintf("/%s/%s", md.GetService().GetFullyQualifiedName(), md.GetName()) 50 } 51 52 // InvokeRpc sends a unary RPC and returns the response. Use this for unary methods. 53 func (s Stub) InvokeRpc(ctx context.Context, method *desc.MethodDescriptor, request proto.Message, opts ...grpc.CallOption) (proto.Message, error) { 54 if method.IsClientStreaming() || method.IsServerStreaming() { 55 return nil, fmt.Errorf("InvokeRpc is for unary methods; %q is %s", method.GetFullyQualifiedName(), methodType(method)) 56 } 57 if err := checkMessageType(method.GetInputType(), request); err != nil { 58 return nil, err 59 } 60 resp := s.mf.NewMessage(method.GetOutputType()) 61 if err := s.channel.Invoke(ctx, requestMethod(method), request, resp, opts...); err != nil { 62 return nil, err 63 } 64 return resp, nil 65 } 66 67 // InvokeRpcServerStream sends a unary RPC and returns the response stream. Use this for server-streaming methods. 68 func (s Stub) InvokeRpcServerStream(ctx context.Context, method *desc.MethodDescriptor, request proto.Message, opts ...grpc.CallOption) (*ServerStream, error) { 69 if method.IsClientStreaming() || !method.IsServerStreaming() { 70 return nil, fmt.Errorf("InvokeRpcServerStream is for server-streaming methods; %q is %s", method.GetFullyQualifiedName(), methodType(method)) 71 } 72 if err := checkMessageType(method.GetInputType(), request); err != nil { 73 return nil, err 74 } 75 ctx, cancel := context.WithCancel(ctx) 76 sd := grpc.StreamDesc{ 77 StreamName: method.GetName(), 78 ServerStreams: method.IsServerStreaming(), 79 ClientStreams: method.IsClientStreaming(), 80 } 81 if cs, err := s.channel.NewStream(ctx, &sd, requestMethod(method), opts...); err != nil { 82 return nil, err 83 } else { 84 err = cs.SendMsg(request) 85 if err != nil { 86 cancel() 87 return nil, err 88 } 89 err = cs.CloseSend() 90 if err != nil { 91 cancel() 92 return nil, err 93 } 94 return &ServerStream{cs, method.GetOutputType(), s.mf}, nil 95 } 96 } 97 98 // InvokeRpcClientStream creates a new stream that is used to send request messages and, at the end, 99 // receive the response message. Use this for client-streaming methods. 100 func (s Stub) InvokeRpcClientStream(ctx context.Context, method *desc.MethodDescriptor, opts ...grpc.CallOption) (*ClientStream, error) { 101 if !method.IsClientStreaming() || method.IsServerStreaming() { 102 return nil, fmt.Errorf("InvokeRpcClientStream is for client-streaming methods; %q is %s", method.GetFullyQualifiedName(), methodType(method)) 103 } 104 ctx, cancel := context.WithCancel(ctx) 105 sd := grpc.StreamDesc{ 106 StreamName: method.GetName(), 107 ServerStreams: method.IsServerStreaming(), 108 ClientStreams: method.IsClientStreaming(), 109 } 110 if cs, err := s.channel.NewStream(ctx, &sd, requestMethod(method), opts...); err != nil { 111 return nil, err 112 } else { 113 return &ClientStream{cs, method, s.mf, cancel}, nil 114 } 115 } 116 117 // InvokeRpcBidiStream creates a new stream that is used to both send request messages and receive response 118 // messages. Use this for bidi-streaming methods. 119 func (s Stub) InvokeRpcBidiStream(ctx context.Context, method *desc.MethodDescriptor, opts ...grpc.CallOption) (*BidiStream, error) { 120 if !method.IsClientStreaming() || !method.IsServerStreaming() { 121 return nil, fmt.Errorf("InvokeRpcBidiStream is for bidi-streaming methods; %q is %s", method.GetFullyQualifiedName(), methodType(method)) 122 } 123 sd := grpc.StreamDesc{ 124 StreamName: method.GetName(), 125 ServerStreams: method.IsServerStreaming(), 126 ClientStreams: method.IsClientStreaming(), 127 } 128 if cs, err := s.channel.NewStream(ctx, &sd, requestMethod(method), opts...); err != nil { 129 return nil, err 130 } else { 131 return &BidiStream{cs, method.GetInputType(), method.GetOutputType(), s.mf}, nil 132 } 133 } 134 135 func methodType(md *desc.MethodDescriptor) string { 136 if md.IsClientStreaming() && md.IsServerStreaming() { 137 return "bidi-streaming" 138 } else if md.IsClientStreaming() { 139 return "client-streaming" 140 } else if md.IsServerStreaming() { 141 return "server-streaming" 142 } else { 143 return "unary" 144 } 145 } 146 147 func checkMessageType(md *desc.MessageDescriptor, msg proto.Message) error { 148 var typeName string 149 if dm, ok := msg.(*dynamic.Message); ok { 150 typeName = dm.GetMessageDescriptor().GetFullyQualifiedName() 151 } else { 152 typeName = proto.MessageName(msg) 153 } 154 if typeName != md.GetFullyQualifiedName() { 155 return fmt.Errorf("expecting message of type %s; got %s", md.GetFullyQualifiedName(), typeName) 156 } 157 return nil 158 } 159 160 // ServerStream represents a response stream from a server. Messages in the stream can be queried 161 // as can header and trailer metadata sent by the server. 162 type ServerStream struct { 163 stream grpc.ClientStream 164 respType *desc.MessageDescriptor 165 mf *dynamic.MessageFactory 166 } 167 168 // Header returns any header metadata sent by the server (blocks if necessary until headers are 169 // received). 170 func (s *ServerStream) Header() (metadata.MD, error) { 171 return s.stream.Header() 172 } 173 174 // Trailer returns the trailer metadata sent by the server. It must only be called after 175 // RecvMsg returns a non-nil error (which may be EOF for normal completion of stream). 176 func (s *ServerStream) Trailer() metadata.MD { 177 return s.stream.Trailer() 178 } 179 180 // Context returns the context associated with this streaming operation. 181 func (s *ServerStream) Context() context.Context { 182 return s.stream.Context() 183 } 184 185 // RecvMsg returns the next message in the response stream or an error. If the stream 186 // has completed normally, the error is io.EOF. Otherwise, the error indicates the 187 // nature of the abnormal termination of the stream. 188 func (s *ServerStream) RecvMsg() (proto.Message, error) { 189 resp := s.mf.NewMessage(s.respType) 190 if err := s.stream.RecvMsg(resp); err != nil { 191 return nil, err 192 } else { 193 return resp, nil 194 } 195 } 196 197 // ClientStream represents a response stream from a client. Messages in the stream can be sent 198 // and, when done, the unary server message and header and trailer metadata can be queried. 199 type ClientStream struct { 200 stream grpc.ClientStream 201 method *desc.MethodDescriptor 202 mf *dynamic.MessageFactory 203 cancel context.CancelFunc 204 } 205 206 // Header returns any header metadata sent by the server (blocks if necessary until headers are 207 // received). 208 func (s *ClientStream) Header() (metadata.MD, error) { 209 return s.stream.Header() 210 } 211 212 // Trailer returns the trailer metadata sent by the server. It must only be called after 213 // RecvMsg returns a non-nil error (which may be EOF for normal completion of stream). 214 func (s *ClientStream) Trailer() metadata.MD { 215 return s.stream.Trailer() 216 } 217 218 // Context returns the context associated with this streaming operation. 219 func (s *ClientStream) Context() context.Context { 220 return s.stream.Context() 221 } 222 223 // SendMsg sends a request message to the server. 224 func (s *ClientStream) SendMsg(m proto.Message) error { 225 if err := checkMessageType(s.method.GetInputType(), m); err != nil { 226 return err 227 } 228 return s.stream.SendMsg(m) 229 } 230 231 // CloseAndReceive closes the outgoing request stream and then blocks for the server's response. 232 func (s *ClientStream) CloseAndReceive() (proto.Message, error) { 233 if err := s.stream.CloseSend(); err != nil { 234 return nil, err 235 } 236 resp := s.mf.NewMessage(s.method.GetOutputType()) 237 if err := s.stream.RecvMsg(resp); err != nil { 238 return nil, err 239 } 240 // make sure we get EOF for a second message 241 if err := s.stream.RecvMsg(resp); err != io.EOF { 242 if err == nil { 243 s.cancel() 244 return nil, fmt.Errorf("client-streaming method %q returned more than one response message", s.method.GetFullyQualifiedName()) 245 } else { 246 return nil, err 247 } 248 } 249 return resp, nil 250 } 251 252 // BidiStream represents a bi-directional stream for sending messages to and receiving 253 // messages from a server. The header and trailer metadata sent by the server can also be 254 // queried. 255 type BidiStream struct { 256 stream grpc.ClientStream 257 reqType *desc.MessageDescriptor 258 respType *desc.MessageDescriptor 259 mf *dynamic.MessageFactory 260 } 261 262 // Header returns any header metadata sent by the server (blocks if necessary until headers are 263 // received). 264 func (s *BidiStream) Header() (metadata.MD, error) { 265 return s.stream.Header() 266 } 267 268 // Trailer returns the trailer metadata sent by the server. It must only be called after 269 // RecvMsg returns a non-nil error (which may be EOF for normal completion of stream). 270 func (s *BidiStream) Trailer() metadata.MD { 271 return s.stream.Trailer() 272 } 273 274 // Context returns the context associated with this streaming operation. 275 func (s *BidiStream) Context() context.Context { 276 return s.stream.Context() 277 } 278 279 // SendMsg sends a request message to the server. 280 func (s *BidiStream) SendMsg(m proto.Message) error { 281 if err := checkMessageType(s.reqType, m); err != nil { 282 return err 283 } 284 return s.stream.SendMsg(m) 285 } 286 287 // CloseSend indicates the request stream has ended. Invoke this after all request messages 288 // are sent (even if there are zero such messages). 289 func (s *BidiStream) CloseSend() error { 290 return s.stream.CloseSend() 291 } 292 293 // RecvMsg returns the next message in the response stream or an error. If the stream 294 // has completed normally, the error is io.EOF. Otherwise, the error indicates the 295 // nature of the abnormal termination of the stream. 296 func (s *BidiStream) RecvMsg() (proto.Message, error) { 297 resp := s.mf.NewMessage(s.respType) 298 if err := s.stream.RecvMsg(resp); err != nil { 299 return nil, err 300 } else { 301 return resp, nil 302 } 303 }