github.com/cloudwego/kitex@v0.9.0/tool/internal_pkg/generator/generator.go (about)

     1  // Copyright 2021 CloudWeGo Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //   http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package generator .
    16  package generator
    17  
    18  import (
    19  	"fmt"
    20  	"go/token"
    21  	"path/filepath"
    22  	"reflect"
    23  	"strconv"
    24  	"strings"
    25  	"time"
    26  
    27  	"github.com/cloudwego/kitex/tool/internal_pkg/log"
    28  	"github.com/cloudwego/kitex/tool/internal_pkg/tpl"
    29  	"github.com/cloudwego/kitex/tool/internal_pkg/util"
    30  	"github.com/cloudwego/kitex/transport"
    31  )
    32  
    33  // Constants .
    34  const (
    35  	KitexGenPath = "kitex_gen"
    36  	DefaultCodec = "thrift"
    37  
    38  	BuildFileName       = "build.sh"
    39  	BootstrapFileName   = "bootstrap.sh"
    40  	ToolVersionFileName = "kitex_info.yaml"
    41  	HandlerFileName     = "handler.go"
    42  	MainFileName        = "main.go"
    43  	ClientFileName      = "client.go"
    44  	ServerFileName      = "server.go"
    45  	InvokerFileName     = "invoker.go"
    46  	ServiceFileName     = "*service.go"
    47  	ExtensionFilename   = "extensions.yaml"
    48  
    49  	DefaultThriftPluginTimeLimit = time.Minute
    50  )
    51  
    52  var (
    53  	kitexImportPath = "github.com/cloudwego/kitex"
    54  
    55  	globalMiddlewares  []Middleware
    56  	globalDependencies = map[string]string{
    57  		"kitex":   kitexImportPath,
    58  		"client":  ImportPathTo("client"),
    59  		"server":  ImportPathTo("server"),
    60  		"callopt": ImportPathTo("client/callopt"),
    61  		"frugal":  "github.com/cloudwego/frugal",
    62  	}
    63  )
    64  
    65  // SetKitexImportPath sets the import path of kitex.
    66  // Must be called before generating code.
    67  func SetKitexImportPath(path string) {
    68  	for k, v := range globalDependencies {
    69  		globalDependencies[k] = strings.ReplaceAll(v, kitexImportPath, path)
    70  	}
    71  	kitexImportPath = path
    72  }
    73  
    74  // ImportPathTo returns an import path to the specified package under kitex.
    75  func ImportPathTo(pkg string) string {
    76  	return util.JoinPath(kitexImportPath, pkg)
    77  }
    78  
    79  // AddGlobalMiddleware adds middleware for all generators
    80  func AddGlobalMiddleware(mw Middleware) {
    81  	globalMiddlewares = append(globalMiddlewares, mw)
    82  }
    83  
    84  // AddGlobalDependency adds dependency for all generators
    85  func AddGlobalDependency(ref, path string) bool {
    86  	if _, ok := globalDependencies[ref]; !ok {
    87  		globalDependencies[ref] = path
    88  		return true
    89  	}
    90  	return false
    91  }
    92  
    93  // Generator generates the codes of main package and scripts for building a server based on kitex.
    94  type Generator interface {
    95  	GenerateService(pkg *PackageInfo) ([]*File, error)
    96  	GenerateMainPackage(pkg *PackageInfo) ([]*File, error)
    97  	GenerateCustomPackage(pkg *PackageInfo) ([]*File, error)
    98  }
    99  
