github.com/cloudwego/kitex@v0.9.0/tool/internal_pkg/pluginmode/thriftgo/convertor.go (about)

     1  // Copyright 2021 CloudWeGo Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //   http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package thriftgo
    16  
    17  import (
    18  	"fmt"
    19  	"go/format"
    20  	"io"
    21  	"io/ioutil"
    22  	"os"
    23  	"path/filepath"
    24  	"regexp"
    25  	"strings"
    26  
    27  	"github.com/cloudwego/thriftgo/generator/backend"
    28  	"github.com/cloudwego/thriftgo/generator/golang"
    29  	"github.com/cloudwego/thriftgo/generator/golang/streaming"
    30  	"github.com/cloudwego/thriftgo/parser"
    31  	"github.com/cloudwego/thriftgo/plugin"
    32  	"github.com/cloudwego/thriftgo/semantic"
    33  
    34  	"github.com/cloudwego/kitex/tool/internal_pkg/generator"
    35  	internal_log "github.com/cloudwego/kitex/tool/internal_pkg/log"
    36  	"github.com/cloudwego/kitex/tool/internal_pkg/util"
    37  	"github.com/cloudwego/kitex/transport"
    38  )
    39  
    40  var (
    41  	prelude  = map[string]bool{"client": true, "server": true, "callopt": true, "context": true, "thrift": true, "kitex": true}
    42  	keyWords = []string{"client", "server", "callopt", "context", "thrift", "kitex"}
    43  )
    44  
    45  type converter struct {
    46  	Warnings []string
    47  	Utils    *golang.CodeUtils
    48  	Config   generator.Config
    49  	Package  generator.PackageInfo
    50  	Services []*generator.ServiceInfo
    51  	svc2ast  map[*generator.ServiceInfo]*parser.Thrift
    52  }
    53  
    54  func (c *converter) init(req *plugin.Request) error {
    55  	if req.Language != "go" {
    56  		return fmt.Errorf("expect language to be 'go'. Encountered '%s'", req.Language)
    57  	}
    58  
    59  	// resotre the arguments for kitex
    60  	if err := c.Config.Unpack(req.PluginParameters); err != nil {
    61  		return err
    62  	}
    63  
    64  	c.Utils = golang.NewCodeUtils(c.initLogs())
    65  	c.Utils.HandleOptions(req.GeneratorParameters)
    66  
    67  	return nil
    68  }
    69  
    70  func (c *converter) initLogs() backend.LogFunc {
    71  	lf := backend.LogFunc{
    72  		Info: func(v ...interface{}) {},
    73  		Warn: func(v ...interface{}) {
    74  			c.Warnings = append(c.Warnings, fmt.Sprint(v...))
    75  		},
    76  		MultiWarn: func(warns []string) {
    77  			c.Warnings = append(c.Warnings, warns...)
    78  		},
    79  	}
    80  	if c.Config.Verbose {
    81  		lf.Info = lf.Warn
    82  	}
    83  
    84  	internal_log.SetDefaultLogger(internal_log.Logger{
    85  		Println: func(w io.Writer, a ...interface{}) (n int, err error) {
    86  			if w != os.Stdout || c.Config.Verbose {
    87  				c.Warnings = append(c.Warnings, fmt.Sprint(a...))
    88  			}
    89  			return 0, nil
    90  		},
    91  		Printf: func(w io.Writer, format string, a ...interface{}) (n int, err error) {
    92  			if w != os.Stdout || c.Config.Verbose {
    93  				c.Warnings = append(c.Warnings, fmt.Sprintf(format, a...))
    94  			}
    95  			return 0, nil
    96  		},
    97  	})
    98  	return lf
    99  }
   100  
   101  func (c *converter) fail(err error) int {
   102  	res := plugin.BuildErrorResponse(err.Error(), c.Warnings...)
   103  	return exit(res)
   104  }
   105  
   106  func (c *converter) avoidIncludeConflict(ast *parser.Thrift, ref string) (*parser.Thrift, string) {
   107  	fn := filepath.Base(ast.Filename)
   108  	for _, inc := range ast.Includes {
   109  		if filepath.Base(inc.Path) == fn { // will cause include conflict
   110  			ref = "kitex_faked_idl"
   111  			faked := *ast
   112  			faked.Filename = util.JoinPath(filepath.Dir(faked.Filename), ref+".thrift")
   113  			_, hasNamespace := ast.GetNamespace("go")
   114  			if !hasNamespace {
   115  				faked.Namespaces = append(faked.Namespaces, &parser.Namespace{
   116  					Language: "go",
   117  					Name:     ast.GetNamespaceOrReferenceName("go"),
   118  				})
   119  			}
   120  			return &faked, ref
   121  		}
   122  	}
   123  	return ast, ref
   124  }
   125  
   126  // TODO: copy by marshal & unmarshal? to avoid missing fields.
   127  func (c *converter) copyTreeWithRef(ast *parser.Thrift, ref string) *parser.Thrift {
   128  	ast, ref = c.avoidIncludeConflict(ast, ref)
   129  
   130  	t := &parser.Thrift{
   131  		Filename: ast.Filename,
   132  		Namespaces: []*parser.Namespace{
   133  			{Language: "*", Name: "fake"},
   134  		},
   135  	}
   136  	t.Includes = append(t.Includes, &parser.Include{Path: ast.Filename, Reference: ast})
   137  	t.Includes = append(t.Includes, ast.Includes...)
   138  
   139  	for _, s := range ast.Services {
   140  		ss := &parser.Service{
   141  			Name:    s.Name,
   142  			Extends: s.Extends,
   143  		}
   144  		for _, f := range s.Functions {
   145  			ff := c.copyFunctionWithRef(f, ref)
   146  			ss.Functions = append(ss.Functions, ff)
   147  		}
   148  		t.Services = append(t.Services, ss)
   149  	}
   150  	return t
   151  }
   152  
   153  func (c *converter) copyFunctionWithRef(f *parser.Function, ref string) *parser.Function {
   154  	ff := &parser.Function{
   155  		Name:         f.Name,
   156  		Oneway:       f.Oneway,
   157  		Void:         f.Void,
   158  		FunctionType: c.copyTypeWithRef(f.FunctionType, ref),
   159  		Annotations:  c.copyAnnotations(f.Annotations),
   160  	}
   161  	for _, x := range f.Arguments {
   162  		y := *x
   163  		y.Type = c.copyTypeWithRef(x.Type, ref)
   164  		ff.Arguments = append(ff.Arguments, &y)
   165  	}
   166  	for _, x := range f.Throws {
   167  		y := *x
   168  		y.Type = c.copyTypeWithRef(x.Type, ref)
   169  		ff.Throws = append(ff.Throws, &y)
   170  	}
   171  	return ff
   172  }
   173  
   174  func (c *converter) copyTypeWithRef(t *parser.Type, ref string) (res *parser.Type) {
   175  	switch t.Name {
   176  	case "void":
   177  		return t
   178  	case "bool", "byte", "i8", "i16", "i32", "i64", "double", "string", "binary":
   179  		return t
   180  	case "map":
   181  		return &parser.Type{
   182  			Name:      t.Name,
   183  			KeyType:   c.copyTypeWithRef(t.KeyType, ref),
   184  			ValueType: c.copyTypeWithRef(t.ValueType, ref),
   185  		}
   186  	case "set", "list":
   187  		return &parser.Type{
   188  			Name:      t.Name,
   189  			ValueType: c.copyTypeWithRef(t.ValueType, ref),
   190  		}
   191  	default:
   192  		if strings.Contains(t.Name, ".") {
   193  			return &parser.Type{
   194  				Name:      t.Name,
   195  				KeyType:   t.KeyType,
   196  				ValueType: t.ValueType,
   197  			}
   198  		}
   199  		return &parser.Type{
   200  			Name: ref + "." + t.Name,
   201  		}
   202  	}
   203  }
   204  
   205  func (c *converter) getImports(t *parser.Type) (res []generator.PkgInfo) {
   206  	switch t.Name {
   207  	case "void":
   208  		return nil
   209  	case "bool", "byte", "i8", "i16", "i32", "i64", "double", "string", "binary":
   210  		return nil
   211  	case "map":
   212  		res = append(res, c.getImports(t.KeyType)...)
   213  		fallthrough
   214  	case "set", "list":
   215  		res = append(res, c.getImports(t.ValueType)...)
   216  		return res
   217  	default:
   218  		if ref := t.GetReference(); ref != nil {
   219  			inc := c.Utils.RootScope().Includes().ByIndex(int(ref.GetIndex()))
   220  			res = append(res, generator.PkgInfo{
   221  				PkgRefName: inc.PackageName,
   222  				ImportPath: inc.ImportPath,
   223  			})
   224  		}
   225  		return
   226  	}
   227  }
   228  
   229  func (c *converter) fixImportConflicts() {
   230  	for pkg, pth := range c.Package.Dependencies {
   231  		if prelude[pkg] {
   232  			delete(c.Package.Dependencies, pkg)
   233  			c.Package.Dependencies[pkg+"0"] = pth
   234  		}
   235  	}
   236  	var objs []interface{}
   237  	for _, s := range c.Services {
   238  		objs = append(objs, s)
   239  	}
   240  
   241  	fix := func(p *generator.PkgInfo) {
   242  		if prelude[p.PkgRefName] {
   243  			p.PkgRefName += "0"
   244  		}
   245  	}
   246  	kw := strings.Join(keyWords, "|")
   247  	re := regexp.MustCompile(`^(\*?)(` + kw + `)\.([^.]+)$`)
   248  	for len(objs) > 0 {
   249  		switch v := objs[0].(type) {
   250  		case *generator.ServiceInfo:
   251  			fix(&v.PkgInfo)
   252  			if v.Base != nil {
   253  				objs = append(objs, v.Base)
   254  			}
   255  			for _, m := range v.Methods {
   256  				objs = append(objs, m)
   257  			}
   258  		case *generator.MethodInfo:
   259  			fix(&v.PkgInfo)
   260  			for _, a := range v.Args {
   261  				objs = append(objs, a)
   262  			}
   263  			if !v.Void {
   264  				objs = append(objs, v.Resp)
   265  			}
   266  			for _, e := range v.Exceptions {
   267  				objs = append(objs, e)
   268  			}
   269  			v.ArgStructName = re.ReplaceAllString(v.ArgStructName, "${1}${2}0.${3}")
   270  			v.ResStructName = re.ReplaceAllString(v.ResStructName, "${1}${2}0.${3}")
   271  		case *generator.Parameter:
   272  			for i := 0; i < len(v.Deps); i++ {
   273  				fix(&v.Deps[i])
   274  			}
   275  			v.Type = re.ReplaceAllString(v.Type, "${1}${2}0.${3}")
   276  		}
   277  		objs = objs[1:]
   278  	}
   279  }
   280  
   281  type ast2svc map[string][]*generator.ServiceInfo
   282  
   283  func (t ast2svc) findService(ast *parser.Thrift, name string) *generator.ServiceInfo {
   284  	var list []*generator.ServiceInfo
   285  	for filename, l := range t {
   286  		if filename == ast.Filename {
   287  			list = l
   288  			break
   289  		}
   290  	}
   291  	for _, s := range list {
   292  		if s.RawServiceName == name {
   293  			return s
   294  		}
   295  	}
   296  	return nil
   297  }
   298  
   299  func (c *converter) convertTypes(req *plugin.Request) error {
   300  	var all ast2svc = make(map[string][]*generator.ServiceInfo)
   301  
   302  	c.svc2ast = make(map[*generator.ServiceInfo]*parser.Thrift)
   303  	for ast := range req.AST.DepthFirstSearch() {
   304  		ref, pkg, pth := c.Utils.ParseNamespace(ast)
   305  		// make the current ast as an include to produce correct type references.
   306  		fake := c.copyTreeWithRef(ast, ref)
   307  		fake.Name2Category = nil
   308  		if err := semantic.ResolveSymbols(fake); err != nil {
   309  			return fmt.Errorf("resolve fakse ast '%s': %w", ast.Filename, err)
   310  		}
   311  		used := true
   312  		fake.ForEachInclude(func(v *parser.Include) bool {
   313  			v.Used = &used // mark all includes used to force renaming for conflict IDLs in thriftgo
   314  			return true
   315  		})
   316  
   317  		scope, err := golang.BuildScope(c.Utils, fake)
   318  		if err != nil {
   319  			return fmt.Errorf("build scope for fake ast '%s': %w", ast.Filename, err)
   320  		}
   321  		c.Utils.SetRootScope(scope)
   322  		pi := generator.PkgInfo{
   323  			PkgName:    pkg,
   324  			PkgRefName: pkg,
   325  			ImportPath: util.CombineOutputPath(c.Config.PackagePrefix, pth),
   326  		}
   327  		for _, svc := range scope.Services() {
   328  			si, err := c.makeService(pi, svc)
   329  			if err != nil {
   330  				return fmt.Errorf("%s: makeService '%s': %w", ast.Filename, svc.Name, err)
   331  			}
   332  			si.ServiceFilePath = ast.Filename
   333  			all[ast.Filename] = append(all[ast.Filename], si)
   334  			c.svc2ast[si] = ast
   335  		}
   336  		// fill .Base
   337  		for i, svc := range ast.Services {
   338  			if len(svc.Extends) > 0 {
   339  				si := all[ast.Filename][i]
   340  				parts := semantic.SplitType(svc.Extends)
   341  				switch len(parts) {
   342  				case 1:
   343  					si.Base = all.findService(ast, parts[0])
   344  				case 2:
   345  					tmp, found := ast.GetReference(parts[0])
   346  					if !found {
   347  						break
   348  					}
   349  					si.Base = all.findService(tmp, parts[1])
   350  				}
   351  				if len(parts) > 0 && si.Base == nil {
   352  					return fmt.Errorf("base service '%s' %d not found for '%s'", svc.Extends, len(parts), svc.Name)
   353  				}
   354  			}
   355  		}
   356  
   357  		c.fixStreamingForExtendedServices(ast, all)
   358  
   359  		// combine service
   360  		if ast == req.AST && c.Config.CombineService && len(ast.Services) > 0 {
   361  			var (
   362  				svcs    []*generator.ServiceInfo
   363  				methods []*generator.MethodInfo
   364  			)
   365  			hasStreaming := false
   366  			for _, s := range all[ast.Filename] {
   367  				svcs = append(svcs, s)
   368  				hasStreaming = hasStreaming || s.HasStreaming
   369  				methods = append(methods, s.AllMethods()...)
   370  			}
   371  			// check method name conflict
   372  			mm := make(map[string]*generator.MethodInfo)
   373  			for _, m := range methods {
   374  				if _, ok := mm[m.Name]; ok {
   375  					return fmt.Errorf("combine service method %s in %s conflicts with %s in %s", m.Name, m.ServiceName, m.Name, mm[m.Name].ServiceName)
   376  				}
   377  				mm[m.Name] = m
   378  			}
   379  			svcName := c.getCombineServiceName("CombineService", all[ast.Filename])
   380  			si := &generator.ServiceInfo{
   381  				PkgInfo:         pi,
   382  				ServiceName:     svcName,
   383  				RawServiceName:  svcName,
   384  				CombineServices: svcs,
   385  				Methods:         methods,
   386  				ServiceFilePath: ast.Filename,
   387  				HasStreaming:    hasStreaming,
   388  			}
   389  
   390  			if c.IsHessian2() {
   391  				si.Protocol = transport.HESSIAN2.String()
   392  			}
   393  
   394  			si.HandlerReturnKeepResp = c.Config.HandlerReturnKeepResp
   395  			si.UseThriftReflection = c.Utils.Features().WithReflection
   396  			si.ServiceTypeName = func() string { return si.ServiceName }
   397  			all[ast.Filename] = append(all[ast.Filename], si)
   398  			c.svc2ast[si] = ast
   399  		}
   400  
   401  		c.Services = append(c.Services, all[ast.Filename]...)
   402  	}
   403  	return nil
   404  }
   405  
   406  func (c *converter) fixStreamingForExtendedServices(ast *parser.Thrift, all ast2svc) {
   407  	for i, svc := range ast.Services {
   408  		if svc.Extends == "" {
   409  			continue
   410  		}
   411  		si := all[ast.Filename][i]
   412  		if si.Base != nil {
   413  			si.FixHasStreamingForExtendedService()
   414  		}
   415  	}
   416  }
   417  
   418  func (c *converter) makeService(pkg generator.PkgInfo, svc *golang.Service) (*generator.ServiceInfo, error) {
   419  	si := &generator.ServiceInfo{
   420  		PkgInfo:        pkg,
   421  		ServiceName:    svc.GoName().String(),
   422  		RawServiceName: svc.Name,
   423  	}
   424  	si.ServiceTypeName = func() string { return si.PkgRefName + "." + si.ServiceName }
   425  
   426  	for _, f := range svc.Functions() {
   427  		if strings.HasPrefix(f.Name, "_") {
   428  			continue
   429  		}
   430  		mi, err := c.makeMethod(si, f)
   431  		if err != nil {
   432  			return nil, err
   433  		}
   434  		si.Methods = append(si.Methods, mi)
   435  	}
   436  
   437  	if c.IsHessian2() {
   438  		si.Protocol = transport.HESSIAN2.String()
   439  	}
   440  	si.HandlerReturnKeepResp = c.Config.HandlerReturnKeepResp
   441  	si.UseThriftReflection = c.Utils.Features().WithReflection
   442  	return si, nil
   443  }
   444  
   445  func (c *converter) makeMethod(si *generator.ServiceInfo, f *golang.Function) (*generator.MethodInfo, error) {
   446  	st, err := streaming.ParseStreaming(f.Function)
   447  	if err != nil {
   448  		return nil, err
   449  	}
   450  	mi := &generator.MethodInfo{
   451  		PkgInfo:            si.PkgInfo,
   452  		ServiceName:        si.ServiceName,
   453  		Name:               f.GoName().String(),
   454  		RawName:            f.Name,
   455  		Oneway:             f.Oneway,
   456  		Void:               f.Void,
   457  		ArgStructName:      f.ArgType().GoName().String(),
   458  		GenArgResultStruct: false,
   459  		Streaming:          st,
   460  		ClientStreaming:    st.ClientStreaming,
   461  		ServerStreaming:    st.ServerStreaming,
   462  		ArgsLength:         len(f.Arguments()),
   463  	}
   464  	if st.IsStreaming {
   465  		si.HasStreaming = true
   466  	}
   467  
   468  	if !f.Oneway {
   469  		mi.ResStructName = f.ResType().GoName().String()
   470  	}
   471  	if !f.Void {
   472  		typeName := f.ResponseGoTypeName().String()
   473  		mi.Resp = &generator.Parameter{
   474  			Deps: c.getImports(f.FunctionType),
   475  			Type: typeName,
   476  		}
   477  		mi.IsResponseNeedRedirect = "*"+typeName == f.ResType().Fields()[0].GoTypeName().String()
   478  	}
   479  
   480  	for _, a := range f.Arguments() {
   481  		arg := &generator.Parameter{
   482  			Deps:    c.getImports(a.Type),
   483  			Name:    f.ArgType().Field(a.Name).GoName().String(),
   484  			RawName: a.GoName().String(),
   485  			Type:    a.GoTypeName().String(),
   486  		}
   487  		mi.Args = append(mi.Args, arg)
   488  	}
   489  	for _, t := range f.Throws() {
   490  		ex := &generator.Parameter{
   491  			Deps:    c.getImports(t.Type),
   492  			Name:    f.ResType().Field(t.Name).GoName().String(),
   493  			RawName: t.GoName().String(),
   494  			Type:    t.GoTypeName().String(),
   495  		}
   496  		mi.Exceptions = append(mi.Exceptions, ex)
   497  	}
   498  	return mi, nil
   499  }
   500  
   501  func (c *converter) persist(res *plugin.Response) error {
   502  	for _, c := range res.Contents {
   503  		full := c.GetName()
   504  		content := []byte(c.Content)
   505  		if filepath.Ext(full) == ".go" {
   506  			if formatted, err := format.Source([]byte(c.Content)); err != nil {
   507  				internal_log.Warn(fmt.Sprintf("Failed to format %s: %s", full, err.Error()))
   508  			} else {
   509  				content = formatted
   510  			}
   511  		}
   512  
   513  		internal_log.Info("Write", full)
   514  		path := filepath.Dir(full)
   515  		if err := os.MkdirAll(path, 0o755); err != nil && !os.IsExist(err) {
   516  			return fmt.Errorf("failed to create path '%s': %w", path, err)
   517  		}
   518  		if err := ioutil.WriteFile(full, content, 0o644); err != nil {
   519  			return fmt.Errorf("failed to write file '%s': %w", full, err)
   520  		}
   521  	}
   522  	return nil
   523  }
   524  
   525  func (c *converter) getCombineServiceName(name string, svcs []*generator.ServiceInfo) string {
   526  	for _, svc := range svcs {
   527  		if svc.ServiceName == name {
   528  			return c.getCombineServiceName(name+"_", svcs)
   529  		}
   530  	}
   531  	return name
   532  }
   533  
   534  func (c *converter) IsHessian2() bool {
   535  	return strings.EqualFold(c.Config.Protocol, transport.HESSIAN2.String())
   536  }
   537  
   538  func (c *converter) copyAnnotations(annotations parser.Annotations) parser.Annotations {
   539  	copied := make(parser.Annotations, 0, len(annotations))
   540  	for _, annotation := range annotations {
   541  		values := make([]string, 0, len(annotation.Values))
   542  		values = append(values, annotation.Values...)
   543  		copied = append(copied, &parser.Annotation{
   544  			Key:    annotation.Key,
   545  			Values: values,
   546  		})
   547  	}
   548  	return copied
   549  }