github.com/benma/gogen@v0.0.0-20160826115606-cf49914b915a/exportdefault/generator.go (about) 1 // Package exportdefault provides the functionality to automatically generate 2 // package-level exported functions wrapping calls to a package-level default 3 // instance of a type. 4 // 5 // This helps auto-generating code for the common use case where a package 6 // implements certain information as methods within a stub and, for 7 // convenience, exports functions that wrap calls to those methods on a default 8 // variable. 9 // 10 // Some examples of that behaviour in the stdlib: 11 // 12 // - `net/http` has `http.DefaultClient` and functions like `http.Get` just 13 // call the default `http.DefaultClient.Get` 14 // - `log` has `log.Logger` and functions like `log.Print` just call the 15 // default `log.std.Print` 16 package exportdefault 17 18 import ( 19 "bytes" 20 "fmt" 21 "go/ast" 22 "go/build" 23 "go/doc" 24 "go/importer" 25 "go/parser" 26 "go/token" 27 "go/types" 28 "io" 29 "io/ioutil" 30 "path" 31 "regexp" 32 "text/template" 33 34 "github.com/ernesto-jimenez/gogen/cleanimports" 35 "github.com/ernesto-jimenez/gogen/imports" 36 ) 37 38 // Generator contains the metadata needed to generate all the function wrappers 39 // arround methods from a package variable 40 type Generator struct { 41 Name string 42 Imports map[string]string 43 funcs []fn 44 FuncNamePrefix string 45 Include *regexp.Regexp 46 Exclude *regexp.Regexp 47 } 48 49 // New initialises a new Generator for the corresponding package's variable 50 // 51 // Returns an error if the package or variable are invalid 52 func New(pkg string, variable string) (*Generator, error) { 53 scope, docs, err := parsePackageSource(pkg) 54 if err != nil { 55 return nil, err 56 } 57 58 importer, funcs, err := analyzeCode(scope, docs, variable) 59 if err != nil { 60 return nil, err 61 } 62 63 return &Generator{ 64 Name: docs.Name, 65 Imports: importer.Imports(), 66 funcs: funcs, 67 }, nil 68 } 69 70 // Write the generated code into the given io.Writer 71 // 72 // Returns an error if there is a problem generating the code 73 func (g *Generator) Write(w io.Writer) error { 74 buff := bytes.NewBuffer(nil) 75 76 // Generate header 77 if err := headerTpl.Execute(buff, g); err != nil { 78 return err 79 } 80 81 // Generate funcs 82 for _, fn := range g.funcs { 83 if g.Include != nil && !g.Include.MatchString(fn.Name) { 84 continue 85 } 86 if g.Exclude != nil && g.Exclude.MatchString(fn.Name) { 87 continue 88 } 89 fn.FuncNamePrefix = g.FuncNamePrefix 90 buff.Write([]byte("\n\n")) 91 if err := funcTpl.Execute(buff, &fn); err != nil { 92 return err 93 } 94 } 95 96 return cleanimports.Clean(w, buff.Bytes()) 97 } 98 99 type fn struct { 100 FuncNamePrefix string 101 WrappedVar string 102 Name string 103 CurrentPkg string 104 TypeInfo *types.Func 105 } 106 107 func (f *fn) Qualifier(p *types.Package) string { 108 if p == nil || p.Name() == f.CurrentPkg { 109 return "" 110 } 111 return p.Name() 112 } 113 114 func (f *fn) Params() string { 115 sig := f.TypeInfo.Type().(*types.Signature) 116 params := sig.Params() 117 p := "" 118 comma := "" 119 to := params.Len() 120 var i int 121 122 if sig.Variadic() { 123 to-- 124 } 125 for i = 0; i < to; i++ { 126 param := params.At(i) 127 name := param.Name() 128 if name == "" { 129 name = fmt.Sprintf("p%d", i) 130 } 131 p += fmt.Sprintf("%s%s %s", comma, name, types.TypeString(param.Type(), f.Qualifier)) 132 comma = ", " 133 } 134 if sig.Variadic() { 135 param := params.At(params.Len() - 1) 136 name := param.Name() 137 if name == "" { 138 name = fmt.Sprintf("p%d", to) 139 } 140 p += fmt.Sprintf("%s%s ...%s", comma, name, types.TypeString(param.Type().(*types.Slice).Elem(), f.Qualifier)) 141 } 142 return p 143 } 144 145 func (f *fn) ReturnsAnything() bool { 146 sig := f.TypeInfo.Type().(*types.Signature) 147 params := sig.Results() 148 return params.Len() > 0 149 } 150 151 func (f *fn) ReturnTypes() string { 152 sig := f.TypeInfo.Type().(*types.Signature) 153 params := sig.Results() 154 p := "" 155 comma := "" 156 to := params.Len() 157 var i int 158 159 for i = 0; i < to; i++ { 160 param := params.At(i) 161 p += fmt.Sprintf("%s %s", comma, types.TypeString(param.Type(), f.Qualifier)) 162 comma = ", " 163 } 164 if to > 1 { 165 p = fmt.Sprintf("(%s)", p) 166 } 167 return p 168 } 169 170 func (f *fn) ForwardedParams() string { 171 sig := f.TypeInfo.Type().(*types.Signature) 172 params := sig.Params() 173 p := "" 174 comma := "" 175 to := params.Len() 176 var i int 177 178 if sig.Variadic() { 179 to-- 180 } 181 for i = 0; i < to; i++ { 182 param := params.At(i) 183 name := param.Name() 184 if name == "" { 185 name = fmt.Sprintf("p%d", i) 186 } 187 p += fmt.Sprintf("%s%s", comma, name) 188 comma = ", " 189 } 190 if sig.Variadic() { 191 param := params.At(params.Len() - 1) 192 name := param.Name() 193 if name == "" { 194 name = fmt.Sprintf("p%d", to) 195 } 196 p += fmt.Sprintf("%s%s...", comma, name) 197 } 198 return p 199 } 200 201 // parsePackageSource returns the types scope and the package documentation from the specified package 202 func parsePackageSource(pkg string) (*types.Scope, *doc.Package, error) { 203 pd, err := build.Import(pkg, ".", 0) 204 if err != nil { 205 return nil, nil, err 206 } 207 208 fset := token.NewFileSet() 209 files := make(map[string]*ast.File) 210 fileList := make([]*ast.File, len(pd.GoFiles)) 211 for i, fname := range pd.GoFiles { 212 src, err := ioutil.ReadFile(path.Join(pd.SrcRoot, pd.ImportPath, fname)) 213 if err != nil { 214 return nil, nil, err 215 } 216 f, err := parser.ParseFile(fset, fname, src, parser.ParseComments|parser.AllErrors) 217 if err != nil { 218 return nil, nil, err 219 } 220 files[fname] = f 221 fileList[i] = f 222 } 223 224 cfg := types.Config{ 225 Importer: importer.Default(), 226 } 227 info := types.Info{ 228 Defs: make(map[*ast.Ident]types.Object), 229 } 230 tp, err := cfg.Check(pkg, fset, fileList, &info) 231 if err != nil { 232 return nil, nil, err 233 } 234 235 scope := tp.Scope() 236 237 ap, _ := ast.NewPackage(fset, files, nil, nil) 238 docs := doc.New(ap, pkg, doc.AllDecls|doc.AllMethods) 239 240 return scope, docs, nil 241 } 242 243 func analyzeCode(scope *types.Scope, docs *doc.Package, variable string) (imports.Importer, []fn, error) { 244 pkg := docs.Name 245 v, ok := scope.Lookup(variable).(*types.Var) 246 if v == nil { 247 return nil, nil, fmt.Errorf("impossible to find variable %s", variable) 248 } 249 if !ok { 250 return nil, nil, fmt.Errorf("%s must be a variable", variable) 251 } 252 var vType interface { 253 NumMethods() int 254 Method(int) *types.Func 255 } 256 switch t := v.Type().(type) { 257 case *types.Interface: 258 vType = t 259 case *types.Pointer: 260 vType = t.Elem().(*types.Named) 261 case *types.Named: 262 vType = t 263 if t, ok := t.Underlying().(*types.Interface); ok { 264 vType = t 265 } 266 default: 267 return nil, nil, fmt.Errorf("variable is of an invalid type: %T", v.Type().Underlying()) 268 } 269 270 importer := imports.New(pkg) 271 var funcs []fn 272 for i := 0; i < vType.NumMethods(); i++ { 273 f := vType.Method(i) 274 275 if !f.Exported() { 276 continue 277 } 278 279 sig := f.Type().(*types.Signature) 280 281 funcs = append(funcs, fn{ 282 WrappedVar: variable, 283 Name: f.Name(), 284 CurrentPkg: pkg, 285 TypeInfo: f, 286 }) 287 importer.AddImportsFrom(sig.Params()) 288 importer.AddImportsFrom(sig.Results()) 289 } 290 return importer, funcs, nil 291 } 292 293 var headerTpl = template.Must(template.New("header").Parse(`/* 294 * CODE GENERATED AUTOMATICALLY WITH goexportdefault 295 * THIS FILE MUST NOT BE EDITED BY HAND 296 * 297 * Install goexportdefault with: 298 * go get github.com/ernesto-jimenez/gogen/cmd/goexportdefault 299 */ 300 301 package {{.Name}} 302 303 import ( 304 {{range $path, $name := .Imports}} 305 {{$name}} "{{$path}}"{{end}} 306 ) 307 `)) 308 309 var funcTpl = template.Must(template.New("func").Parse(`// {{.FuncNamePrefix}}{{.Name}} is a wrapper around {{.WrappedVar}}.{{.Name}} 310 func {{.FuncNamePrefix}}{{.Name}}({{.Params}}) {{.ReturnTypes}} { 311 {{if .ReturnsAnything}}return {{end}}{{.WrappedVar}}.{{.Name}}({{.ForwardedParams}}) 312 }`))