github.com/HaswinVidanage/gqlgen@v0.8.1-0.20220609041233-69528c1bf712/internal/imports/prune.go (about)

     1  // Wrapper around x/tools/imports that only removes imports, never adds new ones.
     2  
     3  package imports
     4  
     5  import (
     6  	"bytes"
     7  	"go/ast"
     8  	"go/parser"
     9  	"go/printer"
    10  	"go/token"
    11  	"strings"
    12  
    13  	"github.com/HaswinVidanage/gqlgen/internal/code"
    14  
    15  	"golang.org/x/tools/go/ast/astutil"
    16  	"golang.org/x/tools/imports"
    17  )
    18  
    19  type visitFn func(node ast.Node)
    20  
    21  func (fn visitFn) Visit(node ast.Node) ast.Visitor {
    22  	fn(node)
    23  	return fn
    24  }
    25  
    26  // Prune removes any unused imports
    27  func Prune(filename string, src []byte) ([]byte, error) {
    28  	fset := token.NewFileSet()
    29  
    30  	file, err := parser.ParseFile(fset, filename, src, parser.ParseComments|parser.AllErrors)
    31  	if err != nil {
    32  		return nil, err
    33  	}
    34  
    35  	unused, err := getUnusedImports(file, filename)
    36  	if err != nil {
    37  		return nil, err
    38  	}
    39  	for ipath, name := range unused {
    40  		astutil.DeleteNamedImport(fset, file, name, ipath)
    41  	}
    42  	printConfig := &printer.Config{Mode: printer.TabIndent, Tabwidth: 8}
    43  
    44  	var buf bytes.Buffer
    45  	if err := printConfig.Fprint(&buf, fset, file); err != nil {
    46  		return nil, err
    47  	}
    48  
    49  	return imports.Process(filename, buf.Bytes(), &imports.Options{FormatOnly: true, Comments: true, TabIndent: true, TabWidth: 8})
    50  }
    51  
    52  func getUnusedImports(file ast.Node, filename string) (map[string]string, error) {
    53  	imported := map[string]*ast.ImportSpec{}
    54  	used := map[string]bool{}
    55  
    56  	ast.Walk(visitFn(func(node ast.Node) {
    57  		if node == nil {
    58  			return
    59  		}
    60  		switch v := node.(type) {
    61  		case *ast.ImportSpec:
    62  			if v.Name != nil {
    63  				imported[v.Name.Name] = v
    64  				break
    65  			}
    66  			ipath := strings.Trim(v.Path.Value, `"`)
    67  			if ipath == "C" {
    68  				break
    69  			}
    70  
    71  			local := code.NameForPackage(ipath)
    72  
    73  			imported[local] = v
    74  		case *ast.SelectorExpr:
    75  			xident, ok := v.X.(*ast.Ident)
    76  			if !ok {
    77  				break
    78  			}
    79  			if xident.Obj != nil {
    80  				// if the parser can resolve it, it's not a package ref
    81  				break
    82  			}
    83  			used[xident.Name] = true
    84  		}
    85  	}), file)
    86  
    87  	for pkg := range used {
    88  		delete(imported, pkg)
    89  	}
    90  
    91  	unusedImport := map[string]string{}
    92  	for pkg, is := range imported {
    93  		if !used[pkg] && pkg != "_" && pkg != "." {
    94  			name := ""
    95  			if is.Name != nil {
    96  				name = is.Name.Name
    97  			}
    98  			unusedImport[strings.Trim(is.Path.Value, `"`)] = name
    99  		}
   100  	}
   101  
   102  	return unusedImport, nil
   103  }