github.com/cloudwego/kitex@v0.9.0/tool/internal_pkg/generator/completor.go (about)

     1  // Copyright 2021 CloudWeGo Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //   http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package generator
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"go/ast"
    21  	"go/parser"
    22  	"go/printer"
    23  	"go/token"
    24  	"io"
    25  	"os"
    26  	"path"
    27  	"path/filepath"
    28  	"strings"
    29  	"text/template"
    30  
    31  	"golang.org/x/tools/go/ast/astutil"
    32  
    33  	"github.com/cloudwego/kitex/tool/internal_pkg/log"
    34  	"github.com/cloudwego/kitex/tool/internal_pkg/tpl"
    35  )
    36  
    37  var errNoNewMethod = fmt.Errorf("no new method")
    38  
    39  type completer struct {
    40  	allMethods  []*MethodInfo
    41  	handlerPath string
    42  	serviceName string
    43  }
    44  
    45  func newCompleter(allMethods []*MethodInfo, handlerPath, serviceName string) *completer {
    46  	return &completer{
    47  		allMethods:  allMethods,
    48  		handlerPath: handlerPath,
    49  		serviceName: serviceName,
    50  	}
    51  }
    52  
    53  func parseFuncDecl(fd *ast.FuncDecl) (recvName, funcName string) {
    54  	funcName = fd.Name.String()
    55  	if fd.Recv != nil && len(fd.Recv.List) > 0 {
    56  		v := fd.Recv.List[0]
    57  		switch xv := v.Type.(type) {
    58  		case *ast.StarExpr:
    59  			if si, ok := xv.X.(*ast.Ident); ok {
    60  				recvName = si.Name
    61  			}
    62  		case *ast.Ident:
    63  			recvName = xv.Name
    64  		}
    65  	}
    66  	return
    67  }
    68  
    69  func (c *completer) compare(pkg *ast.Package) []*MethodInfo {
    70  	var newMethods []*MethodInfo
    71  	for _, m := range c.allMethods {
    72  		var have bool
    73  	PKGFILES:
    74  		for _, file := range pkg.Files {
    75  			for _, d := range file.Decls {
    76  				if fd, ok := d.(*ast.FuncDecl); ok {
    77  					rn, fn := parseFuncDecl(fd)
    78  					if rn == c.serviceName+"Impl" && fn == m.Name {
    79  						have = true
    80  						break PKGFILES
    81  					}
    82  				}
    83  			}
    84  		}
    85  		if !have {
    86  			log.Infof("[complete handler] add '%s' to handler.go\n", m.Name)
    87  			newMethods = append(newMethods, m)
    88  		}
    89  	}
    90  
    91  	return newMethods
    92  }
    93  
    94  func (c *completer) addImplementations(w io.Writer, newMethods []*MethodInfo) error {
    95  	// generate implements of new methods
    96  	mt := template.New(HandlerFileName).Funcs(funcs)
    97  	mt = template.Must(mt.Parse(`{{template "HandlerMethod" .}}`))
    98  	mt = template.Must(mt.Parse(tpl.HandlerMethodsTpl))
    99  	data := struct {
   100  		AllMethods  []*MethodInfo
   101  		ServiceName string
   102  	}{
   103  		AllMethods:  newMethods,
   104  		ServiceName: c.serviceName,
   105  	}
   106  	var buf bytes.Buffer
   107  	if err := mt.ExecuteTemplate(&buf, HandlerFileName, data); err != nil {
   108  		return err
   109  	}
   110  	_, err := w.Write(buf.Bytes())
   111  	return err
   112  }
   113  
   114  // add imports for new methods
   115  func (c *completer) addImport(w io.Writer, newMethods []*MethodInfo, fset *token.FileSet, handlerAST *ast.File) error {
   116  	newImports := make(map[string]bool)
   117  	for _, m := range newMethods {
   118  		for _, arg := range m.Args {
   119  			for _, dep := range arg.Deps {
   120  				newImports[dep.PkgRefName+" "+dep.ImportPath] = true
   121  			}
   122  		}
   123  		if m.Resp != nil {
   124  			for _, dep := range m.Resp.Deps {
   125  				newImports[dep.PkgRefName+" "+dep.ImportPath] = true
   126  			}
   127  		}
   128  	}
   129  	imports := handlerAST.Imports
   130  	for _, i := range imports {
   131  		path := strings.Trim(i.Path.Value, "\"")
   132  		var aliasPath string
   133  		// remove imports that already in handler.go
   134  		if i.Name != nil {
   135  			aliasPath = i.Name.String() + " " + path
   136  		} else {
   137  			aliasPath = filepath.Base(path) + " " + path
   138  			delete(newImports, path)
   139  		}
   140  		delete(newImports, aliasPath)
   141  	}
   142  	for path := range newImports {
   143  		s := strings.Split(path, " ")
   144  		switch len(s) {
   145  		case 1:
   146  			astutil.AddImport(fset, handlerAST, strings.Trim(s[0], "\""))
   147  		case 2:
   148  			astutil.AddNamedImport(fset, handlerAST, s[0], strings.Trim(s[1], "\""))
   149  		default:
   150  			log.Warn("cannot recognize import path", path)
   151  		}
   152  	}
   153  	printer.Fprint(w, fset, handlerAST)
   154  	return nil
   155  }
   156  
   157  func (c *completer) process(w io.Writer) error {
   158  	// get AST of main package
   159  	fset := token.NewFileSet()
   160  	pkgs, err := parser.ParseDir(fset, filepath.Dir(c.handlerPath), nil, parser.ParseComments)
   161  	if err != nil {
   162  		err = fmt.Errorf("go/parser failed to parse the main package: %w", err)
   163  		log.Warn("NOTICE: This is not a bug. We cannot add new methods to handler.go because your codes failed to compile. Fix the compile errors and try again.\n%s", err.Error())
   164  		return err
   165  	}
   166  	main, ok := pkgs["main"]
   167  	if !ok {
   168  		return fmt.Errorf("main package not found")
   169  	}
   170  
   171  	newMethods := c.compare(main)
   172  	if len(newMethods) == 0 {
   173  		return errNoNewMethod
   174  	}
   175  	err = c.addImport(w, newMethods, fset, main.Files[c.handlerPath])
   176  	if err != nil {
   177  		return fmt.Errorf("add imports failed error: %v", err)
   178  	}
   179  	err = c.addImplementations(w, newMethods)
   180  	if err != nil {
   181  		return fmt.Errorf("add implements failed error: %v", err)
   182  	}
   183  	return nil
   184  }
   185  
   186  func (c *completer) CompleteMethods() (*File, error) {
   187  	var buf bytes.Buffer
   188  	err := c.process(&buf)
   189  	if err != nil {
   190  		return nil, err
   191  	}
   192  	return &File{Name: c.handlerPath, Content: buf.String()}, nil
   193  }
   194  
   195  type commonCompleter struct {
   196  	path   string
   197  	pkg    *PackageInfo
   198  	update *Update
   199  }
   200  
   201  func (c *commonCompleter) Complete() (*File, error) {
   202  	var w bytes.Buffer
   203  	// get AST of main package
   204  	fset := token.NewFileSet()
   205  	f, err := parser.ParseFile(fset, c.path, nil, parser.ParseComments)
   206  	if err != nil {
   207  		err = fmt.Errorf("go/parser failed to parse the file: %s, err: %v", c.path, err)
   208  		log.Warnf("NOTICE: This is not a bug. We cannot update the file %s because your codes failed to compile. Fix the compile errors and try again.\n%s", c.path, err.Error())
   209  		return nil, err
   210  	}
   211  
   212  	newMethods, err := c.compare()
   213  	if err != nil {
   214  		return nil, err
   215  	}
   216  	if len(newMethods) == 0 {
   217  		return nil, errNoNewMethod
   218  	}
   219  	err = c.addImport(&w, newMethods, fset, f)
   220  	if err != nil {
   221  		return nil, fmt.Errorf("add imports failed error: %v", err)
   222  	}
   223  	err = c.addImplementations(&w, newMethods)
   224  	if err != nil {
   225  		return nil, fmt.Errorf("add implements failed error: %v", err)
   226  	}
   227  	return &File{Name: c.path, Content: w.String()}, nil
   228  }
   229  
   230  func (c *commonCompleter) compare() ([]*MethodInfo, error) {
   231  	var newMethods []*MethodInfo
   232  	for _, m := range c.pkg.Methods {
   233  		c.pkg.Methods = []*MethodInfo{m}
   234  		keyTask := &Task{
   235  			Text: c.update.Key,
   236  		}
   237  		key, err := keyTask.RenderString(c.pkg)
   238  		if err != nil {
   239  			return newMethods, err
   240  		}
   241  		have := false
   242  
   243  		dir := c.path
   244  		if strings.HasSuffix(c.path, ".go") {
   245  			dir = path.Dir(c.path)
   246  		}
   247  		filepath.Walk(dir, func(fullPath string, info os.FileInfo, err error) error {
   248  			if err != nil {
   249  				return err
   250  			}
   251  
   252  			if path.Base(dir) == info.Name() && info.IsDir() {
   253  				return nil
   254  			}
   255  			if info.IsDir() {
   256  				return filepath.SkipDir
   257  			}
   258  			if !strings.HasSuffix(fullPath, ".go") {
   259  				return nil
   260  			}
   261  			// get AST of main package
   262  			fset := token.NewFileSet()
   263  			f, err := parser.ParseFile(fset, fullPath, nil, parser.ParseComments)
   264  			if err != nil {
   265  				err = fmt.Errorf("go/parser failed to parse the file: %s, err: %v", c.path, err)
   266  				log.Warnf("NOTICE: This is not a bug. We cannot update the file %s because your codes failed to compile. Fix the compile errors and try again.\n%s", c.path, err.Error())
   267  				return err
   268  			}
   269  
   270  			for _, d := range f.Decls {
   271  				if fd, ok := d.(*ast.FuncDecl); ok {
   272  					_, fn := parseFuncDecl(fd)
   273  					if fn == key {
   274  						have = true
   275  						break
   276  					}
   277  				}
   278  			}
   279  			return nil
   280  		})
   281  
   282  		if !have {
   283  			newMethods = append(newMethods, m)
   284  		}
   285  	}
   286  
   287  	return newMethods, nil
   288  }
   289  
   290  // add imports for new methods
   291  func (c *commonCompleter) addImport(w io.Writer, newMethods []*MethodInfo, fset *token.FileSet, handlerAST *ast.File) error {
   292  	existImports := make(map[string]bool)
   293  	for _, i := range handlerAST.Imports {
   294  		existImports[strings.Trim(i.Path.Value, "\"")] = true
   295  	}
   296  	tmp := c.pkg.Methods
   297  	defer func() {
   298  		c.pkg.Methods = tmp
   299  	}()
   300  	c.pkg.Methods = newMethods
   301  	for _, i := range c.update.ImportTpl {
   302  		importTask := &Task{
   303  			Text: i,
   304  		}
   305  		content, err := importTask.RenderString(c.pkg)
   306  		if err != nil {
   307  			return err
   308  		}
   309  		imports := c.parseImports(content)
   310  		for idx := range imports {
   311  			if _, ok := existImports[strings.Trim(imports[idx][1], "\"")]; !ok {
   312  				astutil.AddImport(fset, handlerAST, strings.Trim(imports[idx][1], "\""))
   313  			}
   314  		}
   315  	}
   316  	printer.Fprint(w, fset, handlerAST)
   317  	return nil
   318  }
   319  
   320  func (c *commonCompleter) addImplementations(w io.Writer, newMethods []*MethodInfo) error {
   321  	tmp := c.pkg.Methods
   322  	defer func() {
   323  		c.pkg.Methods = tmp
   324  	}()
   325  	c.pkg.Methods = newMethods
   326  	// generate implements of new methods
   327  	appendTask := &Task{
   328  		Text: c.update.AppendTpl,
   329  	}
   330  	content, err := appendTask.RenderString(c.pkg)
   331  	if err != nil {
   332  		return err
   333  	}
   334  	_, err = w.Write([]byte(content))
   335  	c.pkg.Methods = tmp
   336  	return err
   337  }
   338  
   339  // imports[2] is alias, import
   340  func (c *commonCompleter) parseImports(content string) (imports [][2]string) {
   341  	if !strings.Contains(content, "\"") {
   342  		imports = append(imports, [2]string{"", content})
   343  		return imports
   344  	}
   345  	for i := 0; i < len(content); i++ {
   346  		if content[i] == ' ' {
   347  			continue
   348  		}
   349  		isAlias := content[i] != '"'
   350  
   351  		start := i
   352  		for ; i < len(content); i++ {
   353  			if content[i] == ' ' {
   354  				break
   355  			}
   356  		}
   357  		sub := content[start:i]
   358  		switch {
   359  		case isAlias:
   360  			imports = append(imports, [2]string{sub, ""})
   361  		case len(imports) > 0 && imports[len(imports)-1][1] == "":
   362  			imports[len(imports)-1][1] = sub
   363  		default:
   364  			imports = append(imports, [2]string{"", sub})
   365  		}
   366  	}
   367  	return imports
   368  }