github.com/cloudwego/kitex@v0.9.0/tool/internal_pkg/pluginmode/thriftgo/patcher.go (about)

     1  // Copyright 2021 CloudWeGo Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //   http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package thriftgo
    16  
    17  import (
    18  	"fmt"
    19  	"io/ioutil"
    20  	"path/filepath"
    21  	"reflect"
    22  	"runtime"
    23  	"sort"
    24  	"strconv"
    25  	"strings"
    26  	"text/template"
    27  
    28  	"github.com/cloudwego/thriftgo/generator/golang"
    29  	"github.com/cloudwego/thriftgo/generator/golang/templates"
    30  	"github.com/cloudwego/thriftgo/generator/golang/templates/slim"
    31  	"github.com/cloudwego/thriftgo/parser"
    32  	"github.com/cloudwego/thriftgo/plugin"
    33  
    34  	"github.com/cloudwego/kitex/tool/internal_pkg/generator"
    35  	"github.com/cloudwego/kitex/tool/internal_pkg/util"
    36  )
    37  
    38  var extraTemplates []string
    39  
    40  // AppendToTemplate string
    41  func AppendToTemplate(text string) {
    42  	extraTemplates = append(extraTemplates, text)
    43  }
    44  
    45  const kitexUnusedProtection = `
    46  // KitexUnusedProtection is used to prevent 'imported and not used' error.
    47  var KitexUnusedProtection = struct{}{}
    48  `
    49  
    50  //lint:ignore U1000 until protectionInsertionPoint is used
    51  var protectionInsertionPoint = "KitexUnusedProtection"
    52  
    53  type patcher struct {
    54  	noFastAPI             bool
    55  	utils                 *golang.CodeUtils
    56  	module                string
    57  	copyIDL               bool
    58  	version               string
    59  	record                bool
    60  	recordCmd             []string
    61  	deepCopyAPI           bool
    62  	protocol              string
    63  	handlerReturnKeepResp bool
    64  
    65  	fileTpl *template.Template
    66  }
    67  
    68  func (p *patcher) buildTemplates() (err error) {
    69  	m := p.utils.BuildFuncMap()
    70  	m["ReorderStructFields"] = p.reorderStructFields
    71  	m["TypeIDToGoType"] = func(t string) string { return typeIDToGoType[t] }
    72  	m["IsBinaryOrStringType"] = p.isBinaryOrStringType
    73  	m["Version"] = func() string { return p.version }
    74  	m["GenerateFastAPIs"] = func() bool { return !p.noFastAPI && p.utils.Template() != "slim" }
    75  	m["GenerateDeepCopyAPIs"] = func() bool { return p.deepCopyAPI }
    76  	m["GenerateArgsResultTypes"] = func() bool { return p.utils.Template() == "slim" }
    77  	m["ImportPathTo"] = generator.ImportPathTo
    78  	m["ToPackageNames"] = func(imports map[string]string) (res []string) {
    79  		for pth, alias := range imports {
    80  			if alias != "" {
    81  				res = append(res, alias)
    82  			} else {
    83  				res = append(res, strings.ToLower(filepath.Base(pth)))
    84  			}
    85  		}
    86  		sort.Strings(res)
    87  		return
    88  	}
    89  	m["Str"] = func(id int32) string {
    90  		if id < 0 {
    91  			return "_" + strconv.Itoa(-int(id))
    92  		}
    93  		return strconv.Itoa(int(id))
    94  	}
    95  	m["IsNil"] = func(i interface{}) bool {
    96  		return i == nil || reflect.ValueOf(i).IsNil()
    97  	}
    98  	m["SourceTarget"] = func(s string) string {
    99  		// p.XXX
   100  		if strings.HasPrefix(s, "p.") {
   101  			return "src." + s[2:]
   102  		}
   103  		// _key, _val
   104  		return s[1:]
   105  	}
   106  	m["FieldName"] = func(s string) string {
   107  		// p.XXX
   108  		return strings.ToLower(s[2:3]) + s[3:]
   109  	}
   110  	m["IsHessian"] = func() bool {
   111  		return p.IsHessian2()
   112  	}
   113  	m["IsGoStringType"] = func(typeName golang.TypeName) bool {
   114  		return typeName == "string" || typeName == "*string"
   115  	}
   116  
   117  	tpl := template.New("kitex").Funcs(m)
   118  	allTemplates := basicTemplates
   119  	if p.utils.Template() == "slim" {
   120  		allTemplates = append(allTemplates, slim.StructLike,
   121  			templates.StructLikeDefault,
   122  			templates.FieldGetOrSet,
   123  			templates.FieldIsSet,
   124  			structLikeDeepCopy,
   125  			fieldDeepCopy,
   126  			fieldDeepCopyStructLike,
   127  			fieldDeepCopyContainer,
   128  			fieldDeepCopyMap,
   129  			fieldDeepCopyList,
   130  			fieldDeepCopySet,
   131  			fieldDeepCopyBaseType,
   132  			structLikeCodec,
   133  			structLikeProtocol,
   134  			javaClassName,
   135  			processor,
   136  		)
   137  	} else {
   138  		allTemplates = append(allTemplates, structLikeCodec,
   139  			structLikeFastRead,
   140  			structLikeFastReadField,
   141  			structLikeDeepCopy,
   142  			structLikeFastWrite,
   143  			structLikeFastWriteNocopy,
   144  			structLikeLength,
   145  			structLikeFastWriteField,
   146  			structLikeFieldLength,
   147  			structLikeProtocol,
   148  			javaClassName,
   149  			fieldFastRead,
   150  			fieldFastReadStructLike,
   151  			fieldFastReadBaseType,
   152  			fieldFastReadContainer,
   153  			fieldFastReadMap,
   154  			fieldFastReadSet,
   155  			fieldFastReadList,
   156  			fieldDeepCopy,
   157  			fieldDeepCopyStructLike,
   158  			fieldDeepCopyContainer,
   159  			fieldDeepCopyMap,
   160  			fieldDeepCopyList,
   161  			fieldDeepCopySet,
   162  			fieldDeepCopyBaseType,
   163  			fieldFastWrite,
   164  			fieldLength,
   165  			fieldFastWriteStructLike,
   166  			fieldStructLikeLength,
   167  			fieldFastWriteBaseType,
   168  			fieldBaseTypeLength,
   169  			fieldFixedLengthTypeLength,
   170  			fieldFastWriteContainer,
   171  			fieldContainerLength,
   172  			fieldFastWriteMap,
   173  			fieldMapLength,
   174  			fieldFastWriteSet,
   175  			fieldSetLength,
   176  			fieldFastWriteList,
   177  			fieldListLength,
   178  			templates.FieldDeepEqual,
   179  			templates.FieldDeepEqualBase,
   180  			templates.FieldDeepEqualStructLike,
   181  			templates.FieldDeepEqualContainer,
   182  			validateSet,
   183  			processor,
   184  		)
   185  	}
   186  	for _, txt := range allTemplates {
   187  		tpl = template.Must(tpl.Parse(txt))
   188  	}
   189  
   190  	ext := `{{define "ExtraTemplates"}}{{end}}`
   191  	if len(extraTemplates) > 0 {
   192  		ext = fmt.Sprintf("{{define \"ExtraTemplates\"}}\n%s\n{{end}}",
   193  			strings.Join(extraTemplates, "\n"))
   194  	}
   195  	tpl, err = tpl.Parse(ext)
   196  	if err != nil {
   197  		return fmt.Errorf("failed to parse extra templates: %w: %q", err, ext)
   198  	}
   199  
   200  	if p.IsHessian2() {
   201  		tpl, err = tpl.Parse(registerHessian)
   202  		if err != nil {
   203  			return fmt.Errorf("failed to parse hessian2 templates: %w: %q", err, registerHessian)
   204  		}
   205  	}
   206  
   207  	p.fileTpl = tpl
   208  	return nil
   209  }
   210  
   211  func (p *patcher) patch(req *plugin.Request) (patches []*plugin.Generated, err error) {
   212  	p.buildTemplates()
   213  	var buf strings.Builder
   214  
   215  	protection := make(map[string]*plugin.Generated)
   216  
   217  	for ast := range req.AST.DepthFirstSearch() {
   218  		// scope, err := golang.BuildScope(p.utils, ast)
   219  		scope, _, err := golang.BuildRefScope(p.utils, ast)
   220  		if err != nil {
   221  			return nil, fmt.Errorf("build scope for ast %q: %w", ast.Filename, err)
   222  		}
   223  		p.utils.SetRootScope(scope)
   224  
   225  		pkgName := golang.GetImportPackage(golang.GetImportPath(p.utils, ast))
   226  
   227  		path := p.utils.CombineOutputPath(req.OutputPath, ast)
   228  		base := p.utils.GetFilename(ast)
   229  		target := util.JoinPath(path, "k-"+base)
   230  
   231  		// Define KitexUnusedProtection in k-consts.go .
   232  		// Add k-consts.go before target to force the k-consts.go generated by consts.thrift to be renamed.
   233  		consts := util.JoinPath(path, "k-consts.go")
   234  		if protection[consts] == nil {
   235  			patch := &plugin.Generated{
   236  				Content: "package " + pkgName + "\n" + kitexUnusedProtection,
   237  				Name:    &consts,
   238  			}
   239  			patches = append(patches, patch)
   240  			protection[consts] = patch
   241  		}
   242  
   243  		buf.Reset()
   244  
   245  		// if all scopes are ref, don't generate k-xxx
   246  		if scope == nil {
   247  			continue
   248  		}
   249  
   250  		if p.IsHessian2() {
   251  			register := util.JoinPath(path, fmt.Sprintf("hessian2-register-%s", base))
   252  			patch, err := p.patchHessian(path, scope, pkgName, base)
   253  			if err != nil {
   254  				return nil, fmt.Errorf("patch hessian fail for %q: %w", ast.Filename, err)
   255  			}
   256  
   257  			patches = append(patches, patch)
   258  			protection[register] = patch
   259  		}
   260  
   261  		data := &struct {
   262  			Scope   *golang.Scope
   263  			PkgName string
   264  			Imports map[string]string
   265  		}{Scope: scope, PkgName: pkgName}
   266  		data.Imports, err = scope.ResolveImports()
   267  		if err != nil {
   268  			return nil, fmt.Errorf("resolve imports failed for %q: %w", ast.Filename, err)
   269  		}
   270  		p.filterStdLib(data.Imports)
   271  		if err = p.fileTpl.ExecuteTemplate(&buf, "file", data); err != nil {
   272  			return nil, fmt.Errorf("%q: %w", ast.Filename, err)
   273  		}
   274  		content := buf.String()
   275  		// if kutils is not used, remove the dependency.
   276  		if !strings.Contains(content, "kutils.StringDeepCopy") {
   277  			kutilsImp := `kutils "github.com/cloudwego/kitex/pkg/utils"`
   278  			idx := strings.Index(content, kutilsImp)
   279  			if idx > 0 {
   280  				content = content[:idx-1] + content[idx+len(kutilsImp):]
   281  			}
   282  		}
   283  		patches = append(patches, &plugin.Generated{
   284  			Content: content,
   285  			Name:    &target,
   286  		})
   287  
   288  		if p.copyIDL {
   289  			content, err := ioutil.ReadFile(ast.Filename)
   290  			if err != nil {
   291  				return nil, fmt.Errorf("read %q: %w", ast.Filename, err)
   292  			}
   293  			path := util.JoinPath(path, filepath.Base(ast.Filename))
   294  			patches = append(patches, &plugin.Generated{
   295  				Content: string(content),
   296  				Name:    &path,
   297  			})
   298  		}
   299  
   300  		if p.record {
   301  			content := doRecord(p.recordCmd)
   302  			bashPath := util.JoinPath(getBashPath())
   303  			patches = append(patches, &plugin.Generated{
   304  				Content: content,
   305  				Name:    &bashPath,
   306  			})
   307  		}
   308  
   309  	}
   310  	return
   311  }
   312  
   313  func (p *patcher) patchHessian(path string, scope *golang.Scope, pkgName, base string) (patch *plugin.Generated, err error) {
   314  	buf := strings.Builder{}
   315  	resigterIDLName := fmt.Sprintf("hessian2-register-%s", base)
   316  	register := util.JoinPath(path, resigterIDLName)
   317  	data := &struct {
   318  		Scope   *golang.Scope
   319  		PkgName string
   320  		Imports map[string]string
   321  		GoName  string
   322  		IDLName string
   323  	}{Scope: scope, PkgName: pkgName, IDLName: util.UpperFirst(strings.Replace(base, ".go", "", -1))}
   324  	data.Imports, err = scope.ResolveImports()
   325  	if err != nil {
   326  		return nil, err
   327  	}
   328  
   329  	if err = p.fileTpl.ExecuteTemplate(&buf, "register", data); err != nil {
   330  		return nil, err
   331  	}
   332  	patch = &plugin.Generated{
   333  		Content: buf.String(),
   334  		Name:    &register,
   335  	}
   336  	return patch, nil
   337  }
   338  
   339  func getBashPath() string {
   340  	if runtime.GOOS == "windows" {
   341  		return "kitex-all.bat"
   342  	}
   343  	return "kitex-all.sh"
   344  }
   345  
   346  // DoRecord records current cmd into kitex-all.sh
   347  func doRecord(recordCmd []string) string {
   348  	bytes, err := ioutil.ReadFile(getBashPath())
   349  	content := string(bytes)
   350  	if err != nil {
   351  		content = "#! /usr/bin/env bash\n"
   352  	}
   353  	var input, currentIdl string
   354  	for _, s := range recordCmd {
   355  		if s != "-record" {
   356  			input += s + " "
   357  		}
   358  		if strings.HasSuffix(s, ".thrift") || strings.HasSuffix(s, ".proto") {
   359  			currentIdl = s
   360  		}
   361  	}
   362  	if input != "" && currentIdl != "" {
   363  		find := false
   364  		lines := strings.Split(content, "\n")
   365  		for i, line := range lines {
   366  			if strings.Contains(input, "-service") && strings.Contains(line, "-service") {
   367  				lines[i] = input
   368  				find = true
   369  				break
   370  			}
   371  			if strings.Contains(line, currentIdl) && !strings.Contains(line, "-service") {
   372  				lines[i] = input
   373  				find = true
   374  				break
   375  			}
   376  		}
   377  		if !find {
   378  			content += "\n" + input
   379  		} else {
   380  			content = strings.Join(lines, "\n")
   381  		}
   382  	}
   383  	return content
   384  }
   385  
   386  func (p *patcher) reorderStructFields(fields []*golang.Field) ([]*golang.Field, error) {
   387  	fixedLengthFields := make(map[*golang.Field]bool, len(fields))
   388  	for _, field := range fields {
   389  		fixedLengthFields[field] = golang.IsFixedLengthType(field.Type)
   390  	}
   391  
   392  	sortedFields := make([]*golang.Field, 0, len(fields))
   393  	for _, v := range fields {
   394  		if fixedLengthFields[v] {
   395  			sortedFields = append(sortedFields, v)
   396  		}
   397  	}
   398  	for _, v := range fields {
   399  		if !fixedLengthFields[v] {
   400  			sortedFields = append(sortedFields, v)
   401  		}
   402  	}
   403  
   404  	return sortedFields, nil
   405  }
   406  
   407  func (p *patcher) filterStdLib(imports map[string]string) {
   408  	// remove std libs and thrift to prevent duplicate import.
   409  	prefix := p.module + "/"
   410  	for pth := range imports {
   411  		if strings.HasPrefix(pth, prefix) { // local module
   412  			continue
   413  		}
   414  		if pth == "github.com/apache/thrift/lib/go/thrift" {
   415  			delete(imports, pth)
   416  		}
   417  		if strings.HasPrefix(pth, "github.com/cloudwego/thriftgo") {
   418  			delete(imports, pth)
   419  		}
   420  		if !strings.Contains(pth, ".") { // std lib
   421  			delete(imports, pth)
   422  		}
   423  	}
   424  }
   425  
   426  func (p *patcher) isBinaryOrStringType(t *parser.Type) bool {
   427  	return t.Category.IsBinary() || t.Category.IsString()
   428  }
   429  
   430  func (p *patcher) IsHessian2() bool {
   431  	return strings.EqualFold(p.protocol, "hessian2")
   432  }
   433  
   434  var typeIDToGoType = map[string]string{
   435  	"Bool":   "bool",
   436  	"Byte":   "int8",
   437  	"I16":    "int16",
   438  	"I32":    "int32",
   439  	"I64":    "int64",
   440  	"Double": "float64",
   441  	"String": "string",
   442  	"Binary": "[]byte",
   443  }