github.com/jhump/protoreflect@v1.16.0/dynamic/extension_registry.go (about) 1 package dynamic 2 3 import ( 4 "fmt" 5 "reflect" 6 "sync" 7 8 "github.com/golang/protobuf/proto" 9 10 "github.com/jhump/protoreflect/desc" 11 ) 12 13 // ExtensionRegistry is a registry of known extension fields. This is used to parse 14 // extension fields encountered when de-serializing a dynamic message. 15 type ExtensionRegistry struct { 16 includeDefault bool 17 mu sync.RWMutex 18 exts map[string]map[int32]*desc.FieldDescriptor 19 } 20 21 // NewExtensionRegistryWithDefaults is a registry that includes all "default" extensions, 22 // which are those that are statically linked into the current program (e.g. registered by 23 // protoc-generated code via proto.RegisterExtension). Extensions explicitly added to the 24 // registry will override any default extensions that are for the same extendee and have the 25 // same tag number and/or name. 26 func NewExtensionRegistryWithDefaults() *ExtensionRegistry { 27 return &ExtensionRegistry{includeDefault: true} 28 } 29 30 // AddExtensionDesc adds the given extensions to the registry. 31 func (r *ExtensionRegistry) AddExtensionDesc(exts ...*proto.ExtensionDesc) error { 32 flds := make([]*desc.FieldDescriptor, len(exts)) 33 for i, ext := range exts { 34 fd, err := desc.LoadFieldDescriptorForExtension(ext) 35 if err != nil { 36 return err 37 } 38 flds[i] = fd 39 } 40 r.mu.Lock() 41 defer r.mu.Unlock() 42 if r.exts == nil { 43 r.exts = map[string]map[int32]*desc.FieldDescriptor{} 44 } 45 for _, fd := range flds { 46 r.putExtensionLocked(fd) 47 } 48 return nil 49 } 50 51 // AddExtension adds the given extensions to the registry. The given extensions 52 // will overwrite any previously added extensions that are for the same extendee 53 // message and same extension tag number. 54 func (r *ExtensionRegistry) AddExtension(exts ...*desc.FieldDescriptor) error { 55 for _, ext := range exts { 56 if !ext.IsExtension() { 57 return fmt.Errorf("given field is not an extension: %s", ext.GetFullyQualifiedName()) 58 } 59 } 60 r.mu.Lock() 61 defer r.mu.Unlock() 62 if r.exts == nil { 63 r.exts = map[string]map[int32]*desc.FieldDescriptor{} 64 } 65 for _, ext := range exts { 66 r.putExtensionLocked(ext) 67 } 68 return nil 69 } 70 71 // AddExtensionsFromFile adds to the registry all extension fields defined in the given file descriptor. 72 func (r *ExtensionRegistry) AddExtensionsFromFile(fd *desc.FileDescriptor) { 73 r.mu.Lock() 74 defer r.mu.Unlock() 75 r.addExtensionsFromFileLocked(fd, false, nil) 76 } 77 78 // AddExtensionsFromFileRecursively adds to the registry all extension fields defined in the give file 79 // descriptor and also recursively adds all extensions defined in that file's dependencies. This adds 80 // extensions from the entire transitive closure for the given file. 81 func (r *ExtensionRegistry) AddExtensionsFromFileRecursively(fd *desc.FileDescriptor) { 82 r.mu.Lock() 83 defer r.mu.Unlock() 84 already := map[*desc.FileDescriptor]struct{}{} 85 r.addExtensionsFromFileLocked(fd, true, already) 86 } 87 88 func (r *ExtensionRegistry) addExtensionsFromFileLocked(fd *desc.FileDescriptor, recursive bool, alreadySeen map[*desc.FileDescriptor]struct{}) { 89 if _, ok := alreadySeen[fd]; ok { 90 return 91 } 92 93 if r.exts == nil { 94 r.exts = map[string]map[int32]*desc.FieldDescriptor{} 95 } 96 for _, ext := range fd.GetExtensions() { 97 r.putExtensionLocked(ext) 98 } 99 for _, msg := range fd.GetMessageTypes() { 100 r.addExtensionsFromMessageLocked(msg) 101 } 102 103 if recursive { 104 alreadySeen[fd] = struct{}{} 105 for _, dep := range fd.GetDependencies() { 106 r.addExtensionsFromFileLocked(dep, recursive, alreadySeen) 107 } 108 } 109 } 110 111 func (r *ExtensionRegistry) addExtensionsFromMessageLocked(md *desc.MessageDescriptor) { 112 for _, ext := range md.GetNestedExtensions() { 113 r.putExtensionLocked(ext) 114 } 115 for _, msg := range md.GetNestedMessageTypes() { 116 r.addExtensionsFromMessageLocked(msg) 117 } 118 } 119 120 func (r *ExtensionRegistry) putExtensionLocked(fd *desc.FieldDescriptor) { 121 msgName := fd.GetOwner().GetFullyQualifiedName() 122 m := r.exts[msgName] 123 if m == nil { 124 m = map[int32]*desc.FieldDescriptor{} 125 r.exts[msgName] = m 126 } 127 m[fd.GetNumber()] = fd 128 } 129 130 // FindExtension queries for the extension field with the given extendee name (must be a fully-qualified 131 // message name) and tag number. If no extension is known, nil is returned. 132 func (r *ExtensionRegistry) FindExtension(messageName string, tagNumber int32) *desc.FieldDescriptor { 133 if r == nil { 134 return nil 135 } 136 r.mu.RLock() 137 defer r.mu.RUnlock() 138 fd := r.exts[messageName][tagNumber] 139 if fd == nil && r.includeDefault { 140 ext := getDefaultExtensions(messageName)[tagNumber] 141 if ext != nil { 142 fd, _ = desc.LoadFieldDescriptorForExtension(ext) 143 } 144 } 145 return fd 146 } 147 148 // FindExtensionByName queries for the extension field with the given extendee name (must be a fully-qualified 149 // message name) and field name (must also be a fully-qualified extension name). If no extension is known, nil 150 // is returned. 151 func (r *ExtensionRegistry) FindExtensionByName(messageName string, fieldName string) *desc.FieldDescriptor { 152 if r == nil { 153 return nil 154 } 155 r.mu.RLock() 156 defer r.mu.RUnlock() 157 for _, fd := range r.exts[messageName] { 158 if fd.GetFullyQualifiedName() == fieldName { 159 return fd 160 } 161 } 162 if r.includeDefault { 163 for _, ext := range getDefaultExtensions(messageName) { 164 fd, _ := desc.LoadFieldDescriptorForExtension(ext) 165 if fd.GetFullyQualifiedName() == fieldName { 166 return fd 167 } 168 } 169 } 170 return nil 171 } 172 173 // FindExtensionByJSONName queries for the extension field with the given extendee name (must be a fully-qualified 174 // message name) and JSON field name (must also be a fully-qualified name). If no extension is known, nil is returned. 175 // The fully-qualified JSON name is the same as the extension's normal fully-qualified name except that the last 176 // component uses the field's JSON name (if present). 177 func (r *ExtensionRegistry) FindExtensionByJSONName(messageName string, fieldName string) *desc.FieldDescriptor { 178 if r == nil { 179 return nil 180 } 181 r.mu.RLock() 182 defer r.mu.RUnlock() 183 for _, fd := range r.exts[messageName] { 184 if fd.GetFullyQualifiedJSONName() == fieldName { 185 return fd 186 } 187 } 188 if r.includeDefault { 189 for _, ext := range getDefaultExtensions(messageName) { 190 fd, _ := desc.LoadFieldDescriptorForExtension(ext) 191 if fd.GetFullyQualifiedJSONName() == fieldName { 192 return fd 193 } 194 } 195 } 196 return nil 197 } 198 199 func getDefaultExtensions(messageName string) map[int32]*proto.ExtensionDesc { 200 t := proto.MessageType(messageName) 201 if t != nil { 202 msg := reflect.Zero(t).Interface().(proto.Message) 203 return proto.RegisteredExtensions(msg) 204 } 205 return nil 206 } 207 208 // AllExtensionsForType returns all known extension fields for the given extendee name (must be a 209 // fully-qualified message name). 210 func (r *ExtensionRegistry) AllExtensionsForType(messageName string) []*desc.FieldDescriptor { 211 if r == nil { 212 return []*desc.FieldDescriptor(nil) 213 } 214 r.mu.RLock() 215 defer r.mu.RUnlock() 216 flds := r.exts[messageName] 217 var ret []*desc.FieldDescriptor 218 if r.includeDefault { 219 exts := getDefaultExtensions(messageName) 220 if len(exts) > 0 || len(flds) > 0 { 221 ret = make([]*desc.FieldDescriptor, 0, len(exts)+len(flds)) 222 } 223 for tag, ext := range exts { 224 if _, ok := flds[tag]; ok { 225 // skip default extension and use the one explicitly registered instead 226 continue 227 } 228 fd, _ := desc.LoadFieldDescriptorForExtension(ext) 229 if fd != nil { 230 ret = append(ret, fd) 231 } 232 } 233 } else if len(flds) > 0 { 234 ret = make([]*desc.FieldDescriptor, 0, len(flds)) 235 } 236 237 for _, ext := range flds { 238 ret = append(ret, ext) 239 } 240 return ret 241 }