github.com/cloudwego/kitex@v0.9.0/tool/internal_pkg/pluginmode/thriftgo/patcher.go (about) 1 // Copyright 2021 CloudWeGo Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package thriftgo 16 17 import ( 18 "fmt" 19 "io/ioutil" 20 "path/filepath" 21 "reflect" 22 "runtime" 23 "sort" 24 "strconv" 25 "strings" 26 "text/template" 27 28 "github.com/cloudwego/thriftgo/generator/golang" 29 "github.com/cloudwego/thriftgo/generator/golang/templates" 30 "github.com/cloudwego/thriftgo/generator/golang/templates/slim" 31 "github.com/cloudwego/thriftgo/parser" 32 "github.com/cloudwego/thriftgo/plugin" 33 34 "github.com/cloudwego/kitex/tool/internal_pkg/generator" 35 "github.com/cloudwego/kitex/tool/internal_pkg/util" 36 ) 37 38 var extraTemplates []string 39 40 // AppendToTemplate string 41 func AppendToTemplate(text string) { 42 extraTemplates = append(extraTemplates, text) 43 } 44 45 const kitexUnusedProtection = ` 46 // KitexUnusedProtection is used to prevent 'imported and not used' error. 47 var KitexUnusedProtection = struct{}{} 48 ` 49 50 //lint:ignore U1000 until protectionInsertionPoint is used 51 var protectionInsertionPoint = "KitexUnusedProtection" 52 53 type patcher struct { 54 noFastAPI bool 55 utils *golang.CodeUtils 56 module string 57 copyIDL bool 58 version string 59 record bool 60 recordCmd []string 61 deepCopyAPI bool 62 protocol string 63 handlerReturnKeepResp bool 64 65 fileTpl *template.Template 66 } 67 68 func (p *patcher) buildTemplates() (err error) { 69 m := p.utils.BuildFuncMap() 70 m["ReorderStructFields"] = p.reorderStructFields 71 m["TypeIDToGoType"] = func(t string) string { return typeIDToGoType[t] } 72 m["IsBinaryOrStringType"] = p.isBinaryOrStringType 73 m["Version"] = func() string { return p.version } 74 m["GenerateFastAPIs"] = func() bool { return !p.noFastAPI && p.utils.Template() != "slim" } 75 m["GenerateDeepCopyAPIs"] = func() bool { return p.deepCopyAPI } 76 m["GenerateArgsResultTypes"] = func() bool { return p.utils.Template() == "slim" } 77 m["ImportPathTo"] = generator.ImportPathTo 78 m["ToPackageNames"] = func(imports map[string]string) (res []string) { 79 for pth, alias := range imports { 80 if alias != "" { 81 res = append(res, alias) 82 } else { 83 res = append(res, strings.ToLower(filepath.Base(pth))) 84 } 85 } 86 sort.Strings(res) 87 return 88 } 89 m["Str"] = func(id int32) string { 90 if id < 0 { 91 return "_" + strconv.Itoa(-int(id)) 92 } 93 return strconv.Itoa(int(id)) 94 } 95 m["IsNil"] = func(i interface{}) bool { 96 return i == nil || reflect.ValueOf(i).IsNil() 97 } 98 m["SourceTarget"] = func(s string) string { 99 // p.XXX 100 if strings.HasPrefix(s, "p.") { 101 return "src." + s[2:] 102 } 103 // _key, _val 104 return s[1:] 105 } 106 m["FieldName"] = func(s string) string { 107 // p.XXX 108 return strings.ToLower(s[2:3]) + s[3:] 109 } 110 m["IsHessian"] = func() bool { 111 return p.IsHessian2() 112 } 113 m["IsGoStringType"] = func(typeName golang.TypeName) bool { 114 return typeName == "string" || typeName == "*string" 115 } 116 117 tpl := template.New("kitex").Funcs(m) 118 allTemplates := basicTemplates 119 if p.utils.Template() == "slim" { 120 allTemplates = append(allTemplates, slim.StructLike, 121 templates.StructLikeDefault, 122 templates.FieldGetOrSet, 123 templates.FieldIsSet, 124 structLikeDeepCopy, 125 fieldDeepCopy, 126 fieldDeepCopyStructLike, 127 fieldDeepCopyContainer, 128 fieldDeepCopyMap, 129 fieldDeepCopyList, 130 fieldDeepCopySet, 131 fieldDeepCopyBaseType, 132 structLikeCodec, 133 structLikeProtocol, 134 javaClassName, 135 processor, 136 ) 137 } else { 138 allTemplates = append(allTemplates, structLikeCodec, 139 structLikeFastRead, 140 structLikeFastReadField, 141 structLikeDeepCopy, 142 structLikeFastWrite, 143 structLikeFastWriteNocopy, 144 structLikeLength, 145 structLikeFastWriteField, 146 structLikeFieldLength, 147 structLikeProtocol, 148 javaClassName, 149 fieldFastRead, 150 fieldFastReadStructLike, 151 fieldFastReadBaseType, 152 fieldFastReadContainer, 153 fieldFastReadMap, 154 fieldFastReadSet, 155 fieldFastReadList, 156 fieldDeepCopy, 157 fieldDeepCopyStructLike, 158 fieldDeepCopyContainer, 159 fieldDeepCopyMap, 160 fieldDeepCopyList, 161 fieldDeepCopySet, 162 fieldDeepCopyBaseType, 163 fieldFastWrite, 164 fieldLength, 165 fieldFastWriteStructLike, 166 fieldStructLikeLength, 167 fieldFastWriteBaseType, 168 fieldBaseTypeLength, 169 fieldFixedLengthTypeLength, 170 fieldFastWriteContainer, 171 fieldContainerLength, 172 fieldFastWriteMap, 173 fieldMapLength, 174 fieldFastWriteSet, 175 fieldSetLength, 176 fieldFastWriteList, 177 fieldListLength, 178 templates.FieldDeepEqual, 179 templates.FieldDeepEqualBase, 180 templates.FieldDeepEqualStructLike, 181 templates.FieldDeepEqualContainer, 182 validateSet, 183 processor, 184 ) 185 } 186 for _, txt := range allTemplates { 187 tpl = template.Must(tpl.Parse(txt)) 188 } 189 190 ext := `{{define "ExtraTemplates"}}{{end}}` 191 if len(extraTemplates) > 0 { 192 ext = fmt.Sprintf("{{define \"ExtraTemplates\"}}\n%s\n{{end}}", 193 strings.Join(extraTemplates, "\n")) 194 } 195 tpl, err = tpl.Parse(ext) 196 if err != nil { 197 return fmt.Errorf("failed to parse extra templates: %w: %q", err, ext) 198 } 199 200 if p.IsHessian2() { 201 tpl, err = tpl.Parse(registerHessian) 202 if err != nil { 203 return fmt.Errorf("failed to parse hessian2 templates: %w: %q", err, registerHessian) 204 } 205 } 206 207 p.fileTpl = tpl 208 return nil 209 } 210 211 func (p *patcher) patch(req *plugin.Request) (patches []*plugin.Generated, err error) { 212 p.buildTemplates() 213 var buf strings.Builder 214 215 protection := make(map[string]*plugin.Generated) 216 217 for ast := range req.AST.DepthFirstSearch() { 218 // scope, err := golang.BuildScope(p.utils, ast) 219 scope, _, err := golang.BuildRefScope(p.utils, ast) 220 if err != nil { 221 return nil, fmt.Errorf("build scope for ast %q: %w", ast.Filename, err) 222 } 223 p.utils.SetRootScope(scope) 224 225 pkgName := golang.GetImportPackage(golang.GetImportPath(p.utils, ast)) 226 227 path := p.utils.CombineOutputPath(req.OutputPath, ast) 228 base := p.utils.GetFilename(ast) 229 target := util.JoinPath(path, "k-"+base) 230 231 // Define KitexUnusedProtection in k-consts.go . 232 // Add k-consts.go before target to force the k-consts.go generated by consts.thrift to be renamed. 233 consts := util.JoinPath(path, "k-consts.go") 234 if protection[consts] == nil { 235 patch := &plugin.Generated{ 236 Content: "package " + pkgName + "\n" + kitexUnusedProtection, 237 Name: &consts, 238 } 239 patches = append(patches, patch) 240 protection[consts] = patch 241 } 242 243 buf.Reset() 244 245 // if all scopes are ref, don't generate k-xxx 246 if scope == nil { 247 continue 248 } 249 250 if p.IsHessian2() { 251 register := util.JoinPath(path, fmt.Sprintf("hessian2-register-%s", base)) 252 patch, err := p.patchHessian(path, scope, pkgName, base) 253 if err != nil { 254 return nil, fmt.Errorf("patch hessian fail for %q: %w", ast.Filename, err) 255 } 256 257 patches = append(patches, patch) 258 protection[register] = patch 259 } 260 261 data := &struct { 262 Scope *golang.Scope 263 PkgName string 264 Imports map[string]string 265 }{Scope: scope, PkgName: pkgName} 266 data.Imports, err = scope.ResolveImports() 267 if err != nil { 268 return nil, fmt.Errorf("resolve imports failed for %q: %w", ast.Filename, err) 269 } 270 p.filterStdLib(data.Imports) 271 if err = p.fileTpl.ExecuteTemplate(&buf, "file", data); err != nil { 272 return nil, fmt.Errorf("%q: %w", ast.Filename, err) 273 } 274 content := buf.String() 275 // if kutils is not used, remove the dependency. 276 if !strings.Contains(content, "kutils.StringDeepCopy") { 277 kutilsImp := `kutils "github.com/cloudwego/kitex/pkg/utils"` 278 idx := strings.Index(content, kutilsImp) 279 if idx > 0 { 280 content = content[:idx-1] + content[idx+len(kutilsImp):] 281 } 282 } 283 patches = append(patches, &plugin.Generated{ 284 Content: content, 285 Name: &target, 286 }) 287 288 if p.copyIDL { 289 content, err := ioutil.ReadFile(ast.Filename) 290 if err != nil { 291 return nil, fmt.Errorf("read %q: %w", ast.Filename, err) 292 } 293 path := util.JoinPath(path, filepath.Base(ast.Filename)) 294 patches = append(patches, &plugin.Generated{ 295 Content: string(content), 296 Name: &path, 297 }) 298 } 299 300 if p.record { 301 content := doRecord(p.recordCmd) 302 bashPath := util.JoinPath(getBashPath()) 303 patches = append(patches, &plugin.Generated{ 304 Content: content, 305 Name: &bashPath, 306 }) 307 } 308 309 } 310 return 311 } 312 313 func (p *patcher) patchHessian(path string, scope *golang.Scope, pkgName, base string) (patch *plugin.Generated, err error) { 314 buf := strings.Builder{} 315 resigterIDLName := fmt.Sprintf("hessian2-register-%s", base) 316 register := util.JoinPath(path, resigterIDLName) 317 data := &struct { 318 Scope *golang.Scope 319 PkgName string 320 Imports map[string]string 321 GoName string 322 IDLName string 323 }{Scope: scope, PkgName: pkgName, IDLName: util.UpperFirst(strings.Replace(base, ".go", "", -1))} 324 data.Imports, err = scope.ResolveImports() 325 if err != nil { 326 return nil, err 327 } 328 329 if err = p.fileTpl.ExecuteTemplate(&buf, "register", data); err != nil { 330 return nil, err 331 } 332 patch = &plugin.Generated{ 333 Content: buf.String(), 334 Name: ®ister, 335 } 336 return patch, nil 337 } 338 339 func getBashPath() string { 340 if runtime.GOOS == "windows" { 341 return "kitex-all.bat" 342 } 343 return "kitex-all.sh" 344 } 345 346 // DoRecord records current cmd into kitex-all.sh 347 func doRecord(recordCmd []string) string { 348 bytes, err := ioutil.ReadFile(getBashPath()) 349 content := string(bytes) 350 if err != nil { 351 content = "#! /usr/bin/env bash\n" 352 } 353 var input, currentIdl string 354 for _, s := range recordCmd { 355 if s != "-record" { 356 input += s + " " 357 } 358 if strings.HasSuffix(s, ".thrift") || strings.HasSuffix(s, ".proto") { 359 currentIdl = s 360 } 361 } 362 if input != "" && currentIdl != "" { 363 find := false 364 lines := strings.Split(content, "\n") 365 for i, line := range lines { 366 if strings.Contains(input, "-service") && strings.Contains(line, "-service") { 367 lines[i] = input 368 find = true 369 break 370 } 371 if strings.Contains(line, currentIdl) && !strings.Contains(line, "-service") { 372 lines[i] = input 373 find = true 374 break 375 } 376 } 377 if !find { 378 content += "\n" + input 379 } else { 380 content = strings.Join(lines, "\n") 381 } 382 } 383 return content 384 } 385 386 func (p *patcher) reorderStructFields(fields []*golang.Field) ([]*golang.Field, error) { 387 fixedLengthFields := make(map[*golang.Field]bool, len(fields)) 388 for _, field := range fields { 389 fixedLengthFields[field] = golang.IsFixedLengthType(field.Type) 390 } 391 392 sortedFields := make([]*golang.Field, 0, len(fields)) 393 for _, v := range fields { 394 if fixedLengthFields[v] { 395 sortedFields = append(sortedFields, v) 396 } 397 } 398 for _, v := range fields { 399 if !fixedLengthFields[v] { 400 sortedFields = append(sortedFields, v) 401 } 402 } 403 404 return sortedFields, nil 405 } 406 407 func (p *patcher) filterStdLib(imports map[string]string) { 408 // remove std libs and thrift to prevent duplicate import. 409 prefix := p.module + "/" 410 for pth := range imports { 411 if strings.HasPrefix(pth, prefix) { // local module 412 continue 413 } 414 if pth == "github.com/apache/thrift/lib/go/thrift" { 415 delete(imports, pth) 416 } 417 if strings.HasPrefix(pth, "github.com/cloudwego/thriftgo") { 418 delete(imports, pth) 419 } 420 if !strings.Contains(pth, ".") { // std lib 421 delete(imports, pth) 422 } 423 } 424 } 425 426 func (p *patcher) isBinaryOrStringType(t *parser.Type) bool { 427 return t.Category.IsBinary() || t.Category.IsString() 428 } 429 430 func (p *patcher) IsHessian2() bool { 431 return strings.EqualFold(p.protocol, "hessian2") 432 } 433 434 var typeIDToGoType = map[string]string{ 435 "Bool": "bool", 436 "Byte": "int8", 437 "I16": "int16", 438 "I32": "int32", 439 "I64": "int64", 440 "Double": "float64", 441 "String": "string", 442 "Binary": "[]byte", 443 }