vitess.io/vitess@v0.16.2/go/vt/vtctl/vtctldclient/codegen/main.go (about) 1 /* 2 Copyright 2021 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package main 18 19 import ( 20 "errors" 21 "fmt" 22 "go/types" 23 "io" 24 "os" 25 "path/filepath" 26 "regexp" 27 "sort" 28 29 "github.com/spf13/pflag" 30 "golang.org/x/tools/go/packages" 31 ) 32 33 func main() { // nolint:funlen 34 source := pflag.String("source", "../../proto/vtctlservice", "source package") 35 typeName := pflag.String("type", "VtctldClient", "interface type to implement") 36 implType := pflag.String("impl", "gRPCVtctldClient", "type implementing the interface") 37 pkgName := pflag.String("targetpkg", "grpcvtctldclient", "package name to generate code for") 38 local := pflag.Bool("local", false, "generate a local, in-process client rather than a grpcclient") 39 out := pflag.String("out", "", "output destination. leave empty to use stdout") 40 41 pflag.Parse() 42 43 if *source == "" { 44 panic("--source cannot be empty") 45 } 46 47 if *typeName == "" { 48 panic("--type cannot be empty") 49 } 50 51 if *implType == "" { 52 panic("--impl cannot be empty") 53 } 54 55 if *pkgName == "" { 56 panic("--targetpkg cannot be empty") 57 } 58 59 var output io.Writer = os.Stdout 60 61 if *out != "" { 62 f, err := os.Create(*out) 63 if err != nil { 64 panic(err) 65 } 66 67 defer f.Close() 68 output = f 69 } 70 71 pkg, err := loadPackage(*source) 72 if err != nil { 73 panic(err) 74 } 75 76 iface, err := extractSourceInterface(pkg, *typeName) 77 if err != nil { 78 panic(fmt.Errorf("error getting %s in %s: %w", *typeName, *source, err)) 79 } 80 81 imports := map[string]string{ 82 "context": "context", 83 } 84 importNames := []string{} 85 funcs := make(map[string]*Func, iface.NumExplicitMethods()) 86 funcNames := make([]string, iface.NumExplicitMethods()) 87 88 for i := 0; i < iface.NumExplicitMethods(); i++ { 89 m := iface.ExplicitMethod(i) 90 funcNames[i] = m.Name() 91 92 sig, ok := m.Type().(*types.Signature) 93 if !ok { 94 panic(fmt.Sprintf("could not derive signature from method %s, have %T", m.FullName(), m.Type())) 95 } 96 97 if sig.Params().Len() != 3 { 98 panic(fmt.Sprintf("all methods in a grpc client interface should have exactly 3 params; found\n=> %s", sig)) 99 } 100 101 if sig.Results().Len() != 2 { 102 panic(fmt.Sprintf("all methods in a grpc client interface should have exactly 2 results; found\n=> %s", sig)) 103 } 104 105 f := &Func{ 106 Name: m.Name(), 107 } 108 funcs[f.Name] = f 109 110 // The first parameter is always context.Context. The third parameter is 111 // always a ...grpc.CallOption. 112 param := sig.Params().At(1) 113 114 localType, localImport, pkgPath, err := extractLocalPointerType(param) 115 if err != nil { 116 panic(err) 117 } 118 119 f.Param.Name = param.Name() 120 f.Param.Type = "*" + localImport + "." + localType 121 122 if _, ok := imports[localImport]; !ok { 123 importNames = append(importNames, localImport) 124 } 125 126 imports[localImport] = pkgPath 127 128 // (TODO|@amason): check which grpc lib CallOption is imported from in 129 // this interface; it could be either google.golang.org/grpc or 130 // github.com/golang/protobuf/grpc, although in vitess we currently 131 // always use the former. 132 133 // In the case of unary RPCs, the first result is a Pointer. In the case 134 // of streaming RPCs, it is a Named type whose underlying type is an 135 // Interface. 136 // 137 // The second result is always error. 138 result := sig.Results().At(0) 139 switch result.Type().(type) { 140 case *types.Pointer: 141 localType, localImport, pkgPath, err = extractLocalPointerType(result) 142 case *types.Named: 143 switch result.Type().Underlying().(type) { 144 case *types.Interface: 145 f.IsStreaming = true 146 localType, localImport, pkgPath, err = extractLocalNamedType(result) 147 if err == nil && *local { 148 // We need to get the pointer type returned by `stream.Recv()` 149 // in the local case for the stream adapter. 150 var recvType, recvImport, recvPkgPath string 151 recvType, recvImport, recvPkgPath, err = extractRecvType(result) 152 if err == nil { 153 f.StreamMessage = buildParam("stream", recvImport, recvType, true) 154 importNames = addImport(recvImport, recvPkgPath, importNames, imports) 155 } 156 } 157 default: 158 err = fmt.Errorf("expected either pointer (for unary) or named interface (for streaming) rpc result type, got %T", result.Type().Underlying()) 159 } 160 default: 161 err = fmt.Errorf("expected either pointer (for unary) or named interface (for streaming) rpc result type, got %T", result.Type()) 162 } 163 164 if err != nil { 165 panic(err) 166 } 167 168 f.Result = buildParam(result.Name(), localImport, localType, !f.IsStreaming) 169 importNames = addImport(localImport, pkgPath, importNames, imports) 170 } 171 172 sort.Strings(importNames) 173 sort.Strings(funcNames) 174 175 def := &ClientInterfaceDef{ 176 PackageName: *pkgName, 177 Type: *implType, 178 ClientName: "grpcvtctldclient", 179 } 180 181 if *local { 182 def.ClientName = "localvtctldclient" 183 def.Local = true 184 } 185 186 for _, name := range importNames { 187 imp := &Import{ 188 Path: imports[name], 189 } 190 191 if filepath.Base(imp.Path) != name { 192 imp.Alias = name 193 } 194 195 def.Imports = append(def.Imports, imp) 196 } 197 198 for _, name := range funcNames { 199 def.Methods = append(def.Methods, funcs[name]) 200 } 201 202 if err := tmpl.Execute(output, def); err != nil { 203 panic(err) 204 } 205 } 206 207 // ClientInterfaceDef is a struct providing enough information to generate an 208 // implementation of a gRPC Client interface. 209 type ClientInterfaceDef struct { 210 PackageName string 211 Type string 212 Imports []*Import 213 Methods []*Func 214 Local bool 215 ClientName string 216 } 217 218 // NeedsGRPCShim returns true if the generated client code needs the internal 219 // grpcshim imported. Currently this is true if the client is Local and has any 220 // methods that are streaming RPCs. 221 func (def *ClientInterfaceDef) NeedsGRPCShim() bool { 222 if !def.Local { 223 return false 224 } 225 226 for _, m := range def.Methods { 227 if m.IsStreaming { 228 return true 229 } 230 } 231 232 return false 233 } 234 235 // Import contains the meta information about a Go import. 236 type Import struct { 237 Alias string 238 Path string 239 } 240 241 // Func is the variable part of a gRPC client interface method (i.e. not the 242 // context or dialopts arguments, or the error part of the result tuple). 243 type Func struct { 244 Name string 245 Param Param 246 Result Param 247 IsStreaming bool 248 StreamMessage Param 249 } 250 251 // Param represents an element of either a parameter list or result list. It 252 // contains an optional name, and a package-local type. This struct exists 253 // purely to power template execution, which is why the Type field is simply a 254 // bare string. 255 type Param struct { 256 Name string 257 // locally-qualified type, e.g. "grpc.CallOption", and not "google.golang.org/grpc.CallOption". 258 Type string 259 } 260 261 func buildParam(name string, localImport string, localType string, isPointer bool) Param { 262 p := Param{ 263 Name: name, 264 Type: fmt.Sprintf("%s.%s", localImport, localType), 265 } 266 267 if isPointer { 268 p.Type = "*" + p.Type 269 } 270 271 return p 272 } 273 274 func addImport(localImport string, pkgPath string, importNames []string, imports map[string]string) []string { 275 if _, ok := imports[localImport]; !ok { 276 importNames = append(importNames, localImport) 277 } 278 279 imports[localImport] = pkgPath 280 return importNames 281 } 282 283 func loadPackage(source string) (*packages.Package, error) { 284 pkgs, err := packages.Load(&packages.Config{ 285 Mode: packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo, 286 }, source) 287 if err != nil { 288 return nil, err 289 } 290 291 if len(pkgs) != 1 { 292 return nil, errors.New("must specify exactly one package") 293 } 294 295 pkg := pkgs[0] 296 if len(pkg.Errors) > 0 { 297 var err error 298 299 for _, e := range pkg.Errors { 300 switch err { 301 case nil: 302 err = fmt.Errorf("errors loading package %s: %s", source, e.Error()) 303 default: 304 err = fmt.Errorf("%w; %s", err, e.Error()) 305 } 306 } 307 308 return nil, err 309 } 310 311 return pkg, nil 312 } 313 314 func extractSourceInterface(pkg *packages.Package, name string) (*types.Interface, error) { 315 obj := pkg.Types.Scope().Lookup(name) 316 if obj == nil { 317 return nil, fmt.Errorf("no symbol found with name %s", name) 318 } 319 320 switch t := obj.Type().(type) { 321 case *types.Named: 322 iface, ok := t.Underlying().(*types.Interface) 323 if !ok { 324 return nil, fmt.Errorf("symbol %s was not an interface but %T", name, t.Underlying()) 325 } 326 327 return iface, nil 328 case *types.Interface: 329 return t, nil 330 } 331 332 return nil, fmt.Errorf("symbol %s was not an interface but %T", name, obj.Type()) 333 } 334 335 var vitessProtoRegexp = regexp.MustCompile(`^vitess.io.*/proto/.*`) 336 337 func rewriteProtoImports(pkg *types.Package) string { 338 if vitessProtoRegexp.MatchString(pkg.Path()) { 339 return pkg.Name() + "pb" 340 } 341 342 return pkg.Name() 343 } 344 345 func extractLocalNamedType(v *types.Var) (name string, localImport string, pkgPath string, err error) { 346 named, ok := v.Type().(*types.Named) 347 if !ok { 348 return "", "", "", fmt.Errorf("expected a named type for %s, got %v", v.Name(), v.Type()) 349 } 350 351 name = named.Obj().Name() 352 localImport = rewriteProtoImports(named.Obj().Pkg()) 353 pkgPath = named.Obj().Pkg().Path() 354 355 return name, localImport, pkgPath, nil 356 } 357 358 func extractLocalPointerType(v *types.Var) (name string, localImport string, pkgPath string, err error) { 359 ptr, ok := v.Type().(*types.Pointer) 360 if !ok { 361 return "", "", "", fmt.Errorf("expected a pointer type for %s, got %v", v.Name(), v.Type()) 362 } 363 364 typ, ok := ptr.Elem().(*types.Named) 365 if !ok { 366 return "", "", "", fmt.Errorf("expected an underlying named type for %s, got %v", v.Name(), ptr.Elem()) 367 } 368 369 name = typ.Obj().Name() 370 localImport = rewriteProtoImports(typ.Obj().Pkg()) 371 pkgPath = typ.Obj().Pkg().Path() 372 373 return name, localImport, pkgPath, nil 374 } 375 376 func extractRecvType(v *types.Var) (name string, localImport string, pkgPath string, err error) { 377 named, ok := v.Type().(*types.Named) 378 if !ok { 379 return "", "", "", fmt.Errorf("expected a named type for %s, got %v", v.Name(), v.Type()) 380 } 381 382 iface, ok := named.Underlying().(*types.Interface) 383 if !ok { 384 return "", "", "", fmt.Errorf("expected %s to name an interface type, got %v", v.Name(), named.Underlying()) 385 } 386 387 for i := 0; i < iface.NumExplicitMethods(); i++ { 388 m := iface.ExplicitMethod(i) 389 if m.Name() != "Recv" { 390 continue 391 } 392 393 sig, ok := m.Type().(*types.Signature) 394 if !ok { 395 return "", "", "", fmt.Errorf("%s.Recv should have type Signature; got %v", v.Name(), m.Type()) 396 } 397 398 if sig.Results().Len() != 2 { 399 return "", "", "", fmt.Errorf("%s.Recv should return two values, not %d", v.Name(), sig.Results().Len()) 400 } 401 402 return extractLocalPointerType(sig.Results().At(0)) 403 } 404 405 return "", "", "", fmt.Errorf("interface %s has no explicit method named Recv", named.Obj().Name()) 406 }