github.com/spread-ai/gqlgen@v0.0.0-20221124102857-a6c8ef538a1d/internal/rewrite/rewriter.go (about)

     1  package rewrite
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/ast"
     7  	"go/token"
     8  	"os"
     9  	"path/filepath"
    10  	"strconv"
    11  	"strings"
    12  
    13  	"github.com/spread-ai/gqlgen/internal/code"
    14  	"golang.org/x/tools/go/packages"
    15  )
    16  
// Rewriter wraps one loaded Go package and hands out pieces of its existing
// source (method bodies, doc comments, imports, leftover declarations) so they
// can be carried into a regenerated file. Construct it with New.
type Rewriter struct {
	pkg    *packages.Package // package loaded by New; its Syntax and Fset are read by the lookup methods
	files  map[string]string // filename -> file contents, filled lazily by getFile
	copied map[ast.Decl]bool // declarations already carried over; RemainingSource skips these
}
    22  
    23  func New(dir string) (*Rewriter, error) {
    24  	importPath := code.ImportPathForDir(dir)
    25  	if importPath == "" {
    26  		return nil, fmt.Errorf("import path not found for directory: %q", dir)
    27  	}
    28  	pkgs, err := packages.Load(&packages.Config{
    29  		Mode: packages.NeedSyntax | packages.NeedTypes,
    30  	}, importPath)
    31  	if err != nil {
    32  		return nil, err
    33  	}
    34  	if len(pkgs) == 0 {
    35  		return nil, fmt.Errorf("package not found for importPath: %s", importPath)
    36  	}
    37  
    38  	return &Rewriter{
    39  		pkg:    pkgs[0],
    40  		files:  map[string]string{},
    41  		copied: map[ast.Decl]bool{},
    42  	}, nil
    43  }
    44  
    45  func (r *Rewriter) getSource(start, end token.Pos) string {
    46  	startPos := r.pkg.Fset.Position(start)
    47  	endPos := r.pkg.Fset.Position(end)
    48  
    49  	if startPos.Filename != endPos.Filename {
    50  		panic("cant get source spanning multiple files")
    51  	}
    52  
    53  	file := r.getFile(startPos.Filename)
    54  	return file[startPos.Offset:endPos.Offset]
    55  }
    56  
    57  func (r *Rewriter) getFile(filename string) string {
    58  	if _, ok := r.files[filename]; !ok {
    59  		b, err := os.ReadFile(filename)
    60  		if err != nil {
    61  			panic(fmt.Errorf("unable to load file, already exists: %w", err))
    62  		}
    63  
    64  		r.files[filename] = string(b)
    65  
    66  	}
    67  
    68  	return r.files[filename]
    69  }
    70  
    71  func (r *Rewriter) GetMethodComment(structname string, methodname string) string {
    72  	for _, f := range r.pkg.Syntax {
    73  		for _, d := range f.Decls {
    74  			d, isFunc := d.(*ast.FuncDecl)
    75  			if !isFunc {
    76  				continue
    77  			}
    78  			if d.Name.Name != methodname {
    79  				continue
    80  			}
    81  			if d.Recv == nil || len(d.Recv.List) == 0 {
    82  				continue
    83  			}
    84  			recv := d.Recv.List[0].Type
    85  			if star, isStar := recv.(*ast.StarExpr); isStar {
    86  				recv = star.X
    87  			}
    88  			ident, ok := recv.(*ast.Ident)
    89  			if !ok {
    90  				continue
    91  			}
    92  
    93  			if ident.Name != structname {
    94  				continue
    95  			}
    96  			return d.Doc.Text()
    97  		}
    98  	}
    99  
   100  	return ""
   101  }
   102  func (r *Rewriter) GetMethodBody(structname string, methodname string) string {
   103  	for _, f := range r.pkg.Syntax {
   104  		for _, d := range f.Decls {
   105  			d, isFunc := d.(*ast.FuncDecl)
   106  			if !isFunc {
   107  				continue
   108  			}
   109  			if d.Name.Name != methodname {
   110  				continue
   111  			}
   112  			if d.Recv == nil || len(d.Recv.List) == 0 {
   113  				continue
   114  			}
   115  			recv := d.Recv.List[0].Type
   116  			if star, isStar := recv.(*ast.StarExpr); isStar {
   117  				recv = star.X
   118  			}
   119  			ident, ok := recv.(*ast.Ident)
   120  			if !ok {
   121  				continue
   122  			}
   123  
   124  			if ident.Name != structname {
   125  				continue
   126  			}
   127  
   128  			r.copied[d] = true
   129  
   130  			return r.getSource(d.Body.Pos()+1, d.Body.End()-1)
   131  		}
   132  	}
   133  
   134  	return ""
   135  }
   136  
   137  func (r *Rewriter) MarkStructCopied(name string) {
   138  	for _, f := range r.pkg.Syntax {
   139  		for _, d := range f.Decls {
   140  			d, isGen := d.(*ast.GenDecl)
   141  			if !isGen {
   142  				continue
   143  			}
   144  			if d.Tok != token.TYPE || len(d.Specs) == 0 {
   145  				continue
   146  			}
   147  
   148  			spec, isTypeSpec := d.Specs[0].(*ast.TypeSpec)
   149  			if !isTypeSpec {
   150  				continue
   151  			}
   152  
   153  			if spec.Name.Name != name {
   154  				continue
   155  			}
   156  
   157  			r.copied[d] = true
   158  		}
   159  	}
   160  }
   161  
   162  func (r *Rewriter) ExistingImports(filename string) []Import {
   163  	filename, err := filepath.Abs(filename)
   164  	if err != nil {
   165  		panic(err)
   166  	}
   167  	for _, f := range r.pkg.Syntax {
   168  		pos := r.pkg.Fset.Position(f.Pos())
   169  
   170  		if filename != pos.Filename {
   171  			continue
   172  		}
   173  
   174  		var imps []Import
   175  		for _, i := range f.Imports {
   176  			name := ""
   177  			if i.Name != nil {
   178  				name = i.Name.Name
   179  			}
   180  			path, err := strconv.Unquote(i.Path.Value)
   181  			if err != nil {
   182  				panic(err)
   183  			}
   184  			imps = append(imps, Import{name, path})
   185  		}
   186  		return imps
   187  	}
   188  	return nil
   189  }
   190  
   191  func (r *Rewriter) RemainingSource(filename string) string {
   192  	filename, err := filepath.Abs(filename)
   193  	if err != nil {
   194  		panic(err)
   195  	}
   196  	for _, f := range r.pkg.Syntax {
   197  		pos := r.pkg.Fset.Position(f.Pos())
   198  
   199  		if filename != pos.Filename {
   200  			continue
   201  		}
   202  
   203  		var buf bytes.Buffer
   204  
   205  		for _, d := range f.Decls {
   206  			if r.copied[d] {
   207  				continue
   208  			}
   209  
   210  			if d, isGen := d.(*ast.GenDecl); isGen && d.Tok == token.IMPORT {
   211  				continue
   212  			}
   213  
   214  			buf.WriteString(r.getSource(d.Pos(), d.End()))
   215  			buf.WriteString("\n")
   216  		}
   217  
   218  		return strings.TrimSpace(buf.String())
   219  	}
   220  	return ""
   221  }
   222  
   223  type Import struct {
   224  	Alias      string
   225  	ImportPath string
   226  }