github.com/maeglindeveloper/gqlgen@v0.13.1-0.20210413081235-57808b12a0a0/internal/imports/prune.go (about)

     1  // Wrapper around x/tools/imports that only removes imports, never adds new ones.
     2  
     3  package imports
     4  
     5  import (
     6  	"bytes"
     7  	"go/ast"
     8  	"go/parser"
     9  	"go/printer"
    10  	"go/token"
    11  	"strings"
    12  
    13  	"github.com/99designs/gqlgen/internal/code"
    14  
    15  	"golang.org/x/tools/go/ast/astutil"
    16  	"golang.org/x/tools/imports"
    17  )
    18  
    19  type visitFn func(node ast.Node)
    20  
    21  func (fn visitFn) Visit(node ast.Node) ast.Visitor {
    22  	fn(node)
    23  	return fn
    24  }
    25  
    26  // Prune removes any unused imports
    27  func Prune(filename string, src []byte, packages *code.Packages) ([]byte, error) {
    28  	fset := token.NewFileSet()
    29  
    30  	file, err := parser.ParseFile(fset, filename, src, parser.ParseComments|parser.AllErrors)
    31  	if err != nil {
    32  		return nil, err
    33  	}
    34  
    35  	unused := getUnusedImports(file, packages)
    36  	for ipath, name := range unused {
    37  		astutil.DeleteNamedImport(fset, file, name, ipath)
    38  	}
    39  	printConfig := &printer.Config{Mode: printer.TabIndent, Tabwidth: 8}
    40  
    41  	var buf bytes.Buffer
    42  	if err := printConfig.Fprint(&buf, fset, file); err != nil {
    43  		return nil, err
    44  	}
    45  
    46  	return imports.Process(filename, buf.Bytes(), &imports.Options{FormatOnly: true, Comments: true, TabIndent: true, TabWidth: 8})
    47  }
    48  
    49  func getUnusedImports(file ast.Node, packages *code.Packages) map[string]string {
    50  	imported := map[string]*ast.ImportSpec{}
    51  	used := map[string]bool{}
    52  
    53  	ast.Walk(visitFn(func(node ast.Node) {
    54  		if node == nil {
    55  			return
    56  		}
    57  		switch v := node.(type) {
    58  		case *ast.ImportSpec:
    59  			if v.Name != nil {
    60  				imported[v.Name.Name] = v
    61  				break
    62  			}
    63  			ipath := strings.Trim(v.Path.Value, `"`)
    64  			if ipath == "C" {
    65  				break
    66  			}
    67  
    68  			local := packages.NameForPackage(ipath)
    69  
    70  			imported[local] = v
    71  		case *ast.SelectorExpr:
    72  			xident, ok := v.X.(*ast.Ident)
    73  			if !ok {
    74  				break
    75  			}
    76  			if xident.Obj != nil {
    77  				// if the parser can resolve it, it's not a package ref
    78  				break
    79  			}
    80  			used[xident.Name] = true
    81  		}
    82  	}), file)
    83  
    84  	for pkg := range used {
    85  		delete(imported, pkg)
    86  	}
    87  
    88  	unusedImport := map[string]string{}
    89  	for pkg, is := range imported {
    90  		if !used[pkg] && pkg != "_" && pkg != "." {
    91  			name := ""
    92  			if is.Name != nil {
    93  				name = is.Name.Name
    94  			}
    95  			unusedImport[strings.Trim(is.Path.Value, `"`)] = name
    96  		}
    97  	}
    98  
    99  	return unusedImport
   100  }