github.com/jhump/protocompile@v0.0.0-20221021153901-4f6f732835e8/linker/files.go (about)

     1  package linker
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  
     7  	"google.golang.org/protobuf/reflect/protodesc"
     8  	"google.golang.org/protobuf/reflect/protoreflect"
     9  	"google.golang.org/protobuf/reflect/protoregistry"
    10  	"google.golang.org/protobuf/types/dynamicpb"
    11  
    12  	"github.com/jhump/protocompile/walk"
    13  )
    14  
    15  // File is like a super-powered protoreflect.FileDescriptor. It includes helpful
    16  // methods for looking up elements in the descriptor and can be used to create a
    17  // resolver for all in the file's transitive closure of dependencies. (See
    18  // ResolverFromFile.)
    19  type File interface {
    20  	protoreflect.FileDescriptor
    21  	// FindDescriptorByName returns the given named element that is defined in
    22  	// this file. If no such element exists, nil is returned.
    23  	FindDescriptorByName(name protoreflect.FullName) protoreflect.Descriptor
    24  	// FindImportByPath returns the File corresponding to the given import path.
    25  	// If this file does not import the given path, nil is returned.
    26  	FindImportByPath(path string) File
    27  	// FindExtensionByNumber returns the extension descriptor for the given tag
    28  	// that extends the given message name. If no such extension is defined in this
    29  	// file, nil is returned.
    30  	FindExtensionByNumber(message protoreflect.FullName, tag protoreflect.FieldNumber) protoreflect.ExtensionTypeDescriptor
    31  	// Imports returns this file's imports. These are only the files directly
    32  	// imported by the file. Indirect transitive dependencies will not be in
    33  	// the returned slice.
    34  	importsAsFiles() Files
    35  }
    36  
    37  // NewFile converts a protoreflect.FileDescriptor to a File. The given deps must
    38  // contain all dependencies/imports of f. Also see NewFileRecursive.
    39  func NewFile(f protoreflect.FileDescriptor, deps Files) (File, error) {
    40  	if asFile, ok := f.(File); ok {
    41  		return asFile, nil
    42  	}
    43  	checkedDeps := make(Files, f.Imports().Len())
    44  	for i := 0; i < f.Imports().Len(); i++ {
    45  		imprt := f.Imports().Get(i)
    46  		dep := deps.FindFileByPath(imprt.Path())
    47  		if dep == nil {
    48  			return nil, fmt.Errorf("cannot create File for %q: missing dependency for %q", f.Path(), imprt.Path())
    49  		}
    50  		checkedDeps[i] = dep
    51  	}
    52  	return newFile(f, checkedDeps)
    53  }
    54  
    55  func newFile(f protoreflect.FileDescriptor, deps Files) (File, error) {
    56  	descs := map[protoreflect.FullName]protoreflect.Descriptor{}
    57  	err := walk.Descriptors(f, func(d protoreflect.Descriptor) error {
    58  		if _, ok := descs[d.FullName()]; ok {
    59  			return fmt.Errorf("file %q contains multiple elements with the name %s", f.Path(), d.FullName())
    60  		}
    61  		descs[d.FullName()] = d
    62  		return nil
    63  	})
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  	return file{
    68  		FileDescriptor: f,
    69  		descs:          descs,
    70  		deps:           deps,
    71  	}, nil
    72  }
    73  
    74  // NewFileRecursive recursively converts a protoreflect.FileDescriptor to a File.
    75  // If f has any dependencies/imports, they are converted, too, including any and
    76  // all transitive dependencies.
    77  func NewFileRecursive(f protoreflect.FileDescriptor) (File, error) {
    78  	if asFile, ok := f.(File); ok {
    79  		return asFile, nil
    80  	}
    81  	return newFileRecursive(f, map[protoreflect.FileDescriptor]File{})
    82  }
    83  
    84  func newFileRecursive(fd protoreflect.FileDescriptor, seen map[protoreflect.FileDescriptor]File) (File, error) {
    85  	if res, ok := seen[fd]; ok {
    86  		if res == nil {
    87  			return nil, fmt.Errorf("import cycle encountered: file %s transitively imports itself", fd.Path())
    88  		}
    89  		return res, nil
    90  	}
    91  
    92  	if f, ok := fd.(File); ok {
    93  		seen[fd] = f
    94  		return f, nil
    95  	}
    96  
    97  	seen[fd] = nil
    98  	deps := make([]File, fd.Imports().Len())
    99  	for i := 0; i < fd.Imports().Len(); i++ {
   100  		imprt := fd.Imports().Get(i)
   101  		dep, err := newFileRecursive(imprt, seen)
   102  		if err != nil {
   103  			return nil, err
   104  		}
   105  		deps[i] = dep
   106  	}
   107  
   108  	f, err := newFile(fd, deps)
   109  	if err != nil {
   110  		return nil, err
   111  	}
   112  	seen[fd] = f
   113  	return f, nil
   114  }
   115  
   116  type file struct {
   117  	protoreflect.FileDescriptor
   118  	descs map[protoreflect.FullName]protoreflect.Descriptor
   119  	deps  Files
   120  }
   121  
   122  func (f file) FindDescriptorByName(name protoreflect.FullName) protoreflect.Descriptor {
   123  	return f.descs[name]
   124  }
   125  
   126  func (f file) FindImportByPath(path string) File {
   127  	return f.deps.FindFileByPath(path)
   128  }
   129  
   130  func (f file) FindExtensionByNumber(msg protoreflect.FullName, tag protoreflect.FieldNumber) protoreflect.ExtensionTypeDescriptor {
   131  	return findExtension(f, msg, tag)
   132  }
   133  
   134  func (f file) importsAsFiles() Files {
   135  	return f.deps
   136  }
   137  
   138  var _ File = file{}
   139  
   140  // Files represents a set of protobuf files. It is a slice of File values, but
   141  // also provides a method for easily looking up files by path and name.
   142  type Files []File
   143  
   144  // FindFileByPath finds a file in f that has the given path and name. If f
   145  // contains no such file, nil is returned.
   146  func (f Files) FindFileByPath(path string) File {
   147  	for _, file := range f {
   148  		if file.Path() == path {
   149  			return file
   150  		}
   151  	}
   152  	return nil
   153  }
   154  
   155  // AsResolver returns a Resolver that uses f as the source of descriptors. If
   156  // a given query cannot be answered with the files in f, the query will fail
   157  // with a protoregistry.NotFound error. The implementation just delegates calls
   158  // to each file until a result is found.
   159  //
   160  // Also see ResolverFromFile.
   161  func (f Files) AsResolver() Resolver {
   162  	return filesResolver(f)
   163  }
   164  
   165  // Resolver is an interface that can resolve various kinds of queries about
   166  // descriptors. It satisfies the resolver interfaces defined in protodesc
   167  // and protoregistry packages.
   168  type Resolver interface {
   169  	protodesc.Resolver
   170  	protoregistry.MessageTypeResolver
   171  	protoregistry.ExtensionTypeResolver
   172  }
   173  
   174  // ResolverFromFile returns a Resolver that uses the given file plus its full set of
   175  // transitive dependencies as the source of descriptors. If a given query
   176  // cannot be answered with these files, the query will fail with a
   177  // protoregistry.NotFound error.
   178  //
   179  // Note that this function does not compute any additional indexes, for
   180  // efficient search, so queries generally take linear time, O(n) where n is the
   181  // number of files in the transitive closure of the given file. Queries for an
   182  // extension by number are linear with the number of messages and extensions
   183  // defined across all the files.
   184  func ResolverFromFile(f File) Resolver {
   185  	return fileResolver{
   186  		f:    f,
   187  		deps: f.importsAsFiles().AsResolver(),
   188  	}
   189  }
   190  
   191  type fileResolver struct {
   192  	f    File
   193  	deps Resolver
   194  }
   195  
   196  func (r fileResolver) FindFileByPath(path string) (protoreflect.FileDescriptor, error) {
   197  	if r.f.Path() == path {
   198  		return r.f, nil
   199  	}
   200  	return r.deps.FindFileByPath(path)
   201  }
   202  
   203  func (r fileResolver) FindDescriptorByName(name protoreflect.FullName) (protoreflect.Descriptor, error) {
   204  	d := r.f.FindDescriptorByName(name)
   205  	if d != nil {
   206  		return d, nil
   207  	}
   208  	return r.deps.FindDescriptorByName(name)
   209  }
   210  
   211  func (r fileResolver) FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error) {
   212  	d := r.f.FindDescriptorByName(message)
   213  	if d != nil {
   214  		if md, ok := d.(protoreflect.MessageDescriptor); ok {
   215  			return dynamicpb.NewMessageType(md), nil
   216  		}
   217  		return nil, protoregistry.NotFound
   218  	}
   219  	return r.deps.FindMessageByName(message)
   220  }
   221  
   222  func (r fileResolver) FindMessageByURL(url string) (protoreflect.MessageType, error) {
   223  	fullName := messageNameFromUrl(url)
   224  	return r.FindMessageByName(protoreflect.FullName(fullName))
   225  }
   226  
   227  func messageNameFromUrl(url string) string {
   228  	lastSlash := strings.LastIndexByte(url, '/')
   229  	var fullName string
   230  	if lastSlash >= 0 {
   231  		fullName = url[lastSlash+1:]
   232  	} else {
   233  		fullName = url
   234  	}
   235  	return fullName
   236  }
   237  
   238  func (r fileResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
   239  	d := r.f.FindDescriptorByName(field)
   240  	if d != nil {
   241  		if extd, ok := d.(protoreflect.ExtensionTypeDescriptor); ok {
   242  			return extd.Type(), nil
   243  		}
   244  		if fld, ok := d.(protoreflect.FieldDescriptor); ok && fld.IsExtension() {
   245  			return dynamicpb.NewExtensionType(fld), nil
   246  		}
   247  		return nil, protoregistry.NotFound
   248  	}
   249  	return r.deps.FindExtensionByName(field)
   250  }
   251  
   252  func (r fileResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
   253  	ext := findExtension(r.f, message, field)
   254  	if ext != nil {
   255  		return ext.Type(), nil
   256  	}
   257  	return r.deps.FindExtensionByNumber(message, field)
   258  }
   259  
   260  type filesResolver []File
   261  
   262  func (r filesResolver) FindFileByPath(path string) (protoreflect.FileDescriptor, error) {
   263  	for _, f := range r {
   264  		if f.Path() == path {
   265  			return f, nil
   266  		}
   267  	}
   268  	return nil, protoregistry.NotFound
   269  }
   270  
   271  func (r filesResolver) FindDescriptorByName(name protoreflect.FullName) (protoreflect.Descriptor, error) {
   272  	for _, f := range r {
   273  		result := f.FindDescriptorByName(name)
   274  		if result != nil {
   275  			return result, nil
   276  		}
   277  	}
   278  	return nil, protoregistry.NotFound
   279  }
   280  
   281  func (r filesResolver) FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error) {
   282  	for _, f := range r {
   283  		d := f.FindDescriptorByName(message)
   284  		if d != nil {
   285  			if md, ok := d.(protoreflect.MessageDescriptor); ok {
   286  				return dynamicpb.NewMessageType(md), nil
   287  			}
   288  			return nil, protoregistry.NotFound
   289  		}
   290  	}
   291  	return nil, protoregistry.NotFound
   292  }
   293  
   294  func (r filesResolver) FindMessageByURL(url string) (protoreflect.MessageType, error) {
   295  	name := messageNameFromUrl(url)
   296  	return r.FindMessageByName(protoreflect.FullName(name))
   297  }
   298  
   299  func (r filesResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
   300  	for _, f := range r {
   301  		d := f.FindDescriptorByName(field)
   302  		if d != nil {
   303  			if extd, ok := d.(protoreflect.ExtensionTypeDescriptor); ok {
   304  				return extd.Type(), nil
   305  			}
   306  			if fld, ok := d.(protoreflect.FieldDescriptor); ok && fld.IsExtension() {
   307  				return dynamicpb.NewExtensionType(fld), nil
   308  			}
   309  			return nil, protoregistry.NotFound
   310  		}
   311  	}
   312  	return nil, protoregistry.NotFound
   313  }
   314  
   315  func (r filesResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
   316  	for _, f := range r {
   317  		ext := findExtension(f, message, field)
   318  		if ext != nil {
   319  			return ext.Type(), nil
   320  		}
   321  	}
   322  	return nil, protoregistry.NotFound
   323  }
   324  
   325  type hasExtensionsAndMessages interface {
   326  	Messages() protoreflect.MessageDescriptors
   327  	Extensions() protoreflect.ExtensionDescriptors
   328  }
   329  
   330  func findExtension(d hasExtensionsAndMessages, message protoreflect.FullName, field protoreflect.FieldNumber) protoreflect.ExtensionTypeDescriptor {
   331  	for i := 0; i < d.Extensions().Len(); i++ {
   332  		if extType := isExtensionMatch(d.Extensions().Get(i), message, field); extType != nil {
   333  			return extType
   334  		}
   335  	}
   336  
   337  	for i := 0; i < d.Messages().Len(); i++ {
   338  		if extType := findExtension(d.Messages().Get(i), message, field); extType != nil {
   339  			return extType
   340  		}
   341  	}
   342  
   343  	return nil // could not be found
   344  }
   345  
   346  func isExtensionMatch(ext protoreflect.ExtensionDescriptor, message protoreflect.FullName, field protoreflect.FieldNumber) protoreflect.ExtensionTypeDescriptor {
   347  	if ext.Number() != field || ext.ContainingMessage().FullName() != message {
   348  		return nil
   349  	}
   350  	if extType, ok := ext.(protoreflect.ExtensionTypeDescriptor); ok {
   351  		return extType
   352  	}
   353  	return dynamicpb.NewExtensionType(ext).TypeDescriptor()
   354  }