github.com/geneva/gqlgen@v0.17.7-0.20230801155730-7b9317164836/internal/rewrite/rewriter.go (about)

     1  package rewrite
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/ast"
     7  	"go/token"
     8  	"os"
     9  	"path/filepath"
    10  	"strconv"
    11  	"strings"
    12  
    13  	"github.com/geneva/gqlgen/internal/code"
    14  	"golang.org/x/tools/go/packages"
    15  )
    16  
    17  type Rewriter struct {
    18  	pkg    *packages.Package
    19  	files  map[string]string
    20  	copied map[ast.Decl]bool
    21  }
    22  
    23  func New(dir string) (*Rewriter, error) {
    24  	importPath := code.ImportPathForDir(dir)
    25  	if importPath == "" {
    26  		return nil, fmt.Errorf("import path not found for directory: %q", dir)
    27  	}
    28  	pkgs, err := packages.Load(&packages.Config{
    29  		Mode: packages.NeedSyntax | packages.NeedTypes,
    30  	}, importPath)
    31  	if err != nil {
    32  		return nil, err
    33  	}
    34  	if len(pkgs) == 0 {
    35  		return nil, fmt.Errorf("package not found for importPath: %s", importPath)
    36  	}
    37  
    38  	return &Rewriter{
    39  		pkg:    pkgs[0],
    40  		files:  map[string]string{},
    41  		copied: map[ast.Decl]bool{},
    42  	}, nil
    43  }
    44  
    45  func (r *Rewriter) getSource(start, end token.Pos) string {
    46  	startPos := r.pkg.Fset.Position(start)
    47  	endPos := r.pkg.Fset.Position(end)
    48  
    49  	if startPos.Filename != endPos.Filename {
    50  		panic("cant get source spanning multiple files")
    51  	}
    52  
    53  	file := r.getFile(startPos.Filename)
    54  	return file[startPos.Offset:endPos.Offset]
    55  }
    56  
    57  func (r *Rewriter) getFile(filename string) string {
    58  	if _, ok := r.files[filename]; !ok {
    59  		b, err := os.ReadFile(filename)
    60  		if err != nil {
    61  			panic(fmt.Errorf("unable to load file, already exists: %w", err))
    62  		}
    63  
    64  		r.files[filename] = string(b)
    65  
    66  	}
    67  
    68  	return r.files[filename]
    69  }
    70  
    71  func (r *Rewriter) GetPrevDecl(structname, methodname string) *ast.FuncDecl {
    72  	for _, f := range r.pkg.Syntax {
    73  		for _, d := range f.Decls {
    74  			d, isFunc := d.(*ast.FuncDecl)
    75  			if !isFunc {
    76  				continue
    77  			}
    78  			if d.Name.Name != methodname {
    79  				continue
    80  			}
    81  			if d.Recv == nil || len(d.Recv.List) == 0 {
    82  				continue
    83  			}
    84  			recv := d.Recv.List[0].Type
    85  			if star, isStar := recv.(*ast.StarExpr); isStar {
    86  				recv = star.X
    87  			}
    88  			ident, ok := recv.(*ast.Ident)
    89  			if !ok {
    90  				continue
    91  			}
    92  			if ident.Name != structname {
    93  				continue
    94  			}
    95  			r.copied[d] = true
    96  			return d
    97  		}
    98  	}
    99  	return nil
   100  }
   101  
   102  func (r *Rewriter) GetMethodComment(structname, methodname string) string {
   103  	d := r.GetPrevDecl(structname, methodname)
   104  	if d != nil {
   105  		return d.Doc.Text()
   106  	}
   107  	return ""
   108  }
   109  
   110  func (r *Rewriter) GetMethodBody(structname, methodname string) string {
   111  	d := r.GetPrevDecl(structname, methodname)
   112  	if d != nil {
   113  		return r.getSource(d.Body.Pos()+1, d.Body.End()-1)
   114  	}
   115  	return ""
   116  }
   117  
   118  func (r *Rewriter) MarkStructCopied(name string) {
   119  	for _, f := range r.pkg.Syntax {
   120  		for _, d := range f.Decls {
   121  			d, isGen := d.(*ast.GenDecl)
   122  			if !isGen {
   123  				continue
   124  			}
   125  			if d.Tok != token.TYPE || len(d.Specs) == 0 {
   126  				continue
   127  			}
   128  
   129  			spec, isTypeSpec := d.Specs[0].(*ast.TypeSpec)
   130  			if !isTypeSpec {
   131  				continue
   132  			}
   133  
   134  			if spec.Name.Name != name {
   135  				continue
   136  			}
   137  
   138  			r.copied[d] = true
   139  		}
   140  	}
   141  }
   142  
   143  func (r *Rewriter) ExistingImports(filename string) []Import {
   144  	filename, err := filepath.Abs(filename)
   145  	if err != nil {
   146  		panic(err)
   147  	}
   148  	for _, f := range r.pkg.Syntax {
   149  		pos := r.pkg.Fset.Position(f.Pos())
   150  
   151  		if filename != pos.Filename {
   152  			continue
   153  		}
   154  
   155  		var imps []Import
   156  		for _, i := range f.Imports {
   157  			name := ""
   158  			if i.Name != nil {
   159  				name = i.Name.Name
   160  			}
   161  			path, err := strconv.Unquote(i.Path.Value)
   162  			if err != nil {
   163  				panic(err)
   164  			}
   165  			imps = append(imps, Import{name, path})
   166  		}
   167  		return imps
   168  	}
   169  	return nil
   170  }
   171  
   172  func (r *Rewriter) RemainingSource(filename string) string {
   173  	filename, err := filepath.Abs(filename)
   174  	if err != nil {
   175  		panic(err)
   176  	}
   177  	for _, f := range r.pkg.Syntax {
   178  		pos := r.pkg.Fset.Position(f.Pos())
   179  
   180  		if filename != pos.Filename {
   181  			continue
   182  		}
   183  
   184  		var buf bytes.Buffer
   185  
   186  		for _, d := range f.Decls {
   187  			if r.copied[d] {
   188  				continue
   189  			}
   190  
   191  			if d, isGen := d.(*ast.GenDecl); isGen && d.Tok == token.IMPORT {
   192  				continue
   193  			}
   194  
   195  			buf.WriteString(r.getSource(d.Pos(), d.End()))
   196  			buf.WriteString("\n")
   197  		}
   198  
   199  		return strings.TrimSpace(buf.String())
   200  	}
   201  	return ""
   202  }
   203  
// Import describes a single import spec, as returned by ExistingImports.
// Field order matters: ExistingImports builds values with a positional
// composite literal.
type Import struct {
	// Alias is the local name written in the import spec, or "" when the spec
	// has none. For blank or dot imports it is "_" or ".".
	Alias      string
	// ImportPath is the unquoted import path.
	ImportPath string
}