github.com/xiaoshude/protoreflect@v1.16.1-0.20220310024924-8c94d7247598/grpcreflect/client.go (about)

     1  package grpcreflect
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"reflect"
     8  	"runtime"
     9  	"sync"
    10  
    11  	"github.com/golang/protobuf/proto"
    12  	dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
    13  	"golang.org/x/net/context"
    14  	"google.golang.org/grpc/codes"
    15  	rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
    16  	"google.golang.org/grpc/status"
    17  
    18  	"github.com/xiaoshude/protoreflect/desc"
    19  	"github.com/xiaoshude/protoreflect/internal"
    20  )
    21  
// elementNotFoundError is the error returned by reflective operations where the
// server does not recognize a given file name, symbol name, or extension.
type elementNotFoundError struct {
	name    string      // file name, symbol name, or extended message name that was looked up
	kind    elementKind // which sort of element the lookup was for
	symType symbolType  // only used when kind == elementKindSymbol
	tag     int32       // only used when kind == elementKindExtension

	// cause, when non-nil, describes another element that could not be found
	// while resolving this one (for example, a file that could not be resolved
	// because one of its dependencies is missing). Error renders the whole
	// chain, one "caused by:" line per link.
	cause *elementNotFoundError
}
    35  
// elementKind identifies which sort of element an elementNotFoundError is
// reporting as missing.
type elementKind int

const (
	// elementKindSymbol marks a lookup by symbol name.
	elementKindSymbol elementKind = iota
	// elementKindFile marks a lookup by file name.
	elementKindFile
	// elementKindExtension marks a lookup by extended message name and tag.
	elementKindExtension
)
    43  
// symbolType is the human-readable category of a symbol. It is used as the
// leading word of the message produced for symbol-not-found errors.
type symbolType string

// These are untyped string constants; they convert implicitly to symbolType
// wherever a symbolType is expected.
const (
	symbolTypeService = "Service"
	symbolTypeMessage = "Message"
	symbolTypeEnum    = "Enum"
	// symbolTypeUnknown is used when the category of the symbol is not known.
	symbolTypeUnknown = "Symbol"
)
    52  
    53  func symbolNotFound(symbol string, symType symbolType, cause *elementNotFoundError) error {
    54  	return &elementNotFoundError{name: symbol, symType: symType, kind: elementKindSymbol, cause: cause}
    55  }
    56  
    57  func extensionNotFound(extendee string, tag int32, cause *elementNotFoundError) error {
    58  	return &elementNotFoundError{name: extendee, tag: tag, kind: elementKindExtension, cause: cause}
    59  }
    60  
    61  func fileNotFound(file string, cause *elementNotFoundError) error {
    62  	return &elementNotFoundError{name: file, kind: elementKindFile, cause: cause}
    63  }
    64  
    65  func (e *elementNotFoundError) Error() string {
    66  	first := true
    67  	var b bytes.Buffer
    68  	for ; e != nil; e = e.cause {
    69  		if first {
    70  			first = false
    71  		} else {
    72  			fmt.Fprint(&b, "\ncaused by: ")
    73  		}
    74  		switch e.kind {
    75  		case elementKindSymbol:
    76  			fmt.Fprintf(&b, "%s not found: %s", e.symType, e.name)
    77  		case elementKindExtension:
    78  			fmt.Fprintf(&b, "Extension not found: tag %d for %s", e.tag, e.name)
    79  		default:
    80  			fmt.Fprintf(&b, "File not found: %s", e.name)
    81  		}
    82  	}
    83  	return b.String()
    84  }
    85  
    86  // IsElementNotFoundError determines if the given error indicates that a file
    87  // name, symbol name, or extension field was could not be found by the server.
    88  func IsElementNotFoundError(err error) bool {
    89  	_, ok := err.(*elementNotFoundError)
    90  	return ok
    91  }
    92  
// ProtocolError is an error returned when the server sends a response of the
// wrong type.
type ProtocolError struct {
	// missingType is the type that was expected in the response but was
	// absent; it is included in the error message.
	missingType reflect.Type
}
    98  
    99  func (p ProtocolError) Error() string {
   100  	return fmt.Sprintf("Protocol error: response was missing %v", p.missingType)
   101  }
   102  
