github.com/jhump/protoreflect@v1.16.0/dynamic/msgregistry/ptype_resolver.go (about)

     1  package msgregistry
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"reflect"
     8  	"sort"
     9  	"strings"
    10  	"sync"
    11  	"sync/atomic"
    12  
    13  	"github.com/golang/protobuf/proto"
    14  	"google.golang.org/protobuf/types/descriptorpb"
    15  	"google.golang.org/protobuf/types/known/apipb"
    16  	"google.golang.org/protobuf/types/known/typepb"
    17  	"google.golang.org/protobuf/types/known/wrapperspb"
    18  
    19  	"github.com/jhump/protoreflect/desc"
    20  	"github.com/jhump/protoreflect/dynamic"
    21  )
    22  
    23  var (
    24  	enumOptionsDesc, enumValueOptionsDesc *desc.MessageDescriptor
    25  	msgOptionsDesc, fieldOptionsDesc      *desc.MessageDescriptor
    26  	svcOptionsDesc, methodOptionsDesc     *desc.MessageDescriptor
    27  )
    28  
    29  func init() {
    30  	var err error
    31  	enumOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptorpb.EnumOptions)(nil))
    32  	if err != nil {
    33  		panic("Failed to load descriptor for EnumOptions")
    34  	}
    35  	enumValueOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptorpb.EnumValueOptions)(nil))
    36  	if err != nil {
    37  		panic("Failed to load descriptor for EnumValueOptions")
    38  	}
    39  	msgOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptorpb.MessageOptions)(nil))
    40  	if err != nil {
    41  		panic("Failed to load descriptor for MessageOptions")
    42  	}
    43  	fieldOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptorpb.FieldOptions)(nil))
    44  	if err != nil {
    45  		panic("Failed to load descriptor for FieldOptions")
    46  	}
    47  	svcOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptorpb.ServiceOptions)(nil))
    48  	if err != nil {
    49  		panic("Failed to load descriptor for ServiceOptions")
    50  	}
    51  	methodOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptorpb.MethodOptions)(nil))
    52  	if err != nil {
    53  		panic("Failed to load descriptor for MethodOptions")
    54  	}
    55  }
    56  
    57  func ensureScheme(url string) string {
    58  	pos := strings.Index(url, "://")
    59  	if pos < 0 {
    60  		return "https://" + url
    61  	}
    62  	return url
    63  }
    64  
// typeResolver is used by MessageRegistry to resolve message types. It uses a given TypeFetcher
// to retrieve type definitions and caches resulting descriptor objects.
type typeResolver struct {
	// fetcher downloads the definition (*typepb.Type or *typepb.Enum) for a type URL.
	fetcher TypeFetcher
	// mr is the owning registry. It is consulted for already-registered descriptors,
	// used to compute type URLs, and passed along when converting options.
	mr      *MessageRegistry
	// mu guards cache.
	mu      sync.RWMutex
	// cache maps type URLs (normalized with ensureScheme) to the resolved message or
	// enum descriptor. It is allocated lazily on first insert.
	cache   map[string]desc.Descriptor
}
    73  
    74  // resolveUrlToMessageDescriptor returns a message descriptor that represents the type at the given URL.
    75  func (r *typeResolver) resolveUrlToMessageDescriptor(url string) (*desc.MessageDescriptor, error) {
    76  	url = ensureScheme(url)
    77  	r.mu.RLock()
    78  	cached := r.cache[url]
    79  	r.mu.RUnlock()
    80  	if cached != nil {
    81  		if md, ok := cached.(*desc.MessageDescriptor); ok {
    82  			return md, nil
    83  		} else {
    84  			return nil, fmt.Errorf("type for URL %v is the wrong type: wanted message, got enum", url)
    85  		}
    86  	}
    87  
    88  	rc := newResolutionContext(r)
    89  	if err := rc.addType(url, false); err != nil {
    90  		return nil, err
    91  	}
    92  
    93  	var files map[string]*desc.FileDescriptor
    94  	files, err := rc.toFileDescriptors(r.mr)
    95  	if err != nil {
    96  		return nil, err
    97  	}
    98  	r.mu.Lock()
    99  	defer r.mu.Unlock()
   100  	var md *desc.MessageDescriptor
   101  	if len(rc.typeLocations) > 0 {
   102  		if r.cache == nil {
   103  			r.cache = map[string]desc.Descriptor{}
   104  		}
   105  	}
   106  	for typeUrl, fileName := range rc.typeLocations {
   107  		fd := files[fileName]
   108  		sym := fd.FindSymbol(typeName(typeUrl))
   109  		r.cache[typeUrl] = sym
   110  		if url == typeUrl {
   111  			md = sym.(*desc.MessageDescriptor)
   112  		}
   113  	}
   114  	return md, nil
   115  }
   116  
   117  // resolveUrlsToMessageDescriptors returns a map of the given URLs to corresponding
   118  // message descriptors that represent the types at those URLs.
   119  func (r *typeResolver) resolveUrlsToMessageDescriptors(urls ...string) (map[string]*desc.MessageDescriptor, error) {
   120  	ret := map[string]*desc.MessageDescriptor{}
   121  	var unresolved []string
   122  	r.mu.RLock()
   123  	for _, u := range urls {
   124  		u = ensureScheme(u)
   125  		cached := r.cache[u]
   126  		if cached != nil {
   127  			if md, ok := cached.(*desc.MessageDescriptor); ok {
   128  				ret[u] = md
   129  			} else {
   130  				r.mu.RUnlock()
   131  				return nil, fmt.Errorf("type for URL %v is the wrong type: wanted message, got enum", u)
   132  			}
   133  		} else {
   134  			ret[u] = nil
   135  			unresolved = append(unresolved, u)
   136  		}
   137  	}
   138  	r.mu.RUnlock()
   139  
   140  	if len(unresolved) == 0 {
   141  		return ret, nil
   142  	}
   143  
   144  	rc := newResolutionContext(r)
   145  	for _, u := range unresolved {
   146  		if err := rc.addType(u, false); err != nil {
   147  			return nil, err
   148  		}
   149  	}
   150  
   151  	var files map[string]*desc.FileDescriptor
   152  	files, err := rc.toFileDescriptors(r.mr)
   153  	if err != nil {
   154  		return nil, err
   155  	}
   156  	r.mu.Lock()
   157  	defer r.mu.Unlock()
   158  	if len(rc.typeLocations) > 0 {
   159  		if r.cache == nil {
   160  			r.cache = map[string]desc.Descriptor{}
   161  		}
   162  	}
   163  	for typeUrl, fileName := range rc.typeLocations {
   164  		fd := files[fileName]
   165  		sym := fd.FindSymbol(typeName(typeUrl))
   166  		r.cache[typeUrl] = sym
   167  		if _, ok := ret[typeUrl]; ok {
   168  			ret[typeUrl] = sym.(*desc.MessageDescriptor)
   169  		}
   170  	}
   171  	return ret, nil
   172  }
   173  
   174  // resolveUrlToEnumDescriptor returns an enum descriptor that represents the enum type at the given URL.
   175  func (r *typeResolver) resolveUrlToEnumDescriptor(url string) (*desc.EnumDescriptor, error) {
   176  	url = ensureScheme(url)
   177  	r.mu.RLock()
   178  	cached := r.cache[url]
   179  	r.mu.RUnlock()
   180  	if cached != nil {
   181  		if ed, ok := cached.(*desc.EnumDescriptor); ok {
   182  			return ed, nil
   183  		} else {
   184  			return nil, fmt.Errorf("type for URL %v is the wrong type: wanted enum, got message", url)
   185  		}
   186  	}
   187  
   188  	rc := newResolutionContext(r)
   189  	if err := rc.addType(url, true); err != nil {
   190  		return nil, err
   191  	}
   192  
   193  	var files map[string]*desc.FileDescriptor
   194  	files, err := rc.toFileDescriptors(r.mr)
   195  	if err != nil {
   196  		return nil, err
   197  	}
   198  	r.mu.Lock()
   199  	defer r.mu.Unlock()
   200  	var ed *desc.EnumDescriptor
   201  	if len(rc.typeLocations) > 0 {
   202  		if r.cache == nil {
   203  			r.cache = map[string]desc.Descriptor{}
   204  		}
   205  	}
   206  	for typeUrl, fileName := range rc.typeLocations {
   207  		fd := files[fileName]
   208  		sym := fd.FindSymbol(typeName(typeUrl))
   209  		r.cache[typeUrl] = sym
   210  		if url == typeUrl {
   211  			ed = sym.(*desc.EnumDescriptor)
   212  		}
   213  	}
   214  	return ed, nil
   215  }
   216  
