github.com/jhump/protoreflect@v1.16.0/grpcreflect/client.go (about)

     1  package grpcreflect
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"io"
     8  	"reflect"
     9  	"runtime"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	"github.com/golang/protobuf/proto"
    15  	"google.golang.org/grpc"
    16  	"google.golang.org/grpc/codes"
    17  	refv1 "google.golang.org/grpc/reflection/grpc_reflection_v1"
    18  	refv1alpha "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
    19  	"google.golang.org/grpc/status"
    20  	"google.golang.org/protobuf/types/descriptorpb"
    21  
    22  	"github.com/jhump/protoreflect/desc"
    23  	"github.com/jhump/protoreflect/internal"
    24  )
    25  
    26  // If we try the v1 reflection API and get back "not implemented", we'll wait
    27  // this long before trying v1 again. This allows a long-lived client to
    28  // dynamically switch from v1alpha to v1 if the underlying server is updated
    29  // to support it. But it also prevents every stream request from always trying
    30  // v1 first: if we try it and see it fail, we shouldn't continually retry it
    31  // if we expect it will fail again.
    32  const durationBetweenV1Attempts = time.Hour
    33  
    34  // elementNotFoundError is the error returned by reflective operations where the
    35  // server does not recognize a given file name, symbol name, or extension.
    36  type elementNotFoundError struct {
    37  	name    string
    38  	kind    elementKind
    39  	symType symbolType // only used when kind == elementKindSymbol
    40  	tag     int32      // only used when kind == elementKindExtension
    41  
    42  	// only errors with a kind of elementKindFile will have a cause, which means
    43  	// the named file count not be resolved because of a dependency that could
    44  	// not be found where cause describes the missing dependency
    45  	cause *elementNotFoundError
    46  }
    47  
    48  type elementKind int
    49  
    50  const (
    51  	elementKindSymbol elementKind = iota
    52  	elementKindFile
    53  	elementKindExtension
    54  )
    55  
    56  type symbolType string
    57  
    58  const (
    59  	symbolTypeService = "Service"
    60  	symbolTypeMessage = "Message"
    61  	symbolTypeEnum    = "Enum"
    62  	symbolTypeUnknown = "Symbol"
    63  )
    64  
    65  func symbolNotFound(symbol string, symType symbolType, cause *elementNotFoundError) error {
    66  	if cause != nil && cause.kind == elementKindSymbol && cause.name == symbol {
    67  		// no need to wrap
    68  		if symType != symbolTypeUnknown && cause.symType == symbolTypeUnknown {
    69  			// We previously didn't know symbol type but now do?
    70  			// Create a new error that has the right symbol type.
    71  			return &elementNotFoundError{name: symbol, symType: symType, kind: elementKindSymbol}
    72  		}
    73  		return cause
    74  	}
    75  	return &elementNotFoundError{name: symbol, symType: symType, kind: elementKindSymbol, cause: cause}
    76  }
    77  
    78  func extensionNotFound(extendee string, tag int32, cause *elementNotFoundError) error {
    79  	if cause != nil && cause.kind == elementKindExtension && cause.name == extendee && cause.tag == tag {
    80  		// no need to wrap
    81  		return cause
    82  	}
    83  	return &elementNotFoundError{name: extendee, tag: tag, kind: elementKindExtension, cause: cause}
    84  }
    85  
    86  func fileNotFound(file string, cause *elementNotFoundError) error {
    87  	if cause != nil && cause.kind == elementKindFile && cause.name == file {
    88  		// no need to wrap
    89  		return cause
    90  	}
    91  	return &elementNotFoundError{name: file, kind: elementKindFile, cause: cause}
    92  }
    93  
    94  func (e *elementNotFoundError) Error() string {
    95  	first := true
    96  	var b bytes.Buffer
    97  	for ; e != nil; e = e.cause {
    98  		if first {
    99  			first = false
   100  		} else {
   101  			_, _ = fmt.Fprint(&b, "\ncaused by: ")
   102  		}
   103  		switch e.kind {
   104  		case elementKindSymbol:
   105  			_, _ = fmt.Fprintf(&b, "%s not found: %s", e.symType, e.name)
   106  		case elementKindExtension:
   107  			_, _ = fmt.Fprintf(&b, "Extension not found: tag %d for %s", e.tag, e.name)
   108  		default:
   109  			_, _ = fmt.Fprintf(&b, "File not found: %s", e.name)
   110  		}
   111  	}
   112  	return b.String()
   113  }
   114  
   115  // IsElementNotFoundError determines if the given error indicates that a file
   116  // name, symbol name, or extension field was could not be found by the server.
   117  func IsElementNotFoundError(err error) bool {
   118  	_, ok := err.(*elementNotFoundError)
   119  	return ok
   120  }
   121  
   122  // ProtocolError is an error returned when the server sends a response of the
   123  // wrong type.
   124  type ProtocolError struct {
   125  	missingType reflect.Type
   126  }
   127  
   128  func (p ProtocolError) Error() string {
   129  	return fmt.Sprintf("Protocol error: response was missing %v", p.missingType)
   130  }
   131  
