github.com/sridharv/stencil@v0.0.0-20170626103218-a81b4a7626a1/stencil.go (about) 1 // Package stencil generates specialized versions of Go packages by replacing types. 2 package stencil 3 4 import ( 5 "go/ast" 6 "go/format" 7 "go/parser" 8 "go/token" 9 "strings" 10 11 "bytes" 12 13 "path/filepath" 14 15 "os" 16 17 "io/ioutil" 18 19 "go/build" 20 21 "josharian/apply" 22 23 "github.com/pkg/errors" 24 "golang.org/x/tools/imports" 25 ) 26 27 type file struct { 28 data []byte 29 path string 30 } 31 32 // Process process paths, generating vendored, specialized code for any stencil import paths. 33 // If format is true any go files in paths are processed using goimports. 34 // 35 // For detailed documentation consult the docs for "github.com/sridharv/stencil/cmd/stencil" 36 func Process(paths []string, format bool) error { 37 files, err := processStencil(paths) 38 if err != nil { 39 return err 40 } 41 42 for _, f := range files { 43 dir := filepath.Dir(f.path) 44 if err := os.MkdirAll(dir, 0755); err != nil { 45 return errors.WithStack(err) 46 } 47 if err := ioutil.WriteFile(f.path, f.data, 0644); err != nil { 48 return errors.WithStack(err) 49 } 50 } 51 if !format { 52 return nil 53 } 54 return doImports(paths) 55 } 56 57 func doImports(paths []string) error { 58 for _, p := range paths { 59 s, err := os.Stat(p) 60 if err != nil { 61 return errors.WithStack(err) 62 } 63 if s.IsDir() { 64 continue 65 } 66 b, err := ioutil.ReadFile(p) 67 if err != nil { 68 return errors.Wrapf(err, "%s", p) 69 } 70 if b, err = imports.Process(p, b, nil); err != nil { 71 return errors.Wrapf(err, "%s", p) 72 } 73 if err = ioutil.WriteFile(p, b, s.Mode()); err != nil { 74 return errors.Wrapf(err, "failed to write %s", p) 75 } 76 } 77 return nil 78 } 79 80 type replacer map[string]string 81 82 func (r replacer) preReplace(c apply.ApplyCursor) bool { 83 switch t := c.Node().(type) { 84 case *ast.GenDecl: 85 // Delete named type specifications that will be replaced. 86 if len(t.Specs) == 0 { 87 return true 88 } 89 spec, ok := t.Specs[0].(*ast.TypeSpec) 90 if !ok { 91 return true 92 } 93 94 if _, ok = r[spec.Name.Name]; !ok { 95 return true 96 } 97 c.Delete() 98 case *ast.Ident: 99 if t == nil { 100 return true 101 } 102 if s, ok := r[t.Name]; ok { 103 t.Name = s 104 } 105 case *ast.InterfaceType: 106 rep, ok := r["interface"] 107 if !ok { 108 return true 109 } 110 if _, isType := c.Parent().(*ast.TypeSpec); isType { 111 return true 112 } 113 c.Replace(&ast.Ident{ 114 Name: rep, 115 NamePos: t.Pos(), 116 }) 117 } 118 return true 119 } 120 121 func listPackages(paths []string) (map[string][]string, error) { 122 if len(paths) == 0 { 123 paths = append(paths, ".") 124 } 125 dirs := map[string][]string{} 126 for _, arg := range paths { 127 c, err := filepath.Abs(arg) 128 if err != nil { 129 return nil, errors.WithStack(err) 130 } 131 if strings.HasSuffix(c, ".go") { 132 dir := filepath.Dir(c) 133 dirs[dir] = append(dirs[dir], c) 134 continue 135 } 136 infos, err := ioutil.ReadDir(c) 137 if err != nil { 138 return nil, errors.WithStack(err) 139 } 140 var files []string 141 for _, i := range infos { 142 n := i.Name() 143 if strings.HasSuffix(n, ".go") && !strings.HasSuffix(n, "_test.go") { 144 files = append(files, filepath.Join(c, n)) 145 } 146 } 147 dirs[c] = files 148 } 149 return dirs, nil 150 } 151 152 func packageExists(roots []string, pkg string) (string, bool) { 153 for _, r := range roots { 154 // Rough heuristic to check if a package exists. 155 dir := filepath.Join(r, pkg) 156 if s, err := os.Stat(dir); err == nil && s.IsDir() { 157 return dir, true 158 } 159 } 160 return "", false 161 } 162 163 func replacements(roots []string, pkg string) (string, replacer) { 164 parts, path := strings.Split(pkg, "/"), pkg 165 // See if we can form a substitution pattern from the parts here 166 r := replacer{} 167 dir, found := packageExists(roots, path) 168 for !found && len(parts) > 2 { 169 l := len(parts) 170 // A path looks like github.com/foo/bar/Parameter/Specialization 171 // r[originalType] = replacementType 172 r[parts[l-2]] = parts[l-1] 173 parts = parts[:l-2] 174 path = strings.Join(parts, "/") 175 dir, found = packageExists(roots, path) 176 } 177 if !found || len(r) == 0 { 178 return "", nil 179 } 180 return dir, r 181 } 182 183 func makeStencilled(stencil, stencilled string, r replacer, res *[]file) error { 184 fs := token.NewFileSet() 185 pkgs, err := parser.ParseDir(fs, stencil, func(s os.FileInfo) bool { 186 return !strings.HasSuffix(s.Name(), "_test.go") 187 }, parser.AllErrors|parser.ParseComments) 188 if err != nil { 189 return errors.Wrapf(err, "%s: errors parsing", stencil) 190 } 191 if len(pkgs) != 1 { 192 return errors.Errorf("%d: expected 1 package, got %d", stencil, len(pkgs)) 193 } 194 var files map[string]*ast.File 195 for _, p := range pkgs { 196 files = p.Files 197 break 198 } 199 for path, f := range files { 200 target := filepath.Join(stencilled, filepath.Base(path)) 201 apply.Apply(f, r.preReplace, nil) 202 var b bytes.Buffer 203 if err := format.Node(&b, fs, f); err != nil { 204 return errors.Errorf("%s:%s: code generation failed", stencil, f.Name) 205 } 206 out, err := imports.Process(target, b.Bytes(), nil) 207 if err != nil { 208 return errors.WithStack(err) 209 } 210 *res = append(*res, file{path: target, data: out}) 211 } 212 return nil 213 } 214 215 func srcRoot(dir string) (string, error) { 216 srcs := build.Default.SrcDirs() 217 for _, src := range srcs { 218 if strings.HasPrefix(dir, src) { 219 return src, nil 220 } 221 } 222 223 var candidates []os.FileInfo 224 for d := dir; d != filepath.Dir(d); d = filepath.Dir(d) { 225 if filepath.Base(d) != "src" { 226 continue 227 } 228 info, err := os.Stat(d) 229 if err != nil { 230 return "", errors.Wrapf(err, "failed to stat parent dir: %s", d) 231 } 232 candidates = append(candidates, info) 233 } 234 235 for _, src := range srcs { 236 si, err := os.Stat(src) 237 if err != nil { 238 return "", errors.Wrapf(err, "couldn't stat Go src folder: %s", src) 239 } 240 for _, ci := range candidates { 241 if os.SameFile(ci, si) { 242 return src, nil 243 } 244 } 245 } 246 247 return "", errors.Errorf("%s: not in GOPATH", dir) 248 } 249 250 func processDir(dir string, files []string, res *[]file) error { 251 // Read files 252 fs := token.NewFileSet() 253 srcs, err := srcRoot(dir) 254 if err != nil { 255 return err 256 } 257 258 vendor := filepath.Join(dir, "vendor") 259 for d := dir; d != srcs; d = filepath.Dir(d) { 260 v := filepath.Join(d, "vendor") 261 st, err := os.Stat(d) 262 if err == nil && st.IsDir() { 263 vendor = v 264 break 265 } 266 } 267 roots := append(build.Default.SrcDirs(), vendor) 268 269 for _, fl := range files { 270 f, err := parser.ParseFile(fs, fl, nil, parser.ImportsOnly) 271 if err != nil { 272 return errors.Wrapf(err, "%s: parse failed", fl) 273 } 274 for _, imp := range f.Imports { 275 path := imp.Path.Value 276 path = path[1 : len(path)-1] 277 stencil, r := replacements(roots, path) 278 if stencil == "" { 279 continue 280 } 281 if err = makeStencilled(stencil, filepath.Join(vendor, path), r, res); err != nil { 282 return err 283 } 284 } 285 } 286 return nil 287 } 288 289 func processStencil(paths []string) ([]file, error) { 290 dirs, err := listPackages(paths) 291 if err != nil { 292 return nil, err 293 } 294 var res []file 295 for dir, files := range dirs { 296 if err := processDir(dir, files, &res); err != nil { 297 return nil, err 298 } 299 } 300 return res, nil 301 }