go.uber.org/yarpc@v1.72.1/internal/protoplugin-v2/registry.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package protopluginv2
    22  
    23  import (
    24  	"fmt"
    25  	"path"
    26  	"path/filepath"
    27  	"strings"
    28  
    29  	"github.com/golang/protobuf/protoc-gen-go/descriptor"
    30  	"github.com/golang/protobuf/protoc-gen-go/plugin"
    31  )
    32  
    33  type registry struct {
    34  	// msgs is a mapping from fully-qualified message name to descriptor
    35  	msgs map[string]*Message
    36  	// enums is a mapping from fully-qualified enum name to descriptor
    37  	enums map[string]*Enum
    38  	// files is a mapping from file path to descriptor
    39  	files map[string]*File
    40  	// prefix is a prefix to be inserted to golang package paths generated from proto package names.
    41  	prefix string
    42  	// pkgMap is a user-specified mapping from file path to proto package.
    43  	pkgMap map[string]string
    44  	// pkgAliases is a mapping from package aliases to package paths in go which are already taken.
    45  	pkgAliases map[string]string
    46  }
    47  
    48  func newRegistry() *registry {
    49  	return &registry{
    50  		msgs:       make(map[string]*Message),
    51  		enums:      make(map[string]*Enum),
    52  		files:      make(map[string]*File),
    53  		pkgMap:     make(map[string]string),
    54  		pkgAliases: make(map[string]string),
    55  	}
    56  }
    57  
    58  func (r *registry) Load(req *plugin_go.CodeGeneratorRequest) error {
    59  	for _, file := range req.GetProtoFile() {
    60  		r.loadFile(file)
    61  	}
    62  	var targetPkg string
    63  	for _, name := range req.FileToGenerate {
    64  		target := r.files[name]
    65  		if target == nil {
    66  			return fmt.Errorf("no such file: %s", name)
    67  		}
    68  		name := packageIdentityName(target.FileDescriptorProto)
    69  		if targetPkg == "" {
    70  			targetPkg = name
    71  		} else {
    72  			if targetPkg != name {
    73  				return fmt.Errorf("inconsistent package names: %s %s", targetPkg, name)
    74  			}
    75  		}
    76  		if err := r.loadServices(target); err != nil {
    77  			return err
    78  		}
    79  		if err := r.loadTransitiveFileDependencies(target); err != nil {
    80  			return err
    81  		}
    82  	}
    83  	return nil
    84  }
    85  
    86  func (r *registry) LookupMessage(location string, name string) (*Message, error) {
    87  	if strings.HasPrefix(name, ".") {
    88  		m, ok := r.msgs[name]
    89  		if !ok {
    90  			return nil, fmt.Errorf("no message found: %s", name)
    91  		}
    92  		return m, nil
    93  	}
    94  
    95  	if !strings.HasPrefix(location, ".") {
    96  		location = fmt.Sprintf(".%s", location)
    97  	}
    98  	components := strings.Split(location, ".")
    99  	for len(components) > 0 {
   100  		fqmn := strings.Join(append(components, name), ".")
   101  		if m, ok := r.msgs[fqmn]; ok {
   102  			return m, nil
   103  		}
   104  		components = components[:len(components)-1]
   105  	}
   106  	return nil, fmt.Errorf("no message found: %s", name)
   107  }
   108  
   109  func (r *registry) LookupFile(name string) (*File, error) {
   110  	f, ok := r.files[name]
   111  	if !ok {
   112  		return nil, fmt.Errorf("no such file given: %s", name)
   113  	}
   114  	return f, nil
   115  }
   116  
   117  func (r *registry) AddPackageMap(file, protoPackage string) {
   118  	r.pkgMap[file] = protoPackage
   119  }
   120  
   121  func (r *registry) SetPrefix(prefix string) {
   122  	r.prefix = prefix
   123  }
   124  
   125  func (r *registry) ReserveGoPackageAlias(alias, pkgpath string) error {
   126  	if taken, ok := r.pkgAliases[alias]; ok {
   127  		if taken == pkgpath {
   128  			return nil
   129  		}
   130  		return fmt.Errorf("package name %s is already taken. Use another alias", alias)
   131  	}
   132  	r.pkgAliases[alias] = pkgpath
   133  	return nil
   134  }
   135  
   136  // loadFile loads messages, enumerations and fields from "file".
   137  // It does not loads services and methods in "file".  You need to call
   138  // loadServices after loadFiles is called for all files to load services and methods.
   139  func (r *registry) loadFile(file *descriptor.FileDescriptorProto) {
   140  	pkg := &GoPackage{
   141  		Path: r.goPackagePath(file),
   142  		Name: defaultGoPackageName(file),
   143  	}
   144  	if err := r.ReserveGoPackageAlias(pkg.Name, pkg.Path); err != nil {
   145  		for i := 0; ; i++ {
   146  			alias := fmt.Sprintf("%s_%d", pkg.Name, i)
   147  			if err := r.ReserveGoPackageAlias(alias, pkg.Path); err == nil {
   148  				pkg.Alias = alias
   149  				break
   150  			}
   151  		}
   152  	}
   153  	f := &File{
   154  		FileDescriptorProto: file,
   155  		GoPackage:           pkg,
   156  	}
   157  	r.files[file.GetName()] = f
   158  	r.registerMsg(f, nil, file.GetMessageType())
   159  	r.registerEnum(f, nil, file.GetEnumType())
   160  }
   161  
   162  func (r *registry) registerMsg(file *File, outerPath []string, msgs []*descriptor.DescriptorProto) {
   163  	for i, md := range msgs {
   164  		m := &Message{
   165  			DescriptorProto: md,
   166  			File:            file,
   167  			Outers:          outerPath,
   168  			Index:           i,
   169  		}
   170  		for _, fd := range md.GetField() {
   171  			m.Fields = append(m.Fields, &Field{
   172  				FieldDescriptorProto: fd,
   173  				Message:              m,
   174  			})
   175  		}
   176  		file.Messages = append(file.Messages, m)
   177  		r.msgs[m.FQMN()] = m
   178  
   179  		var outers []string
   180  		outers = append(outers, outerPath...)
   181  		outers = append(outers, m.GetName())
   182  		r.registerMsg(file, outers, m.GetNestedType())
   183  		r.registerEnum(file, outers, m.GetEnumType())
   184  	}
   185  }
   186  
   187  func (r *registry) registerEnum(file *File, outerPath []string, enums []*descriptor.EnumDescriptorProto) {
   188  	for i, ed := range enums {
   189  		e := &Enum{
   190  			EnumDescriptorProto: ed,
   191  			File:                file,
   192  			Outers:              outerPath,
   193  			Index:               i,
   194  		}
   195  		file.Enums = append(file.Enums, e)
   196  		r.enums[e.FQEN()] = e
   197  	}
   198  }
   199  
   200  // goPackagePath returns the go package path which go files generated from "f" should have.
   201  // It respects the mapping registered by AddPkgMap if exists. Or use go_package as import path
   202  // if it includes a slash,  Otherwide, it generates a path from the file name of "f".
   203  func (r *registry) goPackagePath(f *descriptor.FileDescriptorProto) string {
   204  	name := f.GetName()
   205  	if pkg, ok := r.pkgMap[name]; ok {
   206  		return path.Join(r.prefix, pkg)
   207  	}
   208  	gopkg := f.Options.GetGoPackage()
   209  	idx := strings.LastIndex(gopkg, "/")
   210  	if idx >= 0 {
   211  		return gopkg
   212  	}
   213  	return path.Join(r.prefix, path.Dir(name))
   214  }
   215  
   216  // loadServices registers services and their methods from "targetFile" to "r".
   217  // It must be called after loadFile is called for all files so that loadServices
   218  // can resolve names of message types and their fields.
   219  func (r *registry) loadServices(file *File) error {
   220  	var svcs []*Service
   221  	for _, sd := range file.GetService() {
   222  		svc := &Service{
   223  			ServiceDescriptorProto: sd,
   224  			File:                   file,
   225  		}
   226  		for _, md := range sd.GetMethod() {
   227  			meth, err := r.newMethod(svc, md)
   228  			if err != nil {
   229  				return err
   230  			}
   231  			svc.Methods = append(svc.Methods, meth)
   232  		}
   233  		if len(svc.Methods) == 0 {
   234  			continue
   235  		}
   236  		svcs = append(svcs, svc)
   237  	}
   238  	file.Services = svcs
   239  	return nil
   240  }
   241  
   242  func (r *registry) newMethod(svc *Service, md *descriptor.MethodDescriptorProto) (*Method, error) {
   243  	requestType, err := r.LookupMessage(svc.File.GetPackage(), md.GetInputType())
   244  	if err != nil {
   245  		return nil, err
   246  	}
   247  	responseType, err := r.LookupMessage(svc.File.GetPackage(), md.GetOutputType())
   248  	if err != nil {
   249  		return nil, err
   250  	}
   251  	return &Method{
   252  		MethodDescriptorProto: md,
   253  		Service:               svc,
   254  		RequestType:           requestType,
   255  		ResponseType:          responseType,
   256  	}, nil
   257  }
   258  
   259  // loadTransitiveFileDependencies registers services and their methods from "targetFile" to "r".
   260  // It must be called after loadFile is called for all files so that loadTransitiveFileDependencies
   261  // can resolve file descriptors as depdendencies.
   262  func (r *registry) loadTransitiveFileDependencies(file *File) error {
   263  	seen := make(map[string]struct{})
   264  	files, err := r.loadTransitiveFileDependenciesRecurse(file, seen)
   265  	if err != nil {
   266  		return err
   267  	}
   268  	file.TransitiveDependencies = files
   269  	return nil
   270  }
   271  
   272  func (r *registry) loadTransitiveFileDependenciesRecurse(file *File, seen map[string]struct{}) ([]*File, error) {
   273  	seen[file.GetName()] = struct{}{}
   274  	var deps []*File
   275  	for _, fname := range file.GetDependency() {
   276  		if _, ok := seen[fname]; ok {
   277  			continue
   278  		}
   279  		f, err := r.LookupFile(fname)
   280  		if err != nil {
   281  			return nil, err
   282  		}
   283  		deps = append(deps, f)
   284  
   285  		files, err := r.loadTransitiveFileDependenciesRecurse(f, seen)
   286  		if err != nil {
   287  			return nil, err
   288  		}
   289  		deps = append(deps, files...)
   290  	}
   291  	return deps, nil
   292  }
   293  
   294  // defaultGoPackageName returns the default go package name to be used for go files generated from "f".
   295  // You might need to use an unique alias for the package when you import it.  Use ReserveGoPackageAlias to get a unique alias.
   296  func defaultGoPackageName(f *descriptor.FileDescriptorProto) string {
   297  	name := packageIdentityName(f)
   298  	return strings.Replace(name, ".", "_", -1)
   299  }
   300  
   301  // packageIdentityName returns the identity of packages.
   302  // protoc-gen-grpc-gateway rejects CodeGenerationRequests which contains more than one packages
   303  // as protoc-gen-go does.
   304  func packageIdentityName(f *descriptor.FileDescriptorProto) string {
   305  	if f.Options != nil && f.Options.GoPackage != nil {
   306  		gopkg := f.Options.GetGoPackage()
   307  		// if go_package specifies an alias in the form of full/path/package;alias, use alias over package
   308  		idx := strings.Index(gopkg, ";")
   309  		if idx >= 0 {
   310  			return gopkg[idx+1:]
   311  		}
   312  		idx = strings.LastIndex(gopkg, "/")
   313  		if idx < 0 {
   314  			return gopkg
   315  		}
   316  
   317  		return gopkg[idx+1:]
   318  	}
   319  
   320  	if f.Package == nil {
   321  		base := filepath.Base(f.GetName())
   322  		ext := filepath.Ext(base)
   323  		return strings.TrimSuffix(base, ext)
   324  	}
   325  	return f.GetPackage()
   326  }