github.com/syumai/protoreflect@v1.7.1-0.20200810020253-2ac7e3b3a321/desc/load.go (about)

     1  package desc
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"sync"
     7  
     8  	"github.com/golang/protobuf/proto"
     9  	dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
    10  
    11  	"github.com/syumai/protoreflect/internal"
    12  )
    13  
    14  var (
    15  	cacheMu       sync.RWMutex
    16  	filesCache    = map[string]*FileDescriptor{}
    17  	messagesCache = map[string]*MessageDescriptor{}
    18  	enumCache     = map[reflect.Type]*EnumDescriptor{}
    19  )
    20  
    21  // LoadFileDescriptor creates a file descriptor using the bytes returned by
    22  // proto.FileDescriptor. Descriptors are cached so that they do not need to be
    23  // re-processed if the same file is fetched again later.
    24  func LoadFileDescriptor(file string) (*FileDescriptor, error) {
    25  	return loadFileDescriptor(file, nil)
    26  }
    27  
    28  func loadFileDescriptor(file string, r *ImportResolver) (*FileDescriptor, error) {
    29  	f := getFileFromCache(file)
    30  	if f != nil {
    31  		return f, nil
    32  	}
    33  	cacheMu.Lock()
    34  	defer cacheMu.Unlock()
    35  	return loadFileDescriptorLocked(file, r)
    36  }
    37  
    38  func loadFileDescriptorLocked(file string, r *ImportResolver) (*FileDescriptor, error) {
    39  	f := filesCache[file]
    40  	if f != nil {
    41  		return f, nil
    42  	}
    43  	fd, err := internal.LoadFileDescriptor(file)
    44  	if err != nil {
    45  		return nil, err
    46  	}
    47  
    48  	f, err = toFileDescriptorLocked(fd, r)
    49  	if err != nil {
    50  		return nil, err
    51  	}
    52  	putCacheLocked(file, f)
    53  	return f, nil
    54  }
    55  
    56  func toFileDescriptorLocked(fd *dpb.FileDescriptorProto, r *ImportResolver) (*FileDescriptor, error) {
    57  	deps := make([]*FileDescriptor, len(fd.GetDependency()))
    58  	for i, dep := range fd.GetDependency() {
    59  		resolvedDep := r.ResolveImport(fd.GetName(), dep)
    60  		var err error
    61  		deps[i], err = loadFileDescriptorLocked(resolvedDep, r)
    62  		if _, ok := err.(internal.ErrNoSuchFile); ok && resolvedDep != dep {
    63  			// try original path
    64  			deps[i], err = loadFileDescriptorLocked(dep, r)
    65  		}
    66  		if err != nil {
    67  			return nil, err
    68  		}
    69  	}
    70  	return CreateFileDescriptor(fd, deps...)
    71  }
    72  
    73  func getFileFromCache(file string) *FileDescriptor {
    74  	cacheMu.RLock()
    75  	defer cacheMu.RUnlock()
    76  	return filesCache[file]
    77  }
    78  
    79  func putCacheLocked(filename string, fd *FileDescriptor) {
    80  	filesCache[filename] = fd
    81  	putMessageCacheLocked(fd.messages)
    82  }
    83  
    84  func putMessageCacheLocked(mds []*MessageDescriptor) {
    85  	for _, md := range mds {
    86  		messagesCache[md.fqn] = md
    87  		putMessageCacheLocked(md.nested)
    88  	}
    89  }
    90  
    91  // interface implemented by generated messages, which all have a Descriptor() method in
    92  // addition to the methods of proto.Message
    93  type protoMessage interface {
    94  	proto.Message
    95  	Descriptor() ([]byte, []int)
    96  }
    97  
    98  // LoadMessageDescriptor loads descriptor using the encoded descriptor proto returned by
    99  // Message.Descriptor() for the given message type. If the given type is not recognized,
   100  // then a nil descriptor is returned.
   101  func LoadMessageDescriptor(message string) (*MessageDescriptor, error) {
   102  	return loadMessageDescriptor(message, nil)
   103  }
   104  
   105  func loadMessageDescriptor(message string, r *ImportResolver) (*MessageDescriptor, error) {
   106  	m := getMessageFromCache(message)
   107  	if m != nil {
   108  		return m, nil
   109  	}
   110  
   111  	pt := proto.MessageType(message)
   112  	if pt == nil {
   113  		return nil, nil
   114  	}
   115  	msg, err := messageFromType(pt)
   116  	if err != nil {
   117  		return nil, err
   118  	}
   119  
   120  	cacheMu.Lock()
   121  	defer cacheMu.Unlock()
   122  	return loadMessageDescriptorForTypeLocked(message, msg, r)
   123  }
   124  
   125  // LoadMessageDescriptorForType loads descriptor using the encoded descriptor proto returned
   126  // by message.Descriptor() for the given message type. If the given type is not recognized,
   127  // then a nil descriptor is returned.
   128  func LoadMessageDescriptorForType(messageType reflect.Type) (*MessageDescriptor, error) {
   129  	return loadMessageDescriptorForType(messageType, nil)
   130  }
   131  
   132  func loadMessageDescriptorForType(messageType reflect.Type, r *ImportResolver) (*MessageDescriptor, error) {
   133  	m, err := messageFromType(messageType)
   134  	if err != nil {
   135  		return nil, err
   136  	}
   137  	return loadMessageDescriptorForMessage(m, r)
   138  }
   139  
   140  // LoadMessageDescriptorForMessage loads descriptor using the encoded descriptor proto
   141  // returned by message.Descriptor(). If the given type is not recognized, then a nil
   142  // descriptor is returned.
   143  func LoadMessageDescriptorForMessage(message proto.Message) (*MessageDescriptor, error) {
   144  	return loadMessageDescriptorForMessage(message, nil)
   145  }
   146  
   147  func loadMessageDescriptorForMessage(message proto.Message, r *ImportResolver) (*MessageDescriptor, error) {
   148  	// efficiently handle dynamic messages
   149  	type descriptorable interface {
   150  		GetMessageDescriptor() *MessageDescriptor
   151  	}
   152  	if d, ok := message.(descriptorable); ok {
   153  		return d.GetMessageDescriptor(), nil
   154  	}
   155  
   156  	name := proto.MessageName(message)
   157  	if name == "" {
   158  		return nil, nil
   159  	}
   160  	m := getMessageFromCache(name)
   161  	if m != nil {
   162  		return m, nil
   163  	}
   164  
   165  	cacheMu.Lock()
   166  	defer cacheMu.Unlock()
   167  	return loadMessageDescriptorForTypeLocked(name, message.(protoMessage), nil)
   168  }
   169  
   170  func messageFromType(mt reflect.Type) (protoMessage, error) {
   171  	if mt.Kind() != reflect.Ptr {
   172  		mt = reflect.PtrTo(mt)
   173  	}
   174  	m, ok := reflect.Zero(mt).Interface().(protoMessage)
   175  	if !ok {
   176  		return nil, fmt.Errorf("failed to create message from type: %v", mt)
   177  	}
   178  	return m, nil
   179  }
   180  
   181  func loadMessageDescriptorForTypeLocked(name string, message protoMessage, r *ImportResolver) (*MessageDescriptor, error) {
   182  	m := messagesCache[name]
   183  	if m != nil {
   184  		return m, nil
   185  	}
   186  
   187  	fdb, _ := message.Descriptor()
   188  	fd, err := internal.DecodeFileDescriptor(name, fdb)
   189  	if err != nil {
   190  		return nil, err
   191  	}
   192  
   193  	f, err := toFileDescriptorLocked(fd, r)
   194  	if err != nil {
   195  		return nil, err
   196  	}
   197  	putCacheLocked(fd.GetName(), f)
   198  	return f.FindSymbol(name).(*MessageDescriptor), nil
   199  }
   200  
   201  func getMessageFromCache(message string) *MessageDescriptor {
   202  	cacheMu.RLock()
   203  	defer cacheMu.RUnlock()
   204  	return messagesCache[message]
   205  }
   206  
   207  // interface implemented by all generated enums
   208  type protoEnum interface {
   209  	EnumDescriptor() ([]byte, []int)
   210  }
   211  
   212  // NB: There is no LoadEnumDescriptor that takes a fully-qualified enum name because
   213  // it is not useful since protoc-gen-go does not expose the name anywhere in generated
