github.com/s7techlab/cckit@v0.10.5/gateway/protoc-gen-cc-gateway/generator/generator.go (about)

     1  package generator
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/format"
     7  	"path"
     8  	"path/filepath"
     9  	"strings"
    10  
    11  	"github.com/golang/protobuf/proto"
    12  	plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
    13  	"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
    14  )
    15  
    16  var (
    17  	pkg = make(map[string]string)
    18  )
    19  
    20  type Generator struct {
    21  	reg     *descriptor.Registry
    22  	imports []descriptor.GoPackage // common imports
    23  	Opts    Opts
    24  }
    25  
    26  // New returns a new generator which generates handler wrappers.
    27  func New(reg *descriptor.Registry) *Generator {
    28  	return &Generator{
    29  		reg:  reg,
    30  		Opts: Opts{},
    31  	}
    32  }
    33  
    34  func (g *Generator) Generate(targets []*descriptor.File) ([]*plugin.CodeGeneratorResponse_File, error) {
    35  	var files []*plugin.CodeGeneratorResponse_File
    36  	for _, file := range targets {
    37  		if len(file.Services) == 0 {
    38  			continue
    39  		}
    40  
    41  		if code, err := g.generateCC(file); err == nil {
    42  			files = append(files, code)
    43  		} else {
    44  			return nil, err
    45  		}
    46  	}
    47  
    48  	return files, nil
    49  }
    50  
    51  func (g *Generator) generateCC(file *descriptor.File) (*plugin.CodeGeneratorResponse_File, error) {
    52  	code, err := g.getCCTemplate(file)
    53  	if err != nil {
    54  		return nil, err
    55  	}
    56  
    57  	formatted, err := format.Source([]byte(code))
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	name := filepath.Base(file.GetName())
    63  	ext := filepath.Ext(name)
    64  	base := strings.TrimSuffix(name, ext)
    65  
    66  	basePath := path.Dir(*file.FileDescriptorProto.Name)
    67  	if !g.Opts.PathsSourceRelative {
    68  		basePath = file.GoPkg.Path
    69  	}
    70  
    71  	output := fmt.Sprintf(filepath.Join(basePath, "%s.pb.cc.go"), base)
    72  	output = filepath.Clean(output)
    73  
    74  	return &plugin.CodeGeneratorResponse_File{
    75  		Name:    proto.String(output),
    76  		Content: proto.String(string(formatted)),
    77  	}, nil
    78  }
    79  
    80  func (g *Generator) getCCTemplate(f *descriptor.File) (string, error) {
    81  	pkgSeen := make(map[string]bool)
    82  	var imports []descriptor.GoPackage
    83  	for _, pkg := range g.imports {
    84  		pkgSeen[pkg.Path] = true
    85  		imports = append(imports, pkg)
    86  	}
    87  
    88  	pkgs := [][]string{
    89  		{"context", "context"},
    90  		{"github.com/s7techlab/cckit/gateway", "cckit_gateway"},
    91  		{"github.com/s7techlab/cckit/router", "cckit_router"},
    92  		{"github.com/s7techlab/cckit/router/param/defparam", "cckit_defparam"},
    93  		{"github.com/s7techlab/cckit/sdk", "cckit_sdk"},
    94  	}
    95  
    96  	if g.Opts.EmbedSwagger {
    97  		pkgs = append(pkgs, []string{"embed", "_"})
    98  	}
    99  
   100  	if g.Opts.ServiceChaincodeResolver {
   101  		pkgs = append(pkgs, []string{"errors", "errors"})
   102  	}
   103  
   104  	for _, pkg := range pkgs {
   105  		pkgSeen[pkg[0]] = true
   106  		imports = append(imports, g.newGoPackage(pkg[0], pkg[1]))
   107  	}
   108  
   109  	for _, svc := range f.Services {
   110  		for _, m := range svc.Methods {
   111  			checkedAppend := func(pkg descriptor.GoPackage) {
   112  				// Add request type package to imports if needed
   113  				if m.Options == nil || pkg == f.GoPkg || pkgSeen[pkg.Path] {
   114  					return
   115  				}
   116  				pkgSeen[pkg.Path] = true
   117  
   118  				// always generate alias for external packages, when types used in req/resp object
   119  				//if pkg.Alias == "" {
   120  				//	pkg.Alias = pkg.Name
   121  				//	pkgSeen[pkg.Path] = false
   122  				//}
   123  
   124  				imports = append(imports, pkg)
   125  			}
   126  
   127  			checkedAppend(m.RequestType.File.GoPkg)
   128  			checkedAppend(m.ResponseType.File.GoPkg)
   129  		}
   130  	}
   131  
   132  	p := TemplateParams{
   133  		File:    f,
   134  		Imports: imports,
   135  		Opts:    g.Opts,
   136  	}
   137  	return applyTemplate(p)
   138  }
   139  
   140  func (g *Generator) newGoPackage(pkgPath string, aalias ...string) descriptor.GoPackage {
   141  	gopkg := descriptor.GoPackage{
   142  		Path: pkgPath,
   143  		Name: path.Base(pkgPath),
   144  	}
   145  	alias := gopkg.Name
   146  	if len(aalias) > 0 {
   147  		alias = aalias[0]
   148  		gopkg.Alias = alias
   149  	}
   150  
   151  	reference := alias
   152  	if reference == "" {
   153  		reference = gopkg.Name
   154  	}
   155  
   156  	for i := 0; ; i++ {
   157  		if err := g.reg.ReserveGoPackageAlias(alias, gopkg.Path); err == nil {
   158  			break
   159  		}
   160  		alias = fmt.Sprintf("%s_%d", gopkg.Name, i)
   161  		gopkg.Alias = alias
   162  	}
   163  
   164  	pkg[reference] = alias
   165  
   166  	return gopkg
   167  }
   168  
   169  func applyTemplate(p TemplateParams) (string, error) {
   170  	w := bytes.NewBuffer(nil)
   171  	if err := headerTemplate.Execute(w, p); err != nil {
   172  		return "", err
   173  	}
   174  
   175  	if err := ccTemplate.Execute(w, p); err != nil {
   176  		return "", err
   177  	}
   178  
   179  	if err := gatewayTemplate.Execute(w, p); err != nil {
   180  		return "", err
   181  	}
   182  
   183  	if p.Opts.ServiceChaincodeResolver {
   184  		if err := resolverTemplate.Execute(w, p); err != nil {
   185  			return "", err
   186  		}
   187  	}
   188  
   189  	return w.String(), nil
   190  }