github.com/cloudwego/kitex@v0.9.0/tool/internal_pkg/pluginmode/protoc/plugin.go (about) 1 // Copyright 2021 CloudWeGo Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package protoc 16 17 import ( 18 "bytes" 19 "errors" 20 "fmt" 21 "path" 22 "path/filepath" 23 "strings" 24 "text/template" 25 26 genfastpb "github.com/cloudwego/fastpb/protoc-gen-fastpb/generator" 27 "github.com/cloudwego/thriftgo/generator/golang/streaming" 28 gengo "google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo" 29 "google.golang.org/protobuf/compiler/protogen" 30 31 "github.com/cloudwego/kitex/tool/internal_pkg/generator" 32 "github.com/cloudwego/kitex/tool/internal_pkg/log" 33 "github.com/cloudwego/kitex/tool/internal_pkg/util" 34 ) 35 36 type protocPlugin struct { 37 generator.Config 38 generator.PackageInfo 39 Services []*generator.ServiceInfo 40 kg generator.Generator 41 err error 42 importPaths map[string]string // file -> import path 43 } 44 45 // Name implements the protobuf_generator.Plugin interface. 46 func (pp *protocPlugin) Name() string { 47 return "kitex-internal" 48 } 49 50 // Init implements the protobuf_generator.Plugin interface. 51 func (pp *protocPlugin) init() { 52 pp.Dependencies = map[string]string{ 53 "proto": "google.golang.org/protobuf/proto", 54 } 55 } 56 57 // parse the 'M*' option 58 // See https://developers.google.com/protocol-buffers/docs/reference/go-generated for more information. 59 func (pp *protocPlugin) parseM() { 60 pp.importPaths = make(map[string]string) 61 for _, po := range pp.Config.ProtobufOptions { 62 if po == "" || po[0] != 'M' { 63 continue 64 } 65 idx := strings.Index(po, "=") 66 if idx < 0 { 67 continue 68 } 69 key := po[1:idx] 70 val := po[idx+1:] 71 if val == "" { 72 continue 73 } 74 idx = strings.Index(val, ";") 75 if idx >= 0 { 76 val = val[:idx] 77 } 78 pp.importPaths[key] = val 79 } 80 } 81 82 var interfaceTemplate = ` 83 84 // Code generated by Kitex {{.Version}}. DO NOT EDIT. 85 86 {{range .Interfaces}} 87 {{$serviceName := .Name}} 88 type {{.Name}} interface { 89 {{- range .Methods}} 90 {{- if or .ClientStreaming .ServerStreaming}} 91 {{.Name}}({{if and .ServerStreaming (not .ClientStreaming)}}req {{.ReqType}}, {{end}}stream {{$serviceName}}_{{.Name}}Server) (err error) 92 {{- else}} 93 {{.Name}}(ctx context.Context, req {{.ReqType}}) (res {{.ResType}}, err error) 94 {{- end}} 95 {{- end}} 96 } 97 98 {{range .Methods}} 99 {{- if or .ClientStreaming .ServerStreaming}} 100 type {{$serviceName}}_{{.Name}}Server interface { 101 streaming.Stream 102 {{- if .ClientStreaming}} 103 Recv() ({{.ReqType}}, error) 104 {{- end}} 105 {{- if .ServerStreaming}} 106 Send({{.ResType}}) error 107 {{- end}} 108 {{- if and .ClientStreaming (not .ServerStreaming)}} 109 SendAndClose({{.ResType}}) error 110 {{- end}} 111 } 112 {{- end}} 113 {{end}} 114 115 {{end}} 116 ` 117 118 // Generate implements the protobuf_generator.Plugin interface. 119 func (pp *protocPlugin) GenerateFile(gen *protogen.Plugin, file *protogen.File) { 120 if pp.err != nil { 121 return 122 } 123 gopkg := file.Proto.GetOptions().GetGoPackage() 124 if !strings.HasPrefix(gopkg, pp.PackagePrefix) { 125 log.Warnf("[WARN] %q is skipped because its import path %q is not located in ./kitex_gen. Change the go_package option or use '--protobuf M%s=A-Import-Path-In-kitex_gen' to override it if you want this file to be generated under kitex_gen.\n", 126 file.Proto.GetName(), gopkg, file.Proto.GetName()) 127 return 128 } 129 log.Infof("[INFO] Generate %q at %q\n", file.Proto.GetName(), gopkg) 130 131 if parts := strings.Split(gopkg, ";"); len(parts) > 1 { 132 gopkg = parts[0] // remove package alias from file path 133 } 134 pp.Namespace = strings.TrimPrefix(gopkg, pp.PackagePrefix) 135 pp.IDLName = util.IDLName(pp.Config.IDL) 136 137 ss := pp.convertTypes(file) 138 pp.Services = append(pp.Services, ss...) 139 140 if pp.Config.Use != "" { 141 return 142 } 143 144 hasStreaming := false 145 // generate service package 146 for _, si := range ss { 147 pp.ServiceInfo = si 148 fs, err := pp.kg.GenerateService(&pp.PackageInfo) 149 if err != nil { 150 pp.err = err 151 return 152 } 153 if !hasStreaming && si.HasStreaming { 154 hasStreaming = true 155 } 156 for _, f := range fs { 157 gen.NewGeneratedFile(pp.adjustPath(f.Name), "").P(f.Content) 158 } 159 } 160 // generate service interface 161 if pp.err == nil { 162 fixed := *file 163 fixed.GeneratedFilenamePrefix = strings.TrimPrefix(fixed.GeneratedFilenamePrefix, pp.PackagePrefix) 164 f := gengo.GenerateFile(gen, &fixed) 165 f.QualifiedGoIdent(protogen.GoIdent{GoImportPath: "context"}) 166 if hasStreaming { 167 f.QualifiedGoIdent(protogen.GoIdent{ 168 GoImportPath: "github.com/cloudwego/kitex/pkg/streaming", 169 }) 170 } 171 f.P("var _ context.Context") 172 173 if len(file.Services) != 0 { 174 tpl := template.New("interface") 175 tpl = template.Must(tpl.Parse(interfaceTemplate)) 176 var buf bytes.Buffer 177 pp.err = tpl.ExecuteTemplate(&buf, tpl.Name(), pp.makeInterfaces(f, file)) 178 179 f.P(buf.String()) 180 } 181 } 182 183 // generate fast api 184 if !pp.Config.NoFastAPI && pp.err == nil { 185 fixed := *file 186 fixed.GeneratedFilenamePrefix = strings.TrimPrefix(fixed.GeneratedFilenamePrefix, pp.PackagePrefix) 187 genfastpb.GenerateFile(gen, &fixed) 188 } 189 } 190 191 func (pp *protocPlugin) process(gen *protogen.Plugin) { 192 defer func() { 193 if e := recover(); e != nil { 194 if err, ok := e.(error); ok { 195 gen.Error(err) 196 } else { 197 gen.Error(fmt.Errorf("%+v", e)) 198 } 199 } 200 }() 201 if len(gen.Files) == 0 { 202 gen.Error(errors.New("no proto file")) 203 return 204 } 205 pp.kg = generator.NewGenerator(&pp.Config, nil) 206 // iterate over all proto files 207 idl := gen.Request.FileToGenerate[0] 208 for _, f := range gen.Files { 209 if pp.Config.Use != "" && f.Proto.GetName() != idl { 210 continue 211 } 212 pp.GenerateFile(gen, f) 213 } 214 215 if pp.Config.GenerateMain { 216 if len(pp.Services) == 0 { 217 gen.Error(errors.New("no service defined")) 218 return 219 } 220 pp.ServiceInfo = pp.Services[len(pp.Services)-1] 221 fs, err := pp.kg.GenerateMainPackage(&pp.PackageInfo) 222 if err != nil { 223 pp.err = err 224 } 225 for _, f := range fs { 226 gen.NewGeneratedFile(pp.adjustPath(f.Name), "").P(f.Content) 227 } 228 } 229 230 if pp.Config.TemplateDir != "" { 231 if len(pp.Services) == 0 { 232 gen.Error(errors.New("no service defined")) 233 return 234 } 235 pp.ServiceInfo = pp.Services[len(pp.Services)-1] 236 fs, err := pp.kg.GenerateCustomPackage(&pp.PackageInfo) 237 if err != nil { 238 pp.err = err 239 } 240 for _, f := range fs { 241 gen.NewGeneratedFile(pp.adjustPath(f.Name), "").P(f.Content) 242 } 243 } 244 245 if pp.err != nil { 246 gen.Error(pp.err) 247 } 248 return 249 } 250 251 func (pp *protocPlugin) convertTypes(file *protogen.File) (ss []*generator.ServiceInfo) { 252 pth := pp.fixImport(string(file.GoImportPath)) 253 if pth == "" { 254 panic(fmt.Errorf("missing %q option in %q", "go_package", file.Desc.Name())) 255 } 256 pi := generator.PkgInfo{ 257 PkgName: file.Proto.GetPackage(), 258 PkgRefName: goSanitized(path.Base(pth)), 259 ImportPath: pth, 260 } 261 for _, service := range file.Services { 262 si := &generator.ServiceInfo{ 263 PkgInfo: pi, 264 ServiceName: service.GoName, 265 RawServiceName: string(service.Desc.Name()), 266 } 267 si.ServiceTypeName = func() string { return si.PkgRefName + "." + si.ServiceName } 268 for _, m := range service.Methods { 269 req := pp.convertParameter(m.Input, "Req") 270 res := pp.convertParameter(m.Output, "Resp") 271 272 methodName := m.GoName 273 mi := &generator.MethodInfo{ 274 PkgInfo: pi, 275 ServiceName: si.ServiceName, 276 RawName: string(m.Desc.Name()), 277 Name: methodName, 278 Args: []*generator.Parameter{req}, 279 Resp: res, 280 ArgStructName: methodName + "Args", 281 ResStructName: methodName + "Result", 282 GenArgResultStruct: true, 283 ClientStreaming: m.Desc.IsStreamingClient(), 284 ServerStreaming: m.Desc.IsStreamingServer(), 285 } 286 si.Methods = append(si.Methods, mi) 287 if !si.HasStreaming && (mi.ClientStreaming || mi.ServerStreaming) { 288 si.HasStreaming = true 289 } 290 } 291 for _, m := range si.Methods { 292 BuildStreaming(m, si.HasStreaming) 293 } 294 ss = append(ss, si) 295 } 296 // combine service 297 if pp.Config.CombineService && len(file.Services) > 0 { 298 var svcs []*generator.ServiceInfo 299 var methods []*generator.MethodInfo 300 for _, s := range ss { 301 svcs = append(svcs, s) 302 methods = append(methods, s.AllMethods()...) 303 } 304 // check method name conflict 305 mm := make(map[string]*generator.MethodInfo) 306 for _, m := range methods { 307 if _, ok := mm[m.Name]; ok { 308 log.Warnf("[WARN] combine service method %s in %s conflicts with %s in %s\n", 309 m.Name, m.ServiceName, m.Name, mm[m.Name].ServiceName) 310 return 311 } 312 mm[m.Name] = m 313 } 314 var hasStreaming bool 315 for _, m := range methods { 316 if m.ClientStreaming || m.ServerStreaming { 317 hasStreaming = true 318 } 319 } 320 svcName := pp.getCombineServiceName("CombineService", ss) 321 si := &generator.ServiceInfo{ 322 PkgInfo: pi, 323 ServiceName: svcName, 324 RawServiceName: svcName, 325 CombineServices: svcs, 326 Methods: methods, 327 HasStreaming: hasStreaming, 328 } 329 si.ServiceTypeName = func() string { return si.ServiceName } 330 ss = append(ss, si) 331 } 332 return 333 } 334 335 // BuildStreaming builds protobuf MethodInfo.Streaming as for Thrift, to simplify codegen 336 func BuildStreaming(mi *generator.MethodInfo, serviceHasStreaming bool) { 337 s := &streaming.Streaming{ 338 // pb: if one method is streaming, then the service is streaming, making all methods streaming 339 IsStreaming: serviceHasStreaming, 340 } 341 if mi.ClientStreaming && mi.ServerStreaming { 342 s.Mode = streaming.StreamingBidirectional 343 s.BidirectionalStreaming = true 344 s.ClientStreaming = true 345 s.ServerStreaming = true 346 } else if mi.ClientStreaming && !mi.ServerStreaming { 347 s.Mode = streaming.StreamingClientSide 348 s.ClientStreaming = true 349 } else if !mi.ClientStreaming && mi.ServerStreaming { 350 s.Mode = streaming.StreamingServerSide 351 s.ServerStreaming = true 352 } else if serviceHasStreaming { 353 s.Mode = streaming.StreamingUnary // Unary APIs over HTTP2 354 } 355 mi.Streaming = s 356 } 357 358 func (pp *protocPlugin) getCombineServiceName(name string, svcs []*generator.ServiceInfo) string { 359 for _, svc := range svcs { 360 if svc.ServiceName == name { 361 return pp.getCombineServiceName(name+"_", svcs) 362 } 363 } 364 return name 365 } 366 367 func (pp *protocPlugin) convertParameter(msg *protogen.Message, paramName string) *generator.Parameter { 368 importPath := pp.fixImport(msg.GoIdent.GoImportPath.String()) 369 pkgRefName := goSanitized(path.Base(importPath)) 370 res := &generator.Parameter{ 371 Deps: []generator.PkgInfo{ 372 { 373 PkgRefName: pkgRefName, 374 ImportPath: importPath, 375 }, 376 }, 377 Name: paramName, 378 RawName: paramName, 379 Type: "*" + pkgRefName + "." + msg.GoIdent.GoName, 380 } 381 return res 382 } 383 384 func (pp *protocPlugin) makeInterfaces(gf *protogen.GeneratedFile, file *protogen.File) interface{} { 385 var is []interface{} 386 for _, service := range file.Services { 387 i := struct { 388 Name string 389 Methods []interface{} 390 }{ 391 Name: service.GoName, 392 } 393 for _, m := range service.Methods { 394 i.Methods = append(i.Methods, struct { 395 Name string 396 ReqType string 397 ResType string 398 ClientStreaming bool 399 ServerStreaming bool 400 }{ 401 m.GoName, 402 "*" + gf.QualifiedGoIdent(m.Input.GoIdent), 403 "*" + gf.QualifiedGoIdent(m.Output.GoIdent), 404 m.Desc.IsStreamingClient(), 405 m.Desc.IsStreamingServer(), 406 }) 407 } 408 is = append(is, i) 409 } 410 return struct { 411 Version string 412 Interfaces []interface{} 413 }{pp.Config.Version, is} 414 } 415 416 func (pp *protocPlugin) adjustPath(path string) (ret string) { 417 cur, _ := filepath.Abs(".") 418 if pp.Config.Use == "" { 419 cur = util.JoinPath(cur, generator.KitexGenPath) 420 } 421 if filepath.IsAbs(path) { 422 path, _ = filepath.Rel(cur, path) 423 return path 424 } 425 if pp.ModuleName == "" { 426 gopath := util.GetGOPATH() 427 path = util.JoinPath(gopath, "src", path) 428 path, _ = filepath.Rel(cur, path) 429 } else { 430 path, _ = filepath.Rel(pp.ModuleName, path) 431 } 432 return path 433 } 434 435 func (pp *protocPlugin) fixImport(path string) string { 436 path = strings.Trim(path, "\"") 437 return path 438 }