// tracker is the callback addDescriptors consults for each descriptor it visits.
// Returning true means "add it now"; returning false means it has already been
// handled, and addDescriptors then skips it (and does not descend into its fields).
type tracker func(d desc.Descriptor) bool
   218  
   219  func newNameTracker() tracker {
   220  	names := map[string]struct{}{}
   221  	return func(d desc.Descriptor) bool {
   222  		name := d.GetFullyQualifiedName()
   223  		if _, ok := names[name]; ok {
   224  			return false
   225  		}
   226  		names[name] = struct{}{}
   227  		return true
   228  	}
   229  }
   230  
   231  func addDescriptors(ref string, files map[string]*fileEntry, d desc.Descriptor, msgs map[string]*desc.MessageDescriptor, onAdd tracker) {
   232  	name := d.GetFullyQualifiedName()
   233  
   234  	fileName := d.GetFile().GetName()
   235  	if fileName != ref {
   236  		dependee := files[ref]
   237  		if dependee.deps == nil {
   238  			dependee.deps = map[string]struct{}{}
   239  		}
   240  		dependee.deps[fileName] = struct{}{}
   241  	}
   242  
   243  	if !onAdd(d) {
   244  		// already added this one
   245  		return
   246  	}
   247  
   248  	fe := files[fileName]
   249  	if fe == nil {
   250  		fe = &fileEntry{}
   251  		fe.proto3 = d.GetFile().IsProto3()
   252  		files[fileName] = fe
   253  	}
   254  	fe.types.addType(name, d.AsProto())
   255  
   256  	if md, ok := d.(*desc.MessageDescriptor); ok {
   257  		for _, fld := range md.GetFields() {
   258  			if fld.GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE || fld.GetType() == descriptorpb.FieldDescriptorProto_TYPE_GROUP {
   259  				// prefer descriptor in msgs map over what the field descriptor indicates
   260  				md := msgs[fld.GetMessageType().GetFullyQualifiedName()]
   261  				if md == nil {
   262  					md = fld.GetMessageType()
   263  				}
   264  				addDescriptors(fileName, files, md, msgs, onAdd)
   265  			} else if fld.GetType() == descriptorpb.FieldDescriptorProto_TYPE_ENUM {
   266  				addDescriptors(fileName, files, fld.GetEnumType(), msgs, onAdd)
   267  			}
   268  		}
   269  	}
   270  }
   271  
