github.com/cloudwego/kitex@v0.9.0/tool/internal_pkg/generator/generator.go (about) 1 // Copyright 2021 CloudWeGo Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package generator . 16 package generator 17 18 import ( 19 "fmt" 20 "go/token" 21 "path/filepath" 22 "reflect" 23 "strconv" 24 "strings" 25 "time" 26 27 "github.com/cloudwego/kitex/tool/internal_pkg/log" 28 "github.com/cloudwego/kitex/tool/internal_pkg/tpl" 29 "github.com/cloudwego/kitex/tool/internal_pkg/util" 30 "github.com/cloudwego/kitex/transport" 31 ) 32 33 // Constants . 34 const ( 35 KitexGenPath = "kitex_gen" 36 DefaultCodec = "thrift" 37 38 BuildFileName = "build.sh" 39 BootstrapFileName = "bootstrap.sh" 40 ToolVersionFileName = "kitex_info.yaml" 41 HandlerFileName = "handler.go" 42 MainFileName = "main.go" 43 ClientFileName = "client.go" 44 ServerFileName = "server.go" 45 InvokerFileName = "invoker.go" 46 ServiceFileName = "*service.go" 47 ExtensionFilename = "extensions.yaml" 48 49 DefaultThriftPluginTimeLimit = time.Minute 50 ) 51 52 var ( 53 kitexImportPath = "github.com/cloudwego/kitex" 54 55 globalMiddlewares []Middleware 56 globalDependencies = map[string]string{ 57 "kitex": kitexImportPath, 58 "client": ImportPathTo("client"), 59 "server": ImportPathTo("server"), 60 "callopt": ImportPathTo("client/callopt"), 61 "frugal": "github.com/cloudwego/frugal", 62 } 63 ) 64 65 // SetKitexImportPath sets the import path of kitex. 
66 // Must be called before generating code. 67 func SetKitexImportPath(path string) { 68 for k, v := range globalDependencies { 69 globalDependencies[k] = strings.ReplaceAll(v, kitexImportPath, path) 70 } 71 kitexImportPath = path 72 } 73 74 // ImportPathTo returns an import path to the specified package under kitex. 75 func ImportPathTo(pkg string) string { 76 return util.JoinPath(kitexImportPath, pkg) 77 } 78 79 // AddGlobalMiddleware adds middleware for all generators 80 func AddGlobalMiddleware(mw Middleware) { 81 globalMiddlewares = append(globalMiddlewares, mw) 82 } 83 84 // AddGlobalDependency adds dependency for all generators 85 func AddGlobalDependency(ref, path string) bool { 86 if _, ok := globalDependencies[ref]; !ok { 87 globalDependencies[ref] = path 88 return true 89 } 90 return false 91 } 92 93 // Generator generates the codes of main package and scripts for building a server based on kitex. 94 type Generator interface { 95 GenerateService(pkg *PackageInfo) ([]*File, error) 96 GenerateMainPackage(pkg *PackageInfo) ([]*File, error) 97 GenerateCustomPackage(pkg *PackageInfo) ([]*File, error) 98 } 99 100 // Config . 
type Config struct {
	// NOTE: Pack and Unpack walk these fields by reflection, so field names
	// are part of the packed "key=val" format and declaration order decides
	// the order of Pack's output.
	Verbose               bool   // also copied into log.Verbose by Unpack
	GenerateMain          bool   // whether stuff in the main package should be generated
	GenerateInvoker       bool   // main.go is generated with the invoker; GenerateMainPackage then skips the default main.go template
	Version               string // copied to PackageInfo.Version
	NoFastAPI             bool   // copied to PackageInfo.NoFastAPI
	ModuleName            string // copied to PackageInfo.Module
	ServiceName           string // copied to PackageInfo.RealServiceName
	Use                   string // copied to PackageInfo.ExternalKitexGen
	IDLType               string // copied to PackageInfo.Codec; NewGenerator defaults it to DefaultCodec
	Includes              util.StringSlice
	ThriftOptions         util.StringSlice
	ProtobufOptions       util.StringSlice
	Hessian2Options       util.StringSlice
	IDL                   string // the IDL file passed on the command line
	OutputPath            string // the output path for main pkg and kitex_gen
	PackagePrefix         string
	CombineService        bool // combine services to one service
	CopyIDL               bool
	ThriftPlugins         util.StringSlice // deliberately excluded from Pack
	ProtobufPlugins       util.StringSlice
	Features              []feature // enabled features; see AddFeature
	FrugalPretouch        bool      // copied to PackageInfo.FrugalPretouch
	ThriftPluginTimeLimit time.Duration
	CompilerPath          string // specify the path of thriftgo or protoc

	ExtensionFile string             // template extension YAML file read by ApplyExtension
	tmplExt       *TemplateExtension // set by ApplyExtension; read by GenerateService

	Record    bool
	RecordCmd []string

	TemplateDir string // ApplyExtension also reads ExtensionFilename from this directory

	GenPath string // combined with the package namespace under OutputPath by GenerateService

	DeepCopyAPI           bool
	Protocol              string // "hessian2" (case-insensitive) selects transport.HESSIAN2 in updatePackageInfo
	HandlerReturnKeepResp bool
}