// extDesc identifies an extension by the fully-qualified name of the message
// it extends plus its field number. It is the key type of
// Client.filesByExtension and is constructed with positional literals, so the
// field order matters.
type extDesc struct {
	extendedMessageName string
	extensionNumber     int32
}
   136  
// Client is a client connection to a server for performing reflection calls
// and resolving remote symbols.
//
// A Client lazily opens a single reflection stream and reuses it across calls;
// Reset closes it. Downloaded descriptors are cached for the Client's lifetime.
type Client struct {
	// ctx is the root context; each stream's context is derived from it.
	ctx context.Context
	// now is the clock used to decide when v1 should be retried.
	now func() time.Time
	// stubV1 is nil when only v1alpha should be used.
	stubV1       refv1.ServerReflectionClient
	stubV1Alpha  refv1alpha.ServerReflectionClient
	allowMissing atomic.Bool

	// connMu guards the stream state below (cancel, stream, useV1Alpha,
	// lastTriedV1).
	connMu      sync.Mutex
	cancel      context.CancelFunc
	stream      refv1alpha.ServerReflection_ServerReflectionInfoClient
	useV1Alpha  bool
	lastTriedV1 time.Time

	// cacheMu guards the caches below.
	cacheMu          sync.RWMutex
	protosByName     map[string]*descriptorpb.FileDescriptorProto
	filesByName      map[string]*desc.FileDescriptor
	filesBySymbol    map[string]*desc.FileDescriptor
	filesByExtension map[extDesc]*desc.FileDescriptor
}
   158  
   159  // NewClient creates a new Client with the given root context and using the
   160  // given RPC stub for talking to the server.
   161  //
   162  // Deprecated: Use NewClientV1Alpha if you are intentionally pinning the
   163  // v1alpha version of the reflection service. Otherwise, use NewClientAuto
   164  // instead.
   165  func NewClient(ctx context.Context, stub refv1alpha.ServerReflectionClient) *Client {
   166  	return NewClientV1Alpha(ctx, stub)
   167  }
   168  
   169  // NewClientV1Alpha creates a new Client using the v1alpha version of reflection
   170  // with the given root context and using the given RPC stub for talking to the
   171  // server.
   172  func NewClientV1Alpha(ctx context.Context, stub refv1alpha.ServerReflectionClient) *Client {
   173  	return newClient(ctx, nil, stub)
   174  }
   175  
   176  func newClient(ctx context.Context, stubv1 refv1.ServerReflectionClient, stubv1alpha refv1alpha.ServerReflectionClient) *Client {
   177  	cr := &Client{
   178  		ctx:              ctx,
   179  		now:              time.Now,
   180  		stubV1:           stubv1,
   181  		stubV1Alpha:      stubv1alpha,
   182  		protosByName:     map[string]*descriptorpb.FileDescriptorProto{},
   183  		filesByName:      map[string]*desc.FileDescriptor{},
   184  		filesBySymbol:    map[string]*desc.FileDescriptor{},
   185  		filesByExtension: map[extDesc]*desc.FileDescriptor{},
   186  	}
   187  	// don't leak a grpc stream
   188  	runtime.SetFinalizer(cr, (*Client).Reset)
   189  	return cr
   190  }
   191  
   192  // NewClientAuto creates a new Client that will use either v1 or v1alpha version
   193  // of reflection (based on what the server supports) with the given root context
   194  // and using the given client connection.
   195  //
   196  // It will first the v1 version of the reflection service. If it gets back an
   197  // "Unimplemented" error, it will fall back to using the v1alpha version. It
   198  // will remember which version the server supports for any subsequent operations
   199  // that need to re-invoke the streaming RPC. But, if it's a very long-lived
   200  // client, it will periodically retry the v1 version (in case the server is
   201  // updated to support it also). The period for these retries is every hour.
   202  func NewClientAuto(ctx context.Context, cc grpc.ClientConnInterface) *Client {
   203  	stubv1 := refv1.NewServerReflectionClient(cc)
   204  	stubv1alpha := refv1alpha.NewServerReflectionClient(cc)
   205  	return newClient(ctx, stubv1, stubv1alpha)
   206  }
   207  
   208  // AllowMissingFileDescriptors configures the client to allow missing files
   209  // when building descriptors when possible. Missing files are often fatal
   210  // errors, but with this option they can sometimes be worked around. Building
   211  // a schema can only succeed with some files missing if the files in question
   212  // only provide custom options and/or other unused types.
   213  func (cr *Client) AllowMissingFileDescriptors() {
   214  	cr.allowMissing.Store(true)
   215  }
   216  
   217  // TODO: We should also have a NewClientV1. However that should not refer to internal
   218  // generated code. So it will have to wait until the grpc-go team fixes this issue:
   219  //  https://github.com/grpc/grpc-go/issues/5684
   220  
   221  // FileByFilename asks the server for a file descriptor for the proto file with
   222  // the given name.
   223  func (cr *Client) FileByFilename(filename string) (*desc.FileDescriptor, error) {
   224  	// hit the cache first
   225  	cr.cacheMu.RLock()
   226  	if fd, ok := cr.filesByName[filename]; ok {
   227  		cr.cacheMu.RUnlock()
   228  		return fd, nil
   229  	}
   230  	fdp, ok := cr.protosByName[filename]
   231  	cr.cacheMu.RUnlock()
   232  	// not there? see if we've downloaded the proto
   233  	if ok {
   234  		return cr.descriptorFromProto(fdp)
   235  	}
   236  
   237  	req := &refv1alpha.ServerReflectionRequest{
   238  		MessageRequest: &refv1alpha.ServerReflectionRequest_FileByFilename{
   239  			FileByFilename: filename,
   240  		},
   241  	}
   242  	accept := func(fd *desc.FileDescriptor) bool {
   243  		return fd.GetName() == filename
   244  	}
   245  
   246  	fd, err := cr.getAndCacheFileDescriptors(req, filename, "", accept)
   247  	if isNotFound(err) {
   248  		// file not found? see if we can look up via alternate name
   249  		if alternate, ok := internal.StdFileAliases[filename]; ok {
   250  			req := &refv1alpha.ServerReflectionRequest{
   251  				MessageRequest: &refv1alpha.ServerReflectionRequest_FileByFilename{
   252  					FileByFilename: alternate,
   253  				},
   254  			}
   255  			fd, err = cr.getAndCacheFileDescriptors(req, alternate, filename, accept)
   256  			if isNotFound(err) {
   257  				err = fileNotFound(filename, nil)
   258  			}
   259  		} else {
   260  			err = fileNotFound(filename, nil)
   261  		}
   262  	} else if e, ok := err.(*elementNotFoundError); ok {
   263  		err = fileNotFound(filename, e)
   264  	}
   265  	return fd, err
   266  }
   267  
   268  // FileContainingSymbol asks the server for a file descriptor for the proto file
   269  // that declares the given fully-qualified symbol.
   270  func (cr *Client) FileContainingSymbol(symbol string) (*desc.FileDescriptor, error) {
   271  	// hit the cache first
   272  	cr.cacheMu.RLock()
   273  	fd, ok := cr.filesBySymbol[symbol]
   274  	cr.cacheMu.RUnlock()
   275  	if ok {
   276  		return fd, nil
   277  	}
   278  
   279  	req := &refv1alpha.ServerReflectionRequest{
   280  		MessageRequest: &refv1alpha.ServerReflectionRequest_FileContainingSymbol{
   281  			FileContainingSymbol: symbol,
   282  		},
   283  	}
   284  	accept := func(fd *desc.FileDescriptor) bool {
   285  		return fd.FindSymbol(symbol) != nil
   286  	}
   287  	fd, err := cr.getAndCacheFileDescriptors(req, "", "", accept)
   288  	if isNotFound(err) {
   289  		err = symbolNotFound(symbol, symbolTypeUnknown, nil)
   290  	} else if e, ok := err.(*elementNotFoundError); ok {
   291  		err = symbolNotFound(symbol, symbolTypeUnknown, e)
   292  	}
   293  	return fd, err
   294  }
   295  
   296  // FileContainingExtension asks the server for a file descriptor for the proto
   297  // file that declares an extension with the given number for the given
   298  // fully-qualified message name.
   299  func (cr *Client) FileContainingExtension(extendedMessageName string, extensionNumber int32) (*desc.FileDescriptor, error) {
   300  	// hit the cache first
   301  	cr.cacheMu.RLock()
   302  	fd, ok := cr.filesByExtension[extDesc{extendedMessageName, extensionNumber}]
   303  	cr.cacheMu.RUnlock()
   304  	if ok {
   305  		return fd, nil
   306  	}
   307  
   308  	req := &refv1alpha.ServerReflectionRequest{
   309  		MessageRequest: &refv1alpha.ServerReflectionRequest_FileContainingExtension{
   310  			FileContainingExtension: &refv1alpha.ExtensionRequest{
   311  				ContainingType:  extendedMessageName,
   312  				ExtensionNumber: extensionNumber,
   313  			},
   314  		},
   315  	}
   316  	accept := func(fd *desc.FileDescriptor) bool {
   317  		return fd.FindExtension(extendedMessageName, extensionNumber) != nil
   318  	}
   319  	fd, err := cr.getAndCacheFileDescriptors(req, "", "", accept)
   320  	if isNotFound(err) {
   321  		err = extensionNotFound(extendedMessageName, extensionNumber, nil)
   322  	} else if e, ok := err.(*elementNotFoundError); ok {
   323  		err = extensionNotFound(extendedMessageName, extensionNumber, e)
   324  	}
   325  	return fd, err
   326  }
   327  
   328  func (cr *Client) getAndCacheFileDescriptors(req *refv1alpha.ServerReflectionRequest, expectedName, alias string, accept func(*desc.FileDescriptor) bool) (*desc.FileDescriptor, error) {
   329  	resp, err := cr.send(req)
   330  	if err != nil {
   331  		return nil, err
   332  	}
   333  
   334  	fdResp := resp.GetFileDescriptorResponse()
   335  	if fdResp == nil {
   336  		return nil, &ProtocolError{reflect.TypeOf(fdResp).Elem()}
   337  	}
   338  
   339  	// Response can contain the result file descriptor, but also its transitive
   340  	// deps. Furthermore, protocol states that subsequent requests do not need
   341  	// to send transitive deps that have been sent in prior responses. So we
   342  	// need to cache all file descriptors and then return the first one (which
   343  	// should be the answer). If we're looking for a file by name, we can be
   344  	// smarter and make sure to grab one by name instead of just grabbing the
   345  	// first one.
   346  	var fds []*descriptorpb.FileDescriptorProto
   347  	for _, fdBytes := range fdResp.FileDescriptorProto {
   348  		fd := &descriptorpb.FileDescriptorProto{}
   349  		if err = proto.Unmarshal(fdBytes, fd); err != nil {
   350  			return nil, err
   351  		}
   352  
   353  		if expectedName != "" && alias != "" && expectedName != alias && fd.GetName() == expectedName {
   354  			// we found a file was aliased, so we need to update the proto to reflect that
   355  			fd.Name = proto.String(alias)
   356  		}
   357  
   358  		cr.cacheMu.Lock()
   359  		// store in cache of raw descriptor protos, but don't overwrite existing protos
   360  		if existingFd, ok := cr.protosByName[fd.GetName()]; ok {
   361  			fd = existingFd
   362  		} else {
   363  			cr.protosByName[fd.GetName()] = fd
   364  		}
   365  		cr.cacheMu.Unlock()
   366  
   367  		fds = append(fds, fd)
   368  	}
   369  
   370  	// find the right result from the files returned
   371  	for _, fd := range fds {
   372  		result, err := cr.descriptorFromProto(fd)
   373  		if err != nil {
   374  			return nil, err
   375  		}
   376  		if accept(result) {
   377  			return result, nil
   378  		}
   379  	}
   380  
   381  	return nil, status.Errorf(codes.NotFound, "response does not include expected file")
   382  }
   383  
   384  func (cr *Client) descriptorFromProto(fd *descriptorpb.FileDescriptorProto) (*desc.FileDescriptor, error) {
   385  	allowMissing := cr.allowMissing.Load()
   386  	deps := make([]*desc.FileDescriptor, 0, len(fd.GetDependency()))
   387  	var deferredErr error
   388  	var missingDeps []int
   389  	for i, depName := range fd.GetDependency() {
   390  		if dep, err := cr.FileByFilename(depName); err != nil {
   391  			if _, ok := err.(*elementNotFoundError); !ok || !allowMissing {
   392  				return nil, err
   393  			}
   394  			// We'll ignore for now to see if the file is really necessary.
   395  			// (If it only supplies custom options, we can get by without it.)
   396  			if deferredErr == nil {
   397  				deferredErr = err
   398  			}
   399  			missingDeps = append(missingDeps, i)
   400  		} else {
   401  			deps = append(deps, dep)
   402  		}
   403  	}
   404  	if len(missingDeps) > 0 {
   405  		fd = fileWithoutDeps(fd, missingDeps)
   406  	}
   407  	d, err := desc.CreateFileDescriptor(fd, deps...)
   408  	if err != nil {
   409  		if deferredErr != nil {
   410  			// assume the issue is the missing dep
   411  			return nil, deferredErr
   412  		}
   413  		return nil, err
   414  	}
   415  	d = cr.cacheFile(d)
   416  	return d, nil
   417  }
   418  
   419  func (cr *Client) cacheFile(fd *desc.FileDescriptor) *desc.FileDescriptor {
   420  	cr.cacheMu.Lock()
   421  	defer cr.cacheMu.Unlock()
   422  
   423  	// cache file descriptor by name, but don't overwrite existing entry
   424  	// (existing entry could come from concurrent caller)
   425  	if existingFd, ok := cr.filesByName[fd.GetName()]; ok {
   426  		return existingFd
   427  	}
   428  	cr.filesByName[fd.GetName()] = fd
   429  
   430  	// also cache by symbols and extensions
   431  	for _, m := range fd.GetMessageTypes() {
   432  		cr.cacheMessageLocked(fd, m)
   433  	}
   434  	for _, e := range fd.GetEnumTypes() {
   435  		cr.filesBySymbol[e.GetFullyQualifiedName()] = fd
   436  		for _, v := range e.GetValues() {
   437  			cr.filesBySymbol[v.GetFullyQualifiedName()] = fd
   438  		}
   439  	}
   440  	for _, e := range fd.GetExtensions() {
   441  		cr.filesBySymbol[e.GetFullyQualifiedName()] = fd
   442  		cr.filesByExtension[extDesc{e.GetOwner().GetFullyQualifiedName(), e.GetNumber()}] = fd
   443  	}
   444  	for _, s := range fd.GetServices() {
   445  		cr.filesBySymbol[s.GetFullyQualifiedName()] = fd
   446  		for _, m := range s.GetMethods() {
   447  			cr.filesBySymbol[m.GetFullyQualifiedName()] = fd
   448  		}
   449  	}
   450  
   451  	return fd
   452  }
   453  
   454  func (cr *Client) cacheMessageLocked(fd *desc.FileDescriptor, md *desc.MessageDescriptor) {
   455  	cr.filesBySymbol[md.GetFullyQualifiedName()] = fd
   456  	for _, f := range md.GetFields() {
   457  		cr.filesBySymbol[f.GetFullyQualifiedName()] = fd
   458  	}
   459  	for _, o := range md.GetOneOfs() {
   460  		cr.filesBySymbol[o.GetFullyQualifiedName()] = fd
   461  	}
   462  	for _, e := range md.GetNestedEnumTypes() {
   463  		cr.filesBySymbol[e.GetFullyQualifiedName()] = fd
   464  		for _, v := range e.GetValues() {
   465  			cr.filesBySymbol[v.GetFullyQualifiedName()] = fd
   466  		}
   467  	}
   468  	for _, e := range md.GetNestedExtensions() {
   469  		cr.filesBySymbol[e.GetFullyQualifiedName()] = fd
   470  		cr.filesByExtension[extDesc{e.GetOwner().GetFullyQualifiedName(), e.GetNumber()}] = fd
   471  	}
   472  	for _, m := range md.GetNestedMessageTypes() {
   473  		cr.cacheMessageLocked(fd, m) // recurse
   474  	}
   475  }
   476  
   477  // AllExtensionNumbersForType asks the server for all known extension numbers
   478  // for the given fully-qualified message name.
   479  func (cr *Client) AllExtensionNumbersForType(extendedMessageName string) ([]int32, error) {
   480  	req := &refv1alpha.ServerReflectionRequest{
   481  		MessageRequest: &refv1alpha.ServerReflectionRequest_AllExtensionNumbersOfType{
   482  			AllExtensionNumbersOfType: extendedMessageName,
   483  		},
   484  	}
   485  	resp, err := cr.send(req)
   486  	if err != nil {
   487  		if isNotFound(err) {
   488  			return nil, symbolNotFound(extendedMessageName, symbolTypeMessage, nil)
   489  		}
   490  		return nil, err
   491  	}
   492  
   493  	extResp := resp.GetAllExtensionNumbersResponse()
   494  	if extResp == nil {
   495  		return nil, &ProtocolError{reflect.TypeOf(extResp).Elem()}
   496  	}
   497  	return extResp.ExtensionNumber, nil
   498  }
   499  
   500  // ListServices asks the server for the fully-qualified names of all exposed
   501  // services.
   502  func (cr *Client) ListServices() ([]string, error) {
   503  	req := &refv1alpha.ServerReflectionRequest{
   504  		MessageRequest: &refv1alpha.ServerReflectionRequest_ListServices{
   505  			// proto doesn't indicate any purpose for this value and server impl
   506  			// doesn't actually use it...
   507  			ListServices: "*",
   508  		},
   509  	}
   510  	resp, err := cr.send(req)
   511  	if err != nil {
   512  		return nil, err
   513  	}
   514  
   515  	listResp := resp.GetListServicesResponse()
   516  	if listResp == nil {
   517  		return nil, &ProtocolError{reflect.TypeOf(listResp).Elem()}
   518  	}
   519  	serviceNames := make([]string, len(listResp.Service))
   520  	for i, s := range listResp.Service {
   521  		serviceNames[i] = s.Name
   522  	}
   523  	return serviceNames, nil
   524  }
   525  
   526  func (cr *Client) send(req *refv1alpha.ServerReflectionRequest) (*refv1alpha.ServerReflectionResponse, error) {
   527  	// we allow one immediate retry, in case we have a stale stream
   528  	// (e.g. closed by server)
   529  	resp, err := cr.doSend(req)
   530  	if err != nil {
   531  		return nil, err
   532  	}
   533  
   534  	// convert error response messages into errors
   535  	errResp := resp.GetErrorResponse()
   536  	if errResp != nil {
   537  		return nil, status.Errorf(codes.Code(errResp.ErrorCode), "%s", errResp.ErrorMessage)
   538  	}
   539  
   540  	return resp, nil
   541  }
   542  
   543  func isNotFound(err error) bool {
   544  	if err == nil {
   545  		return false
   546  	}
   547  	s, ok := status.FromError(err)
   548  	return ok && s.Code() == codes.NotFound
   549  }
   550  
   551  func (cr *Client) doSend(req *refv1alpha.ServerReflectionRequest) (*refv1alpha.ServerReflectionResponse, error) {
   552  	// TODO: Streams are thread-safe, so we shouldn't need to lock. But without locking, we'll need more machinery
   553  	// (goroutines and channels) to ensure that responses are correctly correlated with their requests and thus
   554  	// delivered in correct oder.
   555  	cr.connMu.Lock()
   556  	defer cr.connMu.Unlock()
   557  	return cr.doSendLocked(0, nil, req)
   558  }
   559  
   560  func (cr *Client) doSendLocked(attemptCount int, prevErr error, req *refv1alpha.ServerReflectionRequest) (*refv1alpha.ServerReflectionResponse, error) {
   561  	if attemptCount >= 3 && prevErr != nil {
   562  		return nil, prevErr
   563  	}
   564  	if (status.Code(prevErr) == codes.Unimplemented ||
   565  		status.Code(prevErr) == codes.Unavailable) &&
   566  		cr.useV1() {
   567  		// If v1 is unimplemented, fallback to v1alpha.
   568  		// We also fallback on unavailable because some servers have been
   569  		// observed to close the connection/cancel the stream, w/out sending
   570  		// back status or headers, when the service name is not known. When
   571  		// this happens, the RPC status code is unavailable.
   572  		// See https://github.com/fullstorydev/grpcurl/issues/434
   573  		cr.useV1Alpha = true
   574  		cr.lastTriedV1 = cr.now()
   575  	}
   576  	attemptCount++
   577  
   578  	if err := cr.initStreamLocked(); err != nil {
   579  		return nil, err
   580  	}
   581  
   582  	if err := cr.stream.Send(req); err != nil {
   583  		if err == io.EOF {
   584  			// if send returns EOF, must call Recv to get real underlying error
   585  			_, err = cr.stream.Recv()
   586  		}
   587  		cr.resetLocked()
   588  		return cr.doSendLocked(attemptCount, err, req)
   589  	}
   590  
   591  	resp, err := cr.stream.Recv()
   592  	if err != nil {
   593  		cr.resetLocked()
   594  		return cr.doSendLocked(attemptCount, err, req)
   595  	}
   596  	return resp, nil
   597  }
   598  
   599  func (cr *Client) initStreamLocked() error {
   600  	if cr.stream != nil {
   601  		return nil
   602  	}
   603  	var newCtx context.Context
   604  	newCtx, cr.cancel = context.WithCancel(cr.ctx)
   605  	if cr.useV1Alpha == true && cr.now().Sub(cr.lastTriedV1) > durationBetweenV1Attempts {
   606  		// we're due for periodic retry of v1
   607  		cr.useV1Alpha = false
   608  	}
   609  	if cr.useV1() {
   610  		// try the v1 API
   611  		streamv1, err := cr.stubV1.ServerReflectionInfo(newCtx)
   612  		if err == nil {
   613  			cr.stream = adaptStreamFromV1{streamv1}
   614  			return nil
   615  		}
   616  		if status.Code(err) != codes.Unimplemented {
   617  			return err
   618  		}
   619  		// oh well, fall through below to try v1alpha and update state
   620  		// so we skip straight to v1alpha next time
   621  		cr.useV1Alpha = true
   622  		cr.lastTriedV1 = cr.now()
   623  	}
   624  	var err error
   625  	cr.stream, err = cr.stubV1Alpha.ServerReflectionInfo(newCtx)
   626  	return err
   627  }
   628  
   629  func (cr *Client) useV1() bool {
   630  	return !cr.useV1Alpha && cr.stubV1 != nil
   631  }
   632  
   633  // Reset ensures that any active stream with the server is closed, releasing any
   634  // resources.
   635  func (cr *Client) Reset() {
   636  	cr.connMu.Lock()
   637  	defer cr.connMu.Unlock()
   638  	cr.resetLocked()
   639  }
   640  
   641  func (cr *Client) resetLocked() {
   642  	if cr.stream != nil {
   643  		cr.stream.CloseSend()
   644  		for {
   645  			// drain the stream, this covers io.EOF too
   646  			if _, err := cr.stream.Recv(); err != nil {
   647  				break
   648  			}
   649  		}
   650  		cr.stream = nil
   651  	}
   652  	if cr.cancel != nil {
   653  		cr.cancel()
   654  		cr.cancel = nil
   655  	}
   656  }
   657  
   658  // ResolveService asks the server to resolve the given fully-qualified service
   659  // name into a service descriptor.
   660  func (cr *Client) ResolveService(serviceName string) (*desc.ServiceDescriptor, error) {
   661  	file, err := cr.FileContainingSymbol(serviceName)
   662  	if err != nil {
   663  		return nil, setSymbolType(err, serviceName, symbolTypeService)
   664  	}
   665  	d := file.FindSymbol(serviceName)
   666  	if d == nil {
   667  		return nil, symbolNotFound(serviceName, symbolTypeService, nil)
   668  	}
   669  	if s, ok := d.(*desc.ServiceDescriptor); ok {
   670  		return s, nil
   671  	} else {
   672  		return nil, symbolNotFound(serviceName, symbolTypeService, nil)
   673  	}
   674  }
   675  
   676  // ResolveMessage asks the server to resolve the given fully-qualified message
   677  // name into a message descriptor.
   678  func (cr *Client) ResolveMessage(messageName string) (*desc.MessageDescriptor, error) {
   679  	file, err := cr.FileContainingSymbol(messageName)
   680  	if err != nil {
   681  		return nil, setSymbolType(err, messageName, symbolTypeMessage)
   682  	}
   683  	d := file.FindSymbol(messageName)
   684  	if d == nil {
   685  		return nil, symbolNotFound(messageName, symbolTypeMessage, nil)
   686  	}
   687  	if s, ok := d.(*desc.MessageDescriptor); ok {
   688  		return s, nil
   689  	} else {
   690  		return nil, symbolNotFound(messageName, symbolTypeMessage, nil)
   691  	}
   692  }
   693  
   694  // ResolveEnum asks the server to resolve the given fully-qualified enum name
   695  // into an enum descriptor.
   696  func (cr *Client) ResolveEnum(enumName string) (*desc.EnumDescriptor, error) {
   697  	file, err := cr.FileContainingSymbol(enumName)
   698  	if err != nil {
   699  		return nil, setSymbolType(err, enumName, symbolTypeEnum)
   700  	}
   701  	d := file.FindSymbol(enumName)
   702  	if d == nil {
   703  		return nil, symbolNotFound(enumName, symbolTypeEnum, nil)
   704  	}
   705  	if s, ok := d.(*desc.EnumDescriptor); ok {
   706  		return s, nil
   707  	} else {
   708  		return nil, symbolNotFound(enumName, symbolTypeEnum, nil)
   709  	}
   710  }
   711  
   712  func setSymbolType(err error, name string, symType symbolType) error {
   713  	if e, ok := err.(*elementNotFoundError); ok {
   714  		if e.kind == elementKindSymbol && e.name == name && e.symType == symbolTypeUnknown {
   715  			e.symType = symType
   716  		}
   717  	}
   718  	return err
   719  }
   720  
   721  // ResolveEnumValues asks the server to resolve the given fully-qualified enum
   722  // name into a map of names to numbers that represents the enum's values.
   723  func (cr *Client) ResolveEnumValues(enumName string) (map[string]int32, error) {
   724  	enumDesc, err := cr.ResolveEnum(enumName)
   725  	if err != nil {
   726  		return nil, err
   727  	}
   728  	vals := map[string]int32{}
   729  	for _, valDesc := range enumDesc.GetValues() {
   730  		vals[valDesc.GetName()] = valDesc.GetNumber()
   731  	}
   732  	return vals, nil
   733  }
   734  
   735  // ResolveExtension asks the server to resolve the given extension number and
   736  // fully-qualified message name into a field descriptor.
   737  func (cr *Client) ResolveExtension(extendedType string, extensionNumber int32) (*desc.FieldDescriptor, error) {
   738  	file, err := cr.FileContainingExtension(extendedType, extensionNumber)
   739  	if err != nil {
   740  		return nil, err
   741  	}
   742  	d := findExtension(extendedType, extensionNumber, fileDescriptorExtensions{file})
   743  	if d == nil {
   744  		return nil, extensionNotFound(extendedType, extensionNumber, nil)
   745  	} else {
   746  		return d, nil
   747  	}
   748  }
   749  
   750  func fileWithoutDeps(fd *descriptorpb.FileDescriptorProto, missingDeps []int) *descriptorpb.FileDescriptorProto {
   751  	// We need to rebuild the file without the missing deps.
   752  	fd = proto.Clone(fd).(*descriptorpb.FileDescriptorProto)
   753  	newNumDeps := len(fd.GetDependency()) - len(missingDeps)
   754  	newDeps := make([]string, 0, newNumDeps)
   755  	remapped := make(map[int]int, newNumDeps)
   756  	missingIdx := 0
   757  	for i, dep := range fd.GetDependency() {
   758  		if missingIdx < len(missingDeps) {
   759  			if i == missingDeps[missingIdx] {
   760  				// This dep was missing. Skip it.
   761  				missingIdx++
   762  				continue
   763  			}
   764  		}
   765  		remapped[i] = len(newDeps)
   766  		newDeps = append(newDeps, dep)
   767  	}
   768  	// Also rebuild public and weak import slices.
   769  	newPublic := make([]int32, 0, len(fd.GetPublicDependency()))
   770  	for _, idx := range fd.GetPublicDependency() {
   771  		newIdx, ok := remapped[int(idx)]
   772  		if ok {
   773  			newPublic = append(newPublic, int32(newIdx))
   774  		}
   775  	}
   776  	newWeak := make([]int32, 0, len(fd.GetWeakDependency()))
   777  	for _, idx := range fd.GetWeakDependency() {
   778  		newIdx, ok := remapped[int(idx)]
   779  		if ok {
   780  			newWeak = append(newWeak, int32(newIdx))
   781  		}
   782  	}
   783  
   784  	fd.Dependency = newDeps
   785  	fd.PublicDependency = newPublic
   786  	fd.WeakDependency = newWeak
   787  	return fd
   788  }
   789  
   790  func findExtension(extendedType string, extensionNumber int32, scope extensionScope) *desc.FieldDescriptor {
   791  	// search extensions in this scope
   792  	for _, ext := range scope.extensions() {
   793  		if ext.GetNumber() == extensionNumber && ext.GetOwner().GetFullyQualifiedName() == extendedType {
   794  			return ext
   795  		}
   796  	}
   797  
   798  	// if not found, search nested scopes
   799  	for _, nested := range scope.nestedScopes() {
   800  		ext := findExtension(extendedType, extensionNumber, nested)
   801  		if ext != nil {
   802  			return ext
   803  		}
   804  	}
   805  
   806  	return nil
   807  }
   808  
   809  type extensionScope interface {
   810  	extensions() []*desc.FieldDescriptor
   811  	nestedScopes() []extensionScope
   812  }
   813  
   814  // fileDescriptorExtensions implements extensionHolder interface on top of
   815  // FileDescriptorProto
   816  type fileDescriptorExtensions struct {
   817  	proto *desc.FileDescriptor
   818  }
   819  
   820  func (fde fileDescriptorExtensions) extensions() []*desc.FieldDescriptor {
   821  	return fde.proto.GetExtensions()
   822  }
   823  
   824  func (fde fileDescriptorExtensions) nestedScopes() []extensionScope {
   825  	scopes := make([]extensionScope, len(fde.proto.GetMessageTypes()))
   826  	for i, m := range fde.proto.GetMessageTypes() {
   827  		scopes[i] = msgDescriptorExtensions{m}
   828  	}
   829  	return scopes
   830  }
   831  
   832  // msgDescriptorExtensions implements extensionHolder interface on top of
   833  // DescriptorProto
   834  type msgDescriptorExtensions struct {
   835  	proto *desc.MessageDescriptor
   836  }
   837  
   838  func (mde msgDescriptorExtensions) extensions() []*desc.FieldDescriptor {
   839  	return mde.proto.GetNestedExtensions()
   840  }
   841  
   842  func (mde msgDescriptorExtensions) nestedScopes() []extensionScope {
   843  	scopes := make([]extensionScope, len(mde.proto.GetNestedMessageTypes()))
   844  	for i, m := range mde.proto.GetNestedMessageTypes() {
   845  		scopes[i] = msgDescriptorExtensions{m}
   846  	}
   847  	return scopes
   848  }
   849  
   850  type adaptStreamFromV1 struct {
   851  	refv1.ServerReflection_ServerReflectionInfoClient
   852  }
   853  
   854  func (a adaptStreamFromV1) Send(request *refv1alpha.ServerReflectionRequest) error {
   855  	v1req := toV1Request(request)
   856  	return a.ServerReflection_ServerReflectionInfoClient.Send(v1req)
   857  }
   858  
   859  func (a adaptStreamFromV1) Recv() (*refv1alpha.ServerReflectionResponse, error) {
   860  	v1resp, err := a.ServerReflection_ServerReflectionInfoClient.Recv()
   861  	if err != nil {
   862  		return nil, err
   863  	}
   864  	return toV1AlphaResponse(v1resp), nil
   865  }