github.com/renbou/grpcbridge@v0.0.2-0.20240416012907-bcbd8b12648a/reflection/client.go (about) 1 package reflection 2 3 import ( 4 "context" 5 "fmt" 6 "sync" 7 "time" 8 9 "github.com/renbou/grpcbridge/grpcadapter" 10 "google.golang.org/grpc/codes" 11 reflectionpb "google.golang.org/grpc/reflection/grpc_reflection_v1" 12 "google.golang.org/grpc/status" 13 "google.golang.org/protobuf/reflect/protoreflect" 14 ) 15 16 type client struct { 17 // Untyped stream instead of the ServerReflection_ServerReflectionInfoClient 18 // because both v1 and v1alpha are exactly the same, so messages for them can be interchanged. 19 stream grpcadapter.ClientStream 20 timeout time.Duration 21 } 22 23 func connectClient(timeout time.Duration, conn grpcadapter.ClientConn, method string) (*client, error) { 24 ctx, cancel := context.WithTimeout(context.Background(), timeout) 25 defer cancel() 26 27 stream, err := conn.Stream(ctx, method) 28 if err != nil { 29 return nil, fmt.Errorf("establishing reflection stream with method %q: %w", method, err) 30 } 31 32 return &client{stream: stream, timeout: timeout}, nil 33 } 34 35 func (c *client) close() { 36 c.stream.CloseSend() 37 38 // Close the reflection stream gracefully when possible, to avoid spurious errors on target servers. 39 ctx, cancel := context.WithTimeout(context.Background(), c.timeout) 40 defer cancel() 41 42 _ = c.stream.Recv(ctx, new(reflectionpb.ServerReflectionResponse)) 43 44 c.stream.Close() 45 } 46 47 // listServiceNames executes the ListServices reflection request using a single timeout for both request and response. 48 // it expects that the response is of type ListServiceResponse, so it should be used once at the start and not 49 // reused alongside other requests. 50 // listServices doesn't deduplicate any received service names, if any, so it should be done by the caller if necessary, 51 // even though this isn't allowed by the protocol, but who knows what some external server might return. 52 // The returned names aren't validated to actually be complete service names, this needs to happen on the caller's side. 53 func (c *client) listServiceNames() ([]string, error) { 54 ctx, cancel := context.WithTimeout(context.Background(), c.timeout) 55 defer cancel() 56 57 // NB: if this returns an error, the stream is successfully closed. 58 if err := c.stream.Send(ctx, &reflectionpb.ServerReflectionRequest{ 59 MessageRequest: &reflectionpb.ServerReflectionRequest_ListServices{}, 60 }); err != nil { 61 return nil, fmt.Errorf("sending ListServices request: %w", err) 62 } 63 64 resp := new(reflectionpb.ServerReflectionResponse) 65 if err := c.recv(ctx, resp); err != nil { 66 return nil, fmt.Errorf("receiving response to ListServices request: %w", err) 67 } 68 69 // sanity check to ensure that the response is valid 70 if _, ok := resp.MessageResponse.(*reflectionpb.ServerReflectionResponse_ListServicesResponse); !ok { 71 return nil, fmt.Errorf("received response to different request instead of ListServices (parallel call to reflection client?): %v", resp.MessageResponse) 72 } 73 74 // the client doesn't do any processing, so just return the names as is. 75 services := resp.GetListServicesResponse().GetService() 76 serviceNames := make([]string, len(services)) 77 78 for i, s := range services { 79 serviceNames[i] = s.GetName() 80 } 81 82 return serviceNames, nil 83 } 84 85 // fileDescriptorsBySymbols retrieves the file descriptors for all the given symbols using the FileContainingSymbol request. 86 func (c *client) fileDescriptorsBySymbols(serviceNames []protoreflect.FullName) ([][]byte, error) { 87 requests := make([]*reflectionpb.ServerReflectionRequest, len(serviceNames)) 88 89 for i, name := range serviceNames { 90 requests[i] = &reflectionpb.ServerReflectionRequest{ 91 MessageRequest: &reflectionpb.ServerReflectionRequest_FileContainingSymbol{ 92 FileContainingSymbol: string(name), 93 }, 94 } 95 } 96 97 return c.execFileDescriptorRequests(requests, "FileContainingSymbol") 98 } 99 100 // fileDescriptorsByFilenames retrieves the file descriptors for all the given symbols using the FileByFilename request. 101 func (c *client) fileDescriptorsByFilenames(fileNames []string) ([][]byte, error) { 102 requests := make([]*reflectionpb.ServerReflectionRequest, len(fileNames)) 103 104 for i, name := range fileNames { 105 requests[i] = &reflectionpb.ServerReflectionRequest{ 106 MessageRequest: &reflectionpb.ServerReflectionRequest_FileByFilename{ 107 FileByFilename: name, 108 }, 109 } 110 } 111 112 return c.execFileDescriptorRequests(requests, "FileByFilename") 113 } 114 115 // Sends/Recvs are made in parallel to minimize delays which would occur if done sequentially for all the symbols. 116 // NB: the responses aren't deduplicated by any means, so the file descriptors need to be properly parsed and de-duped by name. 117 // This is valid behaviour as specified in https://github.com/grpc/grpc/blob/aa67587bac54458464d38126c92d3a586a7c7a21/src/proto/grpc/reflection/v1/reflection.proto#L94. 118 func (c *client) execFileDescriptorRequests(requests []*reflectionpb.ServerReflectionRequest, name string) ([][]byte, error) { 119 // avoid all this logic when it's not needed because why not? 120 if len(requests) == 0 { 121 return [][]byte{}, nil 122 } 123 124 semaphore := make(chan struct{}, len(requests)) 125 sendErr := make(chan error, 1) 126 recvErr := make(chan error, 1) 127 128 // wait for goroutines to exit to avoid any leaks 129 var wg sync.WaitGroup 130 wg.Add(2) 131 defer wg.Wait() 132 133 // single base context to cancel both goroutines when one fails 134 ctx, cancel := context.WithCancel(context.Background()) 135 defer cancel() 136 137 go func() { 138 defer wg.Done() 139 sendErr <- c.fileDescriptorsRequester(ctx, semaphore, requests, name) 140 }() 141 142 var res [][]byte 143 go func() { 144 defer wg.Done() 145 recvd, err := c.fileDescriptorsReceiver(ctx, semaphore, requests, name) 146 res = recvd // before channel send, which is a synchronizing operation 147 recvErr <- err 148 }() 149 150 // If one of the goroutines fails prematurely, immediately return and cancel the context to stop the other one. 151 // Otherwise wait for both of the signals to arrive and return the result. 152 for range 2 { 153 var err error 154 155 select { 156 case err = <-sendErr: 157 case err = <-recvErr: 158 } 159 160 if err != nil { 161 return nil, err 162 } 163 } 164 165 return res, nil 166 } 167 168 func (c *client) fileDescriptorsRequester(ctx context.Context, semaphore chan struct{}, requests []*reflectionpb.ServerReflectionRequest, name string) error { 169 for i, req := range requests { 170 if err := c.sendTimeout(ctx, req); err != nil { 171 return fmt.Errorf("sending %s request %d/%d (%v): %w", name, i, len(requests), req, err) 172 } 173 174 semaphore <- struct{}{} // buffered channel 175 } 176 177 return nil 178 } 179 180 func (c *client) fileDescriptorsReceiver(ctx context.Context, semaphore chan struct{}, requests []*reflectionpb.ServerReflectionRequest, name string) ([][]byte, error) { 181 // this preallocation is just a base prediction of the number of files, 182 // based on the assumption that each symbol will return a different file. 183 // in reality there can be more (files of dependencies) or less (multiple service in a file). 184 fileDescriptors := make([][]byte, 0, len(requests)) 185 resp := new(reflectionpb.ServerReflectionResponse) 186 187 // no validation of received files is performed here because any potential errors will be covered anyway during proto parsing and registry construction. 188 for i := range requests { 189 // wait for barrier #i to be released by the sender, otherwise we would start waiting for a response before the request is sent. 190 select { 191 case <-semaphore: 192 case <-ctx.Done(): 193 return nil, ctx.Err() 194 } 195 196 if err := c.recvTimeout(ctx, resp); err != nil { 197 return nil, fmt.Errorf("receiving response %d/%d to %s request: %w", i, len(requests), name, err) 198 } 199 200 // sanity check to ensure that the response is valid 201 if _, ok := resp.MessageResponse.(*reflectionpb.ServerReflectionResponse_FileDescriptorResponse); !ok { 202 return nil, fmt.Errorf("received response to different request instead of %s (parallel call to reflection client?): %v", name, resp.MessageResponse) 203 } 204 205 fileDescriptors = append(fileDescriptors, resp.GetFileDescriptorResponse().GetFileDescriptorProto()...) 206 } 207 208 return fileDescriptors, nil 209 } 210 211 func (c *client) sendTimeout(ctx context.Context, req *reflectionpb.ServerReflectionRequest) error { 212 ctx, cancel := context.WithTimeout(ctx, c.timeout) 213 defer cancel() 214 215 if err := c.stream.Send(ctx, req); err != nil { 216 // wrapped by the caller 217 return err 218 } 219 220 return nil 221 } 222 223 func (c *client) recvTimeout(ctx context.Context, resp *reflectionpb.ServerReflectionResponse) error { 224 ctx, cancel := context.WithTimeout(ctx, c.timeout) 225 defer cancel() 226 227 return c.recv(ctx, resp) 228 } 229 230 func (c *client) recv(ctx context.Context, resp *reflectionpb.ServerReflectionResponse) error { 231 if err := c.stream.Recv(ctx, resp); err != nil { 232 // wrapped by the caller 233 return err 234 } else if errRespWrapper, ok := resp.MessageResponse.(*reflectionpb.ServerReflectionResponse_ErrorResponse); ok { 235 errResp := errRespWrapper.ErrorResponse 236 return fmt.Errorf("ErrorResponse with status %w", status.Error(codes.Code(errResp.GetErrorCode()), errResp.GetErrorMessage())) 237 } 238 239 return nil 240 }