// Pack packs the Config into a slice of "key=val" strings.
143 func (c *Config) Pack() (res []string) { 144 t := reflect.TypeOf(c).Elem() 145 v := reflect.ValueOf(c).Elem() 146 for i := 0; i < t.NumField(); i++ { 147 f := t.Field(i) 148 x := v.Field(i) 149 n := f.Name 150 151 // skip the plugin arguments to avoid the 'strings in strings' trouble 152 if f.Name == "ThriftPlugins" || !token.IsExported(f.Name) { 153 continue 154 } 155 156 if str, ok := x.Interface().(interface{ String() string }); ok { 157 res = append(res, n+"="+str.String()) 158 continue 159 } 160 161 switch x.Kind() { 162 case reflect.Bool: 163 res = append(res, n+"="+fmt.Sprint(x.Bool())) 164 case reflect.String: 165 res = append(res, n+"="+x.String()) 166 case reflect.Slice: 167 var ss []string 168 if x.Type().Elem().Kind() == reflect.Int { 169 for i := 0; i < x.Len(); i++ { 170 ss = append(ss, strconv.Itoa(int(x.Index(i).Int()))) 171 } 172 } else { 173 for i := 0; i < x.Len(); i++ { 174 ss = append(ss, x.Index(i).String()) 175 } 176 } 177 res = append(res, n+"="+strings.Join(ss, ";")) 178 default: 179 panic(fmt.Errorf("unsupported field type: %+v", f)) 180 } 181 } 182 return res 183 } 184 185 // Unpack restores the Config from a slice of "key=val" strings. 
186 func (c *Config) Unpack(args []string) error { 187 t := reflect.TypeOf(c).Elem() 188 v := reflect.ValueOf(c).Elem() 189 for _, a := range args { 190 parts := strings.SplitN(a, "=", 2) 191 if len(parts) != 2 { 192 return fmt.Errorf("invalid argument: '%s'", a) 193 } 194 name, value := parts[0], parts[1] 195 f, ok := t.FieldByName(name) 196 if ok && value != "" { 197 x := v.FieldByName(name) 198 if _, ok := x.Interface().(time.Duration); ok { 199 if d, err := time.ParseDuration(value); err != nil { 200 return fmt.Errorf("invalid time duration '%s' for %s", value, name) 201 } else { 202 x.SetInt(int64(d)) 203 } 204 continue 205 } 206 switch x.Kind() { 207 case reflect.Bool: 208 x.SetBool(value == "true") 209 case reflect.String: 210 x.SetString(value) 211 case reflect.Slice: 212 ss := strings.Split(value, ";") 213 if x.Type().Elem().Kind() == reflect.Int { 214 n := reflect.MakeSlice(x.Type(), len(ss), len(ss)) 215 for i, s := range ss { 216 val, err := strconv.ParseInt(s, 10, 64) 217 if err != nil { 218 return err 219 } 220 n.Index(i).SetInt(val) 221 } 222 x.Set(n) 223 } else { 224 for _, s := range ss { 225 val := reflect.Append(x, reflect.ValueOf(s)) 226 x.Set(val) 227 } 228 } 229 default: 230 return fmt.Errorf("unsupported field type: %+v", f) 231 } 232 } 233 } 234 log.Verbose = c.Verbose 235 return c.ApplyExtension() 236 } 237 238 // AddFeature add registered feature to config 239 func (c *Config) AddFeature(key string) bool { 240 if f, ok := getFeature(key); ok { 241 c.Features = append(c.Features, f) 242 return true 243 } 244 return false 245 } 246 247 // ApplyExtension applies template extension. 
248 func (c *Config) ApplyExtension() error { 249 templateExtExist := false 250 path := util.JoinPath(c.TemplateDir, ExtensionFilename) 251 if c.TemplateDir != "" && util.Exists(path) { 252 templateExtExist = true 253 } 254 255 if c.ExtensionFile == "" && !templateExtExist { 256 return nil 257 } 258 259 ext := new(TemplateExtension) 260 if c.ExtensionFile != "" { 261 if err := ext.FromYAMLFile(c.ExtensionFile); err != nil { 262 return fmt.Errorf("read template extension %q failed: %s", c.ExtensionFile, err.Error()) 263 } 264 } 265 266 if templateExtExist { 267 yamlExt := new(TemplateExtension) 268 if err := yamlExt.FromYAMLFile(path); err != nil { 269 return fmt.Errorf("read template extension %q failed: %s", path, err.Error()) 270 } 271 ext.Merge(yamlExt) 272 } 273 274 for _, fn := range ext.FeatureNames { 275 RegisterFeature(fn) 276 } 277 for _, fn := range ext.EnableFeatures { 278 c.AddFeature(fn) 279 } 280 for path, alias := range ext.Dependencies { 281 AddGlobalDependency(alias, path) 282 } 283 284 c.tmplExt = ext 285 return nil 286 } 287 288 // NewGenerator . 289 func NewGenerator(config *Config, middlewares []Middleware) Generator { 290 mws := append(globalMiddlewares, middlewares...) 
291 g := &generator{Config: config, middlewares: mws} 292 if g.IDLType == "" { 293 g.IDLType = DefaultCodec 294 } 295 return g 296 } 297 298 // Middleware used generator 299 type Middleware func(HandleFunc) HandleFunc 300 301 // HandleFunc used generator 302 type HandleFunc func(*Task, *PackageInfo) (*File, error) 303 304 type generator struct { 305 *Config 306 middlewares []Middleware 307 } 308 309 func (g *generator) chainMWs(handle HandleFunc) HandleFunc { 310 for i := len(g.middlewares) - 1; i > -1; i-- { 311 handle = g.middlewares[i](handle) 312 } 313 return handle 314 } 315 316 func (g *generator) GenerateMainPackage(pkg *PackageInfo) (fs []*File, err error) { 317 g.updatePackageInfo(pkg) 318 319 tasks := []*Task{ 320 { 321 Name: BuildFileName, 322 Path: util.JoinPath(g.OutputPath, BuildFileName), 323 Text: tpl.BuildTpl, 324 }, 325 { 326 Name: BootstrapFileName, 327 Path: util.JoinPath(g.OutputPath, "script", BootstrapFileName), 328 Text: tpl.BootstrapTpl, 329 }, 330 { 331 Name: ToolVersionFileName, 332 Path: util.JoinPath(g.OutputPath, ToolVersionFileName), 333 Text: tpl.ToolVersionTpl, 334 }, 335 } 336 if !g.Config.GenerateInvoker { 337 tasks = append(tasks, &Task{ 338 Name: MainFileName, 339 Path: util.JoinPath(g.OutputPath, MainFileName), 340 Text: tpl.MainTpl, 341 }) 342 } 343 for _, t := range tasks { 344 if util.Exists(t.Path) { 345 log.Info(t.Path, "exists. 
Skipped.") 346 continue 347 } 348 g.setImports(t.Name, pkg) 349 handle := func(task *Task, pkg *PackageInfo) (*File, error) { 350 return task.Render(pkg) 351 } 352 f, err := g.chainMWs(handle)(t, pkg) 353 if err != nil { 354 return nil, err 355 } 356 fs = append(fs, f) 357 } 358 359 handlerFilePath := filepath.Join(g.OutputPath, HandlerFileName) 360 if util.Exists(handlerFilePath) { 361 comp := newCompleter( 362 pkg.ServiceInfo.AllMethods(), 363 handlerFilePath, 364 pkg.ServiceInfo.ServiceName) 365 f, err := comp.CompleteMethods() 366 if err != nil { 367 if err == errNoNewMethod { 368 return fs, nil 369 } 370 return nil, err 371 } 372 fs = append(fs, f) 373 } else { 374 task := Task{ 375 Name: HandlerFileName, 376 Path: handlerFilePath, 377 Text: tpl.HandlerTpl + "\n" + tpl.HandlerMethodsTpl, 378 } 379 g.setImports(task.Name, pkg) 380 handle := func(task *Task, pkg *PackageInfo) (*File, error) { 381 return task.Render(pkg) 382 } 383 f, err := g.chainMWs(handle)(&task, pkg) 384 if err != nil { 385 return nil, err 386 } 387 fs = append(fs, f) 388 } 389 return 390 } 391 392 func (g *generator) GenerateService(pkg *PackageInfo) ([]*File, error) { 393 g.updatePackageInfo(pkg) 394 output := util.JoinPath(g.OutputPath, util.CombineOutputPath(g.GenPath, pkg.Namespace)) 395 svcPkg := strings.ToLower(pkg.ServiceName) 396 output = util.JoinPath(output, svcPkg) 397 ext := g.tmplExt 398 if ext == nil { 399 ext = new(TemplateExtension) 400 } 401 402 tasks := []*Task{ 403 { 404 Name: ClientFileName, 405 Path: util.JoinPath(output, ClientFileName), 406 Text: tpl.ClientTpl, 407 Ext: ext.ExtendClient, 408 }, 409 { 410 Name: ServerFileName, 411 Path: util.JoinPath(output, ServerFileName), 412 Text: tpl.ServerTpl, 413 Ext: ext.ExtendServer, 414 }, 415 { 416 Name: InvokerFileName, 417 Path: util.JoinPath(output, InvokerFileName), 418 Text: tpl.InvokerTpl, 419 Ext: ext.ExtendInvoker, 420 }, 421 { 422 Name: ServiceFileName, 423 Path: util.JoinPath(output, svcPkg+".go"), 424 Text: 
tpl.ServiceTpl, 425 }, 426 } 427 428 var fs []*File 429 for _, t := range tasks { 430 if err := t.Build(); err != nil { 431 err = fmt.Errorf("build %s failed: %w", t.Name, err) 432 return nil, err 433 } 434 g.setImports(t.Name, pkg) 435 if t.Ext != nil { 436 for _, path := range t.Ext.ImportPaths { 437 if alias, exist := ext.Dependencies[path]; exist { 438 pkg.AddImports(alias) 439 } 440 } 441 } 442 handle := func(task *Task, pkg *PackageInfo) (*File, error) { 443 return task.Render(pkg) 444 } 445 f, err := g.chainMWs(handle)(t, pkg) 446 if err != nil { 447 err = fmt.Errorf("render %s failed: %w", t.Name, err) 448 return nil, err 449 } 450 fs = append(fs, f) 451 } 452 return fs, nil 453 } 454 455 func (g *generator) updatePackageInfo(pkg *PackageInfo) { 456 pkg.NoFastAPI = g.NoFastAPI 457 pkg.Codec = g.IDLType 458 pkg.Version = g.Version 459 pkg.RealServiceName = g.ServiceName 460 pkg.Features = g.Features 461 pkg.ExternalKitexGen = g.Use 462 pkg.FrugalPretouch = g.FrugalPretouch 463 pkg.Module = g.ModuleName 464 if strings.EqualFold(g.Protocol, transport.HESSIAN2.String()) { 465 pkg.Protocol = transport.HESSIAN2 466 } 467 if pkg.Dependencies == nil { 468 pkg.Dependencies = make(map[string]string) 469 } 470 471 for ref, path := range globalDependencies { 472 if _, ok := pkg.Dependencies[ref]; !ok { 473 pkg.Dependencies[ref] = path 474 } 475 } 476 } 477 478 func (g *generator) setImports(name string, pkg *PackageInfo) { 479 pkg.Imports = make(map[string]map[string]bool) 480 switch name { 481 case ClientFileName: 482 pkg.AddImports("client") 483 if pkg.HasStreaming { 484 pkg.AddImport("streaming", "github.com/cloudwego/kitex/pkg/streaming") 485 pkg.AddImport("transport", "github.com/cloudwego/kitex/transport") 486 } 487 if len(pkg.AllMethods()) > 0 { 488 if needCallOpt(pkg) { 489 pkg.AddImports("callopt") 490 } 491 pkg.AddImports("context") 492 } 493 fallthrough 494 case HandlerFileName: 495 for _, m := range pkg.ServiceInfo.AllMethods() { 496 if !m.ServerStreaming 
&& !m.ClientStreaming { 497 pkg.AddImports("context") 498 } 499 for _, a := range m.Args { 500 for _, dep := range a.Deps { 501 pkg.AddImport(dep.PkgRefName, dep.ImportPath) 502 } 503 } 504 if !m.Void && m.Resp != nil { 505 for _, dep := range m.Resp.Deps { 506 pkg.AddImport(dep.PkgRefName, dep.ImportPath) 507 } 508 } 509 } 510 case ServerFileName, InvokerFileName: 511 if len(pkg.CombineServices) == 0 { 512 pkg.AddImport(pkg.ServiceInfo.PkgRefName, pkg.ServiceInfo.ImportPath) 513 } 514 pkg.AddImports("server") 515 case ServiceFileName: 516 pkg.AddImports("errors") 517 pkg.AddImports("client") 518 pkg.AddImport("kitex", "github.com/cloudwego/kitex/pkg/serviceinfo") 519 pkg.AddImport(pkg.ServiceInfo.PkgRefName, pkg.ServiceInfo.ImportPath) 520 if len(pkg.AllMethods()) > 0 { 521 pkg.AddImports("context") 522 } 523 for _, m := range pkg.ServiceInfo.AllMethods() { 524 if m.ClientStreaming || m.ServerStreaming { 525 pkg.AddImports("fmt") 526 } 527 if m.GenArgResultStruct { 528 pkg.AddImports("proto") 529 } else { 530 // for method Arg and Result 531 pkg.AddImport(m.PkgRefName, m.ImportPath) 532 } 533 for _, a := range m.Args { 534 for _, dep := range a.Deps { 535 pkg.AddImport(dep.PkgRefName, dep.ImportPath) 536 } 537 } 538 if m.Streaming.IsStreaming || pkg.Codec == "protobuf" { 539 // protobuf handler support both PingPong and Unary (streaming) requests 540 pkg.AddImport("streaming", "github.com/cloudwego/kitex/pkg/streaming") 541 } 542 if !m.Void && m.Resp != nil { 543 for _, dep := range m.Resp.Deps { 544 pkg.AddImport(dep.PkgRefName, dep.ImportPath) 545 } 546 } 547 for _, e := range m.Exceptions { 548 for _, dep := range e.Deps { 549 pkg.AddImport(dep.PkgRefName, dep.ImportPath) 550 } 551 } 552 } 553 if pkg.FrugalPretouch { 554 pkg.AddImports("sync") 555 if len(pkg.AllMethods()) > 0 { 556 pkg.AddImports("frugal") 557 pkg.AddImports("reflect") 558 } 559 } 560 case MainFileName: 561 pkg.AddImport("log", "log") 562 pkg.AddImport(pkg.PkgRefName, 
util.JoinPath(pkg.ImportPath, strings.ToLower(pkg.ServiceName))) 563 } 564 } 565 566 func needCallOpt(pkg *PackageInfo) bool { 567 // callopt is referenced only by non-streaming methods 568 needCallOpt := false 569 switch pkg.Codec { 570 case "thrift": 571 for _, m := range pkg.ServiceInfo.AllMethods() { 572 if !m.Streaming.IsStreaming { 573 needCallOpt = true 574 break 575 } 576 } 577 case "protobuf": 578 needCallOpt = true 579 } 580 return needCallOpt 581 }