github.com/Big-big-orange/protoreflect@v0.0.0-20240408141420-285cedfdf6a4/desc/load.go (about) 1 package desc 2 3 import ( 4 "fmt" 5 "reflect" 6 "sync" 7 8 "github.com/golang/protobuf/proto" 9 "google.golang.org/protobuf/reflect/protoreflect" 10 "google.golang.org/protobuf/reflect/protoregistry" 11 "google.golang.org/protobuf/types/descriptorpb" 12 13 "github.com/Big-big-orange/protoreflect/desc/sourceinfo" 14 "github.com/Big-big-orange/protoreflect/internal" 15 ) 16 17 // The global cache is used to store descriptors that wrap items in 18 // protoregistry.GlobalTypes and protoregistry.GlobalFiles. This prevents 19 // repeating work to re-wrap underlying global descriptors. 20 var ( 21 // We put all wrapped file and message descriptors in this cache. 22 loadedDescriptors = lockingCache{cache: mapCache{}} 23 24 // Unfortunately, we need a different mechanism for enums for 25 // compatibility with old APIs, which required that they were 26 // registered in a different way :( 27 loadedEnumsMu sync.RWMutex 28 loadedEnums = map[reflect.Type]*EnumDescriptor{} 29 ) 30 31 // LoadFileDescriptor creates a file descriptor using the bytes returned by 32 // proto.FileDescriptor. Descriptors are cached so that they do not need to be 33 // re-processed if the same file is fetched again later. 34 func LoadFileDescriptor(file string) (*FileDescriptor, error) { 35 d, err := sourceinfo.GlobalFiles.FindFileByPath(file) 36 if err == protoregistry.NotFound { 37 // for backwards compatibility, see if this matches a known old 38 // alias for the file (older versions of libraries that registered 39 // the files using incorrect/non-canonical paths) 40 if alt := internal.StdFileAliases[file]; alt != "" { 41 d, err = sourceinfo.GlobalFiles.FindFileByPath(alt) 42 } 43 } 44 if err != nil { 45 if err != protoregistry.NotFound { 46 return nil, internal.ErrNoSuchFile(file) 47 } 48 return nil, err 49 } 50 if fd := loadedDescriptors.get(d); fd != nil { 51 return fd.(*FileDescriptor), nil 52 } 53 54 var fd *FileDescriptor 55 loadedDescriptors.withLock(func(cache descriptorCache) { 56 // double-check cache, in case it was concurrently added while 57 // we were waiting for the lock 58 f := cache.get(d) 59 if f != nil { 60 fd = f.(*FileDescriptor) 61 return 62 } 63 fd, err = wrapFile(d, cache) 64 }) 65 return fd, err 66 } 67 68 // LoadMessageDescriptor loads descriptor using the encoded descriptor proto returned by 69 // Message.Descriptor() for the given message type. If the given type is not recognized, 70 // then a nil descriptor is returned. 71 func LoadMessageDescriptor(message string) (*MessageDescriptor, error) { 72 mt, err := sourceinfo.GlobalTypes.FindMessageByName(protoreflect.FullName(message)) 73 if err != nil { 74 if err == protoregistry.NotFound { 75 return nil, nil 76 } 77 return nil, err 78 } 79 return loadMessageDescriptor(mt.Descriptor()) 80 } 81 82 func loadMessageDescriptor(md protoreflect.MessageDescriptor) (*MessageDescriptor, error) { 83 d := loadedDescriptors.get(md) 84 if d != nil { 85 return d.(*MessageDescriptor), nil 86 } 87 88 var err error 89 loadedDescriptors.withLock(func(cache descriptorCache) { 90 d, err = wrapMessage(md, cache) 91 }) 92 if err != nil { 93 return nil, err 94 } 95 return d.(*MessageDescriptor), err 96 } 97 98 // LoadMessageDescriptorForType loads descriptor using the encoded descriptor proto returned 99 // by message.Descriptor() for the given message type. If the given type is not recognized, 100 // then a nil descriptor is returned. 101 func LoadMessageDescriptorForType(messageType reflect.Type) (*MessageDescriptor, error) { 102 m, err := messageFromType(messageType) 103 if err != nil { 104 return nil, err 105 } 106 return LoadMessageDescriptorForMessage(m) 107 } 108 109 // LoadMessageDescriptorForMessage loads descriptor using the encoded descriptor proto 110 // returned by message.Descriptor(). If the given type is not recognized, then a nil 111 // descriptor is returned. 112 func LoadMessageDescriptorForMessage(message proto.Message) (*MessageDescriptor, error) { 113 // efficiently handle dynamic messages 114 type descriptorable interface { 115 GetMessageDescriptor() *MessageDescriptor 116 } 117 if d, ok := message.(descriptorable); ok { 118 return d.GetMessageDescriptor(), nil 119 } 120 121 var md protoreflect.MessageDescriptor 122 if m, ok := message.(protoreflect.ProtoMessage); ok { 123 md = m.ProtoReflect().Descriptor() 124 } else { 125 md = proto.MessageReflect(message).Descriptor() 126 } 127 return loadMessageDescriptor(sourceinfo.WrapMessage(md)) 128 } 129 130 func messageFromType(mt reflect.Type) (proto.Message, error) { 131 if mt.Kind() != reflect.Ptr { 132 mt = reflect.PtrTo(mt) 133 } 134 m, ok := reflect.Zero(mt).Interface().(proto.Message) 135 if !ok { 136 return nil, fmt.Errorf("failed to create message from type: %v", mt) 137 } 138 return m, nil 139 } 140 141 // interface implemented by all generated enums 142 type protoEnum interface { 143 EnumDescriptor() ([]byte, []int) 144 } 145 146 // NB: There is no LoadEnumDescriptor that takes a fully-qualified enum name because 147 // it is not useful since protoc-gen-go does not expose the name anywhere in generated 148 // code or register it in a way that is it accessible for reflection code. This also 149 // means we have to cache enum descriptors differently -- we can only cache them as 150 // they are requested, as opposed to caching all enum types whenever a file descriptor 151 // is cached. This is because we need to know the generated type of the enums, and we 152 // don't know that at the time of caching file descriptors. 153 154 // LoadEnumDescriptorForType loads descriptor using the encoded descriptor proto returned 155 // by enum.EnumDescriptor() for the given enum type. 156 func LoadEnumDescriptorForType(enumType reflect.Type) (*EnumDescriptor, error) { 157 // we cache descriptors using non-pointer type 158 if enumType.Kind() == reflect.Ptr { 159 enumType = enumType.Elem() 160 } 161 e := getEnumFromCache(enumType) 162 if e != nil { 163 return e, nil 164 } 165 enum, err := enumFromType(enumType) 166 if err != nil { 167 return nil, err 168 } 169 170 return loadEnumDescriptor(enumType, enum) 171 } 172 173 func getEnumFromCache(t reflect.Type) *EnumDescriptor { 174 loadedEnumsMu.RLock() 175 defer loadedEnumsMu.RUnlock() 176 return loadedEnums[t] 177 } 178 179 func putEnumInCache(t reflect.Type, d *EnumDescriptor) { 180 loadedEnumsMu.Lock() 181 defer loadedEnumsMu.Unlock() 182 loadedEnums[t] = d 183 } 184 185 // LoadEnumDescriptorForEnum loads descriptor using the encoded descriptor proto 186 // returned by enum.EnumDescriptor(). 187 func LoadEnumDescriptorForEnum(enum protoEnum) (*EnumDescriptor, error) { 188 et := reflect.TypeOf(enum) 189 // we cache descriptors using non-pointer type 190 if et.Kind() == reflect.Ptr { 191 et = et.Elem() 192 enum = reflect.Zero(et).Interface().(protoEnum) 193 } 194 e := getEnumFromCache(et) 195 if e != nil { 196 return e, nil 197 } 198 199 return loadEnumDescriptor(et, enum) 200 } 201 202 func enumFromType(et reflect.Type) (protoEnum, error) { 203 e, ok := reflect.Zero(et).Interface().(protoEnum) 204 if !ok { 205 if et.Kind() != reflect.Ptr { 206 et = et.Elem() 207 } 208 e, ok = reflect.Zero(et).Interface().(protoEnum) 209 } 210 if !ok { 211 return nil, fmt.Errorf("failed to create enum from type: %v", et) 212 } 213 return e, nil 214 } 215 216 func getDescriptorForEnum(enum protoEnum) (*descriptorpb.FileDescriptorProto, []int, error) { 217 fdb, path := enum.EnumDescriptor() 218 name := fmt.Sprintf("%T", enum) 219 fd, err := internal.DecodeFileDescriptor(name, fdb) 220 return fd, path, err 221 } 222 223 func loadEnumDescriptor(et reflect.Type, enum protoEnum) (*EnumDescriptor, error) { 224 fdp, path, err := getDescriptorForEnum(enum) 225 if err != nil { 226 return nil, err 227 } 228 229 fd, err := LoadFileDescriptor(fdp.GetName()) 230 if err != nil { 231 return nil, err 232 } 233 234 ed := findEnum(fd, path) 235 putEnumInCache(et, ed) 236 return ed, nil 237 } 238 239 func findEnum(fd *FileDescriptor, path []int) *EnumDescriptor { 240 if len(path) == 1 { 241 return fd.GetEnumTypes()[path[0]] 242 } 243 md := fd.GetMessageTypes()[path[0]] 244 for _, i := range path[1 : len(path)-1] { 245 md = md.GetNestedMessageTypes()[i] 246 } 247 return md.GetNestedEnumTypes()[path[len(path)-1]] 248 } 249 250 // LoadFieldDescriptorForExtension loads the field descriptor that corresponds to the given 251 // extension description. 252 func LoadFieldDescriptorForExtension(ext *proto.ExtensionDesc) (*FieldDescriptor, error) { 253 file, err := LoadFileDescriptor(ext.Filename) 254 if err != nil { 255 return nil, err 256 } 257 field, ok := file.FindSymbol(ext.Name).(*FieldDescriptor) 258 // make sure descriptor agrees with attributes of the ExtensionDesc 259 if !ok || !field.IsExtension() || field.GetOwner().GetFullyQualifiedName() != proto.MessageName(ext.ExtendedType) || 260 field.GetNumber() != ext.Field { 261 return nil, fmt.Errorf("file descriptor contained unexpected object with name %s", ext.Name) 262 } 263 return field, nil 264 }