github.com/aristanetworks/goarista@v0.0.0-20240514173732-cca2755bbd44/cmd/importsort/main.go (about)

     1  // Copyright (c) 2017 Arista Networks, Inc.
     2  // Use of this source code is governed by the Apache License 2.0
     3  // that can be found in the COPYING file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"flag"
    11  	"fmt"
    12  	"go/build"
    13  	"io/ioutil"
    14  	"os"
    15  	"path/filepath"
    16  	"sort"
    17  	"strings"
    18  
    19  	"golang.org/x/tools/go/vcs"
    20  )
    21  
    22  // Implementation taken from "isStandardImportPath" in go's source.
    23  func isStdLibPath(path string) bool {
    24  	i := strings.Index(path, "/")
    25  	if i < 0 {
    26  		i = len(path)
    27  	}
    28  	elem := path[:i]
    29  	return !strings.Contains(elem, ".")
    30  }
    31  
    32  // sortImports takes in an "import" body and returns it sorted
    33  func sortImports(in []byte, sections []string) []byte {
    34  	type importLine struct {
    35  		index int    // index into inLines
    36  		path  string // import path used for sorting
    37  	}
    38  	// imports holds all the import lines, separated by section. The
    39  	// first section is for stdlib imports, the following sections
    40  	// hold the user specified sections, the final section is for
    41  	// everything else.
    42  	imports := make([][]importLine, len(sections)+2)
    43  	addImport := func(section, index int, importPath string) {
    44  		imports[section] = append(imports[section], importLine{index, importPath})
    45  	}
    46  	stdlib := 0
    47  	offset := 1
    48  	other := len(imports) - 1
    49  
    50  	inLines := bytes.Split(in, []byte{'\n'})
    51  	for i, line := range inLines {
    52  		if len(line) == 0 {
    53  			continue
    54  		}
    55  		start := bytes.IndexByte(line, '"')
    56  		if start == -1 {
    57  			continue
    58  		}
    59  		if comment := bytes.Index(line, []byte("//")); comment > -1 && comment < start {
    60  			continue
    61  		}
    62  
    63  		start++ // skip '"'
    64  		end := bytes.IndexByte(line[start:], '"') + start
    65  		s := string(line[start:end])
    66  
    67  		found := false
    68  		for j, sect := range sections {
    69  			if strings.HasPrefix(s, sect) && (len(sect) == len(s) || s[len(sect)] == '/') {
    70  				addImport(j+offset, i, s)
    71  				found = true
    72  				break
    73  			}
    74  		}
    75  		if found {
    76  			continue
    77  		}
    78  
    79  		if isStdLibPath(s) {
    80  			addImport(stdlib, i, s)
    81  		} else {
    82  			addImport(other, i, s)
    83  		}
    84  	}
    85  
    86  	out := make([]byte, 0, len(in)+2)
    87  	needSeperator := false
    88  	for _, section := range imports {
    89  		if len(section) == 0 {
    90  			continue
    91  		}
    92  		if needSeperator {
    93  			out = append(out, '\n')
    94  		}
    95  		sort.Slice(section, func(a, b int) bool {
    96  			return section[a].path < section[b].path
    97  		})
    98  		for _, s := range section {
    99  			out = append(out, inLines[s.index]...)
   100  			out = append(out, '\n')
   101  		}
   102  		needSeperator = true
   103  	}
   104  
   105  	return out
   106  }
   107  
   108  func genFile(in []byte, sections []string) ([]byte, error) {
   109  	out := make([]byte, 0, len(in)+3) // Add some fudge to avoid re-allocation
   110  
   111  	for {
   112  		const importLine = "\nimport (\n"
   113  		const importLineLen = len(importLine)
   114  		importStart := bytes.Index(in, []byte(importLine))
   115  		if importStart == -1 {
   116  			break
   117  		}
   118  		// Save to `out` everything up to and including "import(\n"
   119  		out = append(out, in[:importStart+importLineLen]...)
   120  		in = in[importStart+importLineLen:]
   121  		importLen := bytes.Index(in, []byte("\n)\n"))
   122  		if importLen == -1 {
   123  			return nil, errors.New(`parsing error: missing ")"`)
   124  		}
   125  		// Sort body of "import" and write it to `out`
   126  		out = append(out, sortImports(in[:importLen], sections)...)
   127  		out = append(out, []byte(")")...)
   128  		in = in[importLen+2:]
   129  	}
   130  	// Write everything leftover to out
   131  	out = append(out, in...)
   132  	return out, nil
   133  }
   134  
   135  // returns true if the file changed
   136  func processFile(filename string, writeFile, listDiffFiles bool, sections []string) (bool, error) {
   137  	in, err := ioutil.ReadFile(filename)
   138  	if err != nil {
   139  		return false, err
   140  	}
   141  	out, err := genFile(in, sections)
   142  	if err != nil {
   143  		return false, err
   144  	}
   145  
   146  	equal := bytes.Equal(in, out)
   147  	if listDiffFiles {
   148  		return !equal, nil
   149  	}
   150  	if !writeFile {
   151  		os.Stdout.Write(out)
   152  		return !equal, nil
   153  	}
   154  
   155  	if equal {
   156  		return false, nil
   157  	}
   158  	temp, err := ioutil.TempFile(filepath.Dir(filename), filepath.Base(filename))
   159  	if err != nil {
   160  		return false, err
   161  	}
   162  	defer os.RemoveAll(temp.Name())
   163  	s, err := os.Stat(filename)
   164  	if err != nil {
   165  		return false, err
   166  	}
   167  	if _, err = temp.Write(out); err != nil {
   168  		return false, err
   169  	}
   170  	if err := temp.Close(); err != nil {
   171  		return false, err
   172  	}
   173  	if err := os.Chmod(temp.Name(), s.Mode()); err != nil {
   174  		return false, err
   175  	}
   176  	if err := os.Rename(temp.Name(), filename); err != nil {
   177  		return false, err
   178  	}
   179  
   180  	return true, nil
   181  }
   182  
   183  // maps directory to vcsRoot
   184  var vcsRootCache = make(map[string]string)
   185  
   186  func vcsRootImportPath(f string) (string, error) {
   187  	path, err := filepath.Abs(f)
   188  	if err != nil {
   189  		return "", err
   190  	}
   191  	dir := filepath.Dir(path)
   192  	if root, ok := vcsRootCache[dir]; ok {
   193  		return root, nil
   194  	}
   195  	gopath := build.Default.GOPATH
   196  	var root string
   197  	_, root, err = vcs.FromDir(dir, filepath.Join(gopath, "src"))
   198  	if err != nil {
   199  		return "", err
   200  	}
   201  	vcsRootCache[dir] = root
   202  	return root, nil
   203  }
   204  
   205  func main() {
   206  	writeFile := flag.Bool("w", false, "write result to file instead of stdout")
   207  	listDiffFiles := flag.Bool("l", false, "list files whose formatting differs from importsort")
   208  	var sections multistring
   209  	flag.Var(&sections, "s", "package `prefix` to define an import section,"+
   210  		` ex: "cvshub.com/company". May be specified multiple times.`+
   211  		" If not specified the repository root is used.")
   212  
   213  	flag.Parse()
   214  
   215  	checkVCSRoot := sections == nil
   216  	for _, f := range flag.Args() {
   217  		if checkVCSRoot {
   218  			root, err := vcsRootImportPath(f)
   219  			if err != nil {
   220  				fmt.Fprintf(os.Stderr, "error determining VCS root for file %q: %s", f, err)
   221  				continue
   222  			} else {
   223  				sections = multistring{root}
   224  			}
   225  		}
   226  		diff, err := processFile(f, *writeFile, *listDiffFiles, sections)
   227  		if err != nil {
   228  			fmt.Fprintf(os.Stderr, "error while proccessing file %q: %s", f, err)
   229  			continue
   230  		}
   231  		if *listDiffFiles && diff {
   232  			fmt.Println(f)
   233  		}
   234  	}
   235  }
   236  
   237  type multistring []string
   238  
   239  func (m *multistring) String() string {
   240  	return strings.Join(*m, ", ")
   241  }
   242  func (m *multistring) Set(s string) error {
   243  	*m = append(*m, s)
   244  	return nil
   245  }