git.sr.ht/~sircmpwn/gqlgen@v0.0.0-20200522192042-c84d29a1c940/internal/rewrite/rewriter.go (about)

     1  package rewrite
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/ast"
     7  	"go/token"
     8  	"io/ioutil"
     9  	"path/filepath"
    10  	"strconv"
    11  	"strings"
    12  
    13  	"golang.org/x/tools/go/packages"
    14  )
    15  
    16  type Rewriter struct {
    17  	pkg    *packages.Package
    18  	files  map[string]string
    19  	copied map[ast.Decl]bool
    20  }
    21  
    22  func New(importPath string) (*Rewriter, error) {
    23  	pkgs, err := packages.Load(&packages.Config{
    24  		Mode: packages.NeedSyntax | packages.NeedTypes,
    25  	}, importPath)
    26  	if err != nil {
    27  		return nil, err
    28  	}
    29  	if len(pkgs) == 0 {
    30  		return nil, fmt.Errorf("package not found for importPath: %s", importPath)
    31  	}
    32  
    33  	return &Rewriter{
    34  		pkg:    pkgs[0],
    35  		files:  map[string]string{},
    36  		copied: map[ast.Decl]bool{},
    37  	}, nil
    38  }
    39  
    40  func (r *Rewriter) getSource(start, end token.Pos) string {
    41  	startPos := r.pkg.Fset.Position(start)
    42  	endPos := r.pkg.Fset.Position(end)
    43  
    44  	if startPos.Filename != endPos.Filename {
    45  		panic("cant get source spanning multiple files")
    46  	}
    47  
    48  	file := r.getFile(startPos.Filename)
    49  	return file[startPos.Offset:endPos.Offset]
    50  }
    51  
    52  func (r *Rewriter) getFile(filename string) string {
    53  	if _, ok := r.files[filename]; !ok {
    54  		b, err := ioutil.ReadFile(filename)
    55  		if err != nil {
    56  			panic(fmt.Errorf("unable to load file, already exists: %s", err.Error()))
    57  		}
    58  
    59  		r.files[filename] = string(b)
    60  
    61  	}
    62  
    63  	return r.files[filename]
    64  }
    65  
    66  func (r *Rewriter) GetMethodBody(structname string, methodname string) string {
    67  	for _, f := range r.pkg.Syntax {
    68  		for _, d := range f.Decls {
    69  			d, isFunc := d.(*ast.FuncDecl)
    70  			if !isFunc {
    71  				continue
    72  			}
    73  			if d.Name.Name != methodname {
    74  				continue
    75  			}
    76  			if d.Recv == nil || len(d.Recv.List) == 0 {
    77  				continue
    78  			}
    79  			recv := d.Recv.List[0].Type
    80  			if star, isStar := recv.(*ast.StarExpr); isStar {
    81  				recv = star.X
    82  			}
    83  			ident, ok := recv.(*ast.Ident)
    84  			if !ok {
    85  				continue
    86  			}
    87  
    88  			if ident.Name != structname {
    89  				continue
    90  			}
    91  
    92  			r.copied[d] = true
    93  
    94  			return r.getSource(d.Body.Pos()+1, d.Body.End()-1)
    95  		}
    96  	}
    97  
    98  	return ""
    99  }
   100  
   101  func (r *Rewriter) MarkStructCopied(name string) {
   102  	for _, f := range r.pkg.Syntax {
   103  		for _, d := range f.Decls {
   104  			d, isGen := d.(*ast.GenDecl)
   105  			if !isGen {
   106  				continue
   107  			}
   108  			if d.Tok != token.TYPE || len(d.Specs) == 0 {
   109  				continue
   110  			}
   111  
   112  			spec, isTypeSpec := d.Specs[0].(*ast.TypeSpec)
   113  			if !isTypeSpec {
   114  				continue
   115  			}
   116  
   117  			if spec.Name.Name != name {
   118  				continue
   119  			}
   120  
   121  			r.copied[d] = true
   122  		}
   123  	}
   124  }
   125  
   126  func (r *Rewriter) ExistingImports(filename string) []Import {
   127  	filename, err := filepath.Abs(filename)
   128  	if err != nil {
   129  		panic(err)
   130  	}
   131  	for _, f := range r.pkg.Syntax {
   132  		pos := r.pkg.Fset.Position(f.Pos())
   133  
   134  		if filename != pos.Filename {
   135  			continue
   136  		}
   137  
   138  		var imps []Import
   139  		for _, i := range f.Imports {
   140  			name := ""
   141  			if i.Name != nil {
   142  				name = i.Name.Name
   143  			}
   144  			path, err := strconv.Unquote(i.Path.Value)
   145  			if err != nil {
   146  				panic(err)
   147  			}
   148  			imps = append(imps, Import{name, path})
   149  		}
   150  		return imps
   151  	}
   152  	return nil
   153  }
   154  
   155  func (r *Rewriter) RemainingSource(filename string) string {
   156  	filename, err := filepath.Abs(filename)
   157  	if err != nil {
   158  		panic(err)
   159  	}
   160  	for _, f := range r.pkg.Syntax {
   161  		pos := r.pkg.Fset.Position(f.Pos())
   162  
   163  		if filename != pos.Filename {
   164  			continue
   165  		}
   166  
   167  		var buf bytes.Buffer
   168  
   169  		for _, d := range f.Decls {
   170  			if r.copied[d] {
   171  				continue
   172  			}
   173  
   174  			if d, isGen := d.(*ast.GenDecl); isGen && d.Tok == token.IMPORT {
   175  				continue
   176  			}
   177  
   178  			buf.WriteString(r.getSource(d.Pos(), d.End()))
   179  			buf.WriteString("\n")
   180  		}
   181  
   182  		return strings.TrimSpace(buf.String())
   183  	}
   184  	return ""
   185  }
   186  
   187  type Import struct {
   188  	Alias      string
   189  	ImportPath string
   190  }