github.com/Big-big-orange/protoreflect@v0.0.0-20240408141420-285cedfdf6a4/desc/load.go (about)

     1  package desc
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"sync"
     7  
     8  	"github.com/golang/protobuf/proto"
     9  	"google.golang.org/protobuf/reflect/protoreflect"
    10  	"google.golang.org/protobuf/reflect/protoregistry"
    11  	"google.golang.org/protobuf/types/descriptorpb"
    12  
    13  	"github.com/Big-big-orange/protoreflect/desc/sourceinfo"
    14  	"github.com/Big-big-orange/protoreflect/internal"
    15  )
    16  
    17  // The global cache is used to store descriptors that wrap items in
    18  // protoregistry.GlobalTypes and protoregistry.GlobalFiles. This prevents
    19  // repeating work to re-wrap underlying global descriptors.
    20  var (
    21  	// We put all wrapped file and message descriptors in this cache.
    22  	loadedDescriptors = lockingCache{cache: mapCache{}}
    23  
    24  	// Unfortunately, we need a different mechanism for enums for
    25  	// compatibility with old APIs, which required that they were
    26  	// registered in a different way :(
    27  	loadedEnumsMu sync.RWMutex
    28  	loadedEnums   = map[reflect.Type]*EnumDescriptor{}
    29  )
    30  
    31  // LoadFileDescriptor creates a file descriptor using the bytes returned by
    32  // proto.FileDescriptor. Descriptors are cached so that they do not need to be
    33  // re-processed if the same file is fetched again later.
    34  func LoadFileDescriptor(file string) (*FileDescriptor, error) {
    35  	d, err := sourceinfo.GlobalFiles.FindFileByPath(file)
    36  	if err == protoregistry.NotFound {
    37  		// for backwards compatibility, see if this matches a known old
    38  		// alias for the file (older versions of libraries that registered
    39  		// the files using incorrect/non-canonical paths)
    40  		if alt := internal.StdFileAliases[file]; alt != "" {
    41  			d, err = sourceinfo.GlobalFiles.FindFileByPath(alt)
    42  		}
    43  	}
    44  	if err != nil {
    45  		if err != protoregistry.NotFound {
    46  			return nil, internal.ErrNoSuchFile(file)
    47  		}
    48  		return nil, err
    49  	}
    50  	if fd := loadedDescriptors.get(d); fd != nil {
    51  		return fd.(*FileDescriptor), nil
    52  	}
    53  
    54  	var fd *FileDescriptor
    55  	loadedDescriptors.withLock(func(cache descriptorCache) {
    56  		// double-check cache, in case it was concurrently added while
    57  		// we were waiting for the lock
    58  		f := cache.get(d)
    59  		if f != nil {
    60  			fd = f.(*FileDescriptor)
    61  			return
    62  		}
    63  		fd, err = wrapFile(d, cache)
    64  	})
    65  	return fd, err
    66  }
    67  
    68  // LoadMessageDescriptor loads descriptor using the encoded descriptor proto returned by
    69  // Message.Descriptor() for the given message type. If the given type is not recognized,
    70  // then a nil descriptor is returned.
    71  func LoadMessageDescriptor(message string) (*MessageDescriptor, error) {
    72  	mt, err := sourceinfo.GlobalTypes.FindMessageByName(protoreflect.FullName(message))
    73  	if err != nil {
    74  		if err == protoregistry.NotFound {
    75  			return nil, nil
    76  		}
    77  		return nil, err
    78  	}
    79  	return loadMessageDescriptor(mt.Descriptor())
    80  }
    81  
    82  func loadMessageDescriptor(md protoreflect.MessageDescriptor) (*MessageDescriptor, error) {
    83  	d := loadedDescriptors.get(md)
    84  	if d != nil {
    85  		return d.(*MessageDescriptor), nil
    86  	}
    87  
    88  	var err error
    89  	loadedDescriptors.withLock(func(cache descriptorCache) {
    90  		d, err = wrapMessage(md, cache)
    91  	})
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  	return d.(*MessageDescriptor), err
    96  }
    97  
    98  // LoadMessageDescriptorForType loads descriptor using the encoded descriptor proto returned
    99  // by message.Descriptor() for the given message type. If the given type is not recognized,
   100  // then a nil descriptor is returned.
   101  func LoadMessageDescriptorForType(messageType reflect.Type) (*MessageDescriptor, error) {
   102  	m, err := messageFromType(messageType)
   103  	if err != nil {
   104  		return nil, err
   105  	}
   106  	return LoadMessageDescriptorForMessage(m)
   107  }
   108  
   109  // LoadMessageDescriptorForMessage loads descriptor using the encoded descriptor proto
   110  // returned by message.Descriptor(). If the given type is not recognized, then a nil
   111  // descriptor is returned.
   112  func LoadMessageDescriptorForMessage(message proto.Message) (*MessageDescriptor, error) {
   113  	// efficiently handle dynamic messages
   114  	type descriptorable interface {
   115  		GetMessageDescriptor() *MessageDescriptor
   116  	}
   117  	if d, ok := message.(descriptorable); ok {
   118  		return d.GetMessageDescriptor(), nil
   119  	}
   120  
   121  	var md protoreflect.MessageDescriptor
   122  	if m, ok := message.(protoreflect.ProtoMessage); ok {
   123  		md = m.ProtoReflect().Descriptor()
   124  	} else {
   125  		md = proto.MessageReflect(message).Descriptor()
   126  	}
   127  	return loadMessageDescriptor(sourceinfo.WrapMessage(md))
   128  }
   129  
   130  func messageFromType(mt reflect.Type) (proto.Message, error) {
   131  	if mt.Kind() != reflect.Ptr {
   132  		mt = reflect.PtrTo(mt)
   133  	}
   134  	m, ok := reflect.Zero(mt).Interface().(proto.Message)
   135  	if !ok {
   136  		return nil, fmt.Errorf("failed to create message from type: %v", mt)
   137  	}
   138  	return m, nil
   139  }
   140  
   141  // interface implemented by all generated enums
   142  type protoEnum interface {
   143  	EnumDescriptor() ([]byte, []int)
   144  }
   145  
