github.com/spread-ai/gqlgen@v0.0.0-20221124102857-a6c8ef538a1d/internal/rewrite/rewriter.go (about) 1 package rewrite 2 3 import ( 4 "bytes" 5 "fmt" 6 "go/ast" 7 "go/token" 8 "os" 9 "path/filepath" 10 "strconv" 11 "strings" 12 13 "github.com/spread-ai/gqlgen/internal/code" 14 "golang.org/x/tools/go/packages" 15 ) 16 17 type Rewriter struct { 18 pkg *packages.Package 19 files map[string]string 20 copied map[ast.Decl]bool 21 } 22 23 func New(dir string) (*Rewriter, error) { 24 importPath := code.ImportPathForDir(dir) 25 if importPath == "" { 26 return nil, fmt.Errorf("import path not found for directory: %q", dir) 27 } 28 pkgs, err := packages.Load(&packages.Config{ 29 Mode: packages.NeedSyntax | packages.NeedTypes, 30 }, importPath) 31 if err != nil { 32 return nil, err 33 } 34 if len(pkgs) == 0 { 35 return nil, fmt.Errorf("package not found for importPath: %s", importPath) 36 } 37 38 return &Rewriter{ 39 pkg: pkgs[0], 40 files: map[string]string{}, 41 copied: map[ast.Decl]bool{}, 42 }, nil 43 } 44 45 func (r *Rewriter) getSource(start, end token.Pos) string { 46 startPos := r.pkg.Fset.Position(start) 47 endPos := r.pkg.Fset.Position(end) 48 49 if startPos.Filename != endPos.Filename { 50 panic("cant get source spanning multiple files") 51 } 52 53 file := r.getFile(startPos.Filename) 54 return file[startPos.Offset:endPos.Offset] 55 } 56 57 func (r *Rewriter) getFile(filename string) string { 58 if _, ok := r.files[filename]; !ok { 59 b, err := os.ReadFile(filename) 60 if err != nil { 61 panic(fmt.Errorf("unable to load file, already exists: %w", err)) 62 } 63 64 r.files[filename] = string(b) 65 66 } 67 68 return r.files[filename] 69 } 70 71 func (r *Rewriter) GetMethodComment(structname string, methodname string) string { 72 for _, f := range r.pkg.Syntax { 73 for _, d := range f.Decls { 74 d, isFunc := d.(*ast.FuncDecl) 75 if !isFunc { 76 continue 77 } 78 if d.Name.Name != methodname { 79 continue 80 } 81 if d.Recv == nil || len(d.Recv.List) == 0 { 82 continue 83 } 84 recv := d.Recv.List[0].Type 85 if star, isStar := recv.(*ast.StarExpr); isStar { 86 recv = star.X 87 } 88 ident, ok := recv.(*ast.Ident) 89 if !ok { 90 continue 91 } 92 93 if ident.Name != structname { 94 continue 95 } 96 return d.Doc.Text() 97 } 98 } 99 100 return "" 101 } 102 func (r *Rewriter) GetMethodBody(structname string, methodname string) string { 103 for _, f := range r.pkg.Syntax { 104 for _, d := range f.Decls { 105 d, isFunc := d.(*ast.FuncDecl) 106 if !isFunc { 107 continue 108 } 109 if d.Name.Name != methodname { 110 continue 111 } 112 if d.Recv == nil || len(d.Recv.List) == 0 { 113 continue 114 } 115 recv := d.Recv.List[0].Type 116 if star, isStar := recv.(*ast.StarExpr); isStar { 117 recv = star.X 118 } 119 ident, ok := recv.(*ast.Ident) 120 if !ok { 121 continue 122 } 123 124 if ident.Name != structname { 125 continue 126 } 127 128 r.copied[d] = true 129 130 return r.getSource(d.Body.Pos()+1, d.Body.End()-1) 131 } 132 } 133 134 return "" 135 } 136 137 func (r *Rewriter) MarkStructCopied(name string) { 138 for _, f := range r.pkg.Syntax { 139 for _, d := range f.Decls { 140 d, isGen := d.(*ast.GenDecl) 141 if !isGen { 142 continue 143 } 144 if d.Tok != token.TYPE || len(d.Specs) == 0 { 145 continue 146 } 147 148 spec, isTypeSpec := d.Specs[0].(*ast.TypeSpec) 149 if !isTypeSpec { 150 continue 151 } 152 153 if spec.Name.Name != name { 154 continue 155 } 156 157 r.copied[d] = true 158 } 159 } 160 } 161 162 func (r *Rewriter) ExistingImports(filename string) []Import { 163 filename, err := filepath.Abs(filename) 164 if err != nil { 165 panic(err) 166 } 167 for _, f := range r.pkg.Syntax { 168 pos := r.pkg.Fset.Position(f.Pos()) 169 170 if filename != pos.Filename { 171 continue 172 } 173 174 var imps []Import 175 for _, i := range f.Imports { 176 name := "" 177 if i.Name != nil { 178 name = i.Name.Name 179 } 180 path, err := strconv.Unquote(i.Path.Value) 181 if err != nil { 182 panic(err) 183 } 184 imps = append(imps, Import{name, path}) 185 } 186 return imps 187 } 188 return nil 189 } 190 191 func (r *Rewriter) RemainingSource(filename string) string { 192 filename, err := filepath.Abs(filename) 193 if err != nil { 194 panic(err) 195 } 196 for _, f := range r.pkg.Syntax { 197 pos := r.pkg.Fset.Position(f.Pos()) 198 199 if filename != pos.Filename { 200 continue 201 } 202 203 var buf bytes.Buffer 204 205 for _, d := range f.Decls { 206 if r.copied[d] { 207 continue 208 } 209 210 if d, isGen := d.(*ast.GenDecl); isGen && d.Tok == token.IMPORT { 211 continue 212 } 213 214 buf.WriteString(r.getSource(d.Pos(), d.End())) 215 buf.WriteString("\n") 216 } 217 218 return strings.TrimSpace(buf.String()) 219 } 220 return "" 221 } 222 223 type Import struct { 224 Alias string 225 ImportPath string 226 }