
     1  // Copyright 2021 CloudWeGo Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    15  package protoc
    17  import (
    18  	"bytes"
    19  	"errors"
    20  	"fmt"
    21  	"path"
    22  	"path/filepath"
    23  	"strings"
    24  	"text/template"
    26  	genfastpb ""
    27  	""
    28  	gengo ""
    29  	""
    31  	""
    32  	""
    33  	""
    34  )
// protocPlugin holds the state of the Kitex protoc plugin while it processes
// a protoc request.
//
// Its methods read fields such as PackagePrefix, Use, ModuleName, Namespace,
// IDLName, ServiceInfo and Dependencies directly on the plugin. Presumably
// these are promoted from the embedded generator.Config and
// generator.PackageInfo (TODO confirm against the generator package).
type protocPlugin struct {
	generator.Config
	generator.PackageInfo
	// Services accumulates the services converted from every generated file.
	// process uses the last entry for main-package and custom-template output.
	Services    []*generator.ServiceInfo
	// kg holds the Kitex code generator.
	// NOTE(review): neither its assignment nor its uses are fully visible in
	// this copy of the file. Presumably it is set in process and used to
	// generate the service/main/custom packages; confirm.
	kg          generator.Generator
	// err records a generation error. GenerateFile does nothing once it is
	// set, and process reports it through gen.Error.
	err         error
	// importPaths is filled in by parseM from the 'M' protobuf options.
	// NOTE(review): no code visible here reads it.
	importPaths map[string]string // file -> import path
}
    45  // Name implements the protobuf_generator.Plugin interface.
    46  func (pp *protocPlugin) Name() string {
    47  	return "kitex-internal"
    48  }
    50  // Init implements the protobuf_generator.Plugin interface.
    51  func (pp *protocPlugin) init() {
    52  	pp.Dependencies = map[string]string{
    53  		"proto": "",
    54  	}
    55  }
    57  // parse the 'M*' option
    58  // See for more information.
    59  func (pp *protocPlugin) parseM() {
    60  	pp.importPaths = make(map[string]string)
    61  	for _, po := range pp.Config.ProtobufOptions {
    62  		if po == "" || po[0] != 'M' {
    63  			continue
    64  		}
    65  		idx := strings.Index(po, "=")
    66  		if idx < 0 {
    67  			continue
    68  		}
    69  		key := po[1:idx]
    70  		val := po[idx+1:]
    71  		if val == "" {
    72  			continue
    73  		}
    74  		idx = strings.Index(val, ";")
    75  		if idx >= 0 {
    76  			val = val[:idx]
    77  		}
    78  		pp.importPaths[key] = val
    79  	}
    80  }
