github.com/geneva/gqlgen@v0.17.7-0.20230801155730-7b9317164836/internal/rewrite/rewriter.go (about) 1 package rewrite 2 3 import ( 4 "bytes" 5 "fmt" 6 "go/ast" 7 "go/token" 8 "os" 9 "path/filepath" 10 "strconv" 11 "strings" 12 13 "github.com/geneva/gqlgen/internal/code" 14 "golang.org/x/tools/go/packages" 15 ) 16 17 type Rewriter struct { 18 pkg *packages.Package 19 files map[string]string 20 copied map[ast.Decl]bool 21 } 22 23 func New(dir string) (*Rewriter, error) { 24 importPath := code.ImportPathForDir(dir) 25 if importPath == "" { 26 return nil, fmt.Errorf("import path not found for directory: %q", dir) 27 } 28 pkgs, err := packages.Load(&packages.Config{ 29 Mode: packages.NeedSyntax | packages.NeedTypes, 30 }, importPath) 31 if err != nil { 32 return nil, err 33 } 34 if len(pkgs) == 0 { 35 return nil, fmt.Errorf("package not found for importPath: %s", importPath) 36 } 37 38 return &Rewriter{ 39 pkg: pkgs[0], 40 files: map[string]string{}, 41 copied: map[ast.Decl]bool{}, 42 }, nil 43 } 44 45 func (r *Rewriter) getSource(start, end token.Pos) string { 46 startPos := r.pkg.Fset.Position(start) 47 endPos := r.pkg.Fset.Position(end) 48 49 if startPos.Filename != endPos.Filename { 50 panic("cant get source spanning multiple files") 51 } 52 53 file := r.getFile(startPos.Filename) 54 return file[startPos.Offset:endPos.Offset] 55 } 56 57 func (r *Rewriter) getFile(filename string) string { 58 if _, ok := r.files[filename]; !ok { 59 b, err := os.ReadFile(filename) 60 if err != nil { 61 panic(fmt.Errorf("unable to load file, already exists: %w", err)) 62 } 63 64 r.files[filename] = string(b) 65 66 } 67 68 return r.files[filename] 69 } 70 71 func (r *Rewriter) GetPrevDecl(structname, methodname string) *ast.FuncDecl { 72 for _, f := range r.pkg.Syntax { 73 for _, d := range f.Decls { 74 d, isFunc := d.(*ast.FuncDecl) 75 if !isFunc { 76 continue 77 } 78 if d.Name.Name != methodname { 79 continue 80 } 81 if d.Recv == nil || len(d.Recv.List) == 0 { 82 continue 83 } 84 recv := d.Recv.List[0].Type 85 if star, isStar := recv.(*ast.StarExpr); isStar { 86 recv = star.X 87 } 88 ident, ok := recv.(*ast.Ident) 89 if !ok { 90 continue 91 } 92 if ident.Name != structname { 93 continue 94 } 95 r.copied[d] = true 96 return d 97 } 98 } 99 return nil 100 } 101 102 func (r *Rewriter) GetMethodComment(structname, methodname string) string { 103 d := r.GetPrevDecl(structname, methodname) 104 if d != nil { 105 return d.Doc.Text() 106 } 107 return "" 108 } 109 110 func (r *Rewriter) GetMethodBody(structname, methodname string) string { 111 d := r.GetPrevDecl(structname, methodname) 112 if d != nil { 113 return r.getSource(d.Body.Pos()+1, d.Body.End()-1) 114 } 115 return "" 116 } 117 118 func (r *Rewriter) MarkStructCopied(name string) { 119 for _, f := range r.pkg.Syntax { 120 for _, d := range f.Decls { 121 d, isGen := d.(*ast.GenDecl) 122 if !isGen { 123 continue 124 } 125 if d.Tok != token.TYPE || len(d.Specs) == 0 { 126 continue 127 } 128 129 spec, isTypeSpec := d.Specs[0].(*ast.TypeSpec) 130 if !isTypeSpec { 131 continue 132 } 133 134 if spec.Name.Name != name { 135 continue 136 } 137 138 r.copied[d] = true 139 } 140 } 141 } 142 143 func (r *Rewriter) ExistingImports(filename string) []Import { 144 filename, err := filepath.Abs(filename) 145 if err != nil { 146 panic(err) 147 } 148 for _, f := range r.pkg.Syntax { 149 pos := r.pkg.Fset.Position(f.Pos()) 150 151 if filename != pos.Filename { 152 continue 153 } 154 155 var imps []Import 156 for _, i := range f.Imports { 157 name := "" 158 if i.Name != nil { 159 name = i.Name.Name 160 } 161 path, err := strconv.Unquote(i.Path.Value) 162 if err != nil { 163 panic(err) 164 } 165 imps = append(imps, Import{name, path}) 166 } 167 return imps 168 } 169 return nil 170 } 171 172 func (r *Rewriter) RemainingSource(filename string) string { 173 filename, err := filepath.Abs(filename) 174 if err != nil { 175 panic(err) 176 } 177 for _, f := range r.pkg.Syntax { 178 pos := r.pkg.Fset.Position(f.Pos()) 179 180 if filename != pos.Filename { 181 continue 182 } 183 184 var buf bytes.Buffer 185 186 for _, d := range f.Decls { 187 if r.copied[d] { 188 continue 189 } 190 191 if d, isGen := d.(*ast.GenDecl); isGen && d.Tok == token.IMPORT { 192 continue 193 } 194 195 buf.WriteString(r.getSource(d.Pos(), d.End())) 196 buf.WriteString("\n") 197 } 198 199 return strings.TrimSpace(buf.String()) 200 } 201 return "" 202 } 203 204 type Import struct { 205 Alias string 206 ImportPath string 207 }