github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/tools/go_generics/go_merge/main.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package main 16 17 import ( 18 "bytes" 19 "flag" 20 "fmt" 21 "go/ast" 22 "go/format" 23 "go/parser" 24 "go/token" 25 "os" 26 "path/filepath" 27 "strconv" 28 "strings" 29 30 "github.com/SagerNet/gvisor/tools/tags" 31 ) 32 33 var ( 34 output = flag.String("o", "", "output `file`") 35 ) 36 37 func fatalf(s string, args ...interface{}) { 38 fmt.Fprintf(os.Stderr, s, args...) 39 os.Exit(1) 40 } 41 42 func main() { 43 flag.Usage = func() { 44 fmt.Fprintf(os.Stderr, "Usage: %s [options] <input1> [<input2> ...]\n", os.Args[0]) 45 flag.PrintDefaults() 46 } 47 48 flag.Parse() 49 if *output == "" || len(flag.Args()) == 0 { 50 flag.Usage() 51 os.Exit(1) 52 } 53 54 // Load all files. 55 files := make(map[string]*ast.File) 56 fset := token.NewFileSet() 57 var name string 58 for _, fname := range flag.Args() { 59 f, err := parser.ParseFile(fset, fname, nil, parser.ParseComments|parser.DeclarationErrors|parser.SpuriousErrors) 60 if err != nil { 61 fatalf("%v\n", err) 62 } 63 64 files[fname] = f 65 if name == "" { 66 name = f.Name.Name 67 } else if name != f.Name.Name { 68 fatalf("Expected '%s' for package name instead of '%s'.\n", name, f.Name.Name) 69 } 70 } 71 72 // Merge all files into one. 73 pkg := &ast.Package{ 74 Name: name, 75 Files: files, 76 } 77 f := ast.MergePackageFiles(pkg, ast.FilterUnassociatedComments|ast.FilterFuncDuplicates|ast.FilterImportDuplicates) 78 79 // Create a new declaration slice with all imports at the top, merging any 80 // redundant imports. 81 imports := make(map[string]*ast.ImportSpec) 82 var importNames []string // Keep imports in the original order to get deterministic output. 83 var anonImports []*ast.ImportSpec 84 for _, d := range f.Decls { 85 if g, ok := d.(*ast.GenDecl); ok && g.Tok == token.IMPORT { 86 for _, s := range g.Specs { 87 i := s.(*ast.ImportSpec) 88 p, _ := strconv.Unquote(i.Path.Value) 89 var n string 90 if i.Name == nil { 91 n = filepath.Base(p) 92 } else { 93 n = i.Name.Name 94 } 95 if n == "_" { 96 anonImports = append(anonImports, i) 97 } else { 98 if i2, ok := imports[n]; ok { 99 if first, second := i.Path.Value, i2.Path.Value; first != second { 100 fatalf("Conflicting paths for import name '%s': '%s' vs. '%s'\n", n, first, second) 101 } 102 } else { 103 imports[n] = i 104 importNames = append(importNames, n) 105 } 106 } 107 } 108 } 109 } 110 newDecls := make([]ast.Decl, 0, len(f.Decls)) 111 if l := len(imports) + len(anonImports); l > 0 { 112 // Non-NoPos Lparen is needed for Go to recognize more than one spec in 113 // ast.GenDecl.Specs. 114 d := &ast.GenDecl{ 115 Tok: token.IMPORT, 116 Lparen: token.NoPos + 1, 117 Specs: make([]ast.Spec, 0, l), 118 } 119 for _, i := range importNames { 120 d.Specs = append(d.Specs, imports[i]) 121 } 122 for _, i := range anonImports { 123 d.Specs = append(d.Specs, i) 124 } 125 newDecls = append(newDecls, d) 126 } 127 for _, d := range f.Decls { 128 if g, ok := d.(*ast.GenDecl); !ok || g.Tok != token.IMPORT { 129 newDecls = append(newDecls, d) 130 } 131 } 132 f.Decls = newDecls 133 134 // Write the output file. 135 var buf bytes.Buffer 136 if err := format.Node(&buf, fset, f); err != nil { 137 fatalf("fomatting: %v\n", err) 138 } 139 outf, err := os.OpenFile(*output, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) 140 if err != nil { 141 fatalf("opening output: %v\n", err) 142 } 143 defer outf.Close() 144 if t := tags.Aggregate(flag.Args()); len(t) > 0 { 145 fmt.Fprintf(outf, "%s\n\n", strings.Join(t.Lines(), "\n")) 146 } 147 if _, err := outf.Write(buf.Bytes()); err != nil { 148 fatalf("write: %v\n", err) 149 } 150 }