// NB: There is no LoadEnumDescriptor that takes a fully-qualified enum name because
// it is not useful since protoc-gen-go does not expose the name anywhere in generated
// code or register it in a way that is accessible for reflection code. This also
// means we have to cache enum descriptors differently -- we can only cache them as
// they are requested, as opposed to caching all enum types whenever a file descriptor
// is cached. This is because we need to know the generated type of the enums, and we
// don't know that at the time of caching file descriptors.
   153  
   154  // LoadEnumDescriptorForType loads descriptor using the encoded descriptor proto returned
   155  // by enum.EnumDescriptor() for the given enum type.
   156  func LoadEnumDescriptorForType(enumType reflect.Type) (*EnumDescriptor, error) {
   157  	// we cache descriptors using non-pointer type
   158  	if enumType.Kind() == reflect.Ptr {
   159  		enumType = enumType.Elem()
   160  	}
   161  	e := getEnumFromCache(enumType)
   162  	if e != nil {
   163  		return e, nil
   164  	}
   165  	enum, err := enumFromType(enumType)
   166  	if err != nil {
   167  		return nil, err
   168  	}
   169  
   170  	return loadEnumDescriptor(enumType, enum)
   171  }
   172  
   173  func getEnumFromCache(t reflect.Type) *EnumDescriptor {
   174  	loadedEnumsMu.RLock()
   175  	defer loadedEnumsMu.RUnlock()
   176  	return loadedEnums[t]
   177  }
   178  
   179  func putEnumInCache(t reflect.Type, d *EnumDescriptor) {
   180  	loadedEnumsMu.Lock()
   181  	defer loadedEnumsMu.Unlock()
   182  	loadedEnums[t] = d
   183  }
   184  
   185  // LoadEnumDescriptorForEnum loads descriptor using the encoded descriptor proto
   186  // returned by enum.EnumDescriptor().
   187  func LoadEnumDescriptorForEnum(enum protoEnum) (*EnumDescriptor, error) {
   188  	et := reflect.TypeOf(enum)
   189  	// we cache descriptors using non-pointer type
   190  	if et.Kind() == reflect.Ptr {
   191  		et = et.Elem()
   192  		enum = reflect.Zero(et).Interface().(protoEnum)
   193  	}
   194  	e := getEnumFromCache(et)
   195  	if e != nil {
   196  		return e, nil
   197  	}
   198  
   199  	return loadEnumDescriptor(et, enum)
   200  }
   201  
   202  func enumFromType(et reflect.Type) (protoEnum, error) {
   203  	e, ok := reflect.Zero(et).Interface().(protoEnum)
   204  	if !ok {
   205  		if et.Kind() != reflect.Ptr {
   206  			et = et.Elem()
   207  		}
   208  		e, ok = reflect.Zero(et).Interface().(protoEnum)
   209  	}
   210  	if !ok {
   211  		return nil, fmt.Errorf("failed to create enum from type: %v", et)
   212  	}
   213  	return e, nil
   214  }
   215  
   216  func getDescriptorForEnum(enum protoEnum) (*descriptorpb.FileDescriptorProto, []int, error) {
   217  	fdb, path := enum.EnumDescriptor()
   218  	name := fmt.Sprintf("%T", enum)
   219  	fd, err := internal.DecodeFileDescriptor(name, fdb)
   220  	return fd, path, err
   221  }
   222  
   223  func loadEnumDescriptor(et reflect.Type, enum protoEnum) (*EnumDescriptor, error) {
   224  	fdp, path, err := getDescriptorForEnum(enum)
   225  	if err != nil {
   226  		return nil, err
   227  	}
   228  
   229  	fd, err := LoadFileDescriptor(fdp.GetName())
   230  	if err != nil {
   231  		return nil, err
   232  	}
   233  
   234  	ed := findEnum(fd, path)
   235  	putEnumInCache(et, ed)
   236  	return ed, nil
   237  }
   238  
   239  func findEnum(fd *FileDescriptor, path []int) *EnumDescriptor {
   240  	if len(path) == 1 {
   241  		return fd.GetEnumTypes()[path[0]]
   242  	}
   243  	md := fd.GetMessageTypes()[path[0]]
   244  	for _, i := range path[1 : len(path)-1] {
   245  		md = md.GetNestedMessageTypes()[i]
   246  	}
   247  	return md.GetNestedEnumTypes()[path[len(path)-1]]
   248  }
   249  
   250  // LoadFieldDescriptorForExtension loads the field descriptor that corresponds to the given
   251  // extension description.
   252  func LoadFieldDescriptorForExtension(ext *proto.ExtensionDesc) (*FieldDescriptor, error) {
   253  	file, err := LoadFileDescriptor(ext.Filename)
   254  	if err != nil {
   255  		return nil, err
   256  	}
   257  	field, ok := file.FindSymbol(ext.Name).(*FieldDescriptor)
   258  	// make sure descriptor agrees with attributes of the ExtensionDesc
   259  	if !ok || !field.IsExtension() || field.GetOwner().GetFullyQualifiedName() != proto.MessageName(ext.ExtendedType) ||
   260  		field.GetNumber() != ext.Field {
   261  		return nil, fmt.Errorf("file descriptor contained unexpected object with name %s", ext.Name)
   262  	}
   263  	return field, nil
   264  }