github.com/jhump/protoreflect@v1.16.0/dynamic/msgregistry/message_registry.go (about)

     1  package msgregistry
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"strings"
     7  	"sync"
     8  
     9  	"github.com/golang/protobuf/jsonpb"
    10  	"github.com/golang/protobuf/proto"
    11  	"github.com/golang/protobuf/ptypes"
    12  	"google.golang.org/protobuf/types/descriptorpb"
    13  	"google.golang.org/protobuf/types/known/anypb"
    14  	"google.golang.org/protobuf/types/known/apipb"
    15  	"google.golang.org/protobuf/types/known/sourcecontextpb"
    16  	"google.golang.org/protobuf/types/known/typepb"
    17  	"google.golang.org/protobuf/types/known/wrapperspb"
    18  
    19  	"github.com/jhump/protoreflect/desc"
    20  	"github.com/jhump/protoreflect/dynamic"
    21  )
    22  
// googleApisDomain is the base URL that ComputeURL falls back to when no
// per-type, per-package, or registry-wide default base URL is configured.
const googleApisDomain = "type.googleapis.com"
    24  
    25  // ErrUnexpectedType is returned if the URL that was requested
    26  // resolved to an enum instead of a message, or vice versa.
    27  type ErrUnexpectedType struct {
    28  	URL          string
    29  	ShouldBeEnum bool
    30  }
    31  
    32  func (e *ErrUnexpectedType) Error() string {
    33  	msg := "wanted message, got enum"
    34  	if e.ShouldBeEnum {
    35  		msg = "wanted enum, got message"
    36  	}
    37  	return fmt.Sprintf("type for URL %q is the wrong type: %s", e.URL, msg)
    38  }
    39  