// code or register it in a way that is accessible to reflection code. This also
   215  // means we have to cache enum descriptors differently -- we can only cache them as
   216  // they are requested, as opposed to caching all enum types whenever a file descriptor
   217  // is cached. This is because we need to know the generated type of the enums, and we
   218  // don't know that at the time of caching file descriptors.
   219  
   220  // LoadEnumDescriptorForType loads descriptor using the encoded descriptor proto returned
   221  // by enum.EnumDescriptor() for the given enum type.
   222  func LoadEnumDescriptorForType(enumType reflect.Type) (*EnumDescriptor, error) {
   223  	return loadEnumDescriptorForType(enumType, nil)
   224  }
   225  
   226  func loadEnumDescriptorForType(enumType reflect.Type, r *ImportResolver) (*EnumDescriptor, error) {
   227  	// we cache descriptors using non-pointer type
   228  	if enumType.Kind() == reflect.Ptr {
   229  		enumType = enumType.Elem()
   230  	}
   231  	e := getEnumFromCache(enumType)
   232  	if e != nil {
   233  		return e, nil
   234  	}
   235  	enum, err := enumFromType(enumType)
   236  	if err != nil {
   237  		return nil, err
   238  	}
   239  
   240  	cacheMu.Lock()
   241  	defer cacheMu.Unlock()
   242  	return loadEnumDescriptorForTypeLocked(enumType, enum, r)
   243  }
   244  
   245  // LoadEnumDescriptorForEnum loads descriptor using the encoded descriptor proto
   246  // returned by enum.EnumDescriptor().
   247  func LoadEnumDescriptorForEnum(enum protoEnum) (*EnumDescriptor, error) {
   248  	return loadEnumDescriptorForEnum(enum, nil)
   249  }
   250  
   251  func loadEnumDescriptorForEnum(enum protoEnum, r *ImportResolver) (*EnumDescriptor, error) {
   252  	et := reflect.TypeOf(enum)
   253  	// we cache descriptors using non-pointer type
   254  	if et.Kind() == reflect.Ptr {
   255  		et = et.Elem()
   256  		enum = reflect.Zero(et).Interface().(protoEnum)
   257  	}
   258  	e := getEnumFromCache(et)
   259  	if e != nil {
   260  		return e, nil
   261  	}
   262  
   263  	cacheMu.Lock()
   264  	defer cacheMu.Unlock()
   265  	return loadEnumDescriptorForTypeLocked(et, enum, r)
   266  }
   267  
   268  func enumFromType(et reflect.Type) (protoEnum, error) {
   269  	if et.Kind() != reflect.Int32 {
   270  		et = reflect.PtrTo(et)
   271  	}
   272  	e, ok := reflect.Zero(et).Interface().(protoEnum)
   273  	if !ok {
   274  		return nil, fmt.Errorf("failed to create enum from type: %v", et)
   275  	}
   276  	return e, nil
   277  }
   278  
   279  func loadEnumDescriptorForTypeLocked(et reflect.Type, enum protoEnum, r *ImportResolver) (*EnumDescriptor, error) {
   280  	e := enumCache[et]
   281  	if e != nil {
   282  		return e, nil
   283  	}
   284  
   285  	fdb, path := enum.EnumDescriptor()
   286  	name := fmt.Sprintf("%v", et)
   287  	fd, err := internal.DecodeFileDescriptor(name, fdb)
   288  	if err != nil {
   289  		return nil, err
   290  	}
   291  	// see if we already have cached "rich" descriptor
   292  	f, ok := filesCache[fd.GetName()]
   293  	if !ok {
   294  		f, err = toFileDescriptorLocked(fd, r)
   295  		if err != nil {
   296  			return nil, err
   297  		}
   298  		putCacheLocked(fd.GetName(), f)
   299  	}
   300  
   301  	ed := findEnum(f, path)
   302  	enumCache[et] = ed
   303  	return ed, nil
   304  }
   305  
   306  func getEnumFromCache(et reflect.Type) *EnumDescriptor {
   307  	cacheMu.RLock()
   308  	defer cacheMu.RUnlock()
   309  	return enumCache[et]
   310  }
   311  
   312  func findEnum(fd *FileDescriptor, path []int) *EnumDescriptor {
   313  	if len(path) == 1 {
   314  		return fd.GetEnumTypes()[path[0]]
   315  	}
   316  	md := fd.GetMessageTypes()[path[0]]
   317  	for _, i := range path[1 : len(path)-1] {
   318  		md = md.GetNestedMessageTypes()[i]
   319  	}
   320  	return md.GetNestedEnumTypes()[path[len(path)-1]]
   321  }
   322  
   323  // LoadFieldDescriptorForExtension loads the field descriptor that corresponds to the given
   324  // extension description.
   325  func LoadFieldDescriptorForExtension(ext *proto.ExtensionDesc) (*FieldDescriptor, error) {
   326  	return loadFieldDescriptorForExtension(ext, nil)
   327  }
   328  
   329  func loadFieldDescriptorForExtension(ext *proto.ExtensionDesc, r *ImportResolver) (*FieldDescriptor, error) {
   330  	file, err := loadFileDescriptor(ext.Filename, r)
   331  	if err != nil {
   332  		return nil, err
   333  	}
   334  	field, ok := file.FindSymbol(ext.Name).(*FieldDescriptor)
   335  	// make sure descriptor agrees with attributes of the ExtensionDesc
   336  	if !ok || !field.IsExtension() || field.GetOwner().GetFullyQualifiedName() != proto.MessageName(ext.ExtendedType) ||
   337  		field.GetNumber() != ext.Field {
   338  		return nil, fmt.Errorf("file descriptor contained unexpected object with name %s", ext.Name)
   339  	}
   340  	return field, nil
   341  }