// extDesc identifies an extension by the fully-qualified name of the message
// it extends and its field number. It is the key type of the client's
// extension cache.
type extDesc struct {
	extendedMessageName string
	extensionNumber     int32
}
   107  
// Client is a client connection to a server for performing reflection calls
// and resolving remote symbols.
type Client struct {
	ctx  context.Context             // root context; each stream's context is derived from it
	stub rpb.ServerReflectionClient // RPC stub used to open reflection streams

	// connMu is held while sending requests and while resetting the stream.
	connMu sync.Mutex
	cancel context.CancelFunc                             // cancels the current stream's context, if any
	stream rpb.ServerReflection_ServerReflectionInfoClient // current stream, or nil if none is open

	// cacheMu is held while reading or writing the cache maps below.
	cacheMu          sync.RWMutex
	protosByName     map[string]*dpb.FileDescriptorProto // raw descriptor protos received from the server, by file name
	filesByName      map[string]*desc.FileDescriptor     // built file descriptors, by file name
	filesBySymbol    map[string]*desc.FileDescriptor     // built file descriptors, by fully-qualified symbol they declare
	filesByExtension map[extDesc]*desc.FileDescriptor    // built file descriptors, by extension they declare
}
   124  
   125  // NewClient creates a new Client with the given root context and using the
   126  // given RPC stub for talking to the server.
   127  func NewClient(ctx context.Context, stub rpb.ServerReflectionClient) *Client {
   128  	cr := &Client{
   129  		ctx:              ctx,
   130  		stub:             stub,
   131  		protosByName:     map[string]*dpb.FileDescriptorProto{},
   132  		filesByName:      map[string]*desc.FileDescriptor{},
   133  		filesBySymbol:    map[string]*desc.FileDescriptor{},
   134  		filesByExtension: map[extDesc]*desc.FileDescriptor{},
   135  	}
   136  	// don't leak a grpc stream
   137  	runtime.SetFinalizer(cr, (*Client).Reset)
   138  	return cr
   139  }
   140  
   141  // FileByFilename asks the server for a file descriptor for the proto file with
   142  // the given name.
   143  func (cr *Client) FileByFilename(filename string) (*desc.FileDescriptor, error) {
   144  	// hit the cache first
   145  	cr.cacheMu.RLock()
   146  	if fd, ok := cr.filesByName[filename]; ok {
   147  		cr.cacheMu.RUnlock()
   148  		return fd, nil
   149  	}
   150  	fdp, ok := cr.protosByName[filename]
   151  	cr.cacheMu.RUnlock()
   152  	// not there? see if we've downloaded the proto
   153  	if ok {
   154  		return cr.descriptorFromProto(fdp)
   155  	}
   156  
   157  	req := &rpb.ServerReflectionRequest{
   158  		MessageRequest: &rpb.ServerReflectionRequest_FileByFilename{
   159  			FileByFilename: filename,
   160  		},
   161  	}
   162  	fd, err := cr.getAndCacheFileDescriptors(req, filename, "")
   163  	if isNotFound(err) {
   164  		// file not found? see if we can look up via alternate name
   165  		if alternate, ok := internal.StdFileAliases[filename]; ok {
   166  			req := &rpb.ServerReflectionRequest{
   167  				MessageRequest: &rpb.ServerReflectionRequest_FileByFilename{
   168  					FileByFilename: alternate,
   169  				},
   170  			}
   171  			fd, err = cr.getAndCacheFileDescriptors(req, alternate, filename)
   172  			if isNotFound(err) {
   173  				err = fileNotFound(filename, nil)
   174  			}
   175  		} else {
   176  			err = fileNotFound(filename, nil)
   177  		}
   178  	} else if e, ok := err.(*elementNotFoundError); ok {
   179  		err = fileNotFound(filename, e)
   180  	}
   181  	return fd, err
   182  }
   183  
   184  // FileContainingSymbol asks the server for a file descriptor for the proto file
   185  // that declares the given fully-qualified symbol.
   186  func (cr *Client) FileContainingSymbol(symbol string) (*desc.FileDescriptor, error) {
   187  	// hit the cache first
   188  	cr.cacheMu.RLock()
   189  	fd, ok := cr.filesBySymbol[symbol]
   190  	cr.cacheMu.RUnlock()
   191  	if ok {
   192  		return fd, nil
   193  	}
   194  
   195  	req := &rpb.ServerReflectionRequest{
   196  		MessageRequest: &rpb.ServerReflectionRequest_FileContainingSymbol{
   197  			FileContainingSymbol: symbol,
   198  		},
   199  	}
   200  	fd, err := cr.getAndCacheFileDescriptors(req, "", "")
   201  	if isNotFound(err) {
   202  		err = symbolNotFound(symbol, symbolTypeUnknown, nil)
   203  	} else if e, ok := err.(*elementNotFoundError); ok {
   204  		err = symbolNotFound(symbol, symbolTypeUnknown, e)
   205  	}
   206  	return fd, err
   207  }
   208  
   209  // FileContainingExtension asks the server for a file descriptor for the proto
   210  // file that declares an extension with the given number for the given
   211  // fully-qualified message name.
   212  func (cr *Client) FileContainingExtension(extendedMessageName string, extensionNumber int32) (*desc.FileDescriptor, error) {
   213  	// hit the cache first
   214  	cr.cacheMu.RLock()
   215  	fd, ok := cr.filesByExtension[extDesc{extendedMessageName, extensionNumber}]
   216  	cr.cacheMu.RUnlock()
   217  	if ok {
   218  		return fd, nil
   219  	}
   220  
   221  	req := &rpb.ServerReflectionRequest{
   222  		MessageRequest: &rpb.ServerReflectionRequest_FileContainingExtension{
   223  			FileContainingExtension: &rpb.ExtensionRequest{
   224  				ContainingType:  extendedMessageName,
   225  				ExtensionNumber: extensionNumber,
   226  			},
   227  		},
   228  	}
   229  	fd, err := cr.getAndCacheFileDescriptors(req, "", "")
   230  	if isNotFound(err) {
   231  		err = extensionNotFound(extendedMessageName, extensionNumber, nil)
   232  	} else if e, ok := err.(*elementNotFoundError); ok {
   233  		err = extensionNotFound(extendedMessageName, extensionNumber, e)
   234  	}
   235  	return fd, err
   236  }
   237  
