github.com/jhump/protoreflect@v1.16.0/dynamic/grpcdynamic/stub.go (about) 1 // Package grpcdynamic provides a dynamic RPC stub. It can be used to invoke RPC 2 // method where only method descriptors are known. The actual request and response 3 // messages may be dynamic messages. 4 package grpcdynamic 5 6 import ( 7 "context" 8 "fmt" 9 "io" 10 11 "github.com/golang/protobuf/proto" 12 "google.golang.org/grpc" 13 "google.golang.org/grpc/metadata" 14 15 "github.com/jhump/protoreflect/desc" 16 "github.com/jhump/protoreflect/dynamic" 17 ) 18 19 // Stub is an RPC client stub, used for dynamically dispatching RPCs to a server. 20 type Stub struct { 21 channel Channel 22 mf *dynamic.MessageFactory 23 } 24 25 // Channel represents the operations necessary to issue RPCs via gRPC. The 26 // *grpc.ClientConn type provides this interface and will typically the concrete 27 // type used to construct Stubs. But the use of this interface allows 28 // construction of stubs that use alternate concrete types as the transport for 29 // RPC operations. 30 type Channel = grpc.ClientConnInterface 31 32 // NewStub creates a new RPC stub that uses the given channel for dispatching RPCs. 33 func NewStub(channel Channel) Stub { 34 return NewStubWithMessageFactory(channel, nil) 35 } 36 37 // NewStubWithMessageFactory creates a new RPC stub that uses the given channel for 38 // dispatching RPCs and the given MessageFactory for creating response messages. 39 func NewStubWithMessageFactory(channel Channel, mf *dynamic.MessageFactory) Stub { 40 return Stub{channel: channel, mf: mf} 41 } 42 43 func requestMethod(md *desc.MethodDescriptor) string { 44 return fmt.Sprintf("/%s/%s", md.GetService().GetFullyQualifiedName(), md.GetName()) 45 } 46 47 // InvokeRpc sends a unary RPC and returns the response. Use this for unary methods. 48 func (s Stub) InvokeRpc(ctx context.Context, method *desc.MethodDescriptor, request proto.Message, opts ...grpc.CallOption) (proto.Message, error) { 49 if method.IsClientStreaming() || method.IsServerStreaming() { 50 return nil, fmt.Errorf("InvokeRpc is for unary methods; %q is %s", method.GetFullyQualifiedName(), methodType(method)) 51 } 52 if err := checkMessageType(method.GetInputType(), request); err != nil { 53 return nil, err 54 } 55 resp := s.mf.NewMessage(method.GetOutputType()) 56 if err := s.channel.Invoke(ctx, requestMethod(method), request, resp, opts...); err != nil { 57 return nil, err 58 } 59 return resp, nil 60 } 61 62 // InvokeRpcServerStream sends a unary RPC and returns the response stream. Use this for server-streaming methods. 63 func (s Stub) InvokeRpcServerStream(ctx context.Context, method *desc.MethodDescriptor, request proto.Message, opts ...grpc.CallOption) (*ServerStream, error) { 64 if method.IsClientStreaming() || !method.IsServerStreaming() { 65 return nil, fmt.Errorf("InvokeRpcServerStream is for server-streaming methods; %q is %s", method.GetFullyQualifiedName(), methodType(method)) 66 } 67 if err := checkMessageType(method.GetInputType(), request); err != nil { 68 return nil, err 69 } 70 ctx, cancel := context.WithCancel(ctx) 71 sd := grpc.StreamDesc{ 72 StreamName: method.GetName(), 73 ServerStreams: method.IsServerStreaming(), 74 ClientStreams: method.IsClientStreaming(), 75 } 76 if cs, err := s.channel.NewStream(ctx, &sd, requestMethod(method), opts...); err != nil { 77 cancel() 78 return nil, err 79 } else { 80 err = cs.SendMsg(request) 81 if err != nil { 82 cancel() 83 return nil, err 84 } 85 err = cs.CloseSend() 86 if err != nil { 87 cancel() 88 return nil, err 89 } 90 go func() { 91 // when the new stream is finished, also cleanup the parent context 92 <-cs.Context().Done() 93 cancel() 94 }() 95 return &ServerStream{cs, method.GetOutputType(), s.mf}, nil 96 } 97 } 98 99 // InvokeRpcClientStream creates a new stream that is used to send request messages and, at the end, 100 // receive the response message. Use this for client-streaming methods. 101 func (s Stub) InvokeRpcClientStream(ctx context.Context, method *desc.MethodDescriptor, opts ...grpc.CallOption) (*ClientStream, error) { 102 if !method.IsClientStreaming() || method.IsServerStreaming() { 103 return nil, fmt.Errorf("InvokeRpcClientStream is for client-streaming methods; %q is %s", method.GetFullyQualifiedName(), methodType(method)) 104 } 105 ctx, cancel := context.WithCancel(ctx) 106 sd := grpc.StreamDesc{ 107 StreamName: method.GetName(), 108 ServerStreams: method.IsServerStreaming(), 109 ClientStreams: method.IsClientStreaming(), 110 } 111 if cs, err := s.channel.NewStream(ctx, &sd, requestMethod(method), opts...); err != nil { 112 cancel() 113 return nil, err 114 } else { 115 go func() { 116 // when the new stream is finished, also cleanup the parent context 117 <-cs.Context().Done() 118 cancel() 119 }() 120 return &ClientStream{cs, method, s.mf, cancel}, nil 121 } 122 } 123 124 // InvokeRpcBidiStream creates a new stream that is used to both send request messages and receive response 125 // messages. Use this for bidi-streaming methods. 126 func (s Stub) InvokeRpcBidiStream(ctx context.Context, method *desc.MethodDescriptor, opts ...grpc.CallOption) (*BidiStream, error) { 127 if !method.IsClientStreaming() || !method.IsServerStreaming() { 128 return nil, fmt.Errorf("InvokeRpcBidiStream is for bidi-streaming methods; %q is %s", method.GetFullyQualifiedName(), methodType(method)) 129 } 130 sd := grpc.StreamDesc{ 131 StreamName: method.GetName(), 132 ServerStreams: method.IsServerStreaming(), 133 ClientStreams: method.IsClientStreaming(), 134 } 135 if cs, err := s.channel.NewStream(ctx, &sd, requestMethod(method), opts...); err != nil { 136 return nil, err 137 } else { 138 return &BidiStream{cs, method.GetInputType(), method.GetOutputType(), s.mf}, nil 139 } 140 } 141 142 func methodType(md *desc.MethodDescriptor) string { 143 if md.IsClientStreaming() && md.IsServerStreaming() { 144 return "bidi-streaming" 145 } else if md.IsClientStreaming() { 146 return "client-streaming" 147 } else if md.IsServerStreaming() { 148 return "server-streaming" 149 } else { 150 return "unary" 151 } 152 } 153 154 func checkMessageType(md *desc.MessageDescriptor, msg proto.Message) error { 155 var typeName string 156 if dm, ok := msg.(*dynamic.Message); ok { 157 typeName = dm.GetMessageDescriptor().GetFullyQualifiedName() 158 } else { 159 typeName = proto.MessageName(msg) 160 } 161 if typeName != md.GetFullyQualifiedName() { 162 return fmt.Errorf("expecting message of type %s; got %s", md.GetFullyQualifiedName(), typeName) 163 } 164 return nil 165 } 166 167 // ServerStream represents a response stream from a server. Messages in the stream can be queried 168 // as can header and trailer metadata sent by the server. 169 type ServerStream struct { 170 stream grpc.ClientStream 171 respType *desc.MessageDescriptor 172 mf *dynamic.MessageFactory 173 } 174 175 // Header returns any header metadata sent by the server (blocks if necessary until headers are 176 // received). 177 func (s *ServerStream) Header() (metadata.MD, error) { 178 return s.stream.Header() 179 } 180 181 // Trailer returns the trailer metadata sent by the server. It must only be called after 182 // RecvMsg returns a non-nil error (which may be EOF for normal completion of stream). 183 func (s *ServerStream) Trailer() metadata.MD { 184 return s.stream.Trailer() 185 } 186 187 // Context returns the context associated with this streaming operation. 188 func (s *ServerStream) Context() context.Context { 189 return s.stream.Context() 190 } 191 192 // RecvMsg returns the next message in the response stream or an error. If the stream 193 // has completed normally, the error is io.EOF. Otherwise, the error indicates the 194 // nature of the abnormal termination of the stream. 195 func (s *ServerStream) RecvMsg() (proto.Message, error) { 196 resp := s.mf.NewMessage(s.respType) 197 if err := s.stream.RecvMsg(resp); err != nil { 198 return nil, err 199 } else { 200 return resp, nil 201 } 202 } 203 204 // ClientStream represents a response stream from a client. Messages in the stream can be sent 205 // and, when done, the unary server message and header and trailer metadata can be queried. 206 type ClientStream struct { 207 stream grpc.ClientStream 208 method *desc.MethodDescriptor 209 mf *dynamic.MessageFactory 210 cancel context.CancelFunc 211 } 212 213 // Header returns any header metadata sent by the server (blocks if necessary until headers are 214 // received). 215 func (s *ClientStream) Header() (metadata.MD, error) { 216 return s.stream.Header() 217 } 218 219 // Trailer returns the trailer metadata sent by the server. It must only be called after 220 // RecvMsg returns a non-nil error (which may be EOF for normal completion of stream). 221 func (s *ClientStream) Trailer() metadata.MD { 222 return s.stream.Trailer() 223 } 224 225 // Context returns the context associated with this streaming operation. 226 func (s *ClientStream) Context() context.Context { 227 return s.stream.Context() 228 } 229 230 // SendMsg sends a request message to the server. 231 func (s *ClientStream) SendMsg(m proto.Message) error { 232 if err := checkMessageType(s.method.GetInputType(), m); err != nil { 233 return err 234 } 235 return s.stream.SendMsg(m) 236 } 237 238 // CloseAndReceive closes the outgoing request stream and then blocks for the server's response. 239 func (s *ClientStream) CloseAndReceive() (proto.Message, error) { 240 if err := s.stream.CloseSend(); err != nil { 241 return nil, err 242 } 243 resp := s.mf.NewMessage(s.method.GetOutputType()) 244 if err := s.stream.RecvMsg(resp); err != nil { 245 return nil, err 246 } 247 // make sure we get EOF for a second message 248 if err := s.stream.RecvMsg(resp); err != io.EOF { 249 if err == nil { 250 s.cancel() 251 return nil, fmt.Errorf("client-streaming method %q returned more than one response message", s.method.GetFullyQualifiedName()) 252 } else { 253 return nil, err 254 } 255 } 256 return resp, nil 257 } 258 259 // BidiStream represents a bi-directional stream for sending messages to and receiving 260 // messages from a server. The header and trailer metadata sent by the server can also be 261 // queried. 262 type BidiStream struct { 263 stream grpc.ClientStream 264 reqType *desc.MessageDescriptor 265 respType *desc.MessageDescriptor 266 mf *dynamic.MessageFactory 267 } 268 269 // Header returns any header metadata sent by the server (blocks if necessary until headers are 270 // received). 271 func (s *BidiStream) Header() (metadata.MD, error) { 272 return s.stream.Header() 273 } 274 275 // Trailer returns the trailer metadata sent by the server. It must only be called after 276 // RecvMsg returns a non-nil error (which may be EOF for normal completion of stream). 277 func (s *BidiStream) Trailer() metadata.MD { 278 return s.stream.Trailer() 279 } 280 281 // Context returns the context associated with this streaming operation. 282 func (s *BidiStream) Context() context.Context { 283 return s.stream.Context() 284 } 285 286 // SendMsg sends a request message to the server. 287 func (s *BidiStream) SendMsg(m proto.Message) error { 288 if err := checkMessageType(s.reqType, m); err != nil { 289 return err 290 } 291 return s.stream.SendMsg(m) 292 } 293 294 // CloseSend indicates the request stream has ended. Invoke this after all request messages 295 // are sent (even if there are zero such messages). 296 func (s *BidiStream) CloseSend() error { 297 return s.stream.CloseSend() 298 } 299 300 // RecvMsg returns the next message in the response stream or an error. If the stream 301 // has completed normally, the error is io.EOF. Otherwise, the error indicates the 302 // nature of the abnormal termination of the stream. 303 func (s *BidiStream) RecvMsg() (proto.Message, error) { 304 resp := s.mf.NewMessage(s.respType) 305 if err := s.stream.RecvMsg(resp); err != nil { 306 return nil, err 307 } else { 308 return resp, nil 309 } 310 }