// interfaceTemplate renders, for every service of a .proto file, a Go
// interface with one method per RPC:
//   - unary RPCs become Name(ctx context.Context, req *Req) (res *Res, err error);
//   - streaming RPCs become Name([req *Req, ]stream <Service>_<Name>Server) (err error),
//     where req is present only for server-side-only streaming.
//
// For each streaming RPC it also emits a <Service>_<Name>Server interface
// that embeds streaming.Stream and adds the following methods:
//   - Recv for client streaming;
//   - Send for server streaming;
//   - SendAndClose for client-only streaming.
//
// GenerateFile executes it with the value returned by makeInterfaces.
var interfaceTemplate = `
// Code generated by Kitex {{.Version}}. DO NOT EDIT.
{{range .Interfaces}}
{{$serviceName := .Name}}
type {{.Name}} interface {
{{- range .Methods}}
{{- if or .ClientStreaming .ServerStreaming}}
	{{.Name}}({{if and .ServerStreaming (not .ClientStreaming)}}req {{.ReqType}}, {{end}}stream {{$serviceName}}_{{.Name}}Server) (err error)
{{- else}}
	{{.Name}}(ctx context.Context, req {{.ReqType}}) (res {{.ResType}}, err error)
{{- end}}
{{- end}}
}
{{range .Methods}}
{{- if or .ClientStreaming .ServerStreaming}}
type {{$serviceName}}_{{.Name}}Server interface {
	streaming.Stream
	{{- if .ClientStreaming}}
	Recv() ({{.ReqType}}, error)
	{{- end}}
	{{- if .ServerStreaming}}
	Send({{.ResType}}) error
	{{- end}}
	{{- if and .ClientStreaming (not .ServerStreaming)}}
	SendAndClose({{.ResType}}) error
	{{- end}}
}
{{- end}}
{{end}}
{{end}}
`
// GenerateFile emits the Kitex output for one .proto file.
//
// If the file's go_package does not start with pp.PackagePrefix, the file
// is skipped with a warning. Otherwise its services are converted by
// convertTypes and appended to pp.Services. Unless Config.Use is set, the
// method then writes:
//   - the per-service package files;
//   - the file produced by gengo.GenerateFile, extended with the service
//     interfaces rendered from interfaceTemplate;
//   - the genfastpb.GenerateFile output, unless Config.NoFastAPI is set.
//
// Errors are not returned. They are stored in pp.err, and once pp.err is
// set this method does nothing.
func (pp *protocPlugin) GenerateFile(gen *protogen.Plugin, file *protogen.File) {
	if pp.err != nil {
		return
	}
	gopkg := file.Proto.GetOptions().GetGoPackage()
	if !strings.HasPrefix(gopkg, pp.PackagePrefix) {
		log.Warnf("[WARN] %q is skipped because its import path %q is not located in ./kitex_gen. Change the go_package option or use '--protobuf M%s=A-Import-Path-In-kitex_gen' to override it if you want this file to be generated under kitex_gen.\n",
			file.Proto.GetName(), gopkg, file.Proto.GetName())
		return
	}
	log.Infof("[INFO] Generate %q at %q\n", file.Proto.GetName(), gopkg)
	if parts := strings.Split(gopkg, ";"); len(parts) > 1 {
		gopkg = parts[0] // remove package alias from file path
	}
	// Namespace is the go_package path relative to the package prefix.
	pp.Namespace = strings.TrimPrefix(gopkg, pp.PackagePrefix)
	pp.IDLName = util.IDLName(pp.Config.IDL)
	// Services are recorded even when Config.Use is set, because process
	// relies on pp.Services for main-package and custom-template output.
	ss := pp.convertTypes(file)
	pp.Services = append(pp.Services, ss...)
	if pp.Config.Use != "" {
		return
	}
	hasStreaming := false
	// generate service package
	for _, si := range ss {
		pp.ServiceInfo = si
		// NOTE(review): the right-hand side of this assignment is missing in
		// this copy of the file. Presumably it generates the package for
		// pp.ServiceInfo (likely through pp.kg); restore it from upstream.
		fs, err :=
		if err != nil {
			pp.err = err
			return
		}
		if !hasStreaming && si.HasStreaming {
			hasStreaming = true
		}
		for _, f := range fs {
			gen.NewGeneratedFile(pp.adjustPath(f.Name), "").P(f.Content)
		}
	}
	// generate service interface
	if pp.err == nil {
		// Work on a shallow copy so the output prefix can be made relative to
		// the package prefix without mutating the shared *protogen.File.
		fixed := *file
		fixed.GeneratedFilenamePrefix = strings.TrimPrefix(fixed.GeneratedFilenamePrefix, pp.PackagePrefix)
		f := gengo.GenerateFile(gen, &fixed)
		// Force the context import. The blank use printed below keeps it
		// referenced even when the file declares no services.
		f.QualifiedGoIdent(protogen.GoIdent{GoImportPath: "context"})
		if hasStreaming {
			// NOTE(review): the import path literal below is empty in this copy.
			// Presumably it names the package that provides streaming.Stream,
			// which interfaceTemplate references for streaming methods; confirm.
			f.QualifiedGoIdent(protogen.GoIdent{
				GoImportPath: "",
			})
		}
		f.P("var _ context.Context")
		if len(file.Services) != 0 {
			tpl := template.New("interface")
			tpl = template.Must(tpl.Parse(interfaceTemplate))
			var buf bytes.Buffer
			// NOTE(review): the rendered buffer is written even if
			// ExecuteTemplate fails. The failure only surfaces via pp.err.
			pp.err = tpl.ExecuteTemplate(&buf, tpl.Name(), pp.makeInterfaces(f, file))
			f.P(buf.String())
		}
	}
	// generate fast api
	if !pp.Config.NoFastAPI && pp.err == nil {
		// Same prefix adjustment as for the interface file above.
		fixed := *file
		fixed.GeneratedFilenamePrefix = strings.TrimPrefix(fixed.GeneratedFilenamePrefix, pp.PackagePrefix)
		genfastpb.GenerateFile(gen, &fixed)
	}
}
// process runs code generation for a whole protoc plugin request.
//
// It calls GenerateFile for every file in gen.Files. When Config.Use is set,
// it only does so for the first file listed in gen.Request.FileToGenerate.
// When Config.GenerateMain or Config.TemplateDir is set, it also writes
// files for the last collected service. The calls producing those files are
// truncated in this copy; see the NOTE(review) comments below.
//
// Panics raised during generation are recovered and reported through
// gen.Error. One example is the panic convertTypes raises for a missing
// go_package. Any error left in pp.err is reported the same way.
func (pp *protocPlugin) process(gen *protogen.Plugin) {
	// Report panics as plugin errors instead of letting them crash the plugin.
	defer func() {
		if e := recover(); e != nil {
			if err, ok := e.(error); ok {
				gen.Error(err)
			} else {
				gen.Error(fmt.Errorf("%+v", e))
			}
		}
	}()
	if len(gen.Files) == 0 {
		gen.Error(errors.New("no proto file"))
		return
	}
	// NOTE(review): the left-hand side of this assignment is missing in this
	// copy. Presumably it initialises pp.kg; restore it from upstream.
	= generator.NewGenerator(&pp.Config, nil)
	// iterate over all proto files
	idl := gen.Request.FileToGenerate[0]
	for _, f := range gen.Files {
		// With Config.Use set, only the first requested file is generated.
		if pp.Config.Use != "" && f.Proto.GetName() != idl {
			continue
		}
		pp.GenerateFile(gen, f)
	}
	if pp.Config.GenerateMain {
		if len(pp.Services) == 0 {
			gen.Error(errors.New("no service defined"))
			return
		}
		pp.ServiceInfo = pp.Services[len(pp.Services)-1]
		// NOTE(review): the right-hand side of this assignment is missing in
		// this copy. Presumably it generates the main package; confirm.
		fs, err :=
		// NOTE(review): unlike GenerateFile, an error here is recorded but any
		// returned files are still written.
		if err != nil {
			pp.err = err
		}
		for _, f := range fs {
			gen.NewGeneratedFile(pp.adjustPath(f.Name), "").P(f.Content)
		}
	}
	if pp.Config.TemplateDir != "" {
		if len(pp.Services) == 0 {
			gen.Error(errors.New("no service defined"))
			return
		}
		pp.ServiceInfo = pp.Services[len(pp.Services)-1]
		// NOTE(review): the right-hand side of this assignment is missing in
		// this copy. Presumably it renders the templates in
		// Config.TemplateDir; confirm.
		fs, err :=
		if err != nil {
			pp.err = err
		}
		for _, f := range fs {
			gen.NewGeneratedFile(pp.adjustPath(f.Name), "").P(f.Content)
		}
	}
	if pp.err != nil {
		gen.Error(pp.err)
	}
	return
}
   251  func (pp *protocPlugin) convertTypes(file *protogen.File) (ss []*generator.ServiceInfo) {
   252  	pth := pp.fixImport(string(file.GoImportPath))
   253  	if pth == "" {
   254  		panic(fmt.Errorf("missing %q option in %q", "go_package", file.Desc.Name()))
   255  	}
   256  	pi := generator.PkgInfo{
   257  		PkgName:    file.Proto.GetPackage(),
   258  		PkgRefName: goSanitized(path.Base(pth)),
   259  		ImportPath: pth,
   260  	}
   261  	for _, service := range file.Services {
   262  		si := &generator.ServiceInfo{
   263  			PkgInfo:        pi,
   264  			ServiceName:    service.GoName,
   265  			RawServiceName: string(service.Desc.Name()),
   266  		}
   267  		si.ServiceTypeName = func() string { return si.PkgRefName + "." + si.ServiceName }
   268  		for _, m := range service.Methods {
   269  			req := pp.convertParameter(m.Input, "Req")
   270  			res := pp.convertParameter(m.Output, "Resp")
   272  			methodName := m.GoName
   273  			mi := &generator.MethodInfo{
   274  				PkgInfo:            pi,
   275  				ServiceName:        si.ServiceName,
   276  				RawName:            string(m.Desc.Name()),
   277  				Name:               methodName,
   278  				Args:               []*generator.Parameter{req},
   279  				Resp:               res,
   280  				ArgStructName:      methodName + "Args",
   281  				ResStructName:      methodName + "Result",
   282  				GenArgResultStruct: true,
   283  				ClientStreaming:    m.Desc.IsStreamingClient(),
   284  				ServerStreaming:    m.Desc.IsStreamingServer(),
   285  			}
   286  			si.Methods = append(si.Methods, mi)
   287  			if !si.HasStreaming && (mi.ClientStreaming || mi.ServerStreaming) {
   288  				si.HasStreaming = true
   289  			}
   290  		}
   291  		for _, m := range si.Methods {
   292  			BuildStreaming(m, si.HasStreaming)
   293  		}
   294  		ss = append(ss, si)
   295  	}
   296  	// combine service
   297  	if pp.Config.CombineService && len(file.Services) > 0 {
   298  		var svcs []*generator.ServiceInfo
   299  		var methods []*generator.MethodInfo
   300  		for _, s := range ss {
   301  			svcs = append(svcs, s)
   302  			methods = append(methods, s.AllMethods()...)
   303  		}
   304  		// check method name conflict
   305  		mm := make(map[string]*generator.MethodInfo)
   306  		for _, m := range methods {
   307  			if _, ok := mm[m.Name]; ok {
   308  				log.Warnf("[WARN] combine service method %s in %s conflicts with %s in %s\n",
   309  					m.Name, m.ServiceName, m.Name, mm[m.Name].ServiceName)
   310  				return
   311  			}
   312  			mm[m.Name] = m
   313  		}
   314  		var hasStreaming bool
   315  		for _, m := range methods {
   316  			if m.ClientStreaming || m.ServerStreaming {
   317  				hasStreaming = true
   318  			}
   319  		}
   320  		svcName := pp.getCombineServiceName("CombineService", ss)
   321  		si := &generator.ServiceInfo{
   322  			PkgInfo:         pi,
   323  			ServiceName:     svcName,
   324  			RawServiceName:  svcName,
   325  			CombineServices: svcs,
   326  			Methods:         methods,
   327  			HasStreaming:    hasStreaming,
   328  		}
   329  		si.ServiceTypeName = func() string { return si.ServiceName }
   330  		ss = append(ss, si)
   331  	}
   332  	return
   333  }
   335  // BuildStreaming builds protobuf MethodInfo.Streaming as for Thrift, to simplify codegen
   336  func BuildStreaming(mi *generator.MethodInfo, serviceHasStreaming bool) {
   337  	s := &streaming.Streaming{
   338  		// pb: if one method is streaming, then the service is streaming, making all methods streaming
   339  		IsStreaming: serviceHasStreaming,
   340  	}
   341  	if mi.ClientStreaming && mi.ServerStreaming {
   342  		s.Mode = streaming.StreamingBidirectional
   343  		s.BidirectionalStreaming = true
   344  		s.ClientStreaming = true
   345  		s.ServerStreaming = true
   346  	} else if mi.ClientStreaming && !mi.ServerStreaming {
   347  		s.Mode = streaming.StreamingClientSide
   348  		s.ClientStreaming = true
   349  	} else if !mi.ClientStreaming && mi.ServerStreaming {
   350  		s.Mode = streaming.StreamingServerSide
   351  		s.ServerStreaming = true
   352  	} else if serviceHasStreaming {
   353  		s.Mode = streaming.StreamingUnary // Unary APIs over HTTP2
   354  	}
   355  	mi.Streaming = s
   356  }
   358  func (pp *protocPlugin) getCombineServiceName(name string, svcs []*generator.ServiceInfo) string {
   359  	for _, svc := range svcs {
   360  		if svc.ServiceName == name {
   361  			return pp.getCombineServiceName(name+"_", svcs)
   362  		}
   363  	}
   364  	return name
   365  }
   367  func (pp *protocPlugin) convertParameter(msg *protogen.Message, paramName string) *generator.Parameter {
   368  	importPath := pp.fixImport(msg.GoIdent.GoImportPath.String())
   369  	pkgRefName := goSanitized(path.Base(importPath))
   370  	res := &generator.Parameter{
   371  		Deps: []generator.PkgInfo{
   372  			{
   373  				PkgRefName: pkgRefName,
   374  				ImportPath: importPath,
   375  			},
   376  		},
   377  		Name:    paramName,
   378  		RawName: paramName,
   379  		Type:    "*" + pkgRefName + "." + msg.GoIdent.GoName,
   380  	}
   381  	return res
   382  }
   384  func (pp *protocPlugin) makeInterfaces(gf *protogen.GeneratedFile, file *protogen.File) interface{} {
   385  	var is []interface{}
   386  	for _, service := range file.Services {
   387  		i := struct {
   388  			Name    string
   389  			Methods []interface{}
   390  		}{
   391  			Name: service.GoName,
   392  		}
   393  		for _, m := range service.Methods {
   394  			i.Methods = append(i.Methods, struct {
   395  				Name            string
   396  				ReqType         string
   397  				ResType         string
   398  				ClientStreaming bool
   399  				ServerStreaming bool
   400  			}{
   401  				m.GoName,
   402  				"*" + gf.QualifiedGoIdent(m.Input.GoIdent),
   403  				"*" + gf.QualifiedGoIdent(m.Output.GoIdent),
   404  				m.Desc.IsStreamingClient(),
   405  				m.Desc.IsStreamingServer(),
   406  			})
   407  		}
   408  		is = append(is, i)
   409  	}
   410  	return struct {
   411  		Version    string
   412  		Interfaces []interface{}
   413  	}{pp.Config.Version, is}
   414  }
   416  func (pp *protocPlugin) adjustPath(path string) (ret string) {
   417  	cur, _ := filepath.Abs(".")
   418  	if pp.Config.Use == "" {
   419  		cur = util.JoinPath(cur, generator.KitexGenPath)
   420  	}
   421  	if filepath.IsAbs(path) {
   422  		path, _ = filepath.Rel(cur, path)
   423  		return path
   424  	}
   425  	if pp.ModuleName == "" {
   426  		gopath := util.GetGOPATH()
   427  		path = util.JoinPath(gopath, "src", path)
   428  		path, _ = filepath.Rel(cur, path)
   429  	} else {
   430  		path, _ = filepath.Rel(pp.ModuleName, path)
   431  	}
   432  	return path
   433  }
   435  func (pp *protocPlugin) fixImport(path string) string {
   436  	path = strings.Trim(path, "\"")
   437  	return path
   438  }