// getAndCacheFileDescriptors sends req, stores every file descriptor proto in
// the response in the raw proto cache, and returns a built descriptor for the
// first file in the response.
//
// If expectedName and alias are both non-empty and differ, a returned file
// whose name equals expectedName is renamed to alias before being cached.
//
// It returns a *ProtocolError if the response carries no file descriptor
// response or contains no file descriptors at all.
func (cr *Client) getAndCacheFileDescriptors(req *rpb.ServerReflectionRequest, expectedName, alias string) (*desc.FileDescriptor, error) {
	resp, err := cr.send(req)
	if err != nil {
		return nil, err
	}

	fdResp := resp.GetFileDescriptorResponse()
	if fdResp == nil {
		return nil, &ProtocolError{reflect.TypeOf(fdResp).Elem()}
	}

	// Response can contain the result file descriptor, but also its transitive
	// deps. Furthermore, protocol states that subsequent requests do not need
	// to send transitive deps that have been sent in prior responses. So we
	// need to cache all file descriptors and then return the first one (which
	// should be the answer).
	// NOTE(review): the code below always uses the first descriptor in the
	// response; expectedName is only consulted for the alias renaming, not
	// for selecting which file to return.
	var firstFd *dpb.FileDescriptorProto
	for _, fdBytes := range fdResp.FileDescriptorProto {
		fd := &dpb.FileDescriptorProto{}
		if err = proto.Unmarshal(fdBytes, fd); err != nil {
			return nil, err
		}

		if expectedName != "" && alias != "" && expectedName != alias && fd.GetName() == expectedName {
			// we found a file was aliased, so we need to update the proto to reflect that
			fd.Name = proto.String(alias)
		}

		cr.cacheMu.Lock()
		// see if this file was created and cached concurrently; this check
		// applies only to the first file in the response, which is the answer
		if firstFd == nil {
			if d, ok := cr.filesByName[fd.GetName()]; ok {
				cr.cacheMu.Unlock()
				return d, nil
			}
		}
		// store in cache of raw descriptor protos, but don't overwrite existing protos
		if existingFd, ok := cr.protosByName[fd.GetName()]; ok {
			// prefer the proto that is already cached over the fresh copy
			fd = existingFd
		} else {
			cr.protosByName[fd.GetName()] = fd
		}
		cr.cacheMu.Unlock()
		if firstFd == nil {
			firstFd = fd
		}
	}
	// the response held no file descriptors at all
	if firstFd == nil {
		return nil, &ProtocolError{reflect.TypeOf(firstFd).Elem()}
	}

	return cr.descriptorFromProto(firstFd)
}
   293  
   294  func (cr *Client) descriptorFromProto(fd *dpb.FileDescriptorProto) (*desc.FileDescriptor, error) {
   295  	deps := make([]*desc.FileDescriptor, len(fd.GetDependency()))
   296  	for i, depName := range fd.GetDependency() {
   297  		if dep, err := cr.FileByFilename(depName); err != nil {
   298  			return nil, err
   299  		} else {
   300  			deps[i] = dep
   301  		}
   302  	}
   303  	d, err := desc.CreateFileDescriptor(fd, deps...)
   304  	if err != nil {
   305  		return nil, err
   306  	}
   307  	d = cr.cacheFile(d)
   308  	return d, nil
   309  }
   310  
   311  func (cr *Client) cacheFile(fd *desc.FileDescriptor) *desc.FileDescriptor {
   312  	cr.cacheMu.Lock()
   313  	defer cr.cacheMu.Unlock()
   314  
   315  	// cache file descriptor by name, but don't overwrite existing entry
   316  	// (existing entry could come from concurrent caller)
   317  	if existingFd, ok := cr.filesByName[fd.GetName()]; ok {
   318  		return existingFd
   319  	}
   320  	cr.filesByName[fd.GetName()] = fd
   321  
   322  	// also cache by symbols and extensions
   323  	for _, m := range fd.GetMessageTypes() {
   324  		cr.cacheMessageLocked(fd, m)
   325  	}
   326  	for _, e := range fd.GetEnumTypes() {
   327  		cr.filesBySymbol[e.GetFullyQualifiedName()] = fd
   328  		for _, v := range e.GetValues() {
   329  			cr.filesBySymbol[v.GetFullyQualifiedName()] = fd
   330  		}
   331  	}
   332  	for _, e := range fd.GetExtensions() {
   333  		cr.filesBySymbol[e.GetFullyQualifiedName()] = fd
   334  		cr.filesByExtension[extDesc{e.GetOwner().GetFullyQualifiedName(), e.GetNumber()}] = fd
   335  	}
   336  	for _, s := range fd.GetServices() {
   337  		cr.filesBySymbol[s.GetFullyQualifiedName()] = fd
   338  		for _, m := range s.GetMethods() {
   339  			cr.filesBySymbol[m.GetFullyQualifiedName()] = fd
   340  		}
   341  	}
   342  
   343  	return fd
   344  }
   345  
   346  func (cr *Client) cacheMessageLocked(fd *desc.FileDescriptor, md *desc.MessageDescriptor) {
   347  	cr.filesBySymbol[md.GetFullyQualifiedName()] = fd
   348  	for _, f := range md.GetFields() {
   349  		cr.filesBySymbol[f.GetFullyQualifiedName()] = fd
   350  	}
   351  	for _, o := range md.GetOneOfs() {
   352  		cr.filesBySymbol[o.GetFullyQualifiedName()] = fd
   353  	}
   354  	for _, e := range md.GetNestedEnumTypes() {
   355  		cr.filesBySymbol[e.GetFullyQualifiedName()] = fd
   356  		for _, v := range e.GetValues() {
   357  			cr.filesBySymbol[v.GetFullyQualifiedName()] = fd
   358  		}
   359  	}
   360  	for _, e := range md.GetNestedExtensions() {
   361  		cr.filesBySymbol[e.GetFullyQualifiedName()] = fd
   362  		cr.filesByExtension[extDesc{e.GetOwner().GetFullyQualifiedName(), e.GetNumber()}] = fd
   363  	}
   364  	for _, m := range md.GetNestedMessageTypes() {
   365  		cr.cacheMessageLocked(fd, m) // recurse
   366  	}
   367  }
   368  
   369  // AllExtensionNumbersForType asks the server for all known extension numbers
   370  // for the given fully-qualified message name.
   371  func (cr *Client) AllExtensionNumbersForType(extendedMessageName string) ([]int32, error) {
   372  	req := &rpb.ServerReflectionRequest{
   373  		MessageRequest: &rpb.ServerReflectionRequest_AllExtensionNumbersOfType{
   374  			AllExtensionNumbersOfType: extendedMessageName,
   375  		},
   376  	}
   377  	resp, err := cr.send(req)
   378  	if err != nil {
   379  		if isNotFound(err) {
   380  			return nil, symbolNotFound(extendedMessageName, symbolTypeMessage, nil)
   381  		}
   382  		return nil, err
   383  	}
   384  
   385  	extResp := resp.GetAllExtensionNumbersResponse()
   386  	if extResp == nil {
   387  		return nil, &ProtocolError{reflect.TypeOf(extResp).Elem()}
   388  	}
   389  	return extResp.ExtensionNumber, nil
   390  }
   391  
   392  // ListServices asks the server for the fully-qualified names of all exposed
   393  // services.
   394  func (cr *Client) ListServices() ([]string, error) {
   395  	req := &rpb.ServerReflectionRequest{
   396  		MessageRequest: &rpb.ServerReflectionRequest_ListServices{
   397  			// proto doesn't indicate any purpose for this value and server impl
   398  			// doesn't actually use it...
   399  			ListServices: "*",
   400  		},
   401  	}
   402  	resp, err := cr.send(req)
   403  	if err != nil {
   404  		return nil, err
   405  	}
   406  
   407  	listResp := resp.GetListServicesResponse()
   408  	if listResp == nil {
   409  		return nil, &ProtocolError{reflect.TypeOf(listResp).Elem()}
   410  	}
   411  	serviceNames := make([]string, len(listResp.Service))
   412  	for i, s := range listResp.Service {
   413  		serviceNames[i] = s.Name
   414  	}
   415  	return serviceNames, nil
   416  }
   417  
   418  func (cr *Client) send(req *rpb.ServerReflectionRequest) (*rpb.ServerReflectionResponse, error) {
   419  	// we allow one immediate retry, in case we have a stale stream
   420  	// (e.g. closed by server)
   421  	resp, err := cr.doSend(true, req)
   422  	if err != nil {
   423  		return nil, err
   424  	}
   425  
   426  	// convert error response messages into errors
   427  	errResp := resp.GetErrorResponse()
   428  	if errResp != nil {
   429  		return nil, status.Errorf(codes.Code(errResp.ErrorCode), "%s", errResp.ErrorMessage)
   430  	}
   431  
   432  	return resp, nil
   433  }
   434  
   435  func isNotFound(err error) bool {
   436  	if err == nil {
   437  		return false
   438  	}
   439  	s, ok := status.FromError(err)
   440  	return ok && s.Code() == codes.NotFound
   441  }
   442  
   443  func (cr *Client) doSend(retry bool, req *rpb.ServerReflectionRequest) (*rpb.ServerReflectionResponse, error) {
   444  	// TODO: Streams are thread-safe, so we shouldn't need to lock. But without locking, we'll need more machinery
   445  	// (goroutines and channels) to ensure that responses are correctly correlated with their requests and thus
   446  	// delivered in correct oder.
   447  	cr.connMu.Lock()
   448  	defer cr.connMu.Unlock()
   449  	return cr.doSendLocked(retry, req)
   450  }
   451  
   452  func (cr *Client) doSendLocked(retry bool, req *rpb.ServerReflectionRequest) (*rpb.ServerReflectionResponse, error) {
   453  	if err := cr.initStreamLocked(); err != nil {
   454  		return nil, err
   455  	}
   456  
   457  	if err := cr.stream.Send(req); err != nil {
   458  		if err == io.EOF {
   459  			// if send returns EOF, must call Recv to get real underlying error
   460  			_, err = cr.stream.Recv()
   461  		}
   462  		cr.resetLocked()
   463  		if retry {
   464  			return cr.doSendLocked(false, req)
   465  		}
   466  		return nil, err
   467  	}
   468  
   469  	if resp, err := cr.stream.Recv(); err != nil {
   470  		cr.resetLocked()
   471  		if retry {
   472  			return cr.doSendLocked(false, req)
   473  		}
   474  		return nil, err
   475  	} else {
   476  		return resp, nil
   477  	}
   478  }
   479  
   480  func (cr *Client) initStreamLocked() error {
   481  	if cr.stream != nil {
   482  		return nil
   483  	}
   484  	var newCtx context.Context
   485  	newCtx, cr.cancel = context.WithCancel(cr.ctx)
   486  	var err error
   487  	cr.stream, err = cr.stub.ServerReflectionInfo(newCtx)
   488  	return err
   489  }
   490  
   491  // Reset ensures that any active stream with the server is closed, releasing any
   492  // resources.
   493  func (cr *Client) Reset() {
   494  	cr.connMu.Lock()
   495  	defer cr.connMu.Unlock()
   496  	cr.resetLocked()
   497  }
   498  
   499  func (cr *Client) resetLocked() {
   500  	if cr.stream != nil {
   501  		cr.stream.CloseSend()
   502  		for {
   503  			// drain the stream, this covers io.EOF too
   504  			if _, err := cr.stream.Recv(); err != nil {
   505  				break
   506  			}
   507  		}
   508  		cr.stream = nil
   509  	}
   510  	if cr.cancel != nil {
   511  		cr.cancel()
   512  		cr.cancel = nil
   513  	}
   514  }
   515  
   516  // ResolveService asks the server to resolve the given fully-qualified service
   517  // name into a service descriptor.
   518  func (cr *Client) ResolveService(serviceName string) (*desc.ServiceDescriptor, error) {
   519  	file, err := cr.FileContainingSymbol(serviceName)
   520  	if err != nil {
   521  		return nil, setSymbolType(err, serviceName, symbolTypeService)
   522  	}
   523  	d := file.FindSymbol(serviceName)
   524  	if d == nil {
   525  		return nil, symbolNotFound(serviceName, symbolTypeService, nil)
   526  	}
   527  	if s, ok := d.(*desc.ServiceDescriptor); ok {
   528  		return s, nil
   529  	} else {
   530  		return nil, symbolNotFound(serviceName, symbolTypeService, nil)
   531  	}
   532  }
   533  
   534  // ResolveMessage asks the server to resolve the given fully-qualified message
   535  // name into a message descriptor.
   536  func (cr *Client) ResolveMessage(messageName string) (*desc.MessageDescriptor, error) {
   537  	file, err := cr.FileContainingSymbol(messageName)
   538  	if err != nil {
   539  		return nil, setSymbolType(err, messageName, symbolTypeMessage)
   540  	}
   541  	d := file.FindSymbol(messageName)
   542  	if d == nil {
   543  		return nil, symbolNotFound(messageName, symbolTypeMessage, nil)
   544  	}
   545  	if s, ok := d.(*desc.MessageDescriptor); ok {
   546  		return s, nil
   547  	} else {
   548  		return nil, symbolNotFound(messageName, symbolTypeMessage, nil)
   549  	}
   550  }
   551  
   552  // ResolveEnum asks the server to resolve the given fully-qualified enum name
   553  // into an enum descriptor.
   554  func (cr *Client) ResolveEnum(enumName string) (*desc.EnumDescriptor, error) {
   555  	file, err := cr.FileContainingSymbol(enumName)
   556  	if err != nil {
   557  		return nil, setSymbolType(err, enumName, symbolTypeEnum)
   558  	}
   559  	d := file.FindSymbol(enumName)
   560  	if d == nil {
   561  		return nil, symbolNotFound(enumName, symbolTypeEnum, nil)
   562  	}
   563  	if s, ok := d.(*desc.EnumDescriptor); ok {
   564  		return s, nil
   565  	} else {
   566  		return nil, symbolNotFound(enumName, symbolTypeEnum, nil)
   567  	}
   568  }
   569  
   570  func setSymbolType(err error, name string, symType symbolType) error {
   571  	if e, ok := err.(*elementNotFoundError); ok {
   572  		if e.kind == elementKindSymbol && e.name == name && e.symType == symbolTypeUnknown {
   573  			e.symType = symType
   574  		}
   575  	}
   576  	return err
   577  }
   578  
   579  // ResolveEnumValues asks the server to resolve the given fully-qualified enum
   580  // name into a map of names to numbers that represents the enum's values.
   581  func (cr *Client) ResolveEnumValues(enumName string) (map[string]int32, error) {
   582  	enumDesc, err := cr.ResolveEnum(enumName)
   583  	if err != nil {
   584  		return nil, err
   585  	}
   586  	vals := map[string]int32{}
   587  	for _, valDesc := range enumDesc.GetValues() {
   588  		vals[valDesc.GetName()] = valDesc.GetNumber()
   589  	}
   590  	return vals, nil
   591  }
   592  
   593  // ResolveExtension asks the server to resolve the given extension number and
   594  // fully-qualified message name into a field descriptor.
   595  func (cr *Client) ResolveExtension(extendedType string, extensionNumber int32) (*desc.FieldDescriptor, error) {
   596  	file, err := cr.FileContainingExtension(extendedType, extensionNumber)
   597  	if err != nil {
   598  		return nil, err
   599  	}
   600  	d := findExtension(extendedType, extensionNumber, fileDescriptorExtensions{file})
   601  	if d == nil {
   602  		return nil, extensionNotFound(extendedType, extensionNumber, nil)
   603  	} else {
   604  		return d, nil
   605  	}
   606  }
   607  
   608  func findExtension(extendedType string, extensionNumber int32, scope extensionScope) *desc.FieldDescriptor {
   609  	// search extensions in this scope
   610  	for _, ext := range scope.extensions() {
   611  		if ext.GetNumber() == extensionNumber && ext.GetOwner().GetFullyQualifiedName() == extendedType {
   612  			return ext
   613  		}
   614  	}
   615  
   616  	// if not found, search nested scopes
   617  	for _, nested := range scope.nestedScopes() {
   618  		ext := findExtension(extendedType, extensionNumber, nested)
   619  		if ext != nil {
   620  			return ext
   621  		}
   622  	}
   623  
   624  	return nil
   625  }
   626  
