github.com/bakjos/protoreflect@v1.9.2/dynamic/extension_registry.go (about)

     1  package dynamic
     2  
import (
	"fmt"
	"reflect"
	"sort"
	"sync"

	"github.com/golang/protobuf/proto"

	"github.com/bakjos/protoreflect/desc"
)
    12  
// ExtensionRegistry is a registry of known extension fields. This is used to parse
// extension fields encountered when de-serializing a dynamic message.
//
// The zero value is an empty registry that does not consult default extensions: every
// Add* method allocates the internal map on first use, and lookups only read from it
// (reading a nil map is safe in Go). Use NewExtensionRegistryWithDefaults to also fall
// back to extensions statically linked into the program.
//
// The methods that read or write the registered extensions take mu, so one registry
// may be used from multiple goroutines. Because it contains a mutex, an
// ExtensionRegistry must not be copied after first use; pass *ExtensionRegistry.
type ExtensionRegistry struct {
	// includeDefault makes lookups fall back to extensions registered with the
	// golang/protobuf runtime when no explicitly added extension matches.
	includeDefault bool
	// mu guards exts: mutators take the write lock, lookups take the read lock.
	mu             sync.RWMutex
	// exts maps an extendee's fully-qualified message name to its explicitly
	// added extensions, keyed by field (tag) number. Lazily allocated.
	exts           map[string]map[int32]*desc.FieldDescriptor
}
    20  
    21  // NewExtensionRegistryWithDefaults is a registry that includes all "default" extensions,
    22  // which are those that are statically linked into the current program (e.g. registered by
    23  // protoc-generated code via proto.RegisterExtension). Extensions explicitly added to the
    24  // registry will override any default extensions that are for the same extendee and have the
    25  // same tag number and/or name.
    26  func NewExtensionRegistryWithDefaults() *ExtensionRegistry {
    27  	return &ExtensionRegistry{includeDefault: true}
    28  }
    29  
    30  // AddExtensionDesc adds the given extensions to the registry.
    31  func (r *ExtensionRegistry) AddExtensionDesc(exts ...*proto.ExtensionDesc) error {
    32  	flds := make([]*desc.FieldDescriptor, len(exts))
    33  	for i, ext := range exts {
    34  		fd, err := desc.LoadFieldDescriptorForExtension(ext)
    35  		if err != nil {
    36  			return err
    37  		}
    38  		flds[i] = fd
    39  	}
    40  	r.mu.Lock()
    41  	defer r.mu.Unlock()
    42  	if r.exts == nil {
    43  		r.exts = map[string]map[int32]*desc.FieldDescriptor{}
    44  	}
    45  	for _, fd := range flds {
    46  		r.putExtensionLocked(fd)
    47  	}
    48  	return nil
    49  }
    50  
    51  // AddExtension adds the given extensions to the registry. The given extensions
    52  // will overwrite any previously added extensions that are for the same extendee
    53  // message and same extension tag number.
    54  func (r *ExtensionRegistry) AddExtension(exts ...*desc.FieldDescriptor) error {
    55  	for _, ext := range exts {
    56  		if !ext.IsExtension() {
    57  			return fmt.Errorf("given field is not an extension: %s", ext.GetFullyQualifiedName())
    58  		}
    59  	}
    60  	r.mu.Lock()
    61  	defer r.mu.Unlock()
    62  	if r.exts == nil {
    63  		r.exts = map[string]map[int32]*desc.FieldDescriptor{}
    64  	}
    65  	for _, ext := range exts {
    66  		r.putExtensionLocked(ext)
    67  	}
    68  	return nil
    69  }
    70  
    71  // AddExtensionsFromFile adds to the registry all extension fields defined in the given file descriptor.
    72  func (r *ExtensionRegistry) AddExtensionsFromFile(fd *desc.FileDescriptor) {
    73  	r.mu.Lock()
    74  	defer r.mu.Unlock()
    75  	r.addExtensionsFromFileLocked(fd, false, nil)
    76  }
    77  
    78  // AddExtensionsFromFileRecursively adds to the registry all extension fields defined in the give file
    79  // descriptor and also recursively adds all extensions defined in that file's dependencies. This adds
    80  // extensions from the entire transitive closure for the given file.
    81  func (r *ExtensionRegistry) AddExtensionsFromFileRecursively(fd *desc.FileDescriptor) {
    82  	r.mu.Lock()
    83  	defer r.mu.Unlock()
    84  	already := map[*desc.FileDescriptor]struct{}{}
    85  	r.addExtensionsFromFileLocked(fd, true, already)
    86  }
    87  
    88  func (r *ExtensionRegistry) addExtensionsFromFileLocked(fd *desc.FileDescriptor, recursive bool, alreadySeen map[*desc.FileDescriptor]struct{}) {
    89  	if _, ok := alreadySeen[fd]; ok {
    90  		return
    91  	}
    92  
    93  	if r.exts == nil {
    94  		r.exts = map[string]map[int32]*desc.FieldDescriptor{}
    95  	}
    96  	for _, ext := range fd.GetExtensions() {
    97  		r.putExtensionLocked(ext)
    98  	}
    99  	for _, msg := range fd.GetMessageTypes() {
   100  		r.addExtensionsFromMessageLocked(msg)
   101  	}
   102  
   103  	if recursive {
   104  		alreadySeen[fd] = struct{}{}
   105  		for _, dep := range fd.GetDependencies() {
   106  			r.addExtensionsFromFileLocked(dep, recursive, alreadySeen)
   107  		}
   108  	}
   109  }
   110  
   111  func (r *ExtensionRegistry) addExtensionsFromMessageLocked(md *desc.MessageDescriptor) {
   112  	for _, ext := range md.GetNestedExtensions() {
   113  		r.putExtensionLocked(ext)
   114  	}
   115  	for _, msg := range md.GetNestedMessageTypes() {
   116  		r.addExtensionsFromMessageLocked(msg)
   117  	}
   118  }
   119  
   120  func (r *ExtensionRegistry) putExtensionLocked(fd *desc.FieldDescriptor) {
   121  	msgName := fd.GetOwner().GetFullyQualifiedName()
   122  	m := r.exts[msgName]
   123  	if m == nil {
   124  		m = map[int32]*desc.FieldDescriptor{}
   125  		r.exts[msgName] = m
   126  	}
   127  	m[fd.GetNumber()] = fd
   128  }
   129  
   130  // FindExtension queries for the extension field with the given extendee name (must be a fully-qualified
   131  // message name) and tag number. If no extension is known, nil is returned.
   132  func (r *ExtensionRegistry) FindExtension(messageName string, tagNumber int32) *desc.FieldDescriptor {
   133  	if r == nil {
   134  		return nil
   135  	}
   136  	r.mu.RLock()
   137  	defer r.mu.RUnlock()
   138  	fd := r.exts[messageName][tagNumber]
   139  	if fd == nil && r.includeDefault {
   140  		ext := getDefaultExtensions(messageName)[tagNumber]
   141  		if ext != nil {
   142  			fd, _ = desc.LoadFieldDescriptorForExtension(ext)
   143  		}
   144  	}
   145  	return fd
   146  }
   147  
   148  // FindExtensionByName queries for the extension field with the given extendee name (must be a fully-qualified
   149  // message name) and field name (must also be a fully-qualified extension name). If no extension is known, nil
   150  // is returned.
   151  func (r *ExtensionRegistry) FindExtensionByName(messageName string, fieldName string) *desc.FieldDescriptor {
   152  	if r == nil {
   153  		return nil
   154  	}
   155  	r.mu.RLock()
   156  	defer r.mu.RUnlock()
   157  	for _, fd := range r.exts[messageName] {
   158  		if fd.GetFullyQualifiedName() == fieldName {
   159  			return fd
   160  		}
   161  	}
   162  	if r.includeDefault {
   163  		for _, ext := range getDefaultExtensions(messageName) {
   164  			fd, _ := desc.LoadFieldDescriptorForExtension(ext)
   165  			if fd.GetFullyQualifiedName() == fieldName {
   166  				return fd
   167  			}
   168  		}
   169  	}
   170  	return nil
   171  }
   172  
   173  // FindExtensionByJSONName queries for the extension field with the given extendee name (must be a fully-qualified
   174  // message name) and JSON field name (must also be a fully-qualified name). If no extension is known, nil is returned.
   175  // The fully-qualified JSON name is the same as the extension's normal fully-qualified name except that the last
   176  // component uses the field's JSON name (if present).
   177  func (r *ExtensionRegistry) FindExtensionByJSONName(messageName string, fieldName string) *desc.FieldDescriptor {
   178  	if r == nil {
   179  		return nil
   180  	}
   181  	r.mu.RLock()
   182  	defer r.mu.RUnlock()
   183  	for _, fd := range r.exts[messageName] {
   184  		if fd.GetFullyQualifiedJSONName() == fieldName {
   185  			return fd
   186  		}
   187  	}
   188  	if r.includeDefault {
   189  		for _, ext := range getDefaultExtensions(messageName) {
   190  			fd, _ := desc.LoadFieldDescriptorForExtension(ext)
   191  			if fd.GetFullyQualifiedJSONName() == fieldName {
   192  				return fd
   193  			}
   194  		}
   195  	}
   196  	return nil
   197  }
   198  
   199  func getDefaultExtensions(messageName string) map[int32]*proto.ExtensionDesc {
   200  	t := proto.MessageType(messageName)
   201  	if t != nil {
   202  		msg := reflect.Zero(t).Interface().(proto.Message)
   203  		return proto.RegisteredExtensions(msg)
   204  	}
   205  	return nil
   206  }
   207  
   208  // AllExtensionsForType returns all known extension fields for the given extendee name (must be a
   209  // fully-qualified message name).
   210  func (r *ExtensionRegistry) AllExtensionsForType(messageName string) []*desc.FieldDescriptor {
   211  	if r == nil {
   212  		return []*desc.FieldDescriptor(nil)
   213  	}
   214  	r.mu.RLock()
   215  	defer r.mu.RUnlock()
   216  	flds := r.exts[messageName]
   217  	var ret []*desc.FieldDescriptor
   218  	if r.includeDefault {
   219  		exts := getDefaultExtensions(messageName)
   220  		if len(exts) > 0 || len(flds) > 0 {
   221  			ret = make([]*desc.FieldDescriptor, 0, len(exts)+len(flds))
   222  		}
   223  		for tag, ext := range exts {
   224  			if _, ok := flds[tag]; ok {
   225  				// skip default extension and use the one explicitly registered instead
   226  				continue
   227  			}
   228  			fd, _ := desc.LoadFieldDescriptorForExtension(ext)
   229  			if fd != nil {
   230  				ret = append(ret, fd)
   231  			}
   232  		}
   233  	} else if len(flds) > 0 {
   234  		ret = make([]*desc.FieldDescriptor, 0, len(flds))
   235  	}
   236  
   237  	for _, ext := range flds {
   238  		ret = append(ret, ext)
   239  	}
   240  	return ret
   241  }