github.com/renbou/grpcbridge@v0.0.2-0.20240416012907-bcbd8b12648a/reflection/client.go (about)

     1  package reflection
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/renbou/grpcbridge/grpcadapter"
    10  	"google.golang.org/grpc/codes"
    11  	reflectionpb "google.golang.org/grpc/reflection/grpc_reflection_v1"
    12  	"google.golang.org/grpc/status"
    13  	"google.golang.org/protobuf/reflect/protoreflect"
    14  )
    15  
    16  type client struct {
    17  	// Untyped stream instead of the ServerReflection_ServerReflectionInfoClient
    18  	// because both v1 and v1alpha are exactly the same, so messages for them can be interchanged.
    19  	stream  grpcadapter.ClientStream
    20  	timeout time.Duration
    21  }
    22  
    23  func connectClient(timeout time.Duration, conn grpcadapter.ClientConn, method string) (*client, error) {
    24  	ctx, cancel := context.WithTimeout(context.Background(), timeout)
    25  	defer cancel()
    26  
    27  	stream, err := conn.Stream(ctx, method)
    28  	if err != nil {
    29  		return nil, fmt.Errorf("establishing reflection stream with method %q: %w", method, err)
    30  	}
    31  
    32  	return &client{stream: stream, timeout: timeout}, nil
    33  }
    34  
    35  func (c *client) close() {
    36  	c.stream.CloseSend()
    37  
    38  	// Close the reflection stream gracefully when possible, to avoid spurious errors on target servers.
    39  	ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
    40  	defer cancel()
    41  
    42  	_ = c.stream.Recv(ctx, new(reflectionpb.ServerReflectionResponse))
    43  
    44  	c.stream.Close()
    45  }
    46  
    47  // listServiceNames executes the ListServices reflection request using a single timeout for both request and response.
    48  // it expects that the response is of type ListServiceResponse, so it should be used once at the start and not
    49  // reused alongside other requests.
    50  // listServices doesn't deduplicate any received service names, if any, so it should be done by the caller if necessary,
    51  // even though this isn't allowed by the protocol, but who knows what some external server might return.
    52  // The returned names aren't validated to actually be complete service names, this needs to happen on the caller's side.
    53  func (c *client) listServiceNames() ([]string, error) {
    54  	ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
    55  	defer cancel()
    56  
    57  	// NB: if this returns an error, the stream is successfully closed.
    58  	if err := c.stream.Send(ctx, &reflectionpb.ServerReflectionRequest{
    59  		MessageRequest: &reflectionpb.ServerReflectionRequest_ListServices{},
    60  	}); err != nil {
    61  		return nil, fmt.Errorf("sending ListServices request: %w", err)
    62  	}
    63  
    64  	resp := new(reflectionpb.ServerReflectionResponse)
    65  	if err := c.recv(ctx, resp); err != nil {
    66  		return nil, fmt.Errorf("receiving response to ListServices request: %w", err)
    67  	}
    68  
    69  	// sanity check to ensure that the response is valid
    70  	if _, ok := resp.MessageResponse.(*reflectionpb.ServerReflectionResponse_ListServicesResponse); !ok {
    71  		return nil, fmt.Errorf("received response to different request instead of ListServices (parallel call to reflection client?): %v", resp.MessageResponse)
    72  	}
    73  
    74  	// the client doesn't do any processing, so just return the names as is.
    75  	services := resp.GetListServicesResponse().GetService()
    76  	serviceNames := make([]string, len(services))
    77  
    78  	for i, s := range services {
    79  		serviceNames[i] = s.GetName()
    80  	}
    81  
    82  	return serviceNames, nil
    83  }
    84  
    85  // fileDescriptorsBySymbols retrieves the file descriptors for all the given symbols using the FileContainingSymbol request.
    86  func (c *client) fileDescriptorsBySymbols(serviceNames []protoreflect.FullName) ([][]byte, error) {
    87  	requests := make([]*reflectionpb.ServerReflectionRequest, len(serviceNames))
    88  
    89  	for i, name := range serviceNames {
    90  		requests[i] = &reflectionpb.ServerReflectionRequest{
    91  			MessageRequest: &reflectionpb.ServerReflectionRequest_FileContainingSymbol{
    92  				FileContainingSymbol: string(name),
    93  			},
    94  		}
    95  	}
    96  
    97  	return c.execFileDescriptorRequests(requests, "FileContainingSymbol")
    98  }
    99  
   100  // fileDescriptorsByFilenames retrieves the file descriptors for all the given symbols using the FileByFilename request.
   101  func (c *client) fileDescriptorsByFilenames(fileNames []string) ([][]byte, error) {
   102  	requests := make([]*reflectionpb.ServerReflectionRequest, len(fileNames))
   103  
   104  	for i, name := range fileNames {
   105  		requests[i] = &reflectionpb.ServerReflectionRequest{
   106  			MessageRequest: &reflectionpb.ServerReflectionRequest_FileByFilename{
   107  				FileByFilename: name,
   108  			},
   109  		}
   110  	}
   111  
   112  	return c.execFileDescriptorRequests(requests, "FileByFilename")
   113  }
   114  
   115  // Sends/Recvs are made in parallel to minimize delays which would occur if done sequentially for all the symbols.
   116  // NB: the responses aren't deduplicated by any means, so the file descriptors need to be properly parsed and de-duped by name.
   117  // This is valid behaviour as specified in https://github.com/grpc/grpc/blob/aa67587bac54458464d38126c92d3a586a7c7a21/src/proto/grpc/reflection/v1/reflection.proto#L94.
   118  func (c *client) execFileDescriptorRequests(requests []*reflectionpb.ServerReflectionRequest, name string) ([][]byte, error) {
   119  	// avoid all this logic when it's not needed because why not?
   120  	if len(requests) == 0 {
   121  		return [][]byte{}, nil
   122  	}
   123  
   124  	semaphore := make(chan struct{}, len(requests))
   125  	sendErr := make(chan error, 1)
   126  	recvErr := make(chan error, 1)
   127  
   128  	// wait for goroutines to exit to avoid any leaks
   129  	var wg sync.WaitGroup
   130  	wg.Add(2)
   131  	defer wg.Wait()
   132  
   133  	// single base context to cancel both goroutines when one fails
   134  	ctx, cancel := context.WithCancel(context.Background())
   135  	defer cancel()
   136  
   137  	go func() {
   138  		defer wg.Done()
   139  		sendErr <- c.fileDescriptorsRequester(ctx, semaphore, requests, name)
   140  	}()
   141  
   142  	var res [][]byte
   143  	go func() {
   144  		defer wg.Done()
   145  		recvd, err := c.fileDescriptorsReceiver(ctx, semaphore, requests, name)
   146  		res = recvd // before channel send, which is a synchronizing operation
   147  		recvErr <- err
   148  	}()
   149  
   150  	// If one of the goroutines fails prematurely, immediately return and cancel the context to stop the other one.
   151  	// Otherwise wait for both of the signals to arrive and return the result.
   152  	for range 2 {
   153  		var err error
   154  
   155  		select {
   156  		case err = <-sendErr:
   157  		case err = <-recvErr:
   158  		}
   159  
   160  		if err != nil {
   161  			return nil, err
   162  		}
   163  	}
   164  
   165  	return res, nil
   166  }
   167  
   168  func (c *client) fileDescriptorsRequester(ctx context.Context, semaphore chan struct{}, requests []*reflectionpb.ServerReflectionRequest, name string) error {
   169  	for i, req := range requests {
   170  		if err := c.sendTimeout(ctx, req); err != nil {
   171  			return fmt.Errorf("sending %s request %d/%d (%v): %w", name, i, len(requests), req, err)
   172  		}
   173  
   174  		semaphore <- struct{}{} // buffered channel
   175  	}
   176  
   177  	return nil
   178  }
   179  
   180  func (c *client) fileDescriptorsReceiver(ctx context.Context, semaphore chan struct{}, requests []*reflectionpb.ServerReflectionRequest, name string) ([][]byte, error) {
   181  	// this preallocation is just a base prediction of the number of files,
   182  	// based on the assumption that each symbol will return a different file.
   183  	// in reality there can be more (files of dependencies) or less (multiple service in a file).
   184  	fileDescriptors := make([][]byte, 0, len(requests))
   185  	resp := new(reflectionpb.ServerReflectionResponse)
   186  
   187  	// no validation of received files is performed here because any potential errors will be covered anyway during proto parsing and registry construction.
   188  	for i := range requests {
   189  		// wait for barrier #i to be released by the sender, otherwise we would start waiting for a response before the request is sent.
   190  		select {
   191  		case <-semaphore:
   192  		case <-ctx.Done():
   193  			return nil, ctx.Err()
   194  		}
   195  
   196  		if err := c.recvTimeout(ctx, resp); err != nil {
   197  			return nil, fmt.Errorf("receiving response %d/%d to %s request: %w", i, len(requests), name, err)
   198  		}
   199  
   200  		// sanity check to ensure that the response is valid
   201  		if _, ok := resp.MessageResponse.(*reflectionpb.ServerReflectionResponse_FileDescriptorResponse); !ok {
   202  			return nil, fmt.Errorf("received response to different request instead of %s (parallel call to reflection client?): %v", name, resp.MessageResponse)
   203  		}
   204  
   205  		fileDescriptors = append(fileDescriptors, resp.GetFileDescriptorResponse().GetFileDescriptorProto()...)
   206  	}
   207  
   208  	return fileDescriptors, nil
   209  }
   210  
   211  func (c *client) sendTimeout(ctx context.Context, req *reflectionpb.ServerReflectionRequest) error {
   212  	ctx, cancel := context.WithTimeout(ctx, c.timeout)
   213  	defer cancel()
   214  
   215  	if err := c.stream.Send(ctx, req); err != nil {
   216  		// wrapped by the caller
   217  		return err
   218  	}
   219  
   220  	return nil
   221  }
   222  
   223  func (c *client) recvTimeout(ctx context.Context, resp *reflectionpb.ServerReflectionResponse) error {
   224  	ctx, cancel := context.WithTimeout(ctx, c.timeout)
   225  	defer cancel()
   226  
   227  	return c.recv(ctx, resp)
   228  }
   229  
   230  func (c *client) recv(ctx context.Context, resp *reflectionpb.ServerReflectionResponse) error {
   231  	if err := c.stream.Recv(ctx, resp); err != nil {
   232  		// wrapped by the caller
   233  		return err
   234  	} else if errRespWrapper, ok := resp.MessageResponse.(*reflectionpb.ServerReflectionResponse_ErrorResponse); ok {
   235  		errResp := errRespWrapper.ErrorResponse
   236  		return fmt.Errorf("ErrorResponse with status %w", status.Error(codes.Code(errResp.GetErrorCode()), errResp.GetErrorMessage()))
   237  	}
   238  
   239  	return nil
   240  }