github.com/maeglindeveloper/gqlgen@v0.13.1-0.20210413081235-57808b12a0a0/internal/rewrite/rewriter.go (about)

     1  package rewrite
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/ast"
     7  	"go/token"
     8  	"io/ioutil"
     9  	"path/filepath"
    10  	"strconv"
    11  	"strings"
    12  
    13  	"github.com/99designs/gqlgen/internal/code"
    14  	"golang.org/x/tools/go/packages"
    15  )
    16  
// Rewriter extracts the source text of declarations from a loaded Go package
// and records which declarations have been extracted, so that the ones left
// over can be retrieved with RemainingSource.
type Rewriter struct {
	pkg    *packages.Package // package loaded by New; its Syntax and Fset are searched
	files  map[string]string // file contents keyed by filename, filled lazily by getFile
	copied map[ast.Decl]bool // declarations taken by GetMethodBody/MarkStructCopied; RemainingSource skips them
}
    22  
    23  func New(dir string) (*Rewriter, error) {
    24  	importPath := code.ImportPathForDir(dir)
    25  	if importPath == "" {
    26  		return nil, fmt.Errorf("import path not found for directory: %q", dir)
    27  	}
    28  	pkgs, err := packages.Load(&packages.Config{
    29  		Mode: packages.NeedSyntax | packages.NeedTypes,
    30  	}, importPath)
    31  	if err != nil {
    32  		return nil, err
    33  	}
    34  	if len(pkgs) == 0 {
    35  		return nil, fmt.Errorf("package not found for importPath: %s", importPath)
    36  	}
    37  
    38  	return &Rewriter{
    39  		pkg:    pkgs[0],
    40  		files:  map[string]string{},
    41  		copied: map[ast.Decl]bool{},
    42  	}, nil
    43  }
    44  
    45  func (r *Rewriter) getSource(start, end token.Pos) string {
    46  	startPos := r.pkg.Fset.Position(start)
    47  	endPos := r.pkg.Fset.Position(end)
    48  
    49  	if startPos.Filename != endPos.Filename {
    50  		panic("cant get source spanning multiple files")
    51  	}
    52  
    53  	file := r.getFile(startPos.Filename)
    54  	return file[startPos.Offset:endPos.Offset]
    55  }
    56  
    57  func (r *Rewriter) getFile(filename string) string {
    58  	if _, ok := r.files[filename]; !ok {
    59  		b, err := ioutil.ReadFile(filename)
    60  		if err != nil {
    61  			panic(fmt.Errorf("unable to load file, already exists: %s", err.Error()))
    62  		}
    63  
    64  		r.files[filename] = string(b)
    65  
    66  	}
    67  
    68  	return r.files[filename]
    69  }
    70  
    71  func (r *Rewriter) GetMethodBody(structname string, methodname string) string {
    72  	for _, f := range r.pkg.Syntax {
    73  		for _, d := range f.Decls {
    74  			d, isFunc := d.(*ast.FuncDecl)
    75  			if !isFunc {
    76  				continue
    77  			}
    78  			if d.Name.Name != methodname {
    79  				continue
    80  			}
    81  			if d.Recv == nil || len(d.Recv.List) == 0 {
    82  				continue
    83  			}
    84  			recv := d.Recv.List[0].Type
    85  			if star, isStar := recv.(*ast.StarExpr); isStar {
    86  				recv = star.X
    87  			}
    88  			ident, ok := recv.(*ast.Ident)
    89  			if !ok {
    90  				continue
    91  			}
    92  
    93  			if ident.Name != structname {
    94  				continue
    95  			}
    96  
    97  			r.copied[d] = true
    98  
    99  			return r.getSource(d.Body.Pos()+1, d.Body.End()-1)
   100  		}
   101  	}
   102  
   103  	return ""
   104  }
   105  
   106  func (r *Rewriter) MarkStructCopied(name string) {
   107  	for _, f := range r.pkg.Syntax {
   108  		for _, d := range f.Decls {
   109  			d, isGen := d.(*ast.GenDecl)
   110  			if !isGen {
   111  				continue
   112  			}
   113  			if d.Tok != token.TYPE || len(d.Specs) == 0 {
   114  				continue
   115  			}
   116  
   117  			spec, isTypeSpec := d.Specs[0].(*ast.TypeSpec)
   118  			if !isTypeSpec {
   119  				continue
   120  			}
   121  
   122  			if spec.Name.Name != name {
   123  				continue
   124  			}
   125  
   126  			r.copied[d] = true
   127  		}
   128  	}
   129  }
   130  
   131  func (r *Rewriter) ExistingImports(filename string) []Import {
   132  	filename, err := filepath.Abs(filename)
   133  	if err != nil {
   134  		panic(err)
   135  	}
   136  	for _, f := range r.pkg.Syntax {
   137  		pos := r.pkg.Fset.Position(f.Pos())
   138  
   139  		if filename != pos.Filename {
   140  			continue
   141  		}
   142  
   143  		var imps []Import
   144  		for _, i := range f.Imports {
   145  			name := ""
   146  			if i.Name != nil {
   147  				name = i.Name.Name
   148  			}
   149  			path, err := strconv.Unquote(i.Path.Value)
   150  			if err != nil {
   151  				panic(err)
   152  			}
   153  			imps = append(imps, Import{name, path})
   154  		}
   155  		return imps
   156  	}
   157  	return nil
   158  }
   159  
   160  func (r *Rewriter) RemainingSource(filename string) string {
   161  	filename, err := filepath.Abs(filename)
   162  	if err != nil {
   163  		panic(err)
   164  	}
   165  	for _, f := range r.pkg.Syntax {
   166  		pos := r.pkg.Fset.Position(f.Pos())
   167  
   168  		if filename != pos.Filename {
   169  			continue
   170  		}
   171  
   172  		var buf bytes.Buffer
   173  
   174  		for _, d := range f.Decls {
   175  			if r.copied[d] {
   176  				continue
   177  			}
   178  
   179  			if d, isGen := d.(*ast.GenDecl); isGen && d.Tok == token.IMPORT {
   180  				continue
   181  			}
   182  
   183  			buf.WriteString(r.getSource(d.Pos(), d.End()))
   184  			buf.WriteString("\n")
   185  		}
   186  
   187  		return strings.TrimSpace(buf.String())
   188  	}
   189  	return ""
   190  }
   191  
// Import describes a single import spec of a Go file, as returned by
// ExistingImports.
type Import struct {
	Alias      string // explicit import name (an identifier, "_" or "."), or "" when none is written
	ImportPath string // the import path with its quotes removed
}