vitess.io/vitess@v0.16.2/go/vt/vtctl/vtctldclient/codegen/main.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package main
    18  
import (
	"bytes"
	"errors"
	"fmt"
	"go/types"
	"io"
	"os"
	"path/filepath"
	"regexp"
	"sort"

	"github.com/spf13/pflag"
	"golang.org/x/tools/go/packages"
)
    32  
    33  func main() { // nolint:funlen
    34  	source := pflag.String("source", "../../proto/vtctlservice", "source package")
    35  	typeName := pflag.String("type", "VtctldClient", "interface type to implement")
    36  	implType := pflag.String("impl", "gRPCVtctldClient", "type implementing the interface")
    37  	pkgName := pflag.String("targetpkg", "grpcvtctldclient", "package name to generate code for")
    38  	local := pflag.Bool("local", false, "generate a local, in-process client rather than a grpcclient")
    39  	out := pflag.String("out", "", "output destination. leave empty to use stdout")
    40  
    41  	pflag.Parse()
    42  
    43  	if *source == "" {
    44  		panic("--source cannot be empty")
    45  	}
    46  
    47  	if *typeName == "" {
    48  		panic("--type cannot be empty")
    49  	}
    50  
    51  	if *implType == "" {
    52  		panic("--impl cannot be empty")
    53  	}
    54  
    55  	if *pkgName == "" {
    56  		panic("--targetpkg cannot be empty")
    57  	}
    58  
    59  	var output io.Writer = os.Stdout
    60  
    61  	if *out != "" {
    62  		f, err := os.Create(*out)
    63  		if err != nil {
    64  			panic(err)
    65  		}
    66  
    67  		defer f.Close()
    68  		output = f
    69  	}
    70  
    71  	pkg, err := loadPackage(*source)
    72  	if err != nil {
    73  		panic(err)
    74  	}
    75  
    76  	iface, err := extractSourceInterface(pkg, *typeName)
    77  	if err != nil {
    78  		panic(fmt.Errorf("error getting %s in %s: %w", *typeName, *source, err))
    79  	}
    80  
    81  	imports := map[string]string{
    82  		"context": "context",
    83  	}
    84  	importNames := []string{}
    85  	funcs := make(map[string]*Func, iface.NumExplicitMethods())
    86  	funcNames := make([]string, iface.NumExplicitMethods())
    87  
    88  	for i := 0; i < iface.NumExplicitMethods(); i++ {
    89  		m := iface.ExplicitMethod(i)
    90  		funcNames[i] = m.Name()
    91  
    92  		sig, ok := m.Type().(*types.Signature)
    93  		if !ok {
    94  			panic(fmt.Sprintf("could not derive signature from method %s, have %T", m.FullName(), m.Type()))
    95  		}
    96  
    97  		if sig.Params().Len() != 3 {
    98  			panic(fmt.Sprintf("all methods in a grpc client interface should have exactly 3 params; found\n=> %s", sig))
    99  		}
   100  
   101  		if sig.Results().Len() != 2 {
   102  			panic(fmt.Sprintf("all methods in a grpc client interface should have exactly 2 results; found\n=> %s", sig))
   103  		}
   104  
   105  		f := &Func{
   106  			Name: m.Name(),
   107  		}
   108  		funcs[f.Name] = f
   109  
   110  		// The first parameter is always context.Context. The third parameter is
   111  		// always a ...grpc.CallOption.
   112  		param := sig.Params().At(1)
   113  
   114  		localType, localImport, pkgPath, err := extractLocalPointerType(param)
   115  		if err != nil {
   116  			panic(err)
   117  		}
   118  
   119  		f.Param.Name = param.Name()
   120  		f.Param.Type = "*" + localImport + "." + localType
   121  
   122  		if _, ok := imports[localImport]; !ok {
   123  			importNames = append(importNames, localImport)
   124  		}
   125  
   126  		imports[localImport] = pkgPath
   127  
   128  		// (TODO|@amason): check which grpc lib CallOption is imported from in
   129  		// this interface; it could be either google.golang.org/grpc or
   130  		// github.com/golang/protobuf/grpc, although in vitess we currently
   131  		// always use the former.
   132  
   133  		// In the case of unary RPCs, the first result is a Pointer. In the case
   134  		// of streaming RPCs, it is a Named type whose underlying type is an
   135  		// Interface.
   136  		//
   137  		// The second result is always error.
   138  		result := sig.Results().At(0)
   139  		switch result.Type().(type) {
   140  		case *types.Pointer:
   141  			localType, localImport, pkgPath, err = extractLocalPointerType(result)
   142  		case *types.Named:
   143  			switch result.Type().Underlying().(type) {
   144  			case *types.Interface:
   145  				f.IsStreaming = true
   146  				localType, localImport, pkgPath, err = extractLocalNamedType(result)
   147  				if err == nil && *local {
   148  					// We need to get the pointer type returned by `stream.Recv()`
   149  					// in the local case for the stream adapter.
   150  					var recvType, recvImport, recvPkgPath string
   151  					recvType, recvImport, recvPkgPath, err = extractRecvType(result)
   152  					if err == nil {
   153  						f.StreamMessage = buildParam("stream", recvImport, recvType, true)
   154  						importNames = addImport(recvImport, recvPkgPath, importNames, imports)
   155  					}
   156  				}
   157  			default:
   158  				err = fmt.Errorf("expected either pointer (for unary) or named interface (for streaming) rpc result type, got %T", result.Type().Underlying())
   159  			}
   160  		default:
   161  			err = fmt.Errorf("expected either pointer (for unary) or named interface (for streaming) rpc result type, got %T", result.Type())
   162  		}
   163  
   164  		if err != nil {
   165  			panic(err)
   166  		}
   167  
   168  		f.Result = buildParam(result.Name(), localImport, localType, !f.IsStreaming)
   169  		importNames = addImport(localImport, pkgPath, importNames, imports)
   170  	}
   171  
   172  	sort.Strings(importNames)
   173  	sort.Strings(funcNames)
   174  
   175  	def := &ClientInterfaceDef{
   176  		PackageName: *pkgName,
   177  		Type:        *implType,
   178  		ClientName:  "grpcvtctldclient",
   179  	}
   180  
   181  	if *local {
   182  		def.ClientName = "localvtctldclient"
   183  		def.Local = true
   184  	}
   185  
   186  	for _, name := range importNames {
   187  		imp := &Import{
   188  			Path: imports[name],
   189  		}
   190  
   191  		if filepath.Base(imp.Path) != name {
   192  			imp.Alias = name
   193  		}
   194  
   195  		def.Imports = append(def.Imports, imp)
   196  	}
   197  
   198  	for _, name := range funcNames {
   199  		def.Methods = append(def.Methods, funcs[name])
   200  	}
   201  
   202  	if err := tmpl.Execute(output, def); err != nil {
   203  		panic(err)
   204  	}
   205  }
   206  
   207  // ClientInterfaceDef is a struct providing enough information to generate an
   208  // implementation of a gRPC Client interface.
   209  type ClientInterfaceDef struct {
   210  	PackageName string
   211  	Type        string
   212  	Imports     []*Import
   213  	Methods     []*Func
   214  	Local       bool
   215  	ClientName  string
   216  }
   217  
   218  // NeedsGRPCShim returns true if the generated client code needs the internal
   219  // grpcshim imported. Currently this is true if the client is Local and has any
   220  // methods that are streaming RPCs.
   221  func (def *ClientInterfaceDef) NeedsGRPCShim() bool {
   222  	if !def.Local {
   223  		return false
   224  	}
   225  
   226  	for _, m := range def.Methods {
   227  		if m.IsStreaming {
   228  			return true
   229  		}
   230  	}
   231  
   232  	return false
   233  }
   234  
   235  // Import contains the meta information about a Go import.
   236  type Import struct {
   237  	Alias string
   238  	Path  string
   239  }
   240  
   241  // Func is the variable part of a gRPC client interface method (i.e. not the
   242  // context or dialopts arguments, or the error part of the result tuple).
   243  type Func struct {
   244  	Name          string
   245  	Param         Param
   246  	Result        Param
   247  	IsStreaming   bool
   248  	StreamMessage Param
   249  }
   250  
   251  // Param represents an element of either a parameter list or result list. It
   252  // contains an optional name, and a package-local type. This struct exists
   253  // purely to power template execution, which is why the Type field is simply a
   254  // bare string.
   255  type Param struct {
   256  	Name string
   257  	// locally-qualified type, e.g. "grpc.CallOption", and not "google.golang.org/grpc.CallOption".
   258  	Type string
   259  }
   260  
   261  func buildParam(name string, localImport string, localType string, isPointer bool) Param {
   262  	p := Param{
   263  		Name: name,
   264  		Type: fmt.Sprintf("%s.%s", localImport, localType),
   265  	}
   266  
   267  	if isPointer {
   268  		p.Type = "*" + p.Type
   269  	}
   270  
   271  	return p
   272  }
   273  
   274  func addImport(localImport string, pkgPath string, importNames []string, imports map[string]string) []string {
   275  	if _, ok := imports[localImport]; !ok {
   276  		importNames = append(importNames, localImport)
   277  	}
   278  
   279  	imports[localImport] = pkgPath
   280  	return importNames
   281  }
   282  
   283  func loadPackage(source string) (*packages.Package, error) {
   284  	pkgs, err := packages.Load(&packages.Config{
   285  		Mode: packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo,
   286  	}, source)
   287  	if err != nil {
   288  		return nil, err
   289  	}
   290  
   291  	if len(pkgs) != 1 {
   292  		return nil, errors.New("must specify exactly one package")
   293  	}
   294  
   295  	pkg := pkgs[0]
   296  	if len(pkg.Errors) > 0 {
   297  		var err error
   298  
   299  		for _, e := range pkg.Errors {
   300  			switch err {
   301  			case nil:
   302  				err = fmt.Errorf("errors loading package %s: %s", source, e.Error())
   303  			default:
   304  				err = fmt.Errorf("%w; %s", err, e.Error())
   305  			}
   306  		}
   307  
   308  		return nil, err
   309  	}
   310  
   311  	return pkg, nil
   312  }
   313  
   314  func extractSourceInterface(pkg *packages.Package, name string) (*types.Interface, error) {
   315  	obj := pkg.Types.Scope().Lookup(name)
   316  	if obj == nil {
   317  		return nil, fmt.Errorf("no symbol found with name %s", name)
   318  	}
   319  
   320  	switch t := obj.Type().(type) {
   321  	case *types.Named:
   322  		iface, ok := t.Underlying().(*types.Interface)
   323  		if !ok {
   324  			return nil, fmt.Errorf("symbol %s was not an interface but %T", name, t.Underlying())
   325  		}
   326  
   327  		return iface, nil
   328  	case *types.Interface:
   329  		return t, nil
   330  	}
   331  
   332  	return nil, fmt.Errorf("symbol %s was not an interface but %T", name, obj.Type())
   333  }
   334  
   335  var vitessProtoRegexp = regexp.MustCompile(`^vitess.io.*/proto/.*`)
   336  
   337  func rewriteProtoImports(pkg *types.Package) string {
   338  	if vitessProtoRegexp.MatchString(pkg.Path()) {
   339  		return pkg.Name() + "pb"
   340  	}
   341  
   342  	return pkg.Name()
   343  }
   344  
   345  func extractLocalNamedType(v *types.Var) (name string, localImport string, pkgPath string, err error) {
   346  	named, ok := v.Type().(*types.Named)
   347  	if !ok {
   348  		return "", "", "", fmt.Errorf("expected a named type for %s, got %v", v.Name(), v.Type())
   349  	}
   350  
   351  	name = named.Obj().Name()
   352  	localImport = rewriteProtoImports(named.Obj().Pkg())
   353  	pkgPath = named.Obj().Pkg().Path()
   354  
   355  	return name, localImport, pkgPath, nil
   356  }
   357  
   358  func extractLocalPointerType(v *types.Var) (name string, localImport string, pkgPath string, err error) {
   359  	ptr, ok := v.Type().(*types.Pointer)
   360  	if !ok {
   361  		return "", "", "", fmt.Errorf("expected a pointer type for %s, got %v", v.Name(), v.Type())
   362  	}
   363  
   364  	typ, ok := ptr.Elem().(*types.Named)
   365  	if !ok {
   366  		return "", "", "", fmt.Errorf("expected an underlying named type for %s, got %v", v.Name(), ptr.Elem())
   367  	}
   368  
   369  	name = typ.Obj().Name()
   370  	localImport = rewriteProtoImports(typ.Obj().Pkg())
   371  	pkgPath = typ.Obj().Pkg().Path()
   372  
   373  	return name, localImport, pkgPath, nil
   374  }
   375  
   376  func extractRecvType(v *types.Var) (name string, localImport string, pkgPath string, err error) {
   377  	named, ok := v.Type().(*types.Named)
   378  	if !ok {
   379  		return "", "", "", fmt.Errorf("expected a named type for %s, got %v", v.Name(), v.Type())
   380  	}
   381  
   382  	iface, ok := named.Underlying().(*types.Interface)
   383  	if !ok {
   384  		return "", "", "", fmt.Errorf("expected %s to name an interface type, got %v", v.Name(), named.Underlying())
   385  	}
   386  
   387  	for i := 0; i < iface.NumExplicitMethods(); i++ {
   388  		m := iface.ExplicitMethod(i)
   389  		if m.Name() != "Recv" {
   390  			continue
   391  		}
   392  
   393  		sig, ok := m.Type().(*types.Signature)
   394  		if !ok {
   395  			return "", "", "", fmt.Errorf("%s.Recv should have type Signature; got %v", v.Name(), m.Type())
   396  		}
   397  
   398  		if sig.Results().Len() != 2 {
   399  			return "", "", "", fmt.Errorf("%s.Recv should return two values, not %d", v.Name(), sig.Results().Len())
   400  		}
   401  
   402  		return extractLocalPointerType(sig.Results().At(0))
   403  	}
   404  
   405  	return "", "", "", fmt.Errorf("interface %s has no explicit method named Recv", named.Obj().Name())
   406  }