github.com/cloudwego/kitex@v0.9.0/tool/internal_pkg/generator/completor.go (about) 1 // Copyright 2021 CloudWeGo Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package generator 16 17 import ( 18 "bytes" 19 "fmt" 20 "go/ast" 21 "go/parser" 22 "go/printer" 23 "go/token" 24 "io" 25 "os" 26 "path" 27 "path/filepath" 28 "strings" 29 "text/template" 30 31 "golang.org/x/tools/go/ast/astutil" 32 33 "github.com/cloudwego/kitex/tool/internal_pkg/log" 34 "github.com/cloudwego/kitex/tool/internal_pkg/tpl" 35 ) 36 37 var errNoNewMethod = fmt.Errorf("no new method") 38 39 type completer struct { 40 allMethods []*MethodInfo 41 handlerPath string 42 serviceName string 43 } 44 45 func newCompleter(allMethods []*MethodInfo, handlerPath, serviceName string) *completer { 46 return &completer{ 47 allMethods: allMethods, 48 handlerPath: handlerPath, 49 serviceName: serviceName, 50 } 51 } 52 53 func parseFuncDecl(fd *ast.FuncDecl) (recvName, funcName string) { 54 funcName = fd.Name.String() 55 if fd.Recv != nil && len(fd.Recv.List) > 0 { 56 v := fd.Recv.List[0] 57 switch xv := v.Type.(type) { 58 case *ast.StarExpr: 59 if si, ok := xv.X.(*ast.Ident); ok { 60 recvName = si.Name 61 } 62 case *ast.Ident: 63 recvName = xv.Name 64 } 65 } 66 return 67 } 68 69 func (c *completer) compare(pkg *ast.Package) []*MethodInfo { 70 var newMethods []*MethodInfo 71 for _, m := range c.allMethods { 72 var have bool 73 PKGFILES: 74 for _, file := range pkg.Files { 75 for _, d := range file.Decls { 76 if fd, ok := d.(*ast.FuncDecl); ok { 77 rn, fn := parseFuncDecl(fd) 78 if rn == c.serviceName+"Impl" && fn == m.Name { 79 have = true 80 break PKGFILES 81 } 82 } 83 } 84 } 85 if !have { 86 log.Infof("[complete handler] add '%s' to handler.go\n", m.Name) 87 newMethods = append(newMethods, m) 88 } 89 } 90 91 return newMethods 92 } 93 94 func (c *completer) addImplementations(w io.Writer, newMethods []*MethodInfo) error { 95 // generate implements of new methods 96 mt := template.New(HandlerFileName).Funcs(funcs) 97 mt = template.Must(mt.Parse(`{{template "HandlerMethod" .}}`)) 98 mt = template.Must(mt.Parse(tpl.HandlerMethodsTpl)) 99 data := struct { 100 AllMethods []*MethodInfo 101 ServiceName string 102 }{ 103 AllMethods: newMethods, 104 ServiceName: c.serviceName, 105 } 106 var buf bytes.Buffer 107 if err := mt.ExecuteTemplate(&buf, HandlerFileName, data); err != nil { 108 return err 109 } 110 _, err := w.Write(buf.Bytes()) 111 return err 112 } 113 114 // add imports for new methods 115 func (c *completer) addImport(w io.Writer, newMethods []*MethodInfo, fset *token.FileSet, handlerAST *ast.File) error { 116 newImports := make(map[string]bool) 117 for _, m := range newMethods { 118 for _, arg := range m.Args { 119 for _, dep := range arg.Deps { 120 newImports[dep.PkgRefName+" "+dep.ImportPath] = true 121 } 122 } 123 if m.Resp != nil { 124 for _, dep := range m.Resp.Deps { 125 newImports[dep.PkgRefName+" "+dep.ImportPath] = true 126 } 127 } 128 } 129 imports := handlerAST.Imports 130 for _, i := range imports { 131 path := strings.Trim(i.Path.Value, "\"") 132 var aliasPath string 133 // remove imports that already in handler.go 134 if i.Name != nil { 135 aliasPath = i.Name.String() + " " + path 136 } else { 137 aliasPath = filepath.Base(path) + " " + path 138 delete(newImports, path) 139 } 140 delete(newImports, aliasPath) 141 } 142 for path := range newImports { 143 s := strings.Split(path, " ") 144 switch len(s) { 145 case 1: 146 astutil.AddImport(fset, handlerAST, strings.Trim(s[0], "\"")) 147 case 2: 148 astutil.AddNamedImport(fset, handlerAST, s[0], strings.Trim(s[1], "\"")) 149 default: 150 log.Warn("cannot recognize import path", path) 151 } 152 } 153 printer.Fprint(w, fset, handlerAST) 154 return nil 155 } 156 157 func (c *completer) process(w io.Writer) error { 158 // get AST of main package 159 fset := token.NewFileSet() 160 pkgs, err := parser.ParseDir(fset, filepath.Dir(c.handlerPath), nil, parser.ParseComments) 161 if err != nil { 162 err = fmt.Errorf("go/parser failed to parse the main package: %w", err) 163 log.Warn("NOTICE: This is not a bug. We cannot add new methods to handler.go because your codes failed to compile. Fix the compile errors and try again.\n%s", err.Error()) 164 return err 165 } 166 main, ok := pkgs["main"] 167 if !ok { 168 return fmt.Errorf("main package not found") 169 } 170 171 newMethods := c.compare(main) 172 if len(newMethods) == 0 { 173 return errNoNewMethod 174 } 175 err = c.addImport(w, newMethods, fset, main.Files[c.handlerPath]) 176 if err != nil { 177 return fmt.Errorf("add imports failed error: %v", err) 178 } 179 err = c.addImplementations(w, newMethods) 180 if err != nil { 181 return fmt.Errorf("add implements failed error: %v", err) 182 } 183 return nil 184 } 185 186 func (c *completer) CompleteMethods() (*File, error) { 187 var buf bytes.Buffer 188 err := c.process(&buf) 189 if err != nil { 190 return nil, err 191 } 192 return &File{Name: c.handlerPath, Content: buf.String()}, nil 193 } 194 195 type commonCompleter struct { 196 path string 197 pkg *PackageInfo 198 update *Update 199 } 200 201 func (c *commonCompleter) Complete() (*File, error) { 202 var w bytes.Buffer 203 // get AST of main package 204 fset := token.NewFileSet() 205 f, err := parser.ParseFile(fset, c.path, nil, parser.ParseComments) 206 if err != nil { 207 err = fmt.Errorf("go/parser failed to parse the file: %s, err: %v", c.path, err) 208 log.Warnf("NOTICE: This is not a bug. We cannot update the file %s because your codes failed to compile. Fix the compile errors and try again.\n%s", c.path, err.Error()) 209 return nil, err 210 } 211 212 newMethods, err := c.compare() 213 if err != nil { 214 return nil, err 215 } 216 if len(newMethods) == 0 { 217 return nil, errNoNewMethod 218 } 219 err = c.addImport(&w, newMethods, fset, f) 220 if err != nil { 221 return nil, fmt.Errorf("add imports failed error: %v", err) 222 } 223 err = c.addImplementations(&w, newMethods) 224 if err != nil { 225 return nil, fmt.Errorf("add implements failed error: %v", err) 226 } 227 return &File{Name: c.path, Content: w.String()}, nil 228 } 229 230 func (c *commonCompleter) compare() ([]*MethodInfo, error) { 231 var newMethods []*MethodInfo 232 for _, m := range c.pkg.Methods { 233 c.pkg.Methods = []*MethodInfo{m} 234 keyTask := &Task{ 235 Text: c.update.Key, 236 } 237 key, err := keyTask.RenderString(c.pkg) 238 if err != nil { 239 return newMethods, err 240 } 241 have := false 242 243 dir := c.path 244 if strings.HasSuffix(c.path, ".go") { 245 dir = path.Dir(c.path) 246 } 247 filepath.Walk(dir, func(fullPath string, info os.FileInfo, err error) error { 248 if err != nil { 249 return err 250 } 251 252 if path.Base(dir) == info.Name() && info.IsDir() { 253 return nil 254 } 255 if info.IsDir() { 256 return filepath.SkipDir 257 } 258 if !strings.HasSuffix(fullPath, ".go") { 259 return nil 260 } 261 // get AST of main package 262 fset := token.NewFileSet() 263 f, err := parser.ParseFile(fset, fullPath, nil, parser.ParseComments) 264 if err != nil { 265 err = fmt.Errorf("go/parser failed to parse the file: %s, err: %v", c.path, err) 266 log.Warnf("NOTICE: This is not a bug. We cannot update the file %s because your codes failed to compile. Fix the compile errors and try again.\n%s", c.path, err.Error()) 267 return err 268 } 269 270 for _, d := range f.Decls { 271 if fd, ok := d.(*ast.FuncDecl); ok { 272 _, fn := parseFuncDecl(fd) 273 if fn == key { 274 have = true 275 break 276 } 277 } 278 } 279 return nil 280 }) 281 282 if !have { 283 newMethods = append(newMethods, m) 284 } 285 } 286 287 return newMethods, nil 288 } 289 290 // add imports for new methods 291 func (c *commonCompleter) addImport(w io.Writer, newMethods []*MethodInfo, fset *token.FileSet, handlerAST *ast.File) error { 292 existImports := make(map[string]bool) 293 for _, i := range handlerAST.Imports { 294 existImports[strings.Trim(i.Path.Value, "\"")] = true 295 } 296 tmp := c.pkg.Methods 297 defer func() { 298 c.pkg.Methods = tmp 299 }() 300 c.pkg.Methods = newMethods 301 for _, i := range c.update.ImportTpl { 302 importTask := &Task{ 303 Text: i, 304 } 305 content, err := importTask.RenderString(c.pkg) 306 if err != nil { 307 return err 308 } 309 imports := c.parseImports(content) 310 for idx := range imports { 311 if _, ok := existImports[strings.Trim(imports[idx][1], "\"")]; !ok { 312 astutil.AddImport(fset, handlerAST, strings.Trim(imports[idx][1], "\"")) 313 } 314 } 315 } 316 printer.Fprint(w, fset, handlerAST) 317 return nil 318 } 319 320 func (c *commonCompleter) addImplementations(w io.Writer, newMethods []*MethodInfo) error { 321 tmp := c.pkg.Methods 322 defer func() { 323 c.pkg.Methods = tmp 324 }() 325 c.pkg.Methods = newMethods 326 // generate implements of new methods 327 appendTask := &Task{ 328 Text: c.update.AppendTpl, 329 } 330 content, err := appendTask.RenderString(c.pkg) 331 if err != nil { 332 return err 333 } 334 _, err = w.Write([]byte(content)) 335 c.pkg.Methods = tmp 336 return err 337 } 338 339 // imports[2] is alias, import 340 func (c *commonCompleter) parseImports(content string) (imports [][2]string) { 341 if !strings.Contains(content, "\"") { 342 imports = append(imports, [2]string{"", content}) 343 return imports 344 } 345 for i := 0; i < len(content); i++ { 346 if content[i] == ' ' { 347 continue 348 } 349 isAlias := content[i] != '"' 350 351 start := i 352 for ; i < len(content); i++ { 353 if content[i] == ' ' { 354 break 355 } 356 } 357 sub := content[start:i] 358 switch { 359 case isAlias: 360 imports = append(imports, [2]string{sub, ""}) 361 case len(imports) > 0 && imports[len(imports)-1][1] == "": 362 imports[len(imports)-1][1] = sub 363 default: 364 imports = append(imports, [2]string{"", sub}) 365 } 366 } 367 return imports 368 }