github.com/sridharv/stencil@v0.0.0-20170626103218-a81b4a7626a1/stencil.go (about)

     1  // Package stencil generates specialized versions of Go packages by replacing types.
     2  package stencil
     3  
     4  import (
     5  	"go/ast"
     6  	"go/format"
     7  	"go/parser"
     8  	"go/token"
     9  	"strings"
    10  
    11  	"bytes"
    12  
    13  	"path/filepath"
    14  
    15  	"os"
    16  
    17  	"io/ioutil"
    18  
    19  	"go/build"
    20  
    21  	"josharian/apply"
    22  
    23  	"github.com/pkg/errors"
    24  	"golang.org/x/tools/imports"
    25  )
    26  
// file is a generated source file held in memory until Process writes it to
// disk.
type file struct {
	data []byte // file contents (goimports-processed Go source)
	path string // destination path the contents are written to
}
    31  
    32  // Process process paths, generating vendored, specialized code for any stencil import paths.
    33  // If format is true any go files in paths are processed using goimports.
    34  //
    35  // For detailed documentation consult the docs for "github.com/sridharv/stencil/cmd/stencil"
    36  func Process(paths []string, format bool) error {
    37  	files, err := processStencil(paths)
    38  	if err != nil {
    39  		return err
    40  	}
    41  
    42  	for _, f := range files {
    43  		dir := filepath.Dir(f.path)
    44  		if err := os.MkdirAll(dir, 0755); err != nil {
    45  			return errors.WithStack(err)
    46  		}
    47  		if err := ioutil.WriteFile(f.path, f.data, 0644); err != nil {
    48  			return errors.WithStack(err)
    49  		}
    50  	}
    51  	if !format {
    52  		return nil
    53  	}
    54  	return doImports(paths)
    55  }
    56  
    57  func doImports(paths []string) error {
    58  	for _, p := range paths {
    59  		s, err := os.Stat(p)
    60  		if err != nil {
    61  			return errors.WithStack(err)
    62  		}
    63  		if s.IsDir() {
    64  			continue
    65  		}
    66  		b, err := ioutil.ReadFile(p)
    67  		if err != nil {
    68  			return errors.Wrapf(err, "%s", p)
    69  		}
    70  		if b, err = imports.Process(p, b, nil); err != nil {
    71  			return errors.Wrapf(err, "%s", p)
    72  		}
    73  		if err = ioutil.WriteFile(p, b, s.Mode()); err != nil {
    74  			return errors.Wrapf(err, "failed to write %s", p)
    75  		}
    76  	}
    77  	return nil
    78  }
    79  
    80  type replacer map[string]string
    81  
    82  func (r replacer) preReplace(c apply.ApplyCursor) bool {
    83  	switch t := c.Node().(type) {
    84  	case *ast.GenDecl:
    85  		// Delete named type specifications that will be replaced.
    86  		if len(t.Specs) == 0 {
    87  			return true
    88  		}
    89  		spec, ok := t.Specs[0].(*ast.TypeSpec)
    90  		if !ok {
    91  			return true
    92  		}
    93  
    94  		if _, ok = r[spec.Name.Name]; !ok {
    95  			return true
    96  		}
    97  		c.Delete()
    98  	case *ast.Ident:
    99  		if t == nil {
   100  			return true
   101  		}
   102  		if s, ok := r[t.Name]; ok {
   103  			t.Name = s
   104  		}
   105  	case *ast.InterfaceType:
   106  		rep, ok := r["interface"]
   107  		if !ok {
   108  			return true
   109  		}
   110  		if _, isType := c.Parent().(*ast.TypeSpec); isType {
   111  			return true
   112  		}
   113  		c.Replace(&ast.Ident{
   114  			Name:    rep,
   115  			NamePos: t.Pos(),
   116  		})
   117  	}
   118  	return true
   119  }
   120  
   121  func listPackages(paths []string) (map[string][]string, error) {
   122  	if len(paths) == 0 {
   123  		paths = append(paths, ".")
   124  	}
   125  	dirs := map[string][]string{}
   126  	for _, arg := range paths {
   127  		c, err := filepath.Abs(arg)
   128  		if err != nil {
   129  			return nil, errors.WithStack(err)
   130  		}
   131  		if strings.HasSuffix(c, ".go") {
   132  			dir := filepath.Dir(c)
   133  			dirs[dir] = append(dirs[dir], c)
   134  			continue
   135  		}
   136  		infos, err := ioutil.ReadDir(c)
   137  		if err != nil {
   138  			return nil, errors.WithStack(err)
   139  		}
   140  		var files []string
   141  		for _, i := range infos {
   142  			n := i.Name()
   143  			if strings.HasSuffix(n, ".go") && !strings.HasSuffix(n, "_test.go") {
   144  				files = append(files, filepath.Join(c, n))
   145  			}
   146  		}
   147  		dirs[c] = files
   148  	}
   149  	return dirs, nil
   150  }
   151  
   152  func packageExists(roots []string, pkg string) (string, bool) {
   153  	for _, r := range roots {
   154  		// Rough heuristic to check if a package exists.
   155  		dir := filepath.Join(r, pkg)
   156  		if s, err := os.Stat(dir); err == nil && s.IsDir() {
   157  			return dir, true
   158  		}
   159  	}
   160  	return "", false
   161  }
   162  
   163  func replacements(roots []string, pkg string) (string, replacer) {
   164  	parts, path := strings.Split(pkg, "/"), pkg
   165  	// See if we can form a substitution pattern from the parts here
   166  	r := replacer{}
   167  	dir, found := packageExists(roots, path)
   168  	for !found && len(parts) > 2 {
   169  		l := len(parts)
   170  		// A path looks like github.com/foo/bar/Parameter/Specialization
   171  		// r[originalType] = replacementType
   172  		r[parts[l-2]] = parts[l-1]
   173  		parts = parts[:l-2]
   174  		path = strings.Join(parts, "/")
   175  		dir, found = packageExists(roots, path)
   176  	}
   177  	if !found || len(r) == 0 {
   178  		return "", nil
   179  	}
   180  	return dir, r
   181  }
   182  
   183  func makeStencilled(stencil, stencilled string, r replacer, res *[]file) error {
   184  	fs := token.NewFileSet()
   185  	pkgs, err := parser.ParseDir(fs, stencil, func(s os.FileInfo) bool {
   186  		return !strings.HasSuffix(s.Name(), "_test.go")
   187  	}, parser.AllErrors|parser.ParseComments)
   188  	if err != nil {
   189  		return errors.Wrapf(err, "%s: errors parsing", stencil)
   190  	}
   191  	if len(pkgs) != 1 {
   192  		return errors.Errorf("%d: expected 1 package, got %d", stencil, len(pkgs))
   193  	}
   194  	var files map[string]*ast.File
   195  	for _, p := range pkgs {
   196  		files = p.Files
   197  		break
   198  	}
   199  	for path, f := range files {
   200  		target := filepath.Join(stencilled, filepath.Base(path))
   201  		apply.Apply(f, r.preReplace, nil)
   202  		var b bytes.Buffer
   203  		if err := format.Node(&b, fs, f); err != nil {
   204  			return errors.Errorf("%s:%s: code generation failed", stencil, f.Name)
   205  		}
   206  		out, err := imports.Process(target, b.Bytes(), nil)
   207  		if err != nil {
   208  			return errors.WithStack(err)
   209  		}
   210  		*res = append(*res, file{path: target, data: out})
   211  	}
   212  	return nil
   213  }
   214  
   215  func srcRoot(dir string) (string, error) {
   216  	srcs := build.Default.SrcDirs()
   217  	for _, src := range srcs {
   218  		if strings.HasPrefix(dir, src) {
   219  			return src, nil
   220  		}
   221  	}
   222  
   223  	var candidates []os.FileInfo
   224  	for d := dir; d != filepath.Dir(d); d = filepath.Dir(d) {
   225  		if filepath.Base(d) != "src" {
   226  			continue
   227  		}
   228  		info, err := os.Stat(d)
   229  		if err != nil {
   230  			return "", errors.Wrapf(err, "failed to stat parent dir: %s", d)
   231  		}
   232  		candidates = append(candidates, info)
   233  	}
   234  
   235  	for _, src := range srcs {
   236  		si, err := os.Stat(src)
   237  		if err != nil {
   238  			return "", errors.Wrapf(err, "couldn't stat Go src folder: %s", src)
   239  		}
   240  		for _, ci := range candidates {
   241  			if os.SameFile(ci, si) {
   242  				return src, nil
   243  			}
   244  		}
   245  	}
   246  
   247  	return "", errors.Errorf("%s: not in GOPATH", dir)
   248  }
   249  
   250  func processDir(dir string, files []string, res *[]file) error {
   251  	// Read files
   252  	fs := token.NewFileSet()
   253  	srcs, err := srcRoot(dir)
   254  	if err != nil {
   255  		return err
   256  	}
   257  
   258  	vendor := filepath.Join(dir, "vendor")
   259  	for d := dir; d != srcs; d = filepath.Dir(d) {
   260  		v := filepath.Join(d, "vendor")
   261  		st, err := os.Stat(d)
   262  		if err == nil && st.IsDir() {
   263  			vendor = v
   264  			break
   265  		}
   266  	}
   267  	roots := append(build.Default.SrcDirs(), vendor)
   268  
   269  	for _, fl := range files {
   270  		f, err := parser.ParseFile(fs, fl, nil, parser.ImportsOnly)
   271  		if err != nil {
   272  			return errors.Wrapf(err, "%s: parse failed", fl)
   273  		}
   274  		for _, imp := range f.Imports {
   275  			path := imp.Path.Value
   276  			path = path[1 : len(path)-1]
   277  			stencil, r := replacements(roots, path)
   278  			if stencil == "" {
   279  				continue
   280  			}
   281  			if err = makeStencilled(stencil, filepath.Join(vendor, path), r, res); err != nil {
   282  				return err
   283  			}
   284  		}
   285  	}
   286  	return nil
   287  }
   288  
   289  func processStencil(paths []string) ([]file, error) {
   290  	dirs, err := listPackages(paths)
   291  	if err != nil {
   292  		return nil, err
   293  	}
   294  	var res []file
   295  	for dir, files := range dirs {
   296  		if err := processDir(dir, files, &res); err != nil {
   297  			return nil, err
   298  		}
   299  	}
   300  	return res, nil
   301  }