github.com/inturn/pre-commit-gobuild@v1.0.12/hooks/go-imports/go-imports.go (about)

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"flag"
     6  	"fmt"
     7  	"go/ast"
     8  	"go/format"
     9  	"go/parser"
    10  	"go/token"
    11  	"io/ioutil"
    12  	"log"
    13  	"os"
    14  	"path/filepath"
    15  	"sort"
    16  	"strings"
    17  	"sync"
    18  
    19  	"github.com/inturn/pre-commit-gobuild/internal/helpers"
    20  )
    21  
    22  type lintError struct {
    23  	err  error
    24  	path string
    25  }
    26  
    27  var (
    28  	errNotSorted = fmt.Errorf("file imports have to be sorted or formatted according to go.fmt")
    29  	checkOnly    bool
    30  )
    31  
    32  func init() {
    33  	flag.BoolVar(&checkOnly, "check-only", false, "linter only checks imports order and don't modify files")
    34  	flag.Parse()
    35  }
    36  
    37  func main() {
    38  	workDir, err := os.Getwd()
    39  	if err != nil {
    40  		log.Fatal(err)
    41  	}
    42  	dirs := helpers.DirsWith(workDir, "\\.go$")
    43  
    44  	errc := make(chan lintError, 10)
    45  	wg := &sync.WaitGroup{}
    46  
    47  	go func() {
    48  		for _, dir := range dirs {
    49  			if !strings.Contains(dir, "/vendor/") {
    50  				files, err := ioutil.ReadDir(dir)
    51  				if err != nil {
    52  					log.Printf("error occured on read dir %s: %s", dir, err)
    53  				}
    54  				for _, f := range files {
    55  					if !strings.HasSuffix(f.Name(), ".go") {
    56  						continue
    57  					}
    58  					wg.Add(1)
    59  					go func(d, name string) {
    60  						sortFileImports(filepath.Join(d, name), errc)
    61  						wg.Done()
    62  					}(dir, f.Name())
    63  				}
    64  			}
    65  		}
    66  		wg.Wait()
    67  		close(errc)
    68  	}()
    69  
    70  	var le []lintError
    71  	for lintErr := range errc {
    72  		if !checkOnly && lintErr.err == errNotSorted {
    73  			continue
    74  		}
    75  		log.Println(lintErr.path, lintErr.err)
    76  		le = append(le, lintErr)
    77  	}
    78  	if len(le) != 0 {
    79  		log.Println("errors count:", len(le))
    80  		os.Exit(1)
    81  	}
    82  }
    83  
    84  func sortFileImports(path string, errc chan<- lintError) {
    85  	fSet := token.NewFileSet()
    86  
    87  	f, err := parser.ParseFile(fSet, path, nil, parser.ParseComments)
    88  	if err != nil {
    89  		errc <- lintError{
    90  			err:  err,
    91  			path: path,
    92  		}
    93  		return
    94  	}
    95  
    96  	sortImports(f)
    97  
    98  	buf := &bytes.Buffer{}
    99  	if err := format.Node(buf, fSet, f); err != nil {
   100  		errc <- lintError{
   101  			err:  err,
   102  			path: path,
   103  		}
   104  		return
   105  	}
   106  
   107  	data, err := ioutil.ReadFile(path)
   108  	if err != nil {
   109  		errc <- lintError{
   110  			err:  err,
   111  			path: path,
   112  		}
   113  		return
   114  	}
   115  
   116  	if buf.String() == string(data) {
   117  		return
   118  	}
   119  
   120  	if checkOnly {
   121  		errc <- lintError{
   122  			err:  errNotSorted,
   123  			path: path,
   124  		}
   125  		return
   126  	}
   127  
   128  	if err := ioutil.WriteFile(path, buf.Bytes(), 0664); err != nil {
   129  		errc <- lintError{
   130  			err:  err,
   131  			path: path,
   132  		}
   133  		return
   134  	}
   135  
   136  	log.Printf("%s file has changed", path)
   137  }
   138  
   139  func sortImports(f *ast.File) {
   140  	if len(f.Imports) <= 1 {
   141  		return
   142  	}
   143  
   144  	imp1 := make(impSlice, 0)
   145  	imp2 := make(impSlice, 0)
   146  
   147  	for _, imp := range f.Imports {
   148  		impData := importData{}
   149  
   150  		if imp.Doc != nil && imp.Name != nil && imp.Name.Name == "_" {
   151  			impData.comment = imp.Doc.Text()
   152  		}
   153  
   154  		if imp.Name != nil {
   155  			impData.name = imp.Name.Name
   156  		}
   157  		impData.value = imp.Path.Value
   158  
   159  		// determine third-party package import
   160  		if strings.Contains(imp.Path.Value, ".") {
   161  			imp2 = append(imp2, impData)
   162  			continue
   163  		}
   164  
   165  		imp1 = append(imp1, impData)
   166  	}
   167  
   168  	nonImportComment := f.Comments[:0]
   169  	startPos := f.Imports[0].Pos()
   170  	lastPos := f.Imports[len(f.Imports)-1].End()
   171  
   172  	for _, c := range f.Comments {
   173  		if c.Pos() > lastPos || c.Pos() < startPos {
   174  			nonImportComment = append(nonImportComment, c)
   175  		}
   176  	}
   177  
   178  	f.Comments = nonImportComment
   179  
   180  	sort.Sort(imp1)
   181  	sort.Sort(imp2)
   182  
   183  	for _, d := range f.Decls {
   184  		d, ok := d.(*ast.GenDecl)
   185  		if !ok || d.Tok != token.IMPORT {
   186  			// Not an import declaration, so we're done.
   187  			// Imports are always first.
   188  			break
   189  		}
   190  
   191  		if !d.Lparen.IsValid() {
   192  			// Not a block: sorted by default.
   193  			continue
   194  		}
   195  
   196  		d.Specs = d.Specs[:0]
   197  
   198  		for _, imp := range imp1 {
   199  			addISpec(imp, d)
   200  		}
   201  
   202  		if len(imp2) != 0 {
   203  			// Add empty line between groups.
   204  			d.Specs = append(d.Specs, &ast.ImportSpec{Path: &ast.BasicLit{}})
   205  
   206  			for _, imp := range imp2 {
   207  				addISpec(imp, d)
   208  
   209  			}
   210  		}
   211  	}
   212  }
   213  
   214  func addISpec(imp importData, d *ast.GenDecl) {
   215  	if imp.name == "_" {
   216  		comm := imp.comment
   217  		if comm == "" {
   218  			comm = "todo comment here why do you use blank import"
   219  		}
   220  		d.Specs = append(d.Specs, &ast.ImportSpec{
   221  			Path: &ast.BasicLit{Value: "// " + strings.TrimSpace(comm)},
   222  		})
   223  	}
   224  	iSpec := ast.ImportSpec{
   225  		Path: &ast.BasicLit{Value: imp.value},
   226  		Name: &ast.Ident{Name: imp.name},
   227  	}
   228  	d.Specs = append(d.Specs, &iSpec)
   229  }
   230  
   231  type impSlice []importData
   232  
   233  type importData struct {
   234  	value   string
   235  	name    string
   236  	comment string
   237  }
   238  
   239  func (s impSlice) Len() int {
   240  	return len(s)
   241  }
   242  
   243  func (s impSlice) Less(i, j int) bool {
   244  	return s[i].value < s[j].value
   245  }
   246  
   247  func (s impSlice) Swap(i, j int) {
   248  	s[i], s[j] = s[j], s[i]
   249  }