github.com/benma/gogen@v0.0.0-20160826115606-cf49914b915a/specific/process.go (about)

     1  // Package specific copies the source from a package and generates a second
     2  // package replacing some of the types used. It's aimed at taking generic
     3  // packages that rely on interface{} and generating packages that use a
     4  // specific type.
     5  package specific
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  	"go/parser"
    11  	"go/printer"
    12  	"go/token"
    13  	"golang.org/x/tools/go/ast/astutil"
    14  	"io/ioutil"
    15  	"os"
    16  	"path"
    17  )
    18  
    19  type Options struct {
    20  	SkipTestFiles bool
    21  }
    22  
    23  var DefaultOptions = Options{
    24  	SkipTestFiles: false,
    25  }
    26  
    27  // Process creates a specific package from the generic specified in pkg
    28  func Process(pkg, outdir string, newType string, optset ...func(*Options)) error {
    29  	opts := DefaultOptions
    30  	for _, fn := range optset {
    31  		fn(&opts)
    32  	}
    33  
    34  	p, err := findPackage(pkg)
    35  	if err != nil {
    36  		return err
    37  	}
    38  
    39  	if outdir == "" {
    40  		outdir = path.Base(pkg)
    41  	}
    42  
    43  	if err := os.MkdirAll(outdir, os.ModePerm); err != nil {
    44  		return err
    45  	}
    46  
    47  	t := parseTargetType(newType)
    48  
    49  	files, err := processFiles(p, p.GoFiles, t)
    50  	if err != nil {
    51  		return err
    52  	}
    53  
    54  	if err := write(outdir, files); err != nil {
    55  		return err
    56  	}
    57  
    58  	if opts.SkipTestFiles {
    59  		return nil
    60  	}
    61  
    62  	files, err = processFiles(p, p.TestGoFiles, t)
    63  	if err != nil {
    64  		return err
    65  	}
    66  
    67  	return write(outdir, files)
    68  }
    69  
    70  func processFiles(p Package, files []string, t targetType) ([]processedFile, error) {
    71  	var result []processedFile
    72  	for _, f := range files {
    73  		res, err := processFile(p, f, t)
    74  		if err != nil {
    75  			return result, err
    76  		}
    77  		result = append(result, res)
    78  	}
    79  	return result, nil
    80  }
    81  
    82  func processFile(p Package, filename string, t targetType) (processedFile, error) {
    83  	res := processedFile{filename: filename}
    84  
    85  	in, err := os.Open(path.Join(p.Dir, filename))
    86  	if err != nil {
    87  		return res, FileError{Package: p.Dir, File: filename, Err: err}
    88  	}
    89  	src, err := ioutil.ReadAll(in)
    90  	if err != nil {
    91  		return res, FileError{Package: p.Dir, File: filename, Err: err}
    92  	}
    93  
    94  	res.fset = token.NewFileSet()
    95  	res.file, err = parser.ParseFile(res.fset, res.filename, src, parser.ParseComments|parser.AllErrors|parser.DeclarationErrors)
    96  	if err != nil {
    97  		return res, FileError{Package: p.Dir, File: filename, Err: err}
    98  	}
    99  
   100  	if replace(t, res.file) && t.newPkg != "" {
   101  		astutil.AddImport(res.fset, res.file, t.newPkg)
   102  	}
   103  
   104  	return res, err
   105  }
   106  
   107  func replace(t targetType, n ast.Node) (replaced bool) {
   108  	newType := t.newType
   109  	ast.Walk(visitFn(func(node ast.Node) {
   110  		if node == nil {
   111  			return
   112  		}
   113  		switch n := node.(type) {
   114  		case *ast.ArrayType:
   115  			if t, ok := n.Elt.(*ast.InterfaceType); ok && t.Methods.NumFields() == 0 {
   116  				str := ast.NewIdent(newType)
   117  				str.NamePos = t.Pos()
   118  				n.Elt = str
   119  				replaced = true
   120  			}
   121  		case *ast.ChanType:
   122  			if t, ok := n.Value.(*ast.InterfaceType); ok && t.Methods.NumFields() == 0 {
   123  				str := ast.NewIdent(newType)
   124  				str.NamePos = t.Pos()
   125  				n.Value = str
   126  				replaced = true
   127  			}
   128  		case *ast.MapType:
   129  			if t, ok := n.Key.(*ast.InterfaceType); ok && t.Methods.NumFields() == 0 {
   130  				str := ast.NewIdent(newType)
   131  				str.NamePos = t.Pos()
   132  				n.Key = str
   133  				replaced = true
   134  			}
   135  			if t, ok := n.Value.(*ast.InterfaceType); ok && t.Methods.NumFields() == 0 {
   136  				str := ast.NewIdent(newType)
   137  				str.NamePos = t.Pos()
   138  				n.Value = str
   139  				replaced = true
   140  			}
   141  		case *ast.Field:
   142  			if t, ok := n.Type.(*ast.InterfaceType); ok && t.Methods.NumFields() == 0 {
   143  				str := ast.NewIdent(newType)
   144  				str.NamePos = t.Pos()
   145  				n.Type = str
   146  				replaced = true
   147  			}
   148  		}
   149  	}), n)
   150  	return replaced
   151  }
   152  
   153  type visitFn func(node ast.Node)
   154  
   155  func (fn visitFn) Visit(node ast.Node) ast.Visitor {
   156  	fn(node)
   157  	return fn
   158  }
   159  
   160  func write(outdir string, files []processedFile) error {
   161  	for _, f := range files {
   162  		out, err := os.Create(path.Join(outdir, f.filename))
   163  		if err != nil {
   164  			return FileError{Package: outdir, File: f.filename, Err: err}
   165  		}
   166  
   167  		fmt.Fprintf(out, "/*\n"+
   168  			"* CODE GENERATED AUTOMATICALLY WITH github.com/ernesto-jimenez/gogen/specific\n"+
   169  			"* THIS FILE SHOULD NOT BE EDITED BY HAND\n"+
   170  			"*/\n\n")
   171  		printer.Fprint(out, f.fset, f.file)
   172  	}
   173  	return nil
   174  }
   175  
   176  type FileError struct {
   177  	Package string
   178  	File    string
   179  	Err     error
   180  }
   181  
   182  func (ferr FileError) Error() string {
   183  	return fmt.Sprintf("error in %s: %s", path.Join(ferr.Package, ferr.File), ferr.Err.Error())
   184  }
   185  
   186  type processedFile struct {
   187  	filename string
   188  	fset     *token.FileSet
   189  	file     *ast.File
   190  }