// resolutionContext provides the state for a resolution operation, accumulating details about
// type descriptions and the files that contain them.
type resolutionContext struct {
	// The context and cancel function, used to coordinate multiple goroutines when there are multiple
	// type or enum descriptions to download.
	ctx    context.Context
	cancel func()
	// res is the resolver whose fetcher and message registry are used during resolution.
	res    *typeResolver

	// mu guards the fields below (and the file entries in files). They are updated from
	// the concurrent goroutines that addType starts.
	mu sync.Mutex
	// map of file names to details regarding the files' contents
	files map[string]*fileEntry
	// map of type URLs to the file name that defines them
	typeLocations map[string]string
	// count of source contexts that do not indicate a file name (used to generate unique file names
	// when synthesizing file descriptors)
	unknownCount int
}
   290  
   291  func newResolutionContext(res *typeResolver) *resolutionContext {
   292  	ctx, cancel := context.WithCancel(context.Background())
   293  	return &resolutionContext{
   294  		ctx:           ctx,
   295  		cancel:        cancel,
   296  		res:           res,
   297  		typeLocations: map[string]string{},
   298  		files:         map[string]*fileEntry{},
   299  	}
   300  }
   301  
// addType adds the type at the given URL to the context, using the given fetcher to download the type's
// description. This function will recursively add dependencies (e.g. types referenced by the given type's
// fields if it is a message type), fetching their type descriptions concurrently.
//
// enum selects what is requested from the resolver's fetcher: a *typepb.Enum when true, or a
// *typepb.Type when false. The fetched value is type-asserted to that kind without a check, so a
// fetcher that returns the other kind causes a panic. Callers in this file pass URLs that have
// already been normalized with ensureScheme.
//
// It returns one of:
//   - the fetcher's error,
//   - an error if the fetcher found nothing,
//   - the first error (other than a resulting cancellation) from resolving a dependency, or
//   - the shared context's error once the resolution has been cancelled.
//
// Any failure cancels the shared context so that other in-flight fetches stop early.
func (rc *resolutionContext) addType(url string, enum bool) error {
	// Bail out early if another branch of this resolution already failed.
	if err := rc.ctx.Err(); err != nil {
		return err
	}

	m, err := rc.res.fetcher(url, enum)
	if err != nil {
		return err
	} else if m == nil {
		return fmt.Errorf("failed to locate type for %s", url)
	}

	if enum {
		// Enums are recorded as-is; nothing further is fetched for them.
		rc.recordEnum(url, m.(*typepb.Enum))
		return nil
	}

	// for messages, resolve dependencies in parallel
	t := m.(*typepb.Type)
	fe, fileName := rc.recordType(url, t)
	if fe == nil {
		// already resolved this one
		return nil
	}

	var wg sync.WaitGroup
	// failed lets only the first failing goroutine assign err; that write is
	// visible below once wg.Wait returns.
	var failed int32
	for _, f := range t.Fields {
		if f.Kind == typepb.Field_TYPE_GROUP || f.Kind == typepb.Field_TYPE_MESSAGE || f.Kind == typepb.Field_TYPE_ENUM {
			typeUrl := ensureScheme(f.TypeUrl)
			// copy the kind so the goroutine does not read the loop variable f
			kind := f.Kind
			wg.Add(1)
			go func() {
				defer wg.Done()
				// first check the registry for descriptors
				var d desc.Descriptor
				var innerErr error
				if kind == typepb.Field_TYPE_ENUM {
					var ed *desc.EnumDescriptor
					ed, innerErr = rc.res.mr.getRegisteredEnumTypeByUrl(typeUrl)
					if ed != nil {
						d = ed
					}
				} else {
					var md *desc.MessageDescriptor
					md, innerErr = rc.res.mr.getRegisteredMessageTypeByUrl(typeUrl)
					if md != nil {
						d = md
					}
				}

				if innerErr == nil {
					if d != nil {
						// found it!
						rc.recordDescriptor(typeUrl, fileName, d)
					} else {
						// not in registry, so we have to recursively fetch
						innerErr = rc.addType(typeUrl, kind == typepb.Field_TYPE_ENUM)
					}
				}

				// We want the "real" error to ultimately propagate to root, not
				// one of the resulting cancellations (from any concurrent goroutines
				// working in the same resolution context).
				if innerErr != nil && (rc.ctx.Err() == nil || innerErr != context.Canceled) {
					if atomic.CompareAndSwapInt32(&failed, 0, 1) {
						err = innerErr
					}
					rc.cancel()
				}
			}()
		}
	}
	wg.Wait()
	if err != nil {
		return err
	}
	// double-check if context has been cancelled
	if err = rc.ctx.Err(); err != nil {
		return err
	}

	rc.mu.Lock()
	defer rc.mu.Unlock()

	// Record a dependency from this type's file on the file defining each
	// referenced type, unless both live in the same file.
	for _, f := range t.Fields {
		if f.Kind == typepb.Field_TYPE_GROUP || f.Kind == typepb.Field_TYPE_MESSAGE || f.Kind == typepb.Field_TYPE_ENUM {
			typeUrl := ensureScheme(f.TypeUrl)
			if fe.deps == nil {
				fe.deps = map[string]struct{}{}
			}
			dep := rc.typeLocations[typeUrl]
			if dep != fileName {
				fe.deps[dep] = struct{}{}
			}
		}
	}
	return nil
}
   404  
   405  func (rc *resolutionContext) recordEnum(url string, e *typepb.Enum) {
   406  	rc.mu.Lock()
   407  	defer rc.mu.Unlock()
   408  
   409  	var fileName string
   410  	if e.SourceContext != nil && e.SourceContext.FileName != "" {
   411  		fileName = e.SourceContext.FileName
   412  	} else {
   413  		fileName = fmt.Sprintf("--unknown--%d.proto", rc.unknownCount)
   414  		rc.unknownCount++
   415  	}
   416  	rc.typeLocations[url] = fileName
   417  
   418  	fe := rc.files[fileName]
   419  	if fe == nil {
   420  		fe = &fileEntry{}
   421  		rc.files[fileName] = fe
   422  	}
   423  	fe.types.addType(e.Name, e)
   424  	if e.Syntax == typepb.Syntax_SYNTAX_PROTO3 {
   425  		fe.proto3 = true
   426  	}
   427  }
   428  
   429  func (rc *resolutionContext) recordType(url string, t *typepb.Type) (*fileEntry, string) {
   430  	rc.mu.Lock()
   431  	defer rc.mu.Unlock()
   432  
   433  	if _, ok := rc.typeLocations[url]; ok {
   434  		return nil, ""
   435  	}
   436  
   437  	var fileName string
   438  	if t.SourceContext != nil && t.SourceContext.FileName != "" {
   439  		fileName = t.SourceContext.FileName
   440  	} else {
   441  		fileName = fmt.Sprintf("--unknown--%d.proto", rc.unknownCount)
   442  		rc.unknownCount++
   443  	}
   444  	rc.typeLocations[url] = fileName
   445  
   446  	fe := rc.files[fileName]
   447  	if fe == nil {
   448  		fe = &fileEntry{}
   449  		rc.files[fileName] = fe
   450  	}
   451  	fe.types.addType(t.Name, t)
   452  	if t.Syntax == typepb.Syntax_SYNTAX_PROTO3 {
   453  		fe.proto3 = true
   454  	}
   455  
   456  	return fe, fileName
   457  }
   458  
   459  func (rc *resolutionContext) recordDescriptor(url, ref string, d desc.Descriptor) {
   460  	rc.mu.Lock()
   461  	defer rc.mu.Unlock()
   462  
   463  	addDescriptors(ref, rc.files, d, nil, func(dsc desc.Descriptor) bool {
   464  		u := ensureScheme(rc.res.mr.ComputeUrl(dsc))
   465  		if _, ok := rc.typeLocations[u]; ok {
   466  			// already seen this one
   467  			return false
   468  		}
   469  		fileName := dsc.GetFile().GetName()
   470  		rc.typeLocations[u] = fileName
   471  		if dsc == d {
   472  			// make sure we're also adding the actual URL reference used
   473  			rc.typeLocations[url] = fileName
   474  		}
   475  		return true
   476  	})
   477  }
   478  
   479  // toFileDescriptors converts the information in the context into a map of file names to file descriptors.
   480  func (rc *resolutionContext) toFileDescriptors(mr *MessageRegistry) (map[string]*desc.FileDescriptor, error) {
   481  	return toFileDescriptors(rc.files, func(tt *typeTrie, name string) (proto.Message, error) {
   482  		mdp, edp := tt.ptypeToDescriptor(name, mr)
   483  		if mdp != nil {
   484  			return mdp, nil
   485  		} else {
   486  			return edp, nil
   487  		}
   488  	})
   489  }
   490  
   491  // converts a map of file entries into a map of file descriptors using the given function to convert
   492  // each trie node into a descriptor proto.
   493  func toFileDescriptors(files map[string]*fileEntry, trieFn func(*typeTrie, string) (proto.Message, error)) (map[string]*desc.FileDescriptor, error) {
   494  	fdps := map[string]*descriptorpb.FileDescriptorProto{}
   495  	for name, file := range files {
   496  		fdp, err := file.toFileDescriptor(name, trieFn)
   497  		if err != nil {
   498  			return nil, err
   499  		}
   500  		fdps[name] = fdp
   501  	}
   502  	fds := map[string]*desc.FileDescriptor{}
   503  	for name, fdp := range fdps {
   504  		if _, ok := fds[name]; ok {
   505  			continue
   506  		}
   507  		var err error
   508  		if fds[name], err = makeFileDesc(fdp, fds, fdps); err != nil {
   509  			return nil, err
   510  		}
   511  	}
   512  	return fds, nil
   513  }
   514  
   515  func makeFileDesc(fdp *descriptorpb.FileDescriptorProto, fds map[string]*desc.FileDescriptor, fdps map[string]*descriptorpb.FileDescriptorProto) (*desc.FileDescriptor, error) {
   516  	deps := make([]*desc.FileDescriptor, len(fdp.Dependency))
   517  	for i, dep := range fdp.Dependency {
   518  		d := fds[dep]
   519  		if d == nil {
   520  			var err error
   521  			depFd := fdps[dep]
   522  			if depFd == nil {
   523  				return nil, fmt.Errorf("missing dependency: %s", dep)
   524  			}
   525  			d, err = makeFileDesc(depFd, fds, fdps)
   526  			if err != nil {
   527  				return nil, err
   528  			}
   529  		}
   530  		deps[i] = d
   531  	}
   532  	if fd, err := desc.CreateFileDescriptor(fdp, deps...); err != nil {
   533  		return nil, err
   534  	} else {
   535  		fds[fdp.GetName()] = fd
   536  		return fd, nil
   537  	}
   538  }
   539  
