github.com/buth/gqlgen@v0.7.2/internal/imports/prune.go (about)

     1  // Wrapper around x/tools/imports that only removes imports, never adds new ones.
     2  
     3  package imports
     4  
     5  import (
     6  	"bytes"
     7  	"go/ast"
     8  	"go/build"
     9  	"go/parser"
    10  	"go/printer"
    11  	"go/token"
    12  	"path/filepath"
    13  	"strings"
    14  
    15  	"golang.org/x/tools/imports"
    16  
    17  	"golang.org/x/tools/go/ast/astutil"
    18  )
    19  
    20  type visitFn func(node ast.Node)
    21  
    22  func (fn visitFn) Visit(node ast.Node) ast.Visitor {
    23  	fn(node)
    24  	return fn
    25  }
    26  
    27  // Prune removes any unused imports
    28  func Prune(filename string, src []byte) ([]byte, error) {
    29  	fset := token.NewFileSet()
    30  
    31  	file, err := parser.ParseFile(fset, filename, src, parser.ParseComments|parser.AllErrors)
    32  	if err != nil {
    33  		return nil, err
    34  	}
    35  
    36  	unused, err := getUnusedImports(file, filename)
    37  	if err != nil {
    38  		return nil, err
    39  	}
    40  	for ipath, name := range unused {
    41  		astutil.DeleteNamedImport(fset, file, name, ipath)
    42  	}
    43  	printConfig := &printer.Config{Mode: printer.TabIndent, Tabwidth: 8}
    44  
    45  	var buf bytes.Buffer
    46  	if err := printConfig.Fprint(&buf, fset, file); err != nil {
    47  		return nil, err
    48  	}
    49  
    50  	return imports.Process(filename, buf.Bytes(), &imports.Options{FormatOnly: true, Comments: true, TabIndent: true, TabWidth: 8})
    51  }
    52  
    53  func getUnusedImports(file ast.Node, filename string) (map[string]string, error) {
    54  	imported := map[string]*ast.ImportSpec{}
    55  	used := map[string]bool{}
    56  
    57  	abs, err := filepath.Abs(filename)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  	srcDir := filepath.Dir(abs)
    62  
    63  	ast.Walk(visitFn(func(node ast.Node) {
    64  		if node == nil {
    65  			return
    66  		}
    67  		switch v := node.(type) {
    68  		case *ast.ImportSpec:
    69  			if v.Name != nil {
    70  				imported[v.Name.Name] = v
    71  				break
    72  			}
    73  			ipath := strings.Trim(v.Path.Value, `"`)
    74  			if ipath == "C" {
    75  				break
    76  			}
    77  
    78  			local := importPathToName(ipath, srcDir)
    79  
    80  			imported[local] = v
    81  		case *ast.SelectorExpr:
    82  			xident, ok := v.X.(*ast.Ident)
    83  			if !ok {
    84  				break
    85  			}
    86  			if xident.Obj != nil {
    87  				// if the parser can resolve it, it's not a package ref
    88  				break
    89  			}
    90  			used[xident.Name] = true
    91  		}
    92  	}), file)
    93  
    94  	for pkg := range used {
    95  		delete(imported, pkg)
    96  	}
    97  
    98  	unusedImport := map[string]string{}
    99  	for pkg, is := range imported {
   100  		if !used[pkg] && pkg != "_" && pkg != "." {
   101  			name := ""
   102  			if is.Name != nil {
   103  				name = is.Name.Name
   104  			}
   105  			unusedImport[strings.Trim(is.Path.Value, `"`)] = name
   106  		}
   107  	}
   108  
   109  	return unusedImport, nil
   110  }
   111  
   112  func importPathToName(importPath, srcDir string) (packageName string) {
   113  	pkg, err := build.Default.Import(importPath, srcDir, 0)
   114  	if err != nil {
   115  		return ""
   116  	}
   117  
   118  	return pkg.Name
   119  }