github.com/benma/gogen@v0.0.0-20160826115606-cf49914b915a/exportdefault/generator.go (about)

     1  // Package exportdefault provides the functionality to automatically generate
     2  // package-level exported functions wrapping calls to a package-level default
     3  // instance of a type.
     4  //
// This helps auto-generate code for the common case where a package
// implements functionality as methods on a type and, for convenience,
// exports functions that wrap calls to those methods on a default
// variable.
     9  //
    10  // Some examples of that behaviour in the stdlib:
    11  //
    12  //  - `net/http` has `http.DefaultClient` and functions like `http.Get` just
    13  //     call the default `http.DefaultClient.Get`
//  - `log` has an unexported default `*log.Logger` named `log.std`, and
//     functions like `log.Print` just call `log.std.Print`
    16  package exportdefault
    17  
import (
	"bytes"
	"fmt"
	"go/ast"
	"go/build"
	"go/doc"
	"go/importer"
	"go/parser"
	"go/token"
	"go/types"
	"io"
	"io/ioutil"
	"path"
	"regexp"
	"strings"
	"text/template"

	"github.com/ernesto-jimenez/gogen/cleanimports"
	"github.com/ernesto-jimenez/gogen/imports"
)
    37  
    38  // Generator contains the metadata needed to generate all the function wrappers
    39  // arround methods from a package variable
    40  type Generator struct {
    41  	Name           string
    42  	Imports        map[string]string
    43  	funcs          []fn
    44  	FuncNamePrefix string
    45  	Include        *regexp.Regexp
    46  	Exclude        *regexp.Regexp
    47  }
    48  
    49  // New initialises a new Generator for the corresponding package's variable
    50  //
    51  // Returns an error if the package or variable are invalid
    52  func New(pkg string, variable string) (*Generator, error) {
    53  	scope, docs, err := parsePackageSource(pkg)
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  
    58  	importer, funcs, err := analyzeCode(scope, docs, variable)
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  
    63  	return &Generator{
    64  		Name:    docs.Name,
    65  		Imports: importer.Imports(),
    66  		funcs:   funcs,
    67  	}, nil
    68  }
    69  
    70  // Write the generated code into the given io.Writer
    71  //
    72  // Returns an error if there is a problem generating the code
    73  func (g *Generator) Write(w io.Writer) error {
    74  	buff := bytes.NewBuffer(nil)
    75  
    76  	// Generate header
    77  	if err := headerTpl.Execute(buff, g); err != nil {
    78  		return err
    79  	}
    80  
    81  	// Generate funcs
    82  	for _, fn := range g.funcs {
    83  		if g.Include != nil && !g.Include.MatchString(fn.Name) {
    84  			continue
    85  		}
    86  		if g.Exclude != nil && g.Exclude.MatchString(fn.Name) {
    87  			continue
    88  		}
    89  		fn.FuncNamePrefix = g.FuncNamePrefix
    90  		buff.Write([]byte("\n\n"))
    91  		if err := funcTpl.Execute(buff, &fn); err != nil {
    92  			return err
    93  		}
    94  	}
    95  
    96  	return cleanimports.Clean(w, buff.Bytes())
    97  }
    98  
    99  type fn struct {
   100  	FuncNamePrefix string
   101  	WrappedVar     string
   102  	Name           string
   103  	CurrentPkg     string
   104  	TypeInfo       *types.Func
   105  }
   106  
   107  func (f *fn) Qualifier(p *types.Package) string {
   108  	if p == nil || p.Name() == f.CurrentPkg {
   109  		return ""
   110  	}
   111  	return p.Name()
   112  }
   113  
   114  func (f *fn) Params() string {
   115  	sig := f.TypeInfo.Type().(*types.Signature)
   116  	params := sig.Params()
   117  	p := ""
   118  	comma := ""
   119  	to := params.Len()
   120  	var i int
   121  
   122  	if sig.Variadic() {
   123  		to--
   124  	}
   125  	for i = 0; i < to; i++ {
   126  		param := params.At(i)
   127  		name := param.Name()
   128  		if name == "" {
   129  			name = fmt.Sprintf("p%d", i)
   130  		}
   131  		p += fmt.Sprintf("%s%s %s", comma, name, types.TypeString(param.Type(), f.Qualifier))
   132  		comma = ", "
   133  	}
   134  	if sig.Variadic() {
   135  		param := params.At(params.Len() - 1)
   136  		name := param.Name()
   137  		if name == "" {
   138  			name = fmt.Sprintf("p%d", to)
   139  		}
   140  		p += fmt.Sprintf("%s%s ...%s", comma, name, types.TypeString(param.Type().(*types.Slice).Elem(), f.Qualifier))
   141  	}
   142  	return p
   143  }
   144  
   145  func (f *fn) ReturnsAnything() bool {
   146  	sig := f.TypeInfo.Type().(*types.Signature)
   147  	params := sig.Results()
   148  	return params.Len() > 0
   149  }
   150  
   151  func (f *fn) ReturnTypes() string {
   152  	sig := f.TypeInfo.Type().(*types.Signature)
   153  	params := sig.Results()
   154  	p := ""
   155  	comma := ""
   156  	to := params.Len()
   157  	var i int
   158  
   159  	for i = 0; i < to; i++ {
   160  		param := params.At(i)
   161  		p += fmt.Sprintf("%s %s", comma, types.TypeString(param.Type(), f.Qualifier))
   162  		comma = ", "
   163  	}
   164  	if to > 1 {
   165  		p = fmt.Sprintf("(%s)", p)
   166  	}
   167  	return p
   168  }
   169  
   170  func (f *fn) ForwardedParams() string {
   171  	sig := f.TypeInfo.Type().(*types.Signature)
   172  	params := sig.Params()
   173  	p := ""
   174  	comma := ""
   175  	to := params.Len()
   176  	var i int
   177  
   178  	if sig.Variadic() {
   179  		to--
   180  	}
   181  	for i = 0; i < to; i++ {
   182  		param := params.At(i)
   183  		name := param.Name()
   184  		if name == "" {
   185  			name = fmt.Sprintf("p%d", i)
   186  		}
   187  		p += fmt.Sprintf("%s%s", comma, name)
   188  		comma = ", "
   189  	}
   190  	if sig.Variadic() {
   191  		param := params.At(params.Len() - 1)
   192  		name := param.Name()
   193  		if name == "" {
   194  			name = fmt.Sprintf("p%d", to)
   195  		}
   196  		p += fmt.Sprintf("%s%s...", comma, name)
   197  	}
   198  	return p
   199  }
   200  
   201  // parsePackageSource returns the types scope and the package documentation from the specified package
   202  func parsePackageSource(pkg string) (*types.Scope, *doc.Package, error) {
   203  	pd, err := build.Import(pkg, ".", 0)
   204  	if err != nil {
   205  		return nil, nil, err
   206  	}
   207  
   208  	fset := token.NewFileSet()
   209  	files := make(map[string]*ast.File)
   210  	fileList := make([]*ast.File, len(pd.GoFiles))
   211  	for i, fname := range pd.GoFiles {
   212  		src, err := ioutil.ReadFile(path.Join(pd.SrcRoot, pd.ImportPath, fname))
   213  		if err != nil {
   214  			return nil, nil, err
   215  		}
   216  		f, err := parser.ParseFile(fset, fname, src, parser.ParseComments|parser.AllErrors)
   217  		if err != nil {
   218  			return nil, nil, err
   219  		}
   220  		files[fname] = f
   221  		fileList[i] = f
   222  	}
   223  
   224  	cfg := types.Config{
   225  		Importer: importer.Default(),
   226  	}
   227  	info := types.Info{
   228  		Defs: make(map[*ast.Ident]types.Object),
   229  	}
   230  	tp, err := cfg.Check(pkg, fset, fileList, &info)
   231  	if err != nil {
   232  		return nil, nil, err
   233  	}
   234  
   235  	scope := tp.Scope()
   236  
   237  	ap, _ := ast.NewPackage(fset, files, nil, nil)
   238  	docs := doc.New(ap, pkg, doc.AllDecls|doc.AllMethods)
   239  
   240  	return scope, docs, nil
   241  }
   242  
   243  func analyzeCode(scope *types.Scope, docs *doc.Package, variable string) (imports.Importer, []fn, error) {
   244  	pkg := docs.Name
   245  	v, ok := scope.Lookup(variable).(*types.Var)
   246  	if v == nil {
   247  		return nil, nil, fmt.Errorf("impossible to find variable %s", variable)
   248  	}
   249  	if !ok {
   250  		return nil, nil, fmt.Errorf("%s must be a variable", variable)
   251  	}
   252  	var vType interface {
   253  		NumMethods() int
   254  		Method(int) *types.Func
   255  	}
   256  	switch t := v.Type().(type) {
   257  	case *types.Interface:
   258  		vType = t
   259  	case *types.Pointer:
   260  		vType = t.Elem().(*types.Named)
   261  	case *types.Named:
   262  		vType = t
   263  		if t, ok := t.Underlying().(*types.Interface); ok {
   264  			vType = t
   265  		}
   266  	default:
   267  		return nil, nil, fmt.Errorf("variable is of an invalid type: %T", v.Type().Underlying())
   268  	}
   269  
   270  	importer := imports.New(pkg)
   271  	var funcs []fn
   272  	for i := 0; i < vType.NumMethods(); i++ {
   273  		f := vType.Method(i)
   274  
   275  		if !f.Exported() {
   276  			continue
   277  		}
   278  
   279  		sig := f.Type().(*types.Signature)
   280  
   281  		funcs = append(funcs, fn{
   282  			WrappedVar: variable,
   283  			Name:       f.Name(),
   284  			CurrentPkg: pkg,
   285  			TypeInfo:   f,
   286  		})
   287  		importer.AddImportsFrom(sig.Params())
   288  		importer.AddImportsFrom(sig.Results())
   289  	}
   290  	return importer, funcs, nil
   291  }
   292  
   293  var headerTpl = template.Must(template.New("header").Parse(`/*
   294  * CODE GENERATED AUTOMATICALLY WITH goexportdefault
   295  * THIS FILE MUST NOT BE EDITED BY HAND
   296  *
   297  * Install goexportdefault with:
   298  * go get github.com/ernesto-jimenez/gogen/cmd/goexportdefault
   299  */
   300  
   301  package {{.Name}}
   302  
   303  import (
   304  {{range $path, $name := .Imports}}
   305  	{{$name}} "{{$path}}"{{end}}
   306  )
   307  `))
   308  
   309  var funcTpl = template.Must(template.New("func").Parse(`// {{.FuncNamePrefix}}{{.Name}} is a wrapper around {{.WrappedVar}}.{{.Name}}
   310  func {{.FuncNamePrefix}}{{.Name}}({{.Params}}) {{.ReturnTypes}} {
   311  	{{if .ReturnsAnything}}return {{end}}{{.WrappedVar}}.{{.Name}}({{.ForwardedParams}})
   312  }`))