// fileEntry represents the contents of a single file.
type fileEntry struct {
	// types holds the file's elements, keyed by their dot-separated fully-qualified names.
	types  typeTrie
	// deps is the set of names of other files that this file's elements refer to.
	deps   map[string]struct{}
	// proto3 is set when an element recorded into this file uses proto3 syntax; it is
	// passed to createFileDescriptor.
	proto3 bool
}
   546  
   547  // toFileDescriptor converts this file entry into a file descriptor proto. The given function
   548  // is used to transform nodes in a typeTrie into message and/or enum descriptor protos.
   549  func (fe *fileEntry) toFileDescriptor(name string, trieFn func(*typeTrie, string) (proto.Message, error)) (*descriptorpb.FileDescriptorProto, error) {
   550  	var pkg bytes.Buffer
   551  	tt := &fe.types
   552  	first := true
   553  	last := ""
   554  	for tt.typ == nil {
   555  		if last != "" {
   556  			if first {
   557  				first = false
   558  			} else {
   559  				pkg.WriteByte('.')
   560  			}
   561  			pkg.WriteString(last)
   562  		}
   563  		if len(tt.children) != 1 {
   564  			break
   565  		}
   566  		for last, tt = range tt.children {
   567  		}
   568  	}
   569  	fd := createFileDescriptor(name, pkg.String(), fe.proto3, fe.deps)
   570  	if tt.typ != nil {
   571  		pm, err := trieFn(tt, last)
   572  		if err != nil {
   573  			return nil, err
   574  		}
   575  		if mdp, ok := pm.(*descriptorpb.DescriptorProto); ok {
   576  			fd.MessageType = append(fd.MessageType, mdp)
   577  		} else if edp, ok := pm.(*descriptorpb.EnumDescriptorProto); ok {
   578  			fd.EnumType = append(fd.EnumType, edp)
   579  		} else {
   580  			sdp := pm.(*descriptorpb.ServiceDescriptorProto)
   581  			fd.Service = append(fd.Service, sdp)
   582  		}
   583  	} else {
   584  		for name, nested := range tt.children {
   585  			pm, err := trieFn(nested, name)
   586  			if err != nil {
   587  				return nil, err
   588  			}
   589  			if mdp, ok := pm.(*descriptorpb.DescriptorProto); ok {
   590  				fd.MessageType = append(fd.MessageType, mdp)
   591  			} else if edp, ok := pm.(*descriptorpb.EnumDescriptorProto); ok {
   592  				fd.EnumType = append(fd.EnumType, edp)
   593  			} else {
   594  				sdp := pm.(*descriptorpb.ServiceDescriptorProto)
   595  				fd.Service = append(fd.Service, sdp)
   596  			}
   597  		}
   598  	}
   599  	return fd, nil
   600  }
   601  
