go.uber.org/yarpc@v1.72.1/internal/protoplugin/generator.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package protoplugin
    22  
    23  import (
    24  	"bytes"
    25  	"errors"
    26  	"fmt"
    27  	"go/format"
    28  	"path"
    29  	"text/template"
    30  
    31  	"github.com/gogo/protobuf/proto"
    32  	"github.com/gogo/protobuf/protoc-gen-gogo/plugin"
    33  )
    34  
    35  var (
    36  	errNoTargetService = errors.New("no target service defined in the file")
    37  )
    38  
    39  type generator struct {
    40  	registry             *registry
    41  	tmpl                 *template.Template
    42  	templateInfoChecker  func(*TemplateInfo) error
    43  	baseImports          []*GoPackage
    44  	fileToOutputFilename func(*File) (string, error)
    45  }
    46  
    47  func newGenerator(
    48  	registry *registry,
    49  	tmpl *template.Template,
    50  	templateInfoChecker func(*TemplateInfo) error,
    51  	baseImportStrings []string,
    52  	fileToOutputFilename func(*File) (string, error),
    53  ) *generator {
    54  	var baseImports []*GoPackage
    55  	for _, pkgpath := range baseImportStrings {
    56  		pkg := &GoPackage{
    57  			Path: pkgpath,
    58  			Name: path.Base(pkgpath),
    59  		}
    60  		if err := registry.ReserveGoPackageAlias(pkg.Name, pkg.Path); err != nil {
    61  			for i := 0; ; i++ {
    62  				alias := fmt.Sprintf("%s_%d", pkg.Name, i)
    63  				if err := registry.ReserveGoPackageAlias(alias, pkg.Path); err != nil {
    64  					continue
    65  				}
    66  				pkg.Alias = alias
    67  				break
    68  			}
    69  		}
    70  		baseImports = append(baseImports, pkg)
    71  	}
    72  	return &generator{
    73  		registry,
    74  		tmpl,
    75  		templateInfoChecker,
    76  		baseImports,
    77  		fileToOutputFilename,
    78  	}
    79  }
    80  
    81  func (g *generator) Generate(targets []*File) ([]*plugin_go.CodeGeneratorResponse_File, error) {
    82  	var files []*plugin_go.CodeGeneratorResponse_File
    83  	for _, file := range targets {
    84  		code, err := g.generate(file)
    85  		if err == errNoTargetService {
    86  			continue
    87  		}
    88  		if err != nil {
    89  			return nil, err
    90  		}
    91  		formatted, err := format.Source([]byte(code))
    92  		if err != nil {
    93  			return nil, fmt.Errorf("could not format go code: %v\n%s", err, code)
    94  		}
    95  		output, err := g.fileToOutputFilename(file)
    96  		if err != nil {
    97  			return nil, err
    98  		}
    99  		files = append(files, &plugin_go.CodeGeneratorResponse_File{
   100  			Name:    proto.String(output),
   101  			Content: proto.String(string(formatted)),
   102  		})
   103  	}
   104  	return files, nil
   105  }
   106  
   107  func (g *generator) generate(file *File) (string, error) {
   108  	pkgSeen := make(map[string]bool)
   109  	var imports []*GoPackage
   110  	for _, pkg := range g.baseImports {
   111  		pkgSeen[pkg.Path] = true
   112  		imports = append(imports, pkg)
   113  	}
   114  	for _, svc := range file.Services {
   115  		for _, m := range svc.Methods {
   116  			for _, pkg := range []*GoPackage{m.RequestType.File.GoPackage, m.ResponseType.File.GoPackage} {
   117  				if pkg.Path == file.GoPackage.Path {
   118  					continue
   119  				}
   120  				if pkgSeen[pkg.Path] {
   121  					continue
   122  				}
   123  				pkgSeen[pkg.Path] = true
   124  				imports = append(imports, pkg)
   125  			}
   126  		}
   127  	}
   128  	templateInfo := &TemplateInfo{file, imports}
   129  	if err := g.templateInfoChecker(templateInfo); err != nil {
   130  		return "", err
   131  	}
   132  	buffer := bytes.NewBuffer(nil)
   133  	if err := g.tmpl.Execute(buffer, templateInfo); err != nil {
   134  		return "", err
   135  	}
   136  	return buffer.String(), nil
   137  }