// MessageRegistry is a registry that maps URLs to message types. It allows for marshalling
// and unmarshalling Any types to and from dynamic messages.
type MessageRegistry struct {
	// resolver fetches descriptors for URLs that are not registered locally.
	// It is only consulted when its fetcher is non-nil (see WithFetcher).
	resolver       typeResolver
	// mf instantiates messages and supplies the known-type registry used to
	// look up statically linked types.
	mf             *dynamic.MessageFactory
	// er is the extension registry obtained from mf (nil when mf is nil).
	er             *dynamic.ExtensionRegistry
	// mu guards types and baseUrls.
	mu             sync.RWMutex
	// types maps a type URL (normalized via ensureScheme) to its registered
	// message or enum descriptor.
	types          map[string]desc.Descriptor
	// baseUrls maps a fully-qualified type name or a package name to the base
	// URL used by ComputeURL.
	baseUrls       map[string]string
	// defaultBaseUrl is used by ComputeURL when baseUrls has no entry for the
	// type or its package; when empty, googleApisDomain is used instead.
	defaultBaseUrl string
}
    51  
    52  // NewMessageRegistryWithDefaults is a registry that includes all "default" message types,
    53  // which are those that are statically linked into the current program (e.g. registered by
    54  // protoc-generated code via proto.RegisterType). Note that it cannot resolve "default" enum
    55  // types since those don't actually get registered by protoc-generated code the same way.
    56  // Any types explicitly added to the registry will override any default message types with
    57  // the same URL.
    58  func NewMessageRegistryWithDefaults() *MessageRegistry {
    59  	mf := dynamic.NewMessageFactoryWithDefaults()
    60  	return &MessageRegistry{
    61  		mf: mf,
    62  		er: mf.GetExtensionRegistry(),
    63  	}
    64  }
    65  
    66  // WithFetcher sets the TypeFetcher that this registry uses to resolve unknown URLs. If no fetcher
    67  // is configured for the registry then unknown URLs cannot be resolved. Known URLs are those for
    68  // explicitly registered types and, if the registry includes "default" types, those for statically
    69  // linked message types. This method is not thread-safe and is intended to be used for one-time
    70  // initialization of the registry, before it is published for use by other threads.
    71  func (r *MessageRegistry) WithFetcher(fetcher TypeFetcher) *MessageRegistry {
    72  	r.resolver = typeResolver{fetcher: fetcher, mr: r}
    73  	return r
    74  }
    75  
    76  // WithMessageFactory sets the MessageFactory used to instantiate any messages.
    77  // This method is not thread-safe and is intended to be used for one-time
    78  // initialization of the registry, before it is published for use by other threads.
    79  func (r *MessageRegistry) WithMessageFactory(mf *dynamic.MessageFactory) *MessageRegistry {
    80  	r.mf = mf
    81  	if mf == nil {
    82  		r.er = nil
    83  	} else {
    84  		r.er = mf.GetExtensionRegistry()
    85  	}
    86  	return r
    87  }
    88  
    89  // WithDefaultBaseUrl sets the default base URL used when constructing type URLs for
    90  // marshalling messages as Any types and converting descriptors to well-known type
    91  // descriptions (ptypes). If unspecified, the default base URL will be "type.googleapis.com".
    92  // This method is not thread-safe and is intended to be used for one-time initialization
    93  // of the registry, before it is published for use by other threads.
    94  func (r *MessageRegistry) WithDefaultBaseUrl(baseUrl string) *MessageRegistry {
    95  	baseUrl = stripTrailingSlash(baseUrl)
    96  	r.defaultBaseUrl = baseUrl
    97  	return r
    98  }
    99  
   100  func stripTrailingSlash(url string) string {
   101  	if url[len(url)-1] == '/' {
   102  		return url[:len(url)-1]
   103  	}
   104  	return url
   105  }
   106  
   107  // AddMessage adds the given URL and associated message descriptor to the registry.
   108  func (r *MessageRegistry) AddMessage(url string, md *desc.MessageDescriptor) error {
   109  	url = ensureScheme(url)
   110  	baseUrl := strings.TrimSuffix(url, "/"+md.GetFullyQualifiedName())
   111  	if url == baseUrl {
   112  		return fmt.Errorf("URL %s is invalid: it should end with path element %s", url, md.GetFullyQualifiedName())
   113  	}
   114  	r.mu.Lock()
   115  	defer r.mu.Unlock()
   116  	if r.types == nil {
   117  		r.types = map[string]desc.Descriptor{}
   118  	}
   119  	r.types[url] = md
   120  	if r.baseUrls == nil {
   121  		r.baseUrls = map[string]string{}
   122  	}
   123  	r.baseUrls[md.GetFullyQualifiedName()] = baseUrl
   124  	return nil
   125  }
   126  
   127  // AddEnum adds the given URL and associated enum descriptor to the registry.
   128  func (r *MessageRegistry) AddEnum(url string, ed *desc.EnumDescriptor) error {
   129  	url = ensureScheme(url)
   130  	baseUrl := strings.TrimSuffix(url, "/"+ed.GetFullyQualifiedName())
   131  	if url == baseUrl {
   132  		return fmt.Errorf("URL %s is invalid: it should end with path element %s", url, ed.GetFullyQualifiedName())
   133  	}
   134  	r.mu.Lock()
   135  	defer r.mu.Unlock()
   136  	if r.types == nil {
   137  		r.types = map[string]desc.Descriptor{}
   138  	}
   139  	r.types[url] = ed
   140  	if r.baseUrls == nil {
   141  		r.baseUrls = map[string]string{}
   142  	}
   143  	r.baseUrls[ed.GetFullyQualifiedName()] = baseUrl
   144  	return nil
   145  }
   146  
   147  // AddFile adds to the registry all message and enum types in the given file. The URL for each type
   148  // is derived using the given base URL as "baseURL/fully.qualified.type.name".
   149  func (r *MessageRegistry) AddFile(baseUrl string, fd *desc.FileDescriptor) {
   150  	baseUrl = stripTrailingSlash(ensureScheme(baseUrl))
   151  	r.mu.Lock()
   152  	defer r.mu.Unlock()
   153  	if r.types == nil {
   154  		r.types = map[string]desc.Descriptor{}
   155  	}
   156  	if r.baseUrls == nil {
   157  		r.baseUrls = map[string]string{}
   158  	}
   159  	r.addEnumTypesLocked(baseUrl, fd.GetEnumTypes())
   160  	r.addMessageTypesLocked(baseUrl, fd.GetMessageTypes())
   161  }
   162  
   163  func (r *MessageRegistry) addEnumTypesLocked(baseUrl string, enums []*desc.EnumDescriptor) {
   164  	for _, ed := range enums {
   165  		url := fmt.Sprintf("%s/%s", baseUrl, ed.GetFullyQualifiedName())
   166  		r.types[url] = ed
   167  		r.baseUrls[ed.GetFullyQualifiedName()] = baseUrl
   168  	}
   169  }
   170  
   171  func (r *MessageRegistry) addMessageTypesLocked(baseUrl string, msgs []*desc.MessageDescriptor) {
   172  	for _, md := range msgs {
   173  		url := fmt.Sprintf("%s/%s", baseUrl, md.GetFullyQualifiedName())
   174  		r.types[url] = md
   175  		r.baseUrls[md.GetFullyQualifiedName()] = baseUrl
   176  		r.addEnumTypesLocked(baseUrl, md.GetNestedEnumTypes())
   177  		r.addMessageTypesLocked(baseUrl, md.GetNestedMessageTypes())
   178  	}
   179  }
   180  
   181  // FindMessageTypeByUrl finds a message descriptor for the type at the given URL. It may
   182  // return nil if the registry is empty and cannot resolve unknown URLs. If an error occurs
   183  // while resolving the URL, it is returned. If the resolved type is a enum, ErrUnexpectedType
   184  // is returned.
   185  func (r *MessageRegistry) FindMessageTypeByUrl(url string) (*desc.MessageDescriptor, error) {
   186  	md, err := r.getRegisteredMessageTypeByUrl(url)
   187  	if err != nil {
   188  		return nil, err
   189  	} else if md != nil {
   190  		return md, err
   191  	}
   192  
   193  	if r.resolver.fetcher == nil {
   194  		return nil, nil
   195  	}
   196  	return r.resolver.resolveUrlToMessageDescriptor(url)
   197  }
   198  
   199  func (r *MessageRegistry) getRegisteredMessageTypeByUrl(url string) (*desc.MessageDescriptor, error) {
   200  	if r != nil {
   201  		r.mu.RLock()
   202  		m := r.types[ensureScheme(url)]
   203  		r.mu.RUnlock()
   204  		if m != nil {
   205  			if md, ok := m.(*desc.MessageDescriptor); ok {
   206  				return md, nil
   207  			}
   208  			return nil, &ErrUnexpectedType{
   209  				URL: url,
   210  			}
   211  		}
   212  	}
   213  
   214  	var ktr *dynamic.KnownTypeRegistry
   215  	if r != nil {
   216  		ktr = r.mf.GetKnownTypeRegistry()
   217  	}
   218  	msgType := ktr.GetKnownType(typeName(url))
   219  	if msgType == nil {
   220  		return nil, nil
   221  	}
   222  	return desc.LoadMessageDescriptorForType(msgType)
   223  }
   224  
   225  // FindEnumTypeByUrl finds an enum descriptor for the type at the given URL. It may return nil
   226  // if the registry is empty and cannot resolve unknown URLs. If an error occurs while resolving
   227  // the URL, it is returned. If the resolved type is a message, ErrUnexpectedType is returned.
   228  func (r *MessageRegistry) FindEnumTypeByUrl(url string) (*desc.EnumDescriptor, error) {
   229  	ed, err := r.getRegisteredEnumTypeByUrl(url)
   230  	if err != nil {
   231  		return nil, err
   232  	} else if ed != nil {
   233  		return ed, err
   234  	}
   235  
   236  	if r.resolver.fetcher == nil {
   237  		return nil, nil
   238  	}
   239  	if ed, err := r.resolver.resolveUrlToEnumDescriptor(url); err != nil {
   240  		return nil, err
   241  	} else {
   242  		return ed, nil
   243  	}
   244  }
   245  
   246  func (r *MessageRegistry) getRegisteredEnumTypeByUrl(url string) (*desc.EnumDescriptor, error) {
   247  	if r == nil {
   248  		return nil, nil
   249  	}
   250  	r.mu.RLock()
   251  	m := r.types[ensureScheme(url)]
   252  	r.mu.RUnlock()
   253  	if m != nil {
   254  		if ed, ok := m.(*desc.EnumDescriptor); ok {
   255  			return ed, nil
   256  		}
   257  		return nil, &ErrUnexpectedType{
   258  			URL: url,
   259  		}
   260  	}
   261  	return nil, nil
   262  }
   263  
   264  // ResolveApiIntoServiceDescriptor constructs a service descriptor that describes the given API.
   265  // If any of the service's request or response type URLs cannot be resolved by this registry, a
   266  // nil descriptor is returned.
   267  func (r *MessageRegistry) ResolveApiIntoServiceDescriptor(a *apipb.Api) (*desc.ServiceDescriptor, error) {
   268  	if r == nil {
   269  		return nil, nil
   270  	}
   271  
   272  	msgs := map[string]*desc.MessageDescriptor{}
   273  	unresolved := map[string]struct{}{}
   274  	for _, m := range a.Methods {
   275  		// request type
   276  		md, err := r.getRegisteredMessageTypeByUrl(m.RequestTypeUrl)
   277  		if err != nil {
   278  			return nil, err
   279  		} else if md == nil {
   280  			if r.resolver.fetcher == nil {
   281  				return nil, nil
   282  			}
   283  			unresolved[m.RequestTypeUrl] = struct{}{}
   284  		} else {
   285  			msgs[m.RequestTypeUrl] = md
   286  		}
   287  		// and response type
   288  		md, err = r.getRegisteredMessageTypeByUrl(m.ResponseTypeUrl)
   289  		if err != nil {
   290  			return nil, err
   291  		} else if md == nil {
   292  			if r.resolver.fetcher == nil {
   293  				return nil, nil
   294  			}
   295  			unresolved[m.ResponseTypeUrl] = struct{}{}
   296  		} else {
   297  			msgs[m.ResponseTypeUrl] = md
   298  		}
   299  	}
   300  
   301  	if len(unresolved) > 0 {
   302  		unresolvedSlice := make([]string, 0, len(unresolved))
   303  		for k := range unresolved {
   304  			unresolvedSlice = append(unresolvedSlice, k)
   305  		}
   306  		mp, err := r.resolver.resolveUrlsToMessageDescriptors(unresolvedSlice...)
   307  		if err != nil {
   308  			return nil, err
   309  		}
   310  		for u, md := range mp {
   311  			msgs[u] = md
   312  		}
   313  	}
   314  
   315  	var fileName string
   316  	if a.SourceContext != nil && a.SourceContext.FileName != "" {
   317  		fileName = a.SourceContext.FileName
   318  	} else {
   319  		fileName = "--unknown--.proto"
   320  	}
   321  
   322  	// now we add all types we care about to a typeTrie and use that to generate file descriptors
   323  	files := map[string]*fileEntry{}
   324  	fe := &fileEntry{}
   325  	fe.proto3 = a.Syntax == typepb.Syntax_SYNTAX_PROTO3
   326  	files[fileName] = fe
   327  	fe.types.addType(a.Name, createServiceDescriptor(a, r))
   328  	added := newNameTracker()
   329  	for _, md := range msgs {
   330  		addDescriptors(fileName, files, md, msgs, added)
   331  	}
   332  
   333  	// build resulting file descriptor(s) and return the final service descriptor
   334  	fileDescriptors, err := toFileDescriptors(files, (*typeTrie).rewriteDescriptor)
   335  	if err != nil {
   336  		return nil, err
   337  	}
   338  	return fileDescriptors[fileName].FindService(a.Name), nil
   339  }
   340  
   341  // UnmarshalAny will unmarshal the value embedded in the given Any value. This will use this
   342  // registry to resolve the given value's type URL. Use this instead of ptypes.UnmarshalAny for
   343  // cases where the type might not be statically linked into the current program.
   344  func (r *MessageRegistry) UnmarshalAny(a *anypb.Any) (proto.Message, error) {
   345  	return r.unmarshalAny(a, r.FindMessageTypeByUrl)
   346  }
   347  
   348  func (r *MessageRegistry) unmarshalAny(a *anypb.Any, fetch func(string) (*desc.MessageDescriptor, error)) (proto.Message, error) {
   349  	name, err := ptypes.AnyMessageName(a)
   350  	if err != nil {
   351  		return nil, err
   352  	}
   353  
   354  	var msg proto.Message
   355  
   356  	var mf *dynamic.MessageFactory
   357  	var ktr *dynamic.KnownTypeRegistry
   358  	if r != nil {
   359  		mf = r.mf
   360  		ktr = r.mf.GetKnownTypeRegistry()
   361  	}
   362  	if msg = ktr.CreateIfKnown(name); msg == nil {
   363  		if md, err := fetch(a.TypeUrl); err != nil {
   364  			return nil, err
   365  		} else if md == nil {
   366  			return nil, fmt.Errorf("unknown message type: %s", a.TypeUrl)
   367  		} else {
   368  			msg = mf.NewDynamicMessage(md)
   369  		}
   370  	}
   371  
   372  	err = proto.Unmarshal(a.Value, msg)
   373  	if err != nil {
   374  		return nil, err
   375  	} else {
   376  		return msg, nil
   377  	}
   378  }
   379  
   380  // AddBaseUrlForElement adds a base URL for the given package or fully-qualified type name.
   381  // This is used to construct type URLs for message types. If a given type has an associated
   382  // base URL, it is used. Otherwise, the base URL for the type's package is used. If that is
   383  // also absent, the registry's default base URL is used.
   384  func (r *MessageRegistry) AddBaseUrlForElement(baseUrl, packageOrTypeName string) {
   385  	if baseUrl[len(baseUrl)-1] == '/' {
   386  		baseUrl = baseUrl[:len(baseUrl)-1]
   387  	}
   388  	r.mu.Lock()
   389  	defer r.mu.Unlock()
   390  	if r.baseUrls == nil {
   391  		r.baseUrls = map[string]string{}
   392  	}
   393  	r.baseUrls[packageOrTypeName] = baseUrl
   394  }
   395  
   396  // MarshalAny wraps the given message in an Any value.
   397  func (r *MessageRegistry) MarshalAny(m proto.Message) (*anypb.Any, error) {
   398  	var md *desc.MessageDescriptor
   399  	if dm, ok := m.(*dynamic.Message); ok {
   400  		md = dm.GetMessageDescriptor()
   401  	} else {
   402  		var err error
   403  		md, err = desc.LoadMessageDescriptorForMessage(m)
   404  		if err != nil {
   405  			return nil, err
   406  		}
   407  	}
   408  
   409  	if b, err := proto.Marshal(m); err != nil {
   410  		return nil, err
   411  	} else {
   412  		return &anypb.Any{TypeUrl: r.ComputeUrl(md), Value: b}, nil
   413  	}
   414  }
   415  
   416  // MessageAsPType converts the given message descriptor into a ptype.Type. Registered
   417  // base URLs are used to compute type URLs for any fields that have message or enum
   418  // types.
   419  func (r *MessageRegistry) MessageAsPType(md *desc.MessageDescriptor) *typepb.Type {
   420  	fs := md.GetFields()
   421  	fields := make([]*typepb.Field, len(fs))
   422  	for i, f := range fs {
   423  		fields[i] = r.fieldAsPType(f)
   424  	}
   425  	oos := md.GetOneOfs()
   426  	oneOfs := make([]string, len(oos))
   427  	for i, oo := range oos {
   428  		oneOfs[i] = oo.GetName()
   429  	}
   430  	return &typepb.Type{
   431  		Name:          md.GetFullyQualifiedName(),
   432  		Fields:        fields,
   433  		Oneofs:        oneOfs,
   434  		Options:       r.options(md.GetOptions()),
   435  		Syntax:        syntax(md.GetFile()),
   436  		SourceContext: &sourcecontextpb.SourceContext{FileName: md.GetFile().GetName()},
   437  	}
   438  }
   439  
   440  func (r *MessageRegistry) fieldAsPType(fd *desc.FieldDescriptor) *typepb.Field {
   441  	opts := r.options(fd.GetOptions())
   442  	// remove the "packed" option as that is represented via separate field in ptype.Field
   443  	for i, o := range opts {
   444  		if o.Name == "packed" {
   445  			opts = append(opts[:i], opts[i+1:]...)
   446  			break
   447  		}
   448  	}
   449  
   450  	var oneOf int32
   451  	if fd.AsFieldDescriptorProto().OneofIndex != nil {
   452  		oneOf = fd.AsFieldDescriptorProto().GetOneofIndex() + 1
   453  	}
   454  
   455  	var card typepb.Field_Cardinality
   456  	switch fd.GetLabel() {
   457  	case descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL:
   458  		card = typepb.Field_CARDINALITY_OPTIONAL
   459  	case descriptorpb.FieldDescriptorProto_LABEL_REPEATED:
   460  		card = typepb.Field_CARDINALITY_REPEATED
   461  	case descriptorpb.FieldDescriptorProto_LABEL_REQUIRED:
   462  		card = typepb.Field_CARDINALITY_REQUIRED
   463  	}
   464  
   465  	var url string
   466  	var kind typepb.Field_Kind
   467  	switch fd.GetType() {
   468  	case descriptorpb.FieldDescriptorProto_TYPE_ENUM:
   469  		kind = typepb.Field_TYPE_ENUM
   470  		url = r.ComputeUrl(fd.GetEnumType())
   471  	case descriptorpb.FieldDescriptorProto_TYPE_GROUP:
   472  		kind = typepb.Field_TYPE_GROUP
   473  		url = r.ComputeUrl(fd.GetMessageType())
   474  	case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE:
   475  		kind = typepb.Field_TYPE_MESSAGE
   476  		url = r.ComputeUrl(fd.GetMessageType())
   477  	case descriptorpb.FieldDescriptorProto_TYPE_BYTES:
   478  		kind = typepb.Field_TYPE_BYTES
   479  	case descriptorpb.FieldDescriptorProto_TYPE_STRING:
   480  		kind = typepb.Field_TYPE_STRING
   481  	case descriptorpb.FieldDescriptorProto_TYPE_BOOL:
   482  		kind = typepb.Field_TYPE_BOOL
   483  	case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE:
   484  		kind = typepb.Field_TYPE_DOUBLE
   485  	case descriptorpb.FieldDescriptorProto_TYPE_FLOAT:
   486  		kind = typepb.Field_TYPE_FLOAT
   487  	case descriptorpb.FieldDescriptorProto_TYPE_FIXED32:
   488  		kind = typepb.Field_TYPE_FIXED32
   489  	case descriptorpb.FieldDescriptorProto_TYPE_FIXED64:
   490  		kind = typepb.Field_TYPE_FIXED64
   491  	case descriptorpb.FieldDescriptorProto_TYPE_INT32:
   492  		kind = typepb.Field_TYPE_INT32
   493  	case descriptorpb.FieldDescriptorProto_TYPE_INT64:
   494  		kind = typepb.Field_TYPE_INT64
   495  	case descriptorpb.FieldDescriptorProto_TYPE_SFIXED32:
   496  		kind = typepb.Field_TYPE_SFIXED32
   497  	case descriptorpb.FieldDescriptorProto_TYPE_SFIXED64:
   498  		kind = typepb.Field_TYPE_SFIXED64
   499  	case descriptorpb.FieldDescriptorProto_TYPE_SINT32:
   500  		kind = typepb.Field_TYPE_SINT32
   501  	case descriptorpb.FieldDescriptorProto_TYPE_SINT64:
   502  		kind = typepb.Field_TYPE_SINT64
   503  	case descriptorpb.FieldDescriptorProto_TYPE_UINT32:
   504  		kind = typepb.Field_TYPE_UINT32
   505  	case descriptorpb.FieldDescriptorProto_TYPE_UINT64:
   506  		kind = typepb.Field_TYPE_UINT64
   507  	}
   508  
   509  	return &typepb.Field{
   510  		Name:         fd.GetName(),
   511  		Number:       fd.GetNumber(),
   512  		JsonName:     fd.AsFieldDescriptorProto().GetJsonName(),
   513  		OneofIndex:   oneOf,
   514  		DefaultValue: fd.AsFieldDescriptorProto().GetDefaultValue(),
   515  		Options:      opts,
   516  		Packed:       fd.GetFieldOptions().GetPacked(),
   517  		TypeUrl:      url,
   518  		Cardinality:  card,
   519  		Kind:         kind,
   520  	}
   521  }
   522  
   523  // EnumAsPType converts the given enum descriptor into a ptype.Enum.
   524  func (r *MessageRegistry) EnumAsPType(ed *desc.EnumDescriptor) *typepb.Enum {
   525  	vs := ed.GetValues()
   526  	vals := make([]*typepb.EnumValue, len(vs))
   527  	for i, v := range vs {
   528  		vals[i] = r.enumValueAsPType(v)
   529  	}
   530  	return &typepb.Enum{
   531  		Name:          ed.GetFullyQualifiedName(),
   532  		Enumvalue:     vals,
   533  		Options:       r.options(ed.GetOptions()),
   534  		Syntax:        syntax(ed.GetFile()),
   535  		SourceContext: &sourcecontextpb.SourceContext{FileName: ed.GetFile().GetName()},
   536  	}
   537  }
   538  
   539  func (r *MessageRegistry) enumValueAsPType(vd *desc.EnumValueDescriptor) *typepb.EnumValue {
   540  	return &typepb.EnumValue{
   541  		Name:    vd.GetName(),
   542  		Number:  vd.GetNumber(),
   543  		Options: r.options(vd.GetOptions()),
   544  	}
   545  }
   546  
   547  // ServiceAsApi converts the given service descriptor into a ptype API description.
   548  func (r *MessageRegistry) ServiceAsApi(sd *desc.ServiceDescriptor) *apipb.Api {
   549  	ms := sd.GetMethods()
   550  	methods := make([]*apipb.Method, len(ms))
   551  	for i, m := range ms {
   552  		methods[i] = r.methodAsApi(m)
   553  	}
   554  	return &apipb.Api{
   555  		Name:          sd.GetFullyQualifiedName(),
   556  		Methods:       methods,
   557  		Options:       r.options(sd.GetOptions()),
   558  		Syntax:        syntax(sd.GetFile()),
   559  		SourceContext: &sourcecontextpb.SourceContext{FileName: sd.GetFile().GetName()},
   560  	}
   561  }
   562  
   563  func (r *MessageRegistry) methodAsApi(md *desc.MethodDescriptor) *apipb.Method {
   564  	return &apipb.Method{
   565  		Name:              md.GetName(),
   566  		RequestStreaming:  md.IsClientStreaming(),
   567  		ResponseStreaming: md.IsServerStreaming(),
   568  		RequestTypeUrl:    r.ComputeUrl(md.GetInputType()),
   569  		ResponseTypeUrl:   r.ComputeUrl(md.GetOutputType()),
   570  		Options:           r.options(md.GetOptions()),
   571  		Syntax:            syntax(md.GetFile()),
   572  	}
   573  }
   574  
   575  func (r *MessageRegistry) options(options proto.Message) []*typepb.Option {
   576  	rv := reflect.ValueOf(options)
   577  	if rv.Kind() == reflect.Ptr {
   578  		if rv.IsNil() {
   579  			return nil
   580  		}
   581  		rv = rv.Elem()
   582  	}
   583  	var opts []*typepb.Option
   584  	for _, p := range proto.GetProperties(rv.Type()).Prop {
   585  		if p.Tag == 0 {
   586  			continue
   587  		}
   588  		o := r.option(p.OrigName, rv.FieldByName(p.Name))
   589  		if o != nil {
   590  			opts = append(opts, o...)
   591  		}
   592  	}
   593  	for _, ext := range proto.RegisteredExtensions(options) {
   594  		if proto.HasExtension(options, ext) {
   595  			v, err := proto.GetExtension(options, ext)
   596  			if err == nil && v != nil {
   597  				o := r.option(ext.Name, reflect.ValueOf(v))
   598  				if o != nil {
   599  					opts = append(opts, o...)
   600  				}
   601  			}
   602  		}
   603  	}
   604  	return opts
   605  }
   606  
