go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/grpc/internal/svctool/tool.go (about) 1 // Copyright 2016 The LUCI Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package svctool implements svcmux/svcdec tools command line parsing 16 package svctool 17 18 import ( 19 "bytes" 20 "context" 21 "flag" 22 "fmt" 23 "go/ast" 24 "go/build" 25 "go/format" 26 "go/token" 27 "io" 28 "os" 29 "path/filepath" 30 "sort" 31 "strings" 32 33 "go.chromium.org/luci/common/logging/gologger" 34 ) 35 36 // Service contains the result of parsing the generated code for a pRPC service. 37 type Service struct { 38 TypeName string 39 Node *ast.InterfaceType 40 Methods []*Method 41 } 42 43 type Method struct { 44 Name string 45 Node *ast.Field 46 InputType string 47 OutputType string 48 } 49 50 type Import struct { 51 Name string 52 Path string 53 } 54 55 // Tool is a helper class for svcmux and svcdec. 56 type Tool struct { 57 // Name of the tool, e.g. "svcmux" or "svcdec". 58 Name string 59 // OutputFilenameSuffix is the suffix of generated file names, 60 // e.g. "mux" or "dec" for foo_mux.go or foo_dec.go. 61 OutputFilenameSuffix string 62 63 // Set by ParseArgs from command-line arguments. 64 65 // Types are type names from the Go package defined by Dir or FileNames. 66 Types []string 67 // Output is the base name for the output file. 68 Output string 69 // Dir is a Go package's directory. 70 Dir string 71 // FileNames is a list of source files from a single Go package. 72 FileNames []string 73 } 74 75 func (t *Tool) usage() { 76 fmt.Fprintf(os.Stderr, "Usage of %s:\n", t.Name) 77 fmt.Fprintf(os.Stderr, "\t%s [flags] -type T [directory]\n", t.Name) 78 fmt.Fprintf(os.Stderr, "\t%s [flags] -type T files... # Must be a single package\n", t.Name) 79 flag.PrintDefaults() 80 } 81 82 func (t *Tool) parseFlags(args []string) []string { 83 var flags = flag.NewFlagSet(t.Name, flag.ExitOnError) 84 typeFlag := flags.String("type", "", "comma-separated list of type names; must be set") 85 flags.StringVar(&t.Output, "output", "", "output file name; default <type>_string.go") 86 flags.Usage = t.usage 87 flags.Parse(args) 88 89 splitTypes := strings.Split(*typeFlag, ",") 90 t.Types = make([]string, 0, len(splitTypes)) 91 for _, typ := range splitTypes { 92 typ = strings.TrimSpace(typ) 93 if typ != "" { 94 t.Types = append(t.Types, typ) 95 } 96 } 97 if len(t.Types) == 0 { 98 fmt.Fprintln(os.Stderr, "type is not specified") 99 flags.Usage() 100 os.Exit(2) 101 } 102 return flags.Args() 103 } 104 105 // ParseArgs parses command arguments. Exits if they are invalid. 106 func (t *Tool) ParseArgs(args []string) { 107 args = t.parseFlags(args) 108 109 switch len(args) { 110 case 0: 111 args = []string{"."} 112 fallthrough 113 114 case 1: 115 info, err := os.Stat(args[0]) 116 if err != nil { 117 fmt.Fprintln(os.Stderr, err) 118 os.Exit(2) 119 } 120 if info.IsDir() { 121 t.Dir = args[0] 122 t.FileNames, err = goFilesIn(args[0]) 123 if err != nil { 124 fmt.Fprintln(os.Stderr, err) 125 os.Exit(2) 126 } 127 break 128 } 129 fallthrough 130 131 default: 132 t.Dir = filepath.Dir(args[0]) 133 t.FileNames = args 134 } 135 } 136 137 // GeneratorArgs is passed to the function responsible for generating files. 138 type GeneratorArgs struct { 139 PackageName string 140 Services []*Service 141 ExtraImports []Import 142 Out io.Writer 143 } 144 type Generator func(ctx context.Context, a *GeneratorArgs) error 145 146 // importSorted converts a map name -> path to []Import sorted by name. 147 func importSorted(imports map[string]string) []Import { 148 names := make([]string, 0, len(imports)) 149 for n := range imports { 150 names = append(names, n) 151 } 152 sort.Strings(names) 153 result := make([]Import, len(names)) 154 for i, n := range names { 155 result[i] = Import{n, imports[n]} 156 } 157 return result 158 } 159 160 // Run parses Go files and generates a new file using f. 161 func (t *Tool) Run(ctx context.Context, f Generator) error { 162 // Validate arguments. 163 if len(t.FileNames) == 0 { 164 return fmt.Errorf("files not specified") 165 } 166 if len(t.Types) == 0 { 167 return fmt.Errorf("types not specified") 168 } 169 170 // Determine output file name. 171 outputName := t.Output 172 if outputName == "" { 173 if t.Dir == "" { 174 return fmt.Errorf("neither output not dir are specified") 175 } 176 baseName := fmt.Sprintf("%s_%s.go", t.Types[0], t.OutputFilenameSuffix) 177 outputName = filepath.Join(t.Dir, strings.ToLower(baseName)) 178 } 179 180 // Parse Go files and resolve specified types. 181 p := &parser{ 182 fileSet: token.NewFileSet(), 183 types: t.Types, 184 } 185 if err := p.parsePackage(t.FileNames); err != nil { 186 return fmt.Errorf("could not parse .go files: %s", err) 187 } 188 if err := p.resolveServices(ctx); err != nil { 189 return err 190 } 191 192 // Run the generator. 193 var buf bytes.Buffer 194 genArgs := &GeneratorArgs{ 195 PackageName: p.files[0].Name.Name, 196 Services: p.services, 197 ExtraImports: importSorted(p.extraImports), 198 Out: &buf, 199 } 200 if err := f(ctx, genArgs); err != nil { 201 return err 202 } 203 204 // Format the output. 205 src, err := format.Source(buf.Bytes()) 206 if err != nil { 207 println(buf.String()) 208 return fmt.Errorf("gofmt: %s", err) 209 } 210 211 // Write to file. 212 return os.WriteFile(outputName, src, 0644) 213 } 214 215 // Main does some setup (arg parsing, logging), calls t.Run, prints any errors 216 // and exits. 217 func (t *Tool) Main(args []string, f Generator) { 218 c := gologger.StdConfig.Use(context.Background()) 219 t.ParseArgs(args) 220 221 if err := t.Run(c, f); err != nil { 222 fmt.Fprintln(os.Stderr, err.Error()) 223 os.Exit(1) 224 } 225 os.Exit(0) 226 } 227 228 // goFilesIn lists .go files in dir. 229 func goFilesIn(dir string) ([]string, error) { 230 pkg, err := build.ImportDir(dir, 0) 231 if err != nil { 232 return nil, fmt.Errorf("cannot process directory %s: %s", dir, err) 233 } 234 var names []string 235 names = append(names, pkg.GoFiles...) 236 names = append(names, pkg.CgoFiles...) 237 names = prefixDirectory(dir, names) 238 return names, nil 239 } 240 241 // prefixDirectory places the directory name on the beginning of each name in the list. 242 func prefixDirectory(directory string, names []string) []string { 243 if directory == "." { 244 return names 245 } 246 ret := make([]string, len(names)) 247 for i, name := range names { 248 ret[i] = filepath.Join(directory, name) 249 } 250 return ret 251 }