// typeTrie is a prefix trie where each key component is part of a fully-qualified type name. So key components
// will either be package name components or element names.
//
// The zero value is an empty trie, ready for use with addType.
type typeTrie struct {
	// successor key components (allocated lazily by addType)
	children map[string]*typeTrie
	// if non-nil, the element whose fully-qualified name is the path from the trie root to this node
	typ proto.Message
}
   610  
   611  // addType recursively adds an element to the trie.
   612  func (t *typeTrie) addType(key string, typ proto.Message) {
   613  	if key == "" {
   614  		t.typ = typ
   615  		return
   616  	}
   617  	if t.children == nil {
   618  		t.children = map[string]*typeTrie{}
   619  	}
   620  	curr, rest := split(key)
   621  	child := t.children[curr]
   622  	if child == nil {
   623  		child = &typeTrie{}
   624  		t.children[curr] = child
   625  	}
   626  	child.addType(rest, typ)
   627  }
   628  
   629  // ptypeToDescriptor converts this level of the trie into a message or enum
   630  // descriptor proto, requiring that the element stored in t.typ is a *ptype.Type
   631  // or *ptype.Enum. If t.typ is nil, a placeholder message (with no fields) is
   632  // returned that contains the trie's children as nested message and/or enum
   633  // types.
   634  //
   635  // If the value in t.typ is already a *descriptor.DescriptorProto or a
   636  // *descriptor.EnumDescriptorProto then it is returned as is. This function
   637  // should not be used in type tries that may have service descriptors. That will
   638  // result in a panic.
   639  func (t *typeTrie) ptypeToDescriptor(name string, mr *MessageRegistry) (*descriptorpb.DescriptorProto, *descriptorpb.EnumDescriptorProto) {
   640  	switch typ := t.typ.(type) {
   641  	case *descriptorpb.EnumDescriptorProto:
   642  		return nil, typ
   643  	case *typepb.Enum:
   644  		return nil, createEnumDescriptor(typ, mr)
   645  	case *descriptorpb.DescriptorProto:
   646  		return typ, nil
   647  	default:
   648  		var msg *descriptorpb.DescriptorProto
   649  		if t.typ == nil {
   650  			msg = createIntermediateMessageDescriptor(name)
   651  		} else {
   652  			msg = createMessageDescriptor(t.typ.(*typepb.Type), mr)
   653  		}
   654  		// sort children for deterministic output
   655  		var keys []string
   656  		for k := range t.children {
   657  			keys = append(keys, k)
   658  		}
   659  		for _, name := range keys {
   660  			nested := t.children[name]
   661  			chMsg, chEnum := nested.ptypeToDescriptor(name, mr)
   662  			if chMsg != nil {
   663  				msg.NestedType = append(msg.NestedType, chMsg)
   664  			}
   665  			if chEnum != nil {
   666  				msg.EnumType = append(msg.EnumType, chEnum)
   667  			}
   668  		}
   669  		return msg, nil
   670  	}
   671  }
   672  
   673  // rewriteDescriptor converts this level of the trie into a new descriptor
   674  // proto, requiring that the element stored in t.type is already a service,
   675  // message, or enum descriptor proto. If this trie has children then t.typ must
   676  // be a message descriptor proto. The returned descriptor proto is the same as
   677  // .type but with possibly new nested elements to represent this trie node's
   678  // children.
   679  func (t *typeTrie) rewriteDescriptor(name string) (proto.Message, error) {
   680  	if len(t.children) == 0 && t.typ != nil {
   681  		if mdp, ok := t.typ.(*descriptorpb.DescriptorProto); ok {
   682  			if len(mdp.NestedType) == 0 && len(mdp.EnumType) == 0 {
   683  				return mdp, nil
   684  			}
   685  			mdp = proto.Clone(mdp).(*descriptorpb.DescriptorProto)
   686  			mdp.NestedType = nil
   687  			mdp.EnumType = nil
   688  			return mdp, nil
   689  		}
   690  		return t.typ, nil
   691  	}
   692  	var mdp *descriptorpb.DescriptorProto
   693  	if t.typ == nil {
   694  		mdp = createIntermediateMessageDescriptor(name)
   695  	} else {
   696  		mdp = t.typ.(*descriptorpb.DescriptorProto)
   697  		mdp = proto.Clone(mdp).(*descriptorpb.DescriptorProto)
   698  		mdp.NestedType = nil
   699  		mdp.EnumType = nil
   700  	}
   701  	// sort children for deterministic output
   702  	var keys []string
   703  	for k := range t.children {
   704  		keys = append(keys, k)
   705  	}
   706  	for _, n := range keys {
   707  		ch := t.children[n]
   708  		typ, err := ch.rewriteDescriptor(n)
   709  		if err != nil {
   710  			return nil, err
   711  		}
   712  		switch typ := typ.(type) {
   713  		case (*descriptorpb.DescriptorProto):
   714  			mdp.NestedType = append(mdp.NestedType, typ)
   715  		case (*descriptorpb.EnumDescriptorProto):
   716  			mdp.EnumType = append(mdp.EnumType, typ)
   717  		default:
   718  			// TODO: this should probably panic instead
   719  			return nil, fmt.Errorf("invalid descriptor trie: message cannot have child of type %v", reflect.TypeOf(typ))
   720  		}
   721  	}
   722  	return mdp, nil
   723  }
   724  
   725  func split(s string) (string, string) {
   726  	pos := strings.Index(s, ".")
   727  	if pos >= 0 {
   728  		return s[:pos], s[pos+1:]
   729  	} else {
   730  		return s, ""
   731  	}
   732  }
   733  
   734  func createEnumDescriptor(e *typepb.Enum, mr *MessageRegistry) *descriptorpb.EnumDescriptorProto {
   735  	var opts *descriptorpb.EnumOptions
   736  	if len(e.Options) > 0 {
   737  		dopts := createOptions(e.Options, enumOptionsDesc, mr)
   738  		opts = &descriptorpb.EnumOptions{}
   739  		dopts.ConvertTo(opts) // ignore any error
   740  	}
   741  
   742  	var vals []*descriptorpb.EnumValueDescriptorProto
   743  	for _, v := range e.Enumvalue {
   744  		evd := createEnumValueDescriptor(v, mr)
   745  		vals = append(vals, evd)
   746  	}
   747  
   748  	return &descriptorpb.EnumDescriptorProto{
   749  		Name:    proto.String(base(e.Name)),
   750  		Options: opts,
   751  		Value:   vals,
   752  	}
   753  }
   754  
   755  func createEnumValueDescriptor(v *typepb.EnumValue, mr *MessageRegistry) *descriptorpb.EnumValueDescriptorProto {
   756  	var opts *descriptorpb.EnumValueOptions
   757  	if len(v.Options) > 0 {
   758  		dopts := createOptions(v.Options, enumValueOptionsDesc, mr)
   759  		opts = &descriptorpb.EnumValueOptions{}
   760  		dopts.ConvertTo(opts) // ignore any error
   761  	}
   762  
   763  	return &descriptorpb.EnumValueDescriptorProto{
   764  		Name:    proto.String(v.Name),
   765  		Number:  proto.Int32(v.Number),
   766  		Options: opts,
   767  	}
   768  }
   769  
   770  func createMessageDescriptor(m *typepb.Type, mr *MessageRegistry) *descriptorpb.DescriptorProto {
   771  	var opts *descriptorpb.MessageOptions
   772  	if len(m.Options) > 0 {
   773  		dopts := createOptions(m.Options, msgOptionsDesc, mr)
   774  		opts = &descriptorpb.MessageOptions{}
   775  		dopts.ConvertTo(opts) // ignore any error
   776  	}
   777  
   778  	var fields []*descriptorpb.FieldDescriptorProto
   779  	for _, f := range m.Fields {
   780  		fields = append(fields, createFieldDescriptor(f, mr))
   781  	}
   782  
   783  	var oneOfs []*descriptorpb.OneofDescriptorProto
   784  	for _, o := range m.Oneofs {
   785  		oneOfs = append(oneOfs, &descriptorpb.OneofDescriptorProto{
   786  			Name: proto.String(o),
   787  		})
   788  	}
   789  
   790  	return &descriptorpb.DescriptorProto{
   791  		Name:      proto.String(base(m.Name)),
   792  		Options:   opts,
   793  		Field:     fields,
   794  		OneofDecl: oneOfs,
   795  	}
   796  }
   797  
   798  func createFieldDescriptor(f *typepb.Field, mr *MessageRegistry) *descriptorpb.FieldDescriptorProto {
   799  	var opts *descriptorpb.FieldOptions
   800  	if len(f.Options) > 0 {
   801  		dopts := createOptions(f.Options, fieldOptionsDesc, mr)
   802  		opts = &descriptorpb.FieldOptions{}
   803  		dopts.ConvertTo(opts) // ignore any error
   804  	}
   805  	if f.Packed {
   806  		if opts == nil {
   807  			opts = &descriptorpb.FieldOptions{Packed: proto.Bool(true)}
   808  		} else {
   809  			opts.Packed = proto.Bool(true)
   810  		}
   811  	}
   812  
   813  	var oneOf *int32
   814  	if f.OneofIndex > 0 {
   815  		oneOf = proto.Int32(f.OneofIndex - 1)
   816  	}
   817  
   818  	var typeName string
   819  	if f.Kind == typepb.Field_TYPE_GROUP || f.Kind == typepb.Field_TYPE_MESSAGE || f.Kind == typepb.Field_TYPE_ENUM {
   820  		pos := strings.LastIndex(f.TypeUrl, "/")
   821  		typeName = "." + f.TypeUrl[pos+1:]
   822  	}
   823  
   824  	var label descriptorpb.FieldDescriptorProto_Label
   825  	switch f.Cardinality {
   826  	case typepb.Field_CARDINALITY_OPTIONAL:
   827  		label = descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL
   828  	case typepb.Field_CARDINALITY_REPEATED:
   829  		label = descriptorpb.FieldDescriptorProto_LABEL_REPEATED
   830  	case typepb.Field_CARDINALITY_REQUIRED:
   831  		label = descriptorpb.FieldDescriptorProto_LABEL_REQUIRED
   832  	}
   833  
   834  	var typ descriptorpb.FieldDescriptorProto_Type
   835  	switch f.Kind {
   836  	case typepb.Field_TYPE_ENUM:
   837  		typ = descriptorpb.FieldDescriptorProto_TYPE_ENUM
   838  	case typepb.Field_TYPE_GROUP:
   839  		typ = descriptorpb.FieldDescriptorProto_TYPE_GROUP
   840  	case typepb.Field_TYPE_MESSAGE:
   841  		typ = descriptorpb.FieldDescriptorProto_TYPE_MESSAGE
   842  	case typepb.Field_TYPE_BYTES:
   843  		typ = descriptorpb.FieldDescriptorProto_TYPE_BYTES
   844  	case typepb.Field_TYPE_STRING:
   845  		typ = descriptorpb.FieldDescriptorProto_TYPE_STRING
   846  	case typepb.Field_TYPE_BOOL:
   847  		typ = descriptorpb.FieldDescriptorProto_TYPE_BOOL
   848  	case typepb.Field_TYPE_DOUBLE:
   849  		typ = descriptorpb.FieldDescriptorProto_TYPE_DOUBLE
   850  	case typepb.Field_TYPE_FLOAT:
   851  		typ = descriptorpb.FieldDescriptorProto_TYPE_FLOAT
   852  	case typepb.Field_TYPE_FIXED32:
   853  		typ = descriptorpb.FieldDescriptorProto_TYPE_FIXED32
   854  	case typepb.Field_TYPE_FIXED64:
   855  		typ = descriptorpb.FieldDescriptorProto_TYPE_FIXED64
   856  	case typepb.Field_TYPE_INT32:
   857  		typ = descriptorpb.FieldDescriptorProto_TYPE_INT32
   858  	case typepb.Field_TYPE_INT64:
   859  		typ = descriptorpb.FieldDescriptorProto_TYPE_INT64
   860  	case typepb.Field_TYPE_SFIXED32:
   861  		typ = descriptorpb.FieldDescriptorProto_TYPE_SFIXED32
   862  	case typepb.Field_TYPE_SFIXED64:
   863  		typ = descriptorpb.FieldDescriptorProto_TYPE_SFIXED64
   864  	case typepb.Field_TYPE_SINT32:
   865  		typ = descriptorpb.FieldDescriptorProto_TYPE_SINT32
   866  	case typepb.Field_TYPE_SINT64:
   867  		typ = descriptorpb.FieldDescriptorProto_TYPE_SINT64
   868  	case typepb.Field_TYPE_UINT32:
   869  		typ = descriptorpb.FieldDescriptorProto_TYPE_UINT32
   870  	case typepb.Field_TYPE_UINT64:
   871  		typ = descriptorpb.FieldDescriptorProto_TYPE_UINT64
   872  	}
   873  	var defaultVal *string
   874  	if f.DefaultValue != "" {
   875  		defaultVal = proto.String(f.DefaultValue)
   876  	}
   877  	return &descriptorpb.FieldDescriptorProto{
   878  		Name:         proto.String(f.Name),
   879  		Number:       proto.Int32(f.Number),
   880  		DefaultValue: defaultVal,
   881  		JsonName:     proto.String(f.JsonName),
   882  		OneofIndex:   oneOf,
   883  		TypeName:     proto.String(typeName),
   884  		Label:        label.Enum(),
   885  		Type:         typ.Enum(),
   886  		Options:      opts,
   887  	}
   888  }
   889  
   890  func createServiceDescriptor(a *apipb.Api, mr *MessageRegistry) *descriptorpb.ServiceDescriptorProto {
   891  	var opts *descriptorpb.ServiceOptions
   892  	if len(a.Options) > 0 {
   893  		dopts := createOptions(a.Options, svcOptionsDesc, mr)
   894  		opts = &descriptorpb.ServiceOptions{}
   895  		dopts.ConvertTo(opts) // ignore any error
   896  	}
   897  
   898  	methods := make([]*descriptorpb.MethodDescriptorProto, len(a.Methods))
   899  	for i, m := range a.Methods {
   900  		methods[i] = createMethodDescriptor(m, mr)
   901  	}
   902  
   903  	return &descriptorpb.ServiceDescriptorProto{
   904  		Name:    proto.String(base(a.Name)),
   905  		Method:  methods,
   906  		Options: opts,
   907  	}
   908  }
   909  
   910  func createMethodDescriptor(m *apipb.Method, mr *MessageRegistry) *descriptorpb.MethodDescriptorProto {
   911  	var opts *descriptorpb.MethodOptions
   912  	if len(m.Options) > 0 {
   913  		dopts := createOptions(m.Options, methodOptionsDesc, mr)
   914  		opts = &descriptorpb.MethodOptions{}
   915  		dopts.ConvertTo(opts) // ignore any error
   916  	}
   917  
   918  	var reqType, respType string
   919  	pos := strings.LastIndex(m.RequestTypeUrl, "/")
   920  	reqType = "." + m.RequestTypeUrl[pos+1:]
   921  	pos = strings.LastIndex(m.ResponseTypeUrl, "/")
   922  	respType = "." + m.ResponseTypeUrl[pos+1:]
   923  
   924  	return &descriptorpb.MethodDescriptorProto{
   925  		Name:            proto.String(m.Name),
   926  		Options:         opts,
   927  		ClientStreaming: proto.Bool(m.RequestStreaming),
   928  		ServerStreaming: proto.Bool(m.ResponseStreaming),
   929  		InputType:       proto.String(reqType),
   930  		OutputType:      proto.String(respType),
   931  	}
   932  }
   933  
   934  func createIntermediateMessageDescriptor(name string) *descriptorpb.DescriptorProto {
   935  	return &descriptorpb.DescriptorProto{
   936  		Name: proto.String(name),
   937  	}
   938  }
   939  
   940  func createFileDescriptor(name, pkg string, proto3 bool, deps map[string]struct{}) *descriptorpb.FileDescriptorProto {
   941  	imports := make([]string, 0, len(deps))
   942  	for k := range deps {
   943  		imports = append(imports, k)
   944  	}
   945  	sort.Strings(imports)
   946  	var syntax string
   947  	if proto3 {
   948  		syntax = "proto3"
   949  	} else {
   950  		syntax = "proto2"
   951  	}
   952  	return &descriptorpb.FileDescriptorProto{
   953  		Name:       proto.String(name),
   954  		Package:    proto.String(pkg),
   955  		Syntax:     proto.String(syntax),
   956  		Dependency: imports,
   957  	}
   958  }
   959  
   960  func createOptions(options []*typepb.Option, optionsDesc *desc.MessageDescriptor, mr *MessageRegistry) *dynamic.Message {
   961  	// these are created "best effort" so entries which are unresolvable
   962  	// (or seemingly invalid) are simply ignored...
   963  	dopts := mr.mf.NewDynamicMessage(optionsDesc)
   964  	for _, o := range options {
   965  		field := optionsDesc.FindFieldByName(o.Name)
   966  		if field == nil {
   967  			field = mr.er.FindExtensionByName(optionsDesc.GetFullyQualifiedName(), o.Name)
   968  			if field == nil && o.Name[0] != '[' {
   969  				field = mr.er.FindExtensionByName(optionsDesc.GetFullyQualifiedName(), fmt.Sprintf("[%s]", o.Name))
   970  			}
   971  			if field == nil {
   972  				// can't resolve option name? skip it
   973  				continue
   974  			}
   975  		}
   976  		v, err := mr.unmarshalAny(o.Value, func(url string) (*desc.MessageDescriptor, error) {
   977  			// we don't want to try to recursively fetch this value's type, so if it doesn't
   978  			// match the type of the extension field, we'll skip it
   979  			if (field.GetType() == descriptorpb.FieldDescriptorProto_TYPE_GROUP ||
   980  				field.GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE) &&
   981  				typeName(url) == field.GetMessageType().GetFullyQualifiedName() {
   982  
   983  				return field.GetMessageType(), nil
   984  			}
   985  			return nil, nil
   986  		})
   987  		if err != nil {
   988  			// can't interpret value? skip it
   989  			continue
   990  		}
   991  		var fv interface{}
   992  		if field.GetType() != descriptorpb.FieldDescriptorProto_TYPE_MESSAGE && field.GetType() != descriptorpb.FieldDescriptorProto_TYPE_GROUP {
   993  			fv = unwrap(v)
   994  			if v == nil {
   995  				// non-wrapper type for scalar field? skip it
   996  				continue
   997  			}
   998  		} else {
   999  			fv = v
  1000  		}
  1001  		if field.IsRepeated() {
  1002  			dopts.TryAddRepeatedField(field, fv) // ignore any error
  1003  		} else {
  1004  			dopts.TrySetField(field, fv) // ignore any error
  1005  		}
  1006  	}
  1007  	return dopts
  1008  }
  1009  
  1010  func base(name string) string {
  1011  	pos := strings.LastIndex(name, ".")
  1012  	if pos >= 0 {
  1013  		return name[pos+1:]
  1014  	}
  1015  	return name
  1016  }
  1017  
  1018  func unwrap(msg proto.Message) interface{} {
  1019  	switch m := msg.(type) {
  1020  	case (*wrapperspb.BoolValue):
  1021  		return m.Value
  1022  	case (*wrapperspb.FloatValue):
  1023  		return m.Value
  1024  	case (*wrapperspb.DoubleValue):
  1025  		return m.Value
  1026  	case (*wrapperspb.Int32Value):
  1027  		return m.Value
  1028  	case (*wrapperspb.Int64Value):
  1029  		return m.Value
  1030  	case (*wrapperspb.UInt32Value):
  1031  		return m.Value
  1032  	case (*wrapperspb.UInt64Value):
  1033  		return m.Value
  1034  	case (*wrapperspb.BytesValue):
  1035  		return m.Value
  1036  	case (*wrapperspb.StringValue):
  1037  		return m.Value
  1038  	default:
  1039  		return nil
  1040  	}
  1041  }
  1042  
  1043  func typeName(url string) string {
  1044  	return url[strings.LastIndex(url, "/")+1:]
  1045  }