github.com/mstephano/gqlgen-schemagen@v0.0.0-20230113041936-dd2cd4ea46aa/internal/rewrite/rewriter.go (about) 1 package rewrite 2 3 import ( 4 "bytes" 5 "fmt" 6 "go/ast" 7 "go/token" 8 "os" 9 "path/filepath" 10 "strconv" 11 "strings" 12 13 "github.com/mstephano/gqlgen-schemagen/internal/code" 14 "golang.org/x/tools/go/packages" 15 ) 16 17 type Rewriter struct { 18 pkg *packages.Package 19 files map[string]string 20 copied map[ast.Decl]bool 21 } 22 23 func New(dir string) (*Rewriter, error) { 24 importPath := code.ImportPathForDir(dir) 25 if importPath == "" { 26 return nil, fmt.Errorf("import path not found for directory: %q", dir) 27 } 28 pkgs, err := packages.Load(&packages.Config{ 29 Mode: packages.NeedSyntax | packages.NeedTypes, 30 }, importPath) 31 if err != nil { 32 return nil, err 33 } 34 if len(pkgs) == 0 { 35 return nil, fmt.Errorf("package not found for importPath: %s", importPath) 36 } 37 38 return &Rewriter{ 39 pkg: pkgs[0], 40 files: map[string]string{}, 41 copied: map[ast.Decl]bool{}, 42 }, nil 43 } 44 45 func (r *Rewriter) getSource(start, end token.Pos) string { 46 startPos := r.pkg.Fset.Position(start) 47 endPos := r.pkg.Fset.Position(end) 48 49 if startPos.Filename != endPos.Filename { 50 panic("cant get source spanning multiple files") 51 } 52 53 file := r.getFile(startPos.Filename) 54 return file[startPos.Offset:endPos.Offset] 55 } 56 57 func (r *Rewriter) getFile(filename string) string { 58 if _, ok := r.files[filename]; !ok { 59 b, err := os.ReadFile(filename) 60 if err != nil { 61 panic(fmt.Errorf("unable to load file, already exists: %w", err)) 62 } 63 64 r.files[filename] = string(b) 65 66 } 67 68 return r.files[filename] 69 } 70 71 func (r *Rewriter) GetMethodComment(structname string, methodname string) string { 72 for _, f := range r.pkg.Syntax { 73 for _, d := range f.Decls { 74 d, isFunc := d.(*ast.FuncDecl) 75 if !isFunc { 76 continue 77 } 78 if d.Name.Name != methodname { 79 continue 80 } 81 if d.Recv == nil || len(d.Recv.List) == 0 { 82 continue 83 } 84 recv := d.Recv.List[0].Type 85 if star, isStar := recv.(*ast.StarExpr); isStar { 86 recv = star.X 87 } 88 ident, ok := recv.(*ast.Ident) 89 if !ok { 90 continue 91 } 92 93 if ident.Name != structname { 94 continue 95 } 96 return d.Doc.Text() 97 } 98 } 99 100 return "" 101 } 102 103 func (r *Rewriter) GetMethodBody(structname string, methodname string) string { 104 for _, f := range r.pkg.Syntax { 105 for _, d := range f.Decls { 106 d, isFunc := d.(*ast.FuncDecl) 107 if !isFunc { 108 continue 109 } 110 if d.Name.Name != methodname { 111 continue 112 } 113 if d.Recv == nil || len(d.Recv.List) == 0 { 114 continue 115 } 116 recv := d.Recv.List[0].Type 117 if star, isStar := recv.(*ast.StarExpr); isStar { 118 recv = star.X 119 } 120 ident, ok := recv.(*ast.Ident) 121 if !ok { 122 continue 123 } 124 125 if ident.Name != structname { 126 continue 127 } 128 129 r.copied[d] = true 130 131 return r.getSource(d.Body.Pos()+1, d.Body.End()-1) 132 } 133 } 134 135 return "" 136 } 137 138 func (r *Rewriter) MarkStructCopied(name string) { 139 for _, f := range r.pkg.Syntax { 140 for _, d := range f.Decls { 141 d, isGen := d.(*ast.GenDecl) 142 if !isGen { 143 continue 144 } 145 if d.Tok != token.TYPE || len(d.Specs) == 0 { 146 continue 147 } 148 149 spec, isTypeSpec := d.Specs[0].(*ast.TypeSpec) 150 if !isTypeSpec { 151 continue 152 } 153 154 if spec.Name.Name != name { 155 continue 156 } 157 158 r.copied[d] = true 159 } 160 } 161 } 162 163 func (r *Rewriter) ExistingImports(filename string) []Import { 164 filename, err := filepath.Abs(filename) 165 if err != nil { 166 panic(err) 167 } 168 for _, f := range r.pkg.Syntax { 169 pos := r.pkg.Fset.Position(f.Pos()) 170 171 if filename != pos.Filename { 172 continue 173 } 174 175 var imps []Import 176 for _, i := range f.Imports { 177 name := "" 178 if i.Name != nil { 179 name = i.Name.Name 180 } 181 path, err := strconv.Unquote(i.Path.Value) 182 if err != nil { 183 panic(err) 184 } 185 imps = append(imps, Import{name, path}) 186 } 187 return imps 188 } 189 return nil 190 } 191 192 func (r *Rewriter) RemainingSource(filename string) string { 193 filename, err := filepath.Abs(filename) 194 if err != nil { 195 panic(err) 196 } 197 for _, f := range r.pkg.Syntax { 198 pos := r.pkg.Fset.Position(f.Pos()) 199 200 if filename != pos.Filename { 201 continue 202 } 203 204 var buf bytes.Buffer 205 206 for _, d := range f.Decls { 207 if r.copied[d] { 208 continue 209 } 210 211 if d, isGen := d.(*ast.GenDecl); isGen && d.Tok == token.IMPORT { 212 continue 213 } 214 215 buf.WriteString(r.getSource(d.Pos(), d.End())) 216 buf.WriteString("\n") 217 } 218 219 return strings.TrimSpace(buf.String()) 220 } 221 return "" 222 } 223 224 type Import struct { 225 Alias string 226 ImportPath string 227 }