git.sr.ht/~sircmpwn/gqlgen@v0.0.0-20200522192042-c84d29a1c940/internal/rewrite/rewriter.go (about) 1 package rewrite 2 3 import ( 4 "bytes" 5 "fmt" 6 "go/ast" 7 "go/token" 8 "io/ioutil" 9 "path/filepath" 10 "strconv" 11 "strings" 12 13 "golang.org/x/tools/go/packages" 14 ) 15 16 type Rewriter struct { 17 pkg *packages.Package 18 files map[string]string 19 copied map[ast.Decl]bool 20 } 21 22 func New(importPath string) (*Rewriter, error) { 23 pkgs, err := packages.Load(&packages.Config{ 24 Mode: packages.NeedSyntax | packages.NeedTypes, 25 }, importPath) 26 if err != nil { 27 return nil, err 28 } 29 if len(pkgs) == 0 { 30 return nil, fmt.Errorf("package not found for importPath: %s", importPath) 31 } 32 33 return &Rewriter{ 34 pkg: pkgs[0], 35 files: map[string]string{}, 36 copied: map[ast.Decl]bool{}, 37 }, nil 38 } 39 40 func (r *Rewriter) getSource(start, end token.Pos) string { 41 startPos := r.pkg.Fset.Position(start) 42 endPos := r.pkg.Fset.Position(end) 43 44 if startPos.Filename != endPos.Filename { 45 panic("cant get source spanning multiple files") 46 } 47 48 file := r.getFile(startPos.Filename) 49 return file[startPos.Offset:endPos.Offset] 50 } 51 52 func (r *Rewriter) getFile(filename string) string { 53 if _, ok := r.files[filename]; !ok { 54 b, err := ioutil.ReadFile(filename) 55 if err != nil { 56 panic(fmt.Errorf("unable to load file, already exists: %s", err.Error())) 57 } 58 59 r.files[filename] = string(b) 60 61 } 62 63 return r.files[filename] 64 } 65 66 func (r *Rewriter) GetMethodBody(structname string, methodname string) string { 67 for _, f := range r.pkg.Syntax { 68 for _, d := range f.Decls { 69 d, isFunc := d.(*ast.FuncDecl) 70 if !isFunc { 71 continue 72 } 73 if d.Name.Name != methodname { 74 continue 75 } 76 if d.Recv == nil || len(d.Recv.List) == 0 { 77 continue 78 } 79 recv := d.Recv.List[0].Type 80 if star, isStar := recv.(*ast.StarExpr); isStar { 81 recv = star.X 82 } 83 ident, ok := recv.(*ast.Ident) 84 if !ok { 85 continue 86 } 87 88 if ident.Name != structname { 89 continue 90 } 91 92 r.copied[d] = true 93 94 return r.getSource(d.Body.Pos()+1, d.Body.End()-1) 95 } 96 } 97 98 return "" 99 } 100 101 func (r *Rewriter) MarkStructCopied(name string) { 102 for _, f := range r.pkg.Syntax { 103 for _, d := range f.Decls { 104 d, isGen := d.(*ast.GenDecl) 105 if !isGen { 106 continue 107 } 108 if d.Tok != token.TYPE || len(d.Specs) == 0 { 109 continue 110 } 111 112 spec, isTypeSpec := d.Specs[0].(*ast.TypeSpec) 113 if !isTypeSpec { 114 continue 115 } 116 117 if spec.Name.Name != name { 118 continue 119 } 120 121 r.copied[d] = true 122 } 123 } 124 } 125 126 func (r *Rewriter) ExistingImports(filename string) []Import { 127 filename, err := filepath.Abs(filename) 128 if err != nil { 129 panic(err) 130 } 131 for _, f := range r.pkg.Syntax { 132 pos := r.pkg.Fset.Position(f.Pos()) 133 134 if filename != pos.Filename { 135 continue 136 } 137 138 var imps []Import 139 for _, i := range f.Imports { 140 name := "" 141 if i.Name != nil { 142 name = i.Name.Name 143 } 144 path, err := strconv.Unquote(i.Path.Value) 145 if err != nil { 146 panic(err) 147 } 148 imps = append(imps, Import{name, path}) 149 } 150 return imps 151 } 152 return nil 153 } 154 155 func (r *Rewriter) RemainingSource(filename string) string { 156 filename, err := filepath.Abs(filename) 157 if err != nil { 158 panic(err) 159 } 160 for _, f := range r.pkg.Syntax { 161 pos := r.pkg.Fset.Position(f.Pos()) 162 163 if filename != pos.Filename { 164 continue 165 } 166 167 var buf bytes.Buffer 168 169 for _, d := range f.Decls { 170 if r.copied[d] { 171 continue 172 } 173 174 if d, isGen := d.(*ast.GenDecl); isGen && d.Tok == token.IMPORT { 175 continue 176 } 177 178 buf.WriteString(r.getSource(d.Pos(), d.End())) 179 buf.WriteString("\n") 180 } 181 182 return strings.TrimSpace(buf.String()) 183 } 184 return "" 185 } 186 187 type Import struct { 188 Alias string 189 ImportPath string 190 }