github.com/aristanetworks/goarista@v0.0.0-20240514173732-cca2755bbd44/cmd/importsort/main.go (about) 1 // Copyright (c) 2017 Arista Networks, Inc. 2 // Use of this source code is governed by the Apache License 2.0 3 // that can be found in the COPYING file. 4 5 package main 6 7 import ( 8 "bytes" 9 "errors" 10 "flag" 11 "fmt" 12 "go/build" 13 "io/ioutil" 14 "os" 15 "path/filepath" 16 "sort" 17 "strings" 18 19 "golang.org/x/tools/go/vcs" 20 ) 21 22 // Implementation taken from "isStandardImportPath" in go's source. 23 func isStdLibPath(path string) bool { 24 i := strings.Index(path, "/") 25 if i < 0 { 26 i = len(path) 27 } 28 elem := path[:i] 29 return !strings.Contains(elem, ".") 30 } 31 32 // sortImports takes in an "import" body and returns it sorted 33 func sortImports(in []byte, sections []string) []byte { 34 type importLine struct { 35 index int // index into inLines 36 path string // import path used for sorting 37 } 38 // imports holds all the import lines, separated by section. The 39 // first section is for stdlib imports, the following sections 40 // hold the user specified sections, the final section is for 41 // everything else. 42 imports := make([][]importLine, len(sections)+2) 43 addImport := func(section, index int, importPath string) { 44 imports[section] = append(imports[section], importLine{index, importPath}) 45 } 46 stdlib := 0 47 offset := 1 48 other := len(imports) - 1 49 50 inLines := bytes.Split(in, []byte{'\n'}) 51 for i, line := range inLines { 52 if len(line) == 0 { 53 continue 54 } 55 start := bytes.IndexByte(line, '"') 56 if start == -1 { 57 continue 58 } 59 if comment := bytes.Index(line, []byte("//")); comment > -1 && comment < start { 60 continue 61 } 62 63 start++ // skip '"' 64 end := bytes.IndexByte(line[start:], '"') + start 65 s := string(line[start:end]) 66 67 found := false 68 for j, sect := range sections { 69 if strings.HasPrefix(s, sect) && (len(sect) == len(s) || s[len(sect)] == '/') { 70 addImport(j+offset, i, s) 71 found = true 72 break 73 } 74 } 75 if found { 76 continue 77 } 78 79 if isStdLibPath(s) { 80 addImport(stdlib, i, s) 81 } else { 82 addImport(other, i, s) 83 } 84 } 85 86 out := make([]byte, 0, len(in)+2) 87 needSeperator := false 88 for _, section := range imports { 89 if len(section) == 0 { 90 continue 91 } 92 if needSeperator { 93 out = append(out, '\n') 94 } 95 sort.Slice(section, func(a, b int) bool { 96 return section[a].path < section[b].path 97 }) 98 for _, s := range section { 99 out = append(out, inLines[s.index]...) 100 out = append(out, '\n') 101 } 102 needSeperator = true 103 } 104 105 return out 106 } 107 108 func genFile(in []byte, sections []string) ([]byte, error) { 109 out := make([]byte, 0, len(in)+3) // Add some fudge to avoid re-allocation 110 111 for { 112 const importLine = "\nimport (\n" 113 const importLineLen = len(importLine) 114 importStart := bytes.Index(in, []byte(importLine)) 115 if importStart == -1 { 116 break 117 } 118 // Save to `out` everything up to and including "import(\n" 119 out = append(out, in[:importStart+importLineLen]...) 120 in = in[importStart+importLineLen:] 121 importLen := bytes.Index(in, []byte("\n)\n")) 122 if importLen == -1 { 123 return nil, errors.New(`parsing error: missing ")"`) 124 } 125 // Sort body of "import" and write it to `out` 126 out = append(out, sortImports(in[:importLen], sections)...) 127 out = append(out, []byte(")")...) 128 in = in[importLen+2:] 129 } 130 // Write everything leftover to out 131 out = append(out, in...) 132 return out, nil 133 } 134 135 // returns true if the file changed 136 func processFile(filename string, writeFile, listDiffFiles bool, sections []string) (bool, error) { 137 in, err := ioutil.ReadFile(filename) 138 if err != nil { 139 return false, err 140 } 141 out, err := genFile(in, sections) 142 if err != nil { 143 return false, err 144 } 145 146 equal := bytes.Equal(in, out) 147 if listDiffFiles { 148 return !equal, nil 149 } 150 if !writeFile { 151 os.Stdout.Write(out) 152 return !equal, nil 153 } 154 155 if equal { 156 return false, nil 157 } 158 temp, err := ioutil.TempFile(filepath.Dir(filename), filepath.Base(filename)) 159 if err != nil { 160 return false, err 161 } 162 defer os.RemoveAll(temp.Name()) 163 s, err := os.Stat(filename) 164 if err != nil { 165 return false, err 166 } 167 if _, err = temp.Write(out); err != nil { 168 return false, err 169 } 170 if err := temp.Close(); err != nil { 171 return false, err 172 } 173 if err := os.Chmod(temp.Name(), s.Mode()); err != nil { 174 return false, err 175 } 176 if err := os.Rename(temp.Name(), filename); err != nil { 177 return false, err 178 } 179 180 return true, nil 181 } 182 183 // maps directory to vcsRoot 184 var vcsRootCache = make(map[string]string) 185 186 func vcsRootImportPath(f string) (string, error) { 187 path, err := filepath.Abs(f) 188 if err != nil { 189 return "", err 190 } 191 dir := filepath.Dir(path) 192 if root, ok := vcsRootCache[dir]; ok { 193 return root, nil 194 } 195 gopath := build.Default.GOPATH 196 var root string 197 _, root, err = vcs.FromDir(dir, filepath.Join(gopath, "src")) 198 if err != nil { 199 return "", err 200 } 201 vcsRootCache[dir] = root 202 return root, nil 203 } 204 205 func main() { 206 writeFile := flag.Bool("w", false, "write result to file instead of stdout") 207 listDiffFiles := flag.Bool("l", false, "list files whose formatting differs from importsort") 208 var sections multistring 209 flag.Var(§ions, "s", "package `prefix` to define an import section,"+ 210 ` ex: "cvshub.com/company". May be specified multiple times.`+ 211 " If not specified the repository root is used.") 212 213 flag.Parse() 214 215 checkVCSRoot := sections == nil 216 for _, f := range flag.Args() { 217 if checkVCSRoot { 218 root, err := vcsRootImportPath(f) 219 if err != nil { 220 fmt.Fprintf(os.Stderr, "error determining VCS root for file %q: %s", f, err) 221 continue 222 } else { 223 sections = multistring{root} 224 } 225 } 226 diff, err := processFile(f, *writeFile, *listDiffFiles, sections) 227 if err != nil { 228 fmt.Fprintf(os.Stderr, "error while proccessing file %q: %s", f, err) 229 continue 230 } 231 if *listDiffFiles && diff { 232 fmt.Println(f) 233 } 234 } 235 } 236 237 type multistring []string 238 239 func (m *multistring) String() string { 240 return strings.Join(*m, ", ") 241 } 242 func (m *multistring) Set(s string) error { 243 *m = append(*m, s) 244 return nil 245 }