github.com/mstephano/gqlgen-schemagen@v0.0.0-20230113041936-dd2cd4ea46aa/internal/rewrite/rewriter.go (about)

     1  package rewrite
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/ast"
     7  	"go/token"
     8  	"os"
     9  	"path/filepath"
    10  	"strconv"
    11  	"strings"
    12  
    13  	"github.com/mstephano/gqlgen-schemagen/internal/code"
    14  	"golang.org/x/tools/go/packages"
    15  )
    16  
// Rewriter holds a loaded Go package and extracts source text from its files:
// method bodies and doc comments (GetMethodBody, GetMethodComment), a file's
// imports (ExistingImports), and the declarations not yet marked as copied
// (RemainingSource).
type Rewriter struct {
	pkg    *packages.Package // package loaded by New
	files  map[string]string // file contents cached by getFile, keyed by filename
	copied map[ast.Decl]bool // declarations marked as copied; RemainingSource skips them
}
    22  
    23  func New(dir string) (*Rewriter, error) {
    24  	importPath := code.ImportPathForDir(dir)
    25  	if importPath == "" {
    26  		return nil, fmt.Errorf("import path not found for directory: %q", dir)
    27  	}
    28  	pkgs, err := packages.Load(&packages.Config{
    29  		Mode: packages.NeedSyntax | packages.NeedTypes,
    30  	}, importPath)
    31  	if err != nil {
    32  		return nil, err
    33  	}
    34  	if len(pkgs) == 0 {
    35  		return nil, fmt.Errorf("package not found for importPath: %s", importPath)
    36  	}
    37  
    38  	return &Rewriter{
    39  		pkg:    pkgs[0],
    40  		files:  map[string]string{},
    41  		copied: map[ast.Decl]bool{},
    42  	}, nil
    43  }
    44  
    45  func (r *Rewriter) getSource(start, end token.Pos) string {
    46  	startPos := r.pkg.Fset.Position(start)
    47  	endPos := r.pkg.Fset.Position(end)
    48  
    49  	if startPos.Filename != endPos.Filename {
    50  		panic("cant get source spanning multiple files")
    51  	}
    52  
    53  	file := r.getFile(startPos.Filename)
    54  	return file[startPos.Offset:endPos.Offset]
    55  }
    56  
    57  func (r *Rewriter) getFile(filename string) string {
    58  	if _, ok := r.files[filename]; !ok {
    59  		b, err := os.ReadFile(filename)
    60  		if err != nil {
    61  			panic(fmt.Errorf("unable to load file, already exists: %w", err))
    62  		}
    63  
    64  		r.files[filename] = string(b)
    65  
    66  	}
    67  
    68  	return r.files[filename]
    69  }
    70  
    71  func (r *Rewriter) GetMethodComment(structname string, methodname string) string {
    72  	for _, f := range r.pkg.Syntax {
    73  		for _, d := range f.Decls {
    74  			d, isFunc := d.(*ast.FuncDecl)
    75  			if !isFunc {
    76  				continue
    77  			}
    78  			if d.Name.Name != methodname {
    79  				continue
    80  			}
    81  			if d.Recv == nil || len(d.Recv.List) == 0 {
    82  				continue
    83  			}
    84  			recv := d.Recv.List[0].Type
    85  			if star, isStar := recv.(*ast.StarExpr); isStar {
    86  				recv = star.X
    87  			}
    88  			ident, ok := recv.(*ast.Ident)
    89  			if !ok {
    90  				continue
    91  			}
    92  
    93  			if ident.Name != structname {
    94  				continue
    95  			}
    96  			return d.Doc.Text()
    97  		}
    98  	}
    99  
   100  	return ""
   101  }
   102  
   103  func (r *Rewriter) GetMethodBody(structname string, methodname string) string {
   104  	for _, f := range r.pkg.Syntax {
   105  		for _, d := range f.Decls {
   106  			d, isFunc := d.(*ast.FuncDecl)
   107  			if !isFunc {
   108  				continue
   109  			}
   110  			if d.Name.Name != methodname {
   111  				continue
   112  			}
   113  			if d.Recv == nil || len(d.Recv.List) == 0 {
   114  				continue
   115  			}
   116  			recv := d.Recv.List[0].Type
   117  			if star, isStar := recv.(*ast.StarExpr); isStar {
   118  				recv = star.X
   119  			}
   120  			ident, ok := recv.(*ast.Ident)
   121  			if !ok {
   122  				continue
   123  			}
   124  
   125  			if ident.Name != structname {
   126  				continue
   127  			}
   128  
   129  			r.copied[d] = true
   130  
   131  			return r.getSource(d.Body.Pos()+1, d.Body.End()-1)
   132  		}
   133  	}
   134  
   135  	return ""
   136  }
   137  
   138  func (r *Rewriter) MarkStructCopied(name string) {
   139  	for _, f := range r.pkg.Syntax {
   140  		for _, d := range f.Decls {
   141  			d, isGen := d.(*ast.GenDecl)
   142  			if !isGen {
   143  				continue
   144  			}
   145  			if d.Tok != token.TYPE || len(d.Specs) == 0 {
   146  				continue
   147  			}
   148  
   149  			spec, isTypeSpec := d.Specs[0].(*ast.TypeSpec)
   150  			if !isTypeSpec {
   151  				continue
   152  			}
   153  
   154  			if spec.Name.Name != name {
   155  				continue
   156  			}
   157  
   158  			r.copied[d] = true
   159  		}
   160  	}
   161  }
   162  
   163  func (r *Rewriter) ExistingImports(filename string) []Import {
   164  	filename, err := filepath.Abs(filename)
   165  	if err != nil {
   166  		panic(err)
   167  	}
   168  	for _, f := range r.pkg.Syntax {
   169  		pos := r.pkg.Fset.Position(f.Pos())
   170  
   171  		if filename != pos.Filename {
   172  			continue
   173  		}
   174  
   175  		var imps []Import
   176  		for _, i := range f.Imports {
   177  			name := ""
   178  			if i.Name != nil {
   179  				name = i.Name.Name
   180  			}
   181  			path, err := strconv.Unquote(i.Path.Value)
   182  			if err != nil {
   183  				panic(err)
   184  			}
   185  			imps = append(imps, Import{name, path})
   186  		}
   187  		return imps
   188  	}
   189  	return nil
   190  }
   191  
   192  func (r *Rewriter) RemainingSource(filename string) string {
   193  	filename, err := filepath.Abs(filename)
   194  	if err != nil {
   195  		panic(err)
   196  	}
   197  	for _, f := range r.pkg.Syntax {
   198  		pos := r.pkg.Fset.Position(f.Pos())
   199  
   200  		if filename != pos.Filename {
   201  			continue
   202  		}
   203  
   204  		var buf bytes.Buffer
   205  
   206  		for _, d := range f.Decls {
   207  			if r.copied[d] {
   208  				continue
   209  			}
   210  
   211  			if d, isGen := d.(*ast.GenDecl); isGen && d.Tok == token.IMPORT {
   212  				continue
   213  			}
   214  
   215  			buf.WriteString(r.getSource(d.Pos(), d.End()))
   216  			buf.WriteString("\n")
   217  		}
   218  
   219  		return strings.TrimSpace(buf.String())
   220  	}
   221  	return ""
   222  }
   223  
// Import describes one import spec of a Go file, as returned by
// ExistingImports.
type Import struct {
	Alias      string // explicit local name from the import spec; "" when none is given
	ImportPath string // import path with the surrounding quotes removed
}