github.com/cloudwego/kitex@v0.9.0/tool/internal_pkg/pluginmode/thriftgo/convertor.go (about) 1 // Copyright 2021 CloudWeGo Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package thriftgo 16 17 import ( 18 "fmt" 19 "go/format" 20 "io" 21 "io/ioutil" 22 "os" 23 "path/filepath" 24 "regexp" 25 "strings" 26 27 "github.com/cloudwego/thriftgo/generator/backend" 28 "github.com/cloudwego/thriftgo/generator/golang" 29 "github.com/cloudwego/thriftgo/generator/golang/streaming" 30 "github.com/cloudwego/thriftgo/parser" 31 "github.com/cloudwego/thriftgo/plugin" 32 "github.com/cloudwego/thriftgo/semantic" 33 34 "github.com/cloudwego/kitex/tool/internal_pkg/generator" 35 internal_log "github.com/cloudwego/kitex/tool/internal_pkg/log" 36 "github.com/cloudwego/kitex/tool/internal_pkg/util" 37 "github.com/cloudwego/kitex/transport" 38 ) 39 40 var ( 41 prelude = map[string]bool{"client": true, "server": true, "callopt": true, "context": true, "thrift": true, "kitex": true} 42 keyWords = []string{"client", "server", "callopt", "context", "thrift", "kitex"} 43 ) 44 45 type converter struct { 46 Warnings []string 47 Utils *golang.CodeUtils 48 Config generator.Config 49 Package generator.PackageInfo 50 Services []*generator.ServiceInfo 51 svc2ast map[*generator.ServiceInfo]*parser.Thrift 52 } 53 54 func (c *converter) init(req *plugin.Request) error { 55 if req.Language != "go" { 56 return fmt.Errorf("expect language to be 'go'. Encountered '%s'", req.Language) 57 } 58 59 // resotre the arguments for kitex 60 if err := c.Config.Unpack(req.PluginParameters); err != nil { 61 return err 62 } 63 64 c.Utils = golang.NewCodeUtils(c.initLogs()) 65 c.Utils.HandleOptions(req.GeneratorParameters) 66 67 return nil 68 } 69 70 func (c *converter) initLogs() backend.LogFunc { 71 lf := backend.LogFunc{ 72 Info: func(v ...interface{}) {}, 73 Warn: func(v ...interface{}) { 74 c.Warnings = append(c.Warnings, fmt.Sprint(v...)) 75 }, 76 MultiWarn: func(warns []string) { 77 c.Warnings = append(c.Warnings, warns...) 78 }, 79 } 80 if c.Config.Verbose { 81 lf.Info = lf.Warn 82 } 83 84 internal_log.SetDefaultLogger(internal_log.Logger{ 85 Println: func(w io.Writer, a ...interface{}) (n int, err error) { 86 if w != os.Stdout || c.Config.Verbose { 87 c.Warnings = append(c.Warnings, fmt.Sprint(a...)) 88 } 89 return 0, nil 90 }, 91 Printf: func(w io.Writer, format string, a ...interface{}) (n int, err error) { 92 if w != os.Stdout || c.Config.Verbose { 93 c.Warnings = append(c.Warnings, fmt.Sprintf(format, a...)) 94 } 95 return 0, nil 96 }, 97 }) 98 return lf 99 } 100 101 func (c *converter) fail(err error) int { 102 res := plugin.BuildErrorResponse(err.Error(), c.Warnings...) 103 return exit(res) 104 } 105 106 func (c *converter) avoidIncludeConflict(ast *parser.Thrift, ref string) (*parser.Thrift, string) { 107 fn := filepath.Base(ast.Filename) 108 for _, inc := range ast.Includes { 109 if filepath.Base(inc.Path) == fn { // will cause include conflict 110 ref = "kitex_faked_idl" 111 faked := *ast 112 faked.Filename = util.JoinPath(filepath.Dir(faked.Filename), ref+".thrift") 113 _, hasNamespace := ast.GetNamespace("go") 114 if !hasNamespace { 115 faked.Namespaces = append(faked.Namespaces, &parser.Namespace{ 116 Language: "go", 117 Name: ast.GetNamespaceOrReferenceName("go"), 118 }) 119 } 120 return &faked, ref 121 } 122 } 123 return ast, ref 124 } 125 126 // TODO: copy by marshal & unmarshal? to avoid missing fields. 127 func (c *converter) copyTreeWithRef(ast *parser.Thrift, ref string) *parser.Thrift { 128 ast, ref = c.avoidIncludeConflict(ast, ref) 129 130 t := &parser.Thrift{ 131 Filename: ast.Filename, 132 Namespaces: []*parser.Namespace{ 133 {Language: "*", Name: "fake"}, 134 }, 135 } 136 t.Includes = append(t.Includes, &parser.Include{Path: ast.Filename, Reference: ast}) 137 t.Includes = append(t.Includes, ast.Includes...) 138 139 for _, s := range ast.Services { 140 ss := &parser.Service{ 141 Name: s.Name, 142 Extends: s.Extends, 143 } 144 for _, f := range s.Functions { 145 ff := c.copyFunctionWithRef(f, ref) 146 ss.Functions = append(ss.Functions, ff) 147 } 148 t.Services = append(t.Services, ss) 149 } 150 return t 151 } 152 153 func (c *converter) copyFunctionWithRef(f *parser.Function, ref string) *parser.Function { 154 ff := &parser.Function{ 155 Name: f.Name, 156 Oneway: f.Oneway, 157 Void: f.Void, 158 FunctionType: c.copyTypeWithRef(f.FunctionType, ref), 159 Annotations: c.copyAnnotations(f.Annotations), 160 } 161 for _, x := range f.Arguments { 162 y := *x 163 y.Type = c.copyTypeWithRef(x.Type, ref) 164 ff.Arguments = append(ff.Arguments, &y) 165 } 166 for _, x := range f.Throws { 167 y := *x 168 y.Type = c.copyTypeWithRef(x.Type, ref) 169 ff.Throws = append(ff.Throws, &y) 170 } 171 return ff 172 } 173 174 func (c *converter) copyTypeWithRef(t *parser.Type, ref string) (res *parser.Type) { 175 switch t.Name { 176 case "void": 177 return t 178 case "bool", "byte", "i8", "i16", "i32", "i64", "double", "string", "binary": 179 return t 180 case "map": 181 return &parser.Type{ 182 Name: t.Name, 183 KeyType: c.copyTypeWithRef(t.KeyType, ref), 184 ValueType: c.copyTypeWithRef(t.ValueType, ref), 185 } 186 case "set", "list": 187 return &parser.Type{ 188 Name: t.Name, 189 ValueType: c.copyTypeWithRef(t.ValueType, ref), 190 } 191 default: 192 if strings.Contains(t.Name, ".") { 193 return &parser.Type{ 194 Name: t.Name, 195 KeyType: t.KeyType, 196 ValueType: t.ValueType, 197 } 198 } 199 return &parser.Type{ 200 Name: ref + "." + t.Name, 201 } 202 } 203 } 204 205 func (c *converter) getImports(t *parser.Type) (res []generator.PkgInfo) { 206 switch t.Name { 207 case "void": 208 return nil 209 case "bool", "byte", "i8", "i16", "i32", "i64", "double", "string", "binary": 210 return nil 211 case "map": 212 res = append(res, c.getImports(t.KeyType)...) 213 fallthrough 214 case "set", "list": 215 res = append(res, c.getImports(t.ValueType)...) 216 return res 217 default: 218 if ref := t.GetReference(); ref != nil { 219 inc := c.Utils.RootScope().Includes().ByIndex(int(ref.GetIndex())) 220 res = append(res, generator.PkgInfo{ 221 PkgRefName: inc.PackageName, 222 ImportPath: inc.ImportPath, 223 }) 224 } 225 return 226 } 227 } 228 229 func (c *converter) fixImportConflicts() { 230 for pkg, pth := range c.Package.Dependencies { 231 if prelude[pkg] { 232 delete(c.Package.Dependencies, pkg) 233 c.Package.Dependencies[pkg+"0"] = pth 234 } 235 } 236 var objs []interface{} 237 for _, s := range c.Services { 238 objs = append(objs, s) 239 } 240 241 fix := func(p *generator.PkgInfo) { 242 if prelude[p.PkgRefName] { 243 p.PkgRefName += "0" 244 } 245 } 246 kw := strings.Join(keyWords, "|") 247 re := regexp.MustCompile(`^(\*?)(` + kw + `)\.([^.]+)$`) 248 for len(objs) > 0 { 249 switch v := objs[0].(type) { 250 case *generator.ServiceInfo: 251 fix(&v.PkgInfo) 252 if v.Base != nil { 253 objs = append(objs, v.Base) 254 } 255 for _, m := range v.Methods { 256 objs = append(objs, m) 257 } 258 case *generator.MethodInfo: 259 fix(&v.PkgInfo) 260 for _, a := range v.Args { 261 objs = append(objs, a) 262 } 263 if !v.Void { 264 objs = append(objs, v.Resp) 265 } 266 for _, e := range v.Exceptions { 267 objs = append(objs, e) 268 } 269 v.ArgStructName = re.ReplaceAllString(v.ArgStructName, "${1}${2}0.${3}") 270 v.ResStructName = re.ReplaceAllString(v.ResStructName, "${1}${2}0.${3}") 271 case *generator.Parameter: 272 for i := 0; i < len(v.Deps); i++ { 273 fix(&v.Deps[i]) 274 } 275 v.Type = re.ReplaceAllString(v.Type, "${1}${2}0.${3}") 276 } 277 objs = objs[1:] 278 } 279 } 280 281 type ast2svc map[string][]*generator.ServiceInfo 282 283 func (t ast2svc) findService(ast *parser.Thrift, name string) *generator.ServiceInfo { 284 var list []*generator.ServiceInfo 285 for filename, l := range t { 286 if filename == ast.Filename { 287 list = l 288 break 289 } 290 } 291 for _, s := range list { 292 if s.RawServiceName == name { 293 return s 294 } 295 } 296 return nil 297 } 298 299 func (c *converter) convertTypes(req *plugin.Request) error { 300 var all ast2svc = make(map[string][]*generator.ServiceInfo) 301 302 c.svc2ast = make(map[*generator.ServiceInfo]*parser.Thrift) 303 for ast := range req.AST.DepthFirstSearch() { 304 ref, pkg, pth := c.Utils.ParseNamespace(ast) 305 // make the current ast as an include to produce correct type references. 306 fake := c.copyTreeWithRef(ast, ref) 307 fake.Name2Category = nil 308 if err := semantic.ResolveSymbols(fake); err != nil { 309 return fmt.Errorf("resolve fakse ast '%s': %w", ast.Filename, err) 310 } 311 used := true 312 fake.ForEachInclude(func(v *parser.Include) bool { 313 v.Used = &used // mark all includes used to force renaming for conflict IDLs in thriftgo 314 return true 315 }) 316 317 scope, err := golang.BuildScope(c.Utils, fake) 318 if err != nil { 319 return fmt.Errorf("build scope for fake ast '%s': %w", ast.Filename, err) 320 } 321 c.Utils.SetRootScope(scope) 322 pi := generator.PkgInfo{ 323 PkgName: pkg, 324 PkgRefName: pkg, 325 ImportPath: util.CombineOutputPath(c.Config.PackagePrefix, pth), 326 } 327 for _, svc := range scope.Services() { 328 si, err := c.makeService(pi, svc) 329 if err != nil { 330 return fmt.Errorf("%s: makeService '%s': %w", ast.Filename, svc.Name, err) 331 } 332 si.ServiceFilePath = ast.Filename 333 all[ast.Filename] = append(all[ast.Filename], si) 334 c.svc2ast[si] = ast 335 } 336 // fill .Base 337 for i, svc := range ast.Services { 338 if len(svc.Extends) > 0 { 339 si := all[ast.Filename][i] 340 parts := semantic.SplitType(svc.Extends) 341 switch len(parts) { 342 case 1: 343 si.Base = all.findService(ast, parts[0]) 344 case 2: 345 tmp, found := ast.GetReference(parts[0]) 346 if !found { 347 break 348 } 349 si.Base = all.findService(tmp, parts[1]) 350 } 351 if len(parts) > 0 && si.Base == nil { 352 return fmt.Errorf("base service '%s' %d not found for '%s'", svc.Extends, len(parts), svc.Name) 353 } 354 } 355 } 356 357 c.fixStreamingForExtendedServices(ast, all) 358 359 // combine service 360 if ast == req.AST && c.Config.CombineService && len(ast.Services) > 0 { 361 var ( 362 svcs []*generator.ServiceInfo 363 methods []*generator.MethodInfo 364 ) 365 hasStreaming := false 366 for _, s := range all[ast.Filename] { 367 svcs = append(svcs, s) 368 hasStreaming = hasStreaming || s.HasStreaming 369 methods = append(methods, s.AllMethods()...) 370 } 371 // check method name conflict 372 mm := make(map[string]*generator.MethodInfo) 373 for _, m := range methods { 374 if _, ok := mm[m.Name]; ok { 375 return fmt.Errorf("combine service method %s in %s conflicts with %s in %s", m.Name, m.ServiceName, m.Name, mm[m.Name].ServiceName) 376 } 377 mm[m.Name] = m 378 } 379 svcName := c.getCombineServiceName("CombineService", all[ast.Filename]) 380 si := &generator.ServiceInfo{ 381 PkgInfo: pi, 382 ServiceName: svcName, 383 RawServiceName: svcName, 384 CombineServices: svcs, 385 Methods: methods, 386 ServiceFilePath: ast.Filename, 387 HasStreaming: hasStreaming, 388 } 389 390 if c.IsHessian2() { 391 si.Protocol = transport.HESSIAN2.String() 392 } 393 394 si.HandlerReturnKeepResp = c.Config.HandlerReturnKeepResp 395 si.UseThriftReflection = c.Utils.Features().WithReflection 396 si.ServiceTypeName = func() string { return si.ServiceName } 397 all[ast.Filename] = append(all[ast.Filename], si) 398 c.svc2ast[si] = ast 399 } 400 401 c.Services = append(c.Services, all[ast.Filename]...) 402 } 403 return nil 404 } 405 406 func (c *converter) fixStreamingForExtendedServices(ast *parser.Thrift, all ast2svc) { 407 for i, svc := range ast.Services { 408 if svc.Extends == "" { 409 continue 410 } 411 si := all[ast.Filename][i] 412 if si.Base != nil { 413 si.FixHasStreamingForExtendedService() 414 } 415 } 416 } 417 418 func (c *converter) makeService(pkg generator.PkgInfo, svc *golang.Service) (*generator.ServiceInfo, error) { 419 si := &generator.ServiceInfo{ 420 PkgInfo: pkg, 421 ServiceName: svc.GoName().String(), 422 RawServiceName: svc.Name, 423 } 424 si.ServiceTypeName = func() string { return si.PkgRefName + "." + si.ServiceName } 425 426 for _, f := range svc.Functions() { 427 if strings.HasPrefix(f.Name, "_") { 428 continue 429 } 430 mi, err := c.makeMethod(si, f) 431 if err != nil { 432 return nil, err 433 } 434 si.Methods = append(si.Methods, mi) 435 } 436 437 if c.IsHessian2() { 438 si.Protocol = transport.HESSIAN2.String() 439 } 440 si.HandlerReturnKeepResp = c.Config.HandlerReturnKeepResp 441 si.UseThriftReflection = c.Utils.Features().WithReflection 442 return si, nil 443 } 444 445 func (c *converter) makeMethod(si *generator.ServiceInfo, f *golang.Function) (*generator.MethodInfo, error) { 446 st, err := streaming.ParseStreaming(f.Function) 447 if err != nil { 448 return nil, err 449 } 450 mi := &generator.MethodInfo{ 451 PkgInfo: si.PkgInfo, 452 ServiceName: si.ServiceName, 453 Name: f.GoName().String(), 454 RawName: f.Name, 455 Oneway: f.Oneway, 456 Void: f.Void, 457 ArgStructName: f.ArgType().GoName().String(), 458 GenArgResultStruct: false, 459 Streaming: st, 460 ClientStreaming: st.ClientStreaming, 461 ServerStreaming: st.ServerStreaming, 462 ArgsLength: len(f.Arguments()), 463 } 464 if st.IsStreaming { 465 si.HasStreaming = true 466 } 467 468 if !f.Oneway { 469 mi.ResStructName = f.ResType().GoName().String() 470 } 471 if !f.Void { 472 typeName := f.ResponseGoTypeName().String() 473 mi.Resp = &generator.Parameter{ 474 Deps: c.getImports(f.FunctionType), 475 Type: typeName, 476 } 477 mi.IsResponseNeedRedirect = "*"+typeName == f.ResType().Fields()[0].GoTypeName().String() 478 } 479 480 for _, a := range f.Arguments() { 481 arg := &generator.Parameter{ 482 Deps: c.getImports(a.Type), 483 Name: f.ArgType().Field(a.Name).GoName().String(), 484 RawName: a.GoName().String(), 485 Type: a.GoTypeName().String(), 486 } 487 mi.Args = append(mi.Args, arg) 488 } 489 for _, t := range f.Throws() { 490 ex := &generator.Parameter{ 491 Deps: c.getImports(t.Type), 492 Name: f.ResType().Field(t.Name).GoName().String(), 493 RawName: t.GoName().String(), 494 Type: t.GoTypeName().String(), 495 } 496 mi.Exceptions = append(mi.Exceptions, ex) 497 } 498 return mi, nil 499 } 500 501 func (c *converter) persist(res *plugin.Response) error { 502 for _, c := range res.Contents { 503 full := c.GetName() 504 content := []byte(c.Content) 505 if filepath.Ext(full) == ".go" { 506 if formatted, err := format.Source([]byte(c.Content)); err != nil { 507 internal_log.Warn(fmt.Sprintf("Failed to format %s: %s", full, err.Error())) 508 } else { 509 content = formatted 510 } 511 } 512 513 internal_log.Info("Write", full) 514 path := filepath.Dir(full) 515 if err := os.MkdirAll(path, 0o755); err != nil && !os.IsExist(err) { 516 return fmt.Errorf("failed to create path '%s': %w", path, err) 517 } 518 if err := ioutil.WriteFile(full, content, 0o644); err != nil { 519 return fmt.Errorf("failed to write file '%s': %w", full, err) 520 } 521 } 522 return nil 523 } 524 525 func (c *converter) getCombineServiceName(name string, svcs []*generator.ServiceInfo) string { 526 for _, svc := range svcs { 527 if svc.ServiceName == name { 528 return c.getCombineServiceName(name+"_", svcs) 529 } 530 } 531 return name 532 } 533 534 func (c *converter) IsHessian2() bool { 535 return strings.EqualFold(c.Config.Protocol, transport.HESSIAN2.String()) 536 } 537 538 func (c *converter) copyAnnotations(annotations parser.Annotations) parser.Annotations { 539 copied := make(parser.Annotations, 0, len(annotations)) 540 for _, annotation := range annotations { 541 values := make([]string, 0, len(annotation.Values)) 542 values = append(values, annotation.Values...) 543 copied = append(copied, &parser.Annotation{ 544 Key: annotation.Key, 545 Values: values, 546 }) 547 } 548 return copied 549 }