// extensionScope is a lexical scope in which extensions can be declared. It is
// what findExtension walks when searching for an extension.
type extensionScope interface {
	// extensions returns the extensions declared directly in this scope.
	extensions() []*desc.FieldDescriptor
	// nestedScopes returns the scopes nested directly inside this one.
	nestedScopes() []extensionScope
}
   631  
   632  // fileDescriptorExtensions implements extensionHolder interface on top of
   633  // FileDescriptorProto
   634  type fileDescriptorExtensions struct {
   635  	proto *desc.FileDescriptor
   636  }
   637  
   638  func (fde fileDescriptorExtensions) extensions() []*desc.FieldDescriptor {
   639  	return fde.proto.GetExtensions()
   640  }
   641  
   642  func (fde fileDescriptorExtensions) nestedScopes() []extensionScope {
   643  	scopes := make([]extensionScope, len(fde.proto.GetMessageTypes()))
   644  	for i, m := range fde.proto.GetMessageTypes() {
   645  		scopes[i] = msgDescriptorExtensions{m}
   646  	}
   647  	return scopes
   648  }
   649  
   650  // msgDescriptorExtensions implements extensionHolder interface on top of
   651  // DescriptorProto
   652  type msgDescriptorExtensions struct {
   653  	proto *desc.MessageDescriptor
   654  }
   655  
   656  func (mde msgDescriptorExtensions) extensions() []*desc.FieldDescriptor {
   657  	return mde.proto.GetNestedExtensions()
   658  }
   659  
   660  func (mde msgDescriptorExtensions) nestedScopes() []extensionScope {
   661  	scopes := make([]extensionScope, len(mde.proto.GetNestedMessageTypes()))
   662  	for i, m := range mde.proto.GetNestedMessageTypes() {
   663  		scopes[i] = msgDescriptorExtensions{m}
   664  	}
   665  	return scopes
   666  }