github.com/benma/gogen@v0.0.0-20160826115606-cf49914b915a/specific/process.go (about) 1 // Package specific copies the source from a package and generates a second 2 // package replacing some of the types used. It's aimed at taking generic 3 // packages that rely on interface{} and generating packages that use a 4 // specific type. 5 package specific 6 7 import ( 8 "fmt" 9 "go/ast" 10 "go/parser" 11 "go/printer" 12 "go/token" 13 "golang.org/x/tools/go/ast/astutil" 14 "io/ioutil" 15 "os" 16 "path" 17 ) 18 19 type Options struct { 20 SkipTestFiles bool 21 } 22 23 var DefaultOptions = Options{ 24 SkipTestFiles: false, 25 } 26 27 // Process creates a specific package from the generic specified in pkg 28 func Process(pkg, outdir string, newType string, optset ...func(*Options)) error { 29 opts := DefaultOptions 30 for _, fn := range optset { 31 fn(&opts) 32 } 33 34 p, err := findPackage(pkg) 35 if err != nil { 36 return err 37 } 38 39 if outdir == "" { 40 outdir = path.Base(pkg) 41 } 42 43 if err := os.MkdirAll(outdir, os.ModePerm); err != nil { 44 return err 45 } 46 47 t := parseTargetType(newType) 48 49 files, err := processFiles(p, p.GoFiles, t) 50 if err != nil { 51 return err 52 } 53 54 if err := write(outdir, files); err != nil { 55 return err 56 } 57 58 if opts.SkipTestFiles { 59 return nil 60 } 61 62 files, err = processFiles(p, p.TestGoFiles, t) 63 if err != nil { 64 return err 65 } 66 67 return write(outdir, files) 68 } 69 70 func processFiles(p Package, files []string, t targetType) ([]processedFile, error) { 71 var result []processedFile 72 for _, f := range files { 73 res, err := processFile(p, f, t) 74 if err != nil { 75 return result, err 76 } 77 result = append(result, res) 78 } 79 return result, nil 80 } 81 82 func processFile(p Package, filename string, t targetType) (processedFile, error) { 83 res := processedFile{filename: filename} 84 85 in, err := os.Open(path.Join(p.Dir, filename)) 86 if err != nil { 87 return res, FileError{Package: p.Dir, File: filename, Err: err} 88 } 89 src, err := ioutil.ReadAll(in) 90 if err != nil { 91 return res, FileError{Package: p.Dir, File: filename, Err: err} 92 } 93 94 res.fset = token.NewFileSet() 95 res.file, err = parser.ParseFile(res.fset, res.filename, src, parser.ParseComments|parser.AllErrors|parser.DeclarationErrors) 96 if err != nil { 97 return res, FileError{Package: p.Dir, File: filename, Err: err} 98 } 99 100 if replace(t, res.file) && t.newPkg != "" { 101 astutil.AddImport(res.fset, res.file, t.newPkg) 102 } 103 104 return res, err 105 } 106 107 func replace(t targetType, n ast.Node) (replaced bool) { 108 newType := t.newType 109 ast.Walk(visitFn(func(node ast.Node) { 110 if node == nil { 111 return 112 } 113 switch n := node.(type) { 114 case *ast.ArrayType: 115 if t, ok := n.Elt.(*ast.InterfaceType); ok && t.Methods.NumFields() == 0 { 116 str := ast.NewIdent(newType) 117 str.NamePos = t.Pos() 118 n.Elt = str 119 replaced = true 120 } 121 case *ast.ChanType: 122 if t, ok := n.Value.(*ast.InterfaceType); ok && t.Methods.NumFields() == 0 { 123 str := ast.NewIdent(newType) 124 str.NamePos = t.Pos() 125 n.Value = str 126 replaced = true 127 } 128 case *ast.MapType: 129 if t, ok := n.Key.(*ast.InterfaceType); ok && t.Methods.NumFields() == 0 { 130 str := ast.NewIdent(newType) 131 str.NamePos = t.Pos() 132 n.Key = str 133 replaced = true 134 } 135 if t, ok := n.Value.(*ast.InterfaceType); ok && t.Methods.NumFields() == 0 { 136 str := ast.NewIdent(newType) 137 str.NamePos = t.Pos() 138 n.Value = str 139 replaced = true 140 } 141 case *ast.Field: 142 if t, ok := n.Type.(*ast.InterfaceType); ok && t.Methods.NumFields() == 0 { 143 str := ast.NewIdent(newType) 144 str.NamePos = t.Pos() 145 n.Type = str 146 replaced = true 147 } 148 } 149 }), n) 150 return replaced 151 } 152 153 type visitFn func(node ast.Node) 154 155 func (fn visitFn) Visit(node ast.Node) ast.Visitor { 156 fn(node) 157 return fn 158 } 159 160 func write(outdir string, files []processedFile) error { 161 for _, f := range files { 162 out, err := os.Create(path.Join(outdir, f.filename)) 163 if err != nil { 164 return FileError{Package: outdir, File: f.filename, Err: err} 165 } 166 167 fmt.Fprintf(out, "/*\n"+ 168 "* CODE GENERATED AUTOMATICALLY WITH github.com/ernesto-jimenez/gogen/specific\n"+ 169 "* THIS FILE SHOULD NOT BE EDITED BY HAND\n"+ 170 "*/\n\n") 171 printer.Fprint(out, f.fset, f.file) 172 } 173 return nil 174 } 175 176 type FileError struct { 177 Package string 178 File string 179 Err error 180 } 181 182 func (ferr FileError) Error() string { 183 return fmt.Sprintf("error in %s: %s", path.Join(ferr.Package, ferr.File), ferr.Err.Error()) 184 } 185 186 type processedFile struct { 187 filename string 188 fset *token.FileSet 189 file *ast.File 190 }