// Config holds the options of one kitex code generation run.
//
// Pack and Unpack serialize the exported fields by reflection. Pack emits them
// in declaration order, so reordering fields changes Pack's output order.
// Pack panics on field kinds other than bool, string, slice or a type with a
// String() method. Unexported fields are never serialized.
type Config struct {
	Verbose               bool // copied into log.Verbose by Unpack
	GenerateMain          bool // whether stuff in the main package should be generated
	GenerateInvoker       bool // when true, GenerateMainPackage does not emit main.go
	Version               string // copied to PackageInfo.Version
	NoFastAPI             bool   // copied to PackageInfo.NoFastAPI
	ModuleName            string // copied to PackageInfo.Module
	ServiceName           string // copied to PackageInfo.RealServiceName
	Use                   string // copied to PackageInfo.ExternalKitexGen
	IDLType               string // copied to PackageInfo.Codec; NewGenerator defaults it to DefaultCodec
	Includes              util.StringSlice
	ThriftOptions         util.StringSlice
	ProtobufOptions       util.StringSlice
	Hessian2Options       util.StringSlice
	IDL                   string // the IDL file passed on the command line
	OutputPath            string // the output path for main pkg and kitex_gen
	PackagePrefix         string
	CombineService        bool // combine services to one service
	CopyIDL               bool
	ThriftPlugins         util.StringSlice // deliberately skipped by Pack
	ProtobufPlugins       util.StringSlice
	Features              []feature     // enabled features; extended by AddFeature, copied to PackageInfo.Features
	FrugalPretouch        bool          // copied to PackageInfo.FrugalPretouch
	ThriftPluginTimeLimit time.Duration // Unpack parses it with time.ParseDuration
	CompilerPath          string        // specify the path of thriftgo or protoc

	// ExtensionFile is a YAML template extension read by ApplyExtension.
	ExtensionFile string
	// tmplExt is the extension loaded by ApplyExtension (nil if none).
	tmplExt *TemplateExtension

	Record    bool
	RecordCmd []string

	// TemplateDir is searched by ApplyExtension for ExtensionFilename.
	TemplateDir string

	// GenPath is combined with the package namespace under OutputPath by
	// GenerateService to locate the service package.
	GenPath string

	DeepCopyAPI bool
	// Protocol selects transport.HESSIAN2 for the package when it matches
	// transport.HESSIAN2.String() case-insensitively.
	Protocol              string
	HandlerReturnKeepResp bool
}
   141  
   142  // Pack packs the Config into a slice of "key=val" strings.
   143  func (c *Config) Pack() (res []string) {
   144  	t := reflect.TypeOf(c).Elem()
   145  	v := reflect.ValueOf(c).Elem()
   146  	for i := 0; i < t.NumField(); i++ {
   147  		f := t.Field(i)
   148  		x := v.Field(i)
   149  		n := f.Name
   150  
   151  		// skip the plugin arguments to avoid the 'strings in strings' trouble
   152  		if f.Name == "ThriftPlugins" || !token.IsExported(f.Name) {
   153  			continue
   154  		}
   155  
   156  		if str, ok := x.Interface().(interface{ String() string }); ok {
   157  			res = append(res, n+"="+str.String())
   158  			continue
   159  		}
   160  
   161  		switch x.Kind() {
   162  		case reflect.Bool:
   163  			res = append(res, n+"="+fmt.Sprint(x.Bool()))
   164  		case reflect.String:
   165  			res = append(res, n+"="+x.String())
   166  		case reflect.Slice:
   167  			var ss []string
   168  			if x.Type().Elem().Kind() == reflect.Int {
   169  				for i := 0; i < x.Len(); i++ {
   170  					ss = append(ss, strconv.Itoa(int(x.Index(i).Int())))
   171  				}
   172  			} else {
   173  				for i := 0; i < x.Len(); i++ {
   174  					ss = append(ss, x.Index(i).String())
   175  				}
   176  			}
   177  			res = append(res, n+"="+strings.Join(ss, ";"))
   178  		default:
   179  			panic(fmt.Errorf("unsupported field type: %+v", f))
   180  		}
   181  	}
   182  	return res
   183  }
   184  
   185  // Unpack restores the Config from a slice of "key=val" strings.
   186  func (c *Config) Unpack(args []string) error {
   187  	t := reflect.TypeOf(c).Elem()
   188  	v := reflect.ValueOf(c).Elem()
   189  	for _, a := range args {
   190  		parts := strings.SplitN(a, "=", 2)
   191  		if len(parts) != 2 {
   192  			return fmt.Errorf("invalid argument: '%s'", a)
   193  		}
   194  		name, value := parts[0], parts[1]
   195  		f, ok := t.FieldByName(name)
   196  		if ok && value != "" {
   197  			x := v.FieldByName(name)
   198  			if _, ok := x.Interface().(time.Duration); ok {
   199  				if d, err := time.ParseDuration(value); err != nil {
   200  					return fmt.Errorf("invalid time duration '%s' for %s", value, name)
   201  				} else {
   202  					x.SetInt(int64(d))
   203  				}
   204  				continue
   205  			}
   206  			switch x.Kind() {
   207  			case reflect.Bool:
   208  				x.SetBool(value == "true")
   209  			case reflect.String:
   210  				x.SetString(value)
   211  			case reflect.Slice:
   212  				ss := strings.Split(value, ";")
   213  				if x.Type().Elem().Kind() == reflect.Int {
   214  					n := reflect.MakeSlice(x.Type(), len(ss), len(ss))
   215  					for i, s := range ss {
   216  						val, err := strconv.ParseInt(s, 10, 64)
   217  						if err != nil {
   218  							return err
   219  						}
   220  						n.Index(i).SetInt(val)
   221  					}
   222  					x.Set(n)
   223  				} else {
   224  					for _, s := range ss {
   225  						val := reflect.Append(x, reflect.ValueOf(s))
   226  						x.Set(val)
   227  					}
   228  				}
   229  			default:
   230  				return fmt.Errorf("unsupported field type: %+v", f)
   231  			}
   232  		}
   233  	}
   234  	log.Verbose = c.Verbose
   235  	return c.ApplyExtension()
   236  }
   237  
   238  // AddFeature add registered feature to config
   239  func (c *Config) AddFeature(key string) bool {
   240  	if f, ok := getFeature(key); ok {
   241  		c.Features = append(c.Features, f)
   242  		return true
   243  	}
   244  	return false
   245  }
   246  
   247  // ApplyExtension applies template extension.
   248  func (c *Config) ApplyExtension() error {
   249  	templateExtExist := false
   250  	path := util.JoinPath(c.TemplateDir, ExtensionFilename)
   251  	if c.TemplateDir != "" && util.Exists(path) {
   252  		templateExtExist = true
   253  	}
   254  
   255  	if c.ExtensionFile == "" && !templateExtExist {
   256  		return nil
   257  	}
   258  
   259  	ext := new(TemplateExtension)
   260  	if c.ExtensionFile != "" {
   261  		if err := ext.FromYAMLFile(c.ExtensionFile); err != nil {
   262  			return fmt.Errorf("read template extension %q failed: %s", c.ExtensionFile, err.Error())
   263  		}
   264  	}
   265  
   266  	if templateExtExist {
   267  		yamlExt := new(TemplateExtension)
   268  		if err := yamlExt.FromYAMLFile(path); err != nil {
   269  			return fmt.Errorf("read template extension %q failed: %s", path, err.Error())
   270  		}
   271  		ext.Merge(yamlExt)
   272  	}
   273  
   274  	for _, fn := range ext.FeatureNames {
   275  		RegisterFeature(fn)
   276  	}
   277  	for _, fn := range ext.EnableFeatures {
   278  		c.AddFeature(fn)
   279  	}
   280  	for path, alias := range ext.Dependencies {
   281  		AddGlobalDependency(alias, path)
   282  	}
   283  
   284  	c.tmplExt = ext
   285  	return nil
   286  }
   287  
   288  // NewGenerator .
   289  func NewGenerator(config *Config, middlewares []Middleware) Generator {
   290  	mws := append(globalMiddlewares, middlewares...)
   291  	g := &generator{Config: config, middlewares: mws}
   292  	if g.IDLType == "" {
   293  		g.IDLType = DefaultCodec
   294  	}
   295  	return g
   296  }
   297  
   298  // Middleware used generator
   299  type Middleware func(HandleFunc) HandleFunc
   300  
   301  // HandleFunc used generator
   302  type HandleFunc func(*Task, *PackageInfo) (*File, error)
   303  
   304  type generator struct {
   305  	*Config
   306  	middlewares []Middleware
   307  }
   308  
   309  func (g *generator) chainMWs(handle HandleFunc) HandleFunc {
   310  	for i := len(g.middlewares) - 1; i > -1; i-- {
   311  		handle = g.middlewares[i](handle)
   312  	}
   313  	return handle
   314  }
   315  
   316  func (g *generator) GenerateMainPackage(pkg *PackageInfo) (fs []*File, err error) {
   317  	g.updatePackageInfo(pkg)
   318  
   319  	tasks := []*Task{
   320  		{
   321  			Name: BuildFileName,
   322  			Path: util.JoinPath(g.OutputPath, BuildFileName),
   323  			Text: tpl.BuildTpl,
   324  		},
   325  		{
   326  			Name: BootstrapFileName,
   327  			Path: util.JoinPath(g.OutputPath, "script", BootstrapFileName),
   328  			Text: tpl.BootstrapTpl,
   329  		},
   330  		{
   331  			Name: ToolVersionFileName,
   332  			Path: util.JoinPath(g.OutputPath, ToolVersionFileName),
   333  			Text: tpl.ToolVersionTpl,
   334  		},
   335  	}
   336  	if !g.Config.GenerateInvoker {
   337  		tasks = append(tasks, &Task{
   338  			Name: MainFileName,
   339  			Path: util.JoinPath(g.OutputPath, MainFileName),
   340  			Text: tpl.MainTpl,
   341  		})
   342  	}
   343  	for _, t := range tasks {
   344  		if util.Exists(t.Path) {
   345  			log.Info(t.Path, "exists. Skipped.")
   346  			continue
   347  		}
   348  		g.setImports(t.Name, pkg)
   349  		handle := func(task *Task, pkg *PackageInfo) (*File, error) {
   350  			return task.Render(pkg)
   351  		}
   352  		f, err := g.chainMWs(handle)(t, pkg)
   353  		if err != nil {
   354  			return nil, err
   355  		}
   356  		fs = append(fs, f)
   357  	}
   358  
   359  	handlerFilePath := filepath.Join(g.OutputPath, HandlerFileName)
   360  	if util.Exists(handlerFilePath) {
   361  		comp := newCompleter(
   362  			pkg.ServiceInfo.AllMethods(),
   363  			handlerFilePath,
   364  			pkg.ServiceInfo.ServiceName)
   365  		f, err := comp.CompleteMethods()
   366  		if err != nil {
   367  			if err == errNoNewMethod {
   368  				return fs, nil
   369  			}
   370  			return nil, err
   371  		}
   372  		fs = append(fs, f)
   373  	} else {
   374  		task := Task{
   375  			Name: HandlerFileName,
   376  			Path: handlerFilePath,
   377  			Text: tpl.HandlerTpl + "\n" + tpl.HandlerMethodsTpl,
   378  		}
   379  		g.setImports(task.Name, pkg)
   380  		handle := func(task *Task, pkg *PackageInfo) (*File, error) {
   381  			return task.Render(pkg)
   382  		}
   383  		f, err := g.chainMWs(handle)(&task, pkg)
   384  		if err != nil {
   385  			return nil, err
   386  		}
   387  		fs = append(fs, f)
   388  	}
   389  	return
   390  }
   391  
   392  func (g *generator) GenerateService(pkg *PackageInfo) ([]*File, error) {
   393  	g.updatePackageInfo(pkg)
   394  	output := util.JoinPath(g.OutputPath, util.CombineOutputPath(g.GenPath, pkg.Namespace))
   395  	svcPkg := strings.ToLower(pkg.ServiceName)
   396  	output = util.JoinPath(output, svcPkg)
   397  	ext := g.tmplExt
   398  	if ext == nil {
   399  		ext = new(TemplateExtension)
   400  	}
   401  
   402  	tasks := []*Task{
   403  		{
   404  			Name: ClientFileName,
   405  			Path: util.JoinPath(output, ClientFileName),
   406  			Text: tpl.ClientTpl,
   407  			Ext:  ext.ExtendClient,
   408  		},
   409  		{
   410  			Name: ServerFileName,
   411  			Path: util.JoinPath(output, ServerFileName),
   412  			Text: tpl.ServerTpl,
   413  			Ext:  ext.ExtendServer,
   414  		},
   415  		{
   416  			Name: InvokerFileName,
   417  			Path: util.JoinPath(output, InvokerFileName),
   418  			Text: tpl.InvokerTpl,
   419  			Ext:  ext.ExtendInvoker,
   420  		},
   421  		{
   422  			Name: ServiceFileName,
   423  			Path: util.JoinPath(output, svcPkg+".go"),
   424  			Text: tpl.ServiceTpl,
   425  		},
   426  	}
   427  
   428  	var fs []*File
   429  	for _, t := range tasks {
   430  		if err := t.Build(); err != nil {
   431  			err = fmt.Errorf("build %s failed: %w", t.Name, err)
   432  			return nil, err
   433  		}
   434  		g.setImports(t.Name, pkg)
   435  		if t.Ext != nil {
   436  			for _, path := range t.Ext.ImportPaths {
   437  				if alias, exist := ext.Dependencies[path]; exist {
   438  					pkg.AddImports(alias)
   439  				}
   440  			}
   441  		}
   442  		handle := func(task *Task, pkg *PackageInfo) (*File, error) {
   443  			return task.Render(pkg)
   444  		}
   445  		f, err := g.chainMWs(handle)(t, pkg)
   446  		if err != nil {
   447  			err = fmt.Errorf("render %s failed: %w", t.Name, err)
   448  			return nil, err
   449  		}
   450  		fs = append(fs, f)
   451  	}
   452  	return fs, nil
   453  }
   454  
   455  func (g *generator) updatePackageInfo(pkg *PackageInfo) {
   456  	pkg.NoFastAPI = g.NoFastAPI
   457  	pkg.Codec = g.IDLType
   458  	pkg.Version = g.Version
   459  	pkg.RealServiceName = g.ServiceName
   460  	pkg.Features = g.Features
   461  	pkg.ExternalKitexGen = g.Use
   462  	pkg.FrugalPretouch = g.FrugalPretouch
   463  	pkg.Module = g.ModuleName
   464  	if strings.EqualFold(g.Protocol, transport.HESSIAN2.String()) {
   465  		pkg.Protocol = transport.HESSIAN2
   466  	}
   467  	if pkg.Dependencies == nil {
   468  		pkg.Dependencies = make(map[string]string)
   469  	}
   470  
   471  	for ref, path := range globalDependencies {
   472  		if _, ok := pkg.Dependencies[ref]; !ok {
   473  			pkg.Dependencies[ref] = path
   474  		}
   475  	}
   476  }
   477  
   478  func (g *generator) setImports(name string, pkg *PackageInfo) {
   479  	pkg.Imports = make(map[string]map[string]bool)
   480  	switch name {
   481  	case ClientFileName:
   482  		pkg.AddImports("client")
   483  		if pkg.HasStreaming {
   484  			pkg.AddImport("streaming", "github.com/cloudwego/kitex/pkg/streaming")
   485  			pkg.AddImport("transport", "github.com/cloudwego/kitex/transport")
   486  		}
   487  		if len(pkg.AllMethods()) > 0 {
   488  			if needCallOpt(pkg) {
   489  				pkg.AddImports("callopt")
   490  			}
   491  			pkg.AddImports("context")
   492  		}
   493  		fallthrough
   494  	case HandlerFileName:
   495  		for _, m := range pkg.ServiceInfo.AllMethods() {
   496  			if !m.ServerStreaming && !m.ClientStreaming {
   497  				pkg.AddImports("context")
   498  			}
   499  			for _, a := range m.Args {
   500  				for _, dep := range a.Deps {
   501  					pkg.AddImport(dep.PkgRefName, dep.ImportPath)
   502  				}
   503  			}
   504  			if !m.Void && m.Resp != nil {
   505  				for _, dep := range m.Resp.Deps {
   506  					pkg.AddImport(dep.PkgRefName, dep.ImportPath)
   507  				}
   508  			}
   509  		}
   510  	case ServerFileName, InvokerFileName:
   511  		if len(pkg.CombineServices) == 0 {
   512  			pkg.AddImport(pkg.ServiceInfo.PkgRefName, pkg.ServiceInfo.ImportPath)
   513  		}
   514  		pkg.AddImports("server")
   515  	case ServiceFileName:
   516  		pkg.AddImports("errors")
   517  		pkg.AddImports("client")
   518  		pkg.AddImport("kitex", "github.com/cloudwego/kitex/pkg/serviceinfo")
   519  		pkg.AddImport(pkg.ServiceInfo.PkgRefName, pkg.ServiceInfo.ImportPath)
   520  		if len(pkg.AllMethods()) > 0 {
   521  			pkg.AddImports("context")
   522  		}
   523  		for _, m := range pkg.ServiceInfo.AllMethods() {
   524  			if m.ClientStreaming || m.ServerStreaming {
   525  				pkg.AddImports("fmt")
   526  			}
   527  			if m.GenArgResultStruct {
   528  				pkg.AddImports("proto")
   529  			} else {
   530  				// for method Arg and Result
   531  				pkg.AddImport(m.PkgRefName, m.ImportPath)
   532  			}
   533  			for _, a := range m.Args {
   534  				for _, dep := range a.Deps {
   535  					pkg.AddImport(dep.PkgRefName, dep.ImportPath)
   536  				}
   537  			}
   538  			if m.Streaming.IsStreaming || pkg.Codec == "protobuf" {
   539  				// protobuf handler support both PingPong and Unary (streaming) requests
   540  				pkg.AddImport("streaming", "github.com/cloudwego/kitex/pkg/streaming")
   541  			}
   542  			if !m.Void && m.Resp != nil {
   543  				for _, dep := range m.Resp.Deps {
   544  					pkg.AddImport(dep.PkgRefName, dep.ImportPath)
   545  				}
   546  			}
   547  			for _, e := range m.Exceptions {
   548  				for _, dep := range e.Deps {
   549  					pkg.AddImport(dep.PkgRefName, dep.ImportPath)
   550  				}
   551  			}
   552  		}
   553  		if pkg.FrugalPretouch {
   554  			pkg.AddImports("sync")
   555  			if len(pkg.AllMethods()) > 0 {
   556  				pkg.AddImports("frugal")
   557  				pkg.AddImports("reflect")
   558  			}
   559  		}
   560  	case MainFileName:
   561  		pkg.AddImport("log", "log")
   562  		pkg.AddImport(pkg.PkgRefName, util.JoinPath(pkg.ImportPath, strings.ToLower(pkg.ServiceName)))
   563  	}
   564  }
   565  
   566  func needCallOpt(pkg *PackageInfo) bool {
   567  	// callopt is referenced only by non-streaming methods
   568  	needCallOpt := false
   569  	switch pkg.Codec {
   570  	case "thrift":
   571  		for _, m := range pkg.ServiceInfo.AllMethods() {
   572  			if !m.Streaming.IsStreaming {
   573  				needCallOpt = true
   574  				break
   575  			}
   576  		}
   577  	case "protobuf":
   578  		needCallOpt = true
   579  	}
   580  	return needCallOpt
   581  }