// typeOfBytes is the reflect.Type of []byte. option and wrap use it to tell a
// bytes value apart from other slice types (repeated fields).
var typeOfBytes = reflect.TypeOf([]byte(nil))
   608  
   609  func (r *MessageRegistry) option(name string, value reflect.Value) []*typepb.Option {
   610  	if value.Kind() == reflect.Slice && value.Type() != typeOfBytes {
   611  		// repeated field
   612  		ret := make([]*typepb.Option, value.Len())
   613  		j := 0
   614  		for i := 0; i < value.Len(); i++ {
   615  			opt := r.singleOption(name, value.Index(i))
   616  			if opt != nil {
   617  				ret[j] = opt
   618  				j++
   619  			}
   620  		}
   621  		return ret[:j]
   622  	} else {
   623  		opt := r.singleOption(name, value)
   624  		if opt != nil {
   625  			return []*typepb.Option{opt}
   626  		}
   627  		return nil
   628  	}
   629  }
   630  
   631  func (r *MessageRegistry) singleOption(name string, value reflect.Value) *typepb.Option {
   632  	pm := wrap(value)
   633  	if pm == nil {
   634  		return nil
   635  	}
   636  	a, err := r.MarshalAny(pm)
   637  	if err != nil {
   638  		return nil
   639  	}
   640  	return &typepb.Option{
   641  		Name:  name,
   642  		Value: a,
   643  	}
   644  }
   645  
   646  func wrap(v reflect.Value) proto.Message {
   647  	if pm, ok := v.Interface().(proto.Message); ok {
   648  		return pm
   649  	}
   650  	if !v.IsValid() {
   651  		return nil
   652  	}
   653  	if v.Kind() == reflect.Ptr {
   654  		if v.IsNil() {
   655  			return nil
   656  		}
   657  		v = v.Elem()
   658  	}
   659  	switch v.Kind() {
   660  	case reflect.Bool:
   661  		return &wrapperspb.BoolValue{Value: v.Bool()}
   662  	case reflect.Slice:
   663  		if v.Type() != typeOfBytes {
   664  			panic(fmt.Sprintf("cannot convert/wrap %T as proto", v.Type()))
   665  		}
   666  		return &wrapperspb.BytesValue{Value: v.Bytes()}
   667  	case reflect.String:
   668  		return &wrapperspb.StringValue{Value: v.String()}
   669  	case reflect.Float32:
   670  		return &wrapperspb.FloatValue{Value: float32(v.Float())}
   671  	case reflect.Float64:
   672  		return &wrapperspb.DoubleValue{Value: v.Float()}
   673  	case reflect.Int32:
   674  		return &wrapperspb.Int32Value{Value: int32(v.Int())}
   675  	case reflect.Int64:
   676  		return &wrapperspb.Int64Value{Value: v.Int()}
   677  	case reflect.Uint32:
   678  		return &wrapperspb.UInt32Value{Value: uint32(v.Uint())}
   679  	case reflect.Uint64:
   680  		return &wrapperspb.UInt64Value{Value: v.Uint()}
   681  	default:
   682  		panic(fmt.Sprintf("cannot convert/wrap %T as proto", v.Type()))
   683  	}
   684  }
   685  
   686  func syntax(fd *desc.FileDescriptor) typepb.Syntax {
   687  	if fd.IsProto3() {
   688  		return typepb.Syntax_SYNTAX_PROTO3
   689  	} else {
   690  		return typepb.Syntax_SYNTAX_PROTO2
   691  	}
   692  }
   693  
   694  // ComputeUrl computes a type URL for element described by the given descriptor.
   695  // The given descriptor must be an enum or message descriptor. This will use any
   696  // registered URLs and base URLs to determine the appropriate URL for the given
   697  // type.
   698  //
   699  // Deprecated: This method is deprecated due to its use of non-idiomatic naming.
   700  // Use ComputeURL instead.
   701  func (r *MessageRegistry) ComputeUrl(d desc.Descriptor) string {
   702  	return r.ComputeURL(d)
   703  }
   704  
   705  // ComputeURL computes a type URL string for the element described by the given
   706  // descriptor. The given descriptor must be an enum or message descriptor. This
   707  // will use any registered URLs and base URLs to determine the appropriate URL
   708  // for the given type.
   709  func (r *MessageRegistry) ComputeURL(d desc.Descriptor) string {
   710  	name, pkg := d.GetFullyQualifiedName(), d.GetFile().GetPackage()
   711  	r.mu.RLock()
   712  	baseUrl := r.baseUrls[name]
   713  	if baseUrl == "" {
   714  		// lookup domain for the package
   715  		baseUrl = r.baseUrls[pkg]
   716  	}
   717  	r.mu.RUnlock()
   718  
   719  	if baseUrl == "" {
   720  		baseUrl = r.defaultBaseUrl
   721  		if baseUrl == "" {
   722  			baseUrl = googleApisDomain
   723  		}
   724  	}
   725  
   726  	return fmt.Sprintf("%s/%s", baseUrl, name)
   727  }
   728  
   729  // Resolve resolves the given type URL into an instance of a message. This
   730  // implements the jsonpb.AnyResolver interface, for use with marshaling and
   731  // unmarshaling Any messages to/from JSON.
   732  func (r *MessageRegistry) Resolve(typeUrl string) (proto.Message, error) {
   733  	md, err := r.FindMessageTypeByUrl(typeUrl)
   734  	if err != nil {
   735  		return nil, err
   736  	}
   737  	if md == nil {
   738  		return nil, fmt.Errorf("unknown message type: %s", typeUrl)
   739  	}
   740  	return r.mf.NewMessage(md), nil
   741  }
   742  
   743  var _ jsonpb.AnyResolver = (*MessageRegistry)(nil)