go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/grpc/internal/svctool/tool.go (about)

     1  // Copyright 2016 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
// Package svctool implements the command-line parsing and code-generation driver shared by the svcmux and svcdec tools.
    16  package svctool
    17  
    18  import (
    19  	"bytes"
    20  	"context"
    21  	"flag"
    22  	"fmt"
    23  	"go/ast"
    24  	"go/build"
    25  	"go/format"
    26  	"go/token"
    27  	"io"
    28  	"os"
    29  	"path/filepath"
    30  	"sort"
    31  	"strings"
    32  
    33  	"go.chromium.org/luci/common/logging/gologger"
    34  )
    35  
// Service contains the result of parsing the generated code for a pRPC service.
type Service struct {
	// TypeName is the name of the service's Go type; presumably one of the
	// names requested via -type (resolution happens in the parser, not shown
	// here — confirm).
	TypeName string
	// Node is the AST of the service's interface type.
	Node *ast.InterfaceType
	// Methods describes the methods of the service interface.
	Methods []*Method
}
    42  
// Method describes a single method of a Service.
type Method struct {
	// Name is the method name.
	Name string
	// Node is the AST field declaring the method inside the service's
	// interface type.
	Node *ast.Field
	// InputType is presumably the method's request type, rendered as a Go
	// type expression by the parser — confirm the exact format there.
	InputType string
	// OutputType is presumably the method's response type, rendered like
	// InputType — confirm in the parser.
	OutputType string
}
    49  
// Import is one entry of GeneratorArgs.ExtraImports: a package name and the
// import path it refers to.
//
// Field order matters: importSorted constructs Import values positionally.
type Import struct {
	// Name is the package name used for the import (the map key in
	// importSorted).
	Name string
	// Path is the import path.
	Path string
}
    54  
// Tool is a helper class for svcmux and svcdec.
//
// Main fills a Tool from command-line arguments via ParseArgs and then
// generates the output file via Run.
type Tool struct {
	// Name of the tool, e.g. "svcmux" or "svcdec". It names the flag set and
	// appears in usage messages.
	Name string
	// OutputFilenameSuffix is the suffix of generated file names,
	// e.g. "mux" or "dec" for foo_mux.go or foo_dec.go.
	OutputFilenameSuffix string

	// Set by ParseArgs from command-line arguments.

	// Types are type names from the Go package defined by Dir or FileNames,
	// taken from the comma-separated -type flag. Types[0] also determines
	// the default output file name.
	Types []string
	// Output is the output file name from the -output flag. Run writes to it
	// as given (it is not joined with Dir); when empty, Run derives the name
	// "<Types[0]>_<OutputFilenameSuffix>.go", lower-cased, inside Dir.
	Output string
	// Dir is a Go package's directory: the directory argument, or the
	// directory of the first file argument.
	Dir string
	// FileNames is a list of source files from a single Go package. For a
	// directory argument it holds the package's Go and cgo files.
	FileNames []string
}
    74  
    75  func (t *Tool) usage() {
    76  	fmt.Fprintf(os.Stderr, "Usage of %s:\n", t.Name)
    77  	fmt.Fprintf(os.Stderr, "\t%s [flags] -type T [directory]\n", t.Name)
    78  	fmt.Fprintf(os.Stderr, "\t%s [flags] -type T files... # Must be a single package\n", t.Name)
    79  	flag.PrintDefaults()
    80  }
    81  
    82  func (t *Tool) parseFlags(args []string) []string {
    83  	var flags = flag.NewFlagSet(t.Name, flag.ExitOnError)
    84  	typeFlag := flags.String("type", "", "comma-separated list of type names; must be set")
    85  	flags.StringVar(&t.Output, "output", "", "output file name; default <type>_string.go")
    86  	flags.Usage = t.usage
    87  	flags.Parse(args)
    88  
    89  	splitTypes := strings.Split(*typeFlag, ",")
    90  	t.Types = make([]string, 0, len(splitTypes))
    91  	for _, typ := range splitTypes {
    92  		typ = strings.TrimSpace(typ)
    93  		if typ != "" {
    94  			t.Types = append(t.Types, typ)
    95  		}
    96  	}
    97  	if len(t.Types) == 0 {
    98  		fmt.Fprintln(os.Stderr, "type is not specified")
    99  		flags.Usage()
   100  		os.Exit(2)
   101  	}
   102  	return flags.Args()
   103  }
   104  
   105  // ParseArgs parses command arguments. Exits if they are invalid.
   106  func (t *Tool) ParseArgs(args []string) {
   107  	args = t.parseFlags(args)
   108  
   109  	switch len(args) {
   110  	case 0:
   111  		args = []string{"."}
   112  		fallthrough
   113  
   114  	case 1:
   115  		info, err := os.Stat(args[0])
   116  		if err != nil {
   117  			fmt.Fprintln(os.Stderr, err)
   118  			os.Exit(2)
   119  		}
   120  		if info.IsDir() {
   121  			t.Dir = args[0]
   122  			t.FileNames, err = goFilesIn(args[0])
   123  			if err != nil {
   124  				fmt.Fprintln(os.Stderr, err)
   125  				os.Exit(2)
   126  			}
   127  			break
   128  		}
   129  		fallthrough
   130  
   131  	default:
   132  		t.Dir = filepath.Dir(args[0])
   133  		t.FileNames = args
   134  	}
   135  }
   136  
// GeneratorArgs is passed to the function responsible for generating files.
type GeneratorArgs struct {
	// PackageName is the package name declared by the first parsed file.
	PackageName string
	// Services are the services resolved from the requested types.
	Services []*Service
	// ExtraImports are imports collected by the parser, sorted by Name;
	// presumably ones the generated code needs in addition to its own —
	// confirm in the parser.
	ExtraImports []Import
	// Out receives the generated Go source. Run gofmt-formats everything
	// written here before saving it to the output file.
	Out io.Writer
}

// Generator writes generated Go code for a.Services to a.Out. A non-nil error
// makes Run return it without writing the output file.
type Generator func(ctx context.Context, a *GeneratorArgs) error
   145  
   146  // importSorted converts a map name -> path to []Import sorted by name.
   147  func importSorted(imports map[string]string) []Import {
   148  	names := make([]string, 0, len(imports))
   149  	for n := range imports {
   150  		names = append(names, n)
   151  	}
   152  	sort.Strings(names)
   153  	result := make([]Import, len(names))
   154  	for i, n := range names {
   155  		result[i] = Import{n, imports[n]}
   156  	}
   157  	return result
   158  }
   159  
   160  // Run parses Go files and generates a new file using f.
   161  func (t *Tool) Run(ctx context.Context, f Generator) error {
   162  	// Validate arguments.
   163  	if len(t.FileNames) == 0 {
   164  		return fmt.Errorf("files not specified")
   165  	}
   166  	if len(t.Types) == 0 {
   167  		return fmt.Errorf("types not specified")
   168  	}
   169  
   170  	// Determine output file name.
   171  	outputName := t.Output
   172  	if outputName == "" {
   173  		if t.Dir == "" {
   174  			return fmt.Errorf("neither output not dir are specified")
   175  		}
   176  		baseName := fmt.Sprintf("%s_%s.go", t.Types[0], t.OutputFilenameSuffix)
   177  		outputName = filepath.Join(t.Dir, strings.ToLower(baseName))
   178  	}
   179  
   180  	// Parse Go files and resolve specified types.
   181  	p := &parser{
   182  		fileSet: token.NewFileSet(),
   183  		types:   t.Types,
   184  	}
   185  	if err := p.parsePackage(t.FileNames); err != nil {
   186  		return fmt.Errorf("could not parse .go files: %s", err)
   187  	}
   188  	if err := p.resolveServices(ctx); err != nil {
   189  		return err
   190  	}
   191  
   192  	// Run the generator.
   193  	var buf bytes.Buffer
   194  	genArgs := &GeneratorArgs{
   195  		PackageName:  p.files[0].Name.Name,
   196  		Services:     p.services,
   197  		ExtraImports: importSorted(p.extraImports),
   198  		Out:          &buf,
   199  	}
   200  	if err := f(ctx, genArgs); err != nil {
   201  		return err
   202  	}
   203  
   204  	// Format the output.
   205  	src, err := format.Source(buf.Bytes())
   206  	if err != nil {
   207  		println(buf.String())
   208  		return fmt.Errorf("gofmt: %s", err)
   209  	}
   210  
   211  	// Write to file.
   212  	return os.WriteFile(outputName, src, 0644)
   213  }
   214  
   215  // Main does some setup (arg parsing, logging), calls t.Run, prints any errors
   216  // and exits.
   217  func (t *Tool) Main(args []string, f Generator) {
   218  	c := gologger.StdConfig.Use(context.Background())
   219  	t.ParseArgs(args)
   220  
   221  	if err := t.Run(c, f); err != nil {
   222  		fmt.Fprintln(os.Stderr, err.Error())
   223  		os.Exit(1)
   224  	}
   225  	os.Exit(0)
   226  }
   227  
   228  // goFilesIn lists .go files in dir.
   229  func goFilesIn(dir string) ([]string, error) {
   230  	pkg, err := build.ImportDir(dir, 0)
   231  	if err != nil {
   232  		return nil, fmt.Errorf("cannot process directory %s: %s", dir, err)
   233  	}
   234  	var names []string
   235  	names = append(names, pkg.GoFiles...)
   236  	names = append(names, pkg.CgoFiles...)
   237  	names = prefixDirectory(dir, names)
   238  	return names, nil
   239  }
   240  
   241  // prefixDirectory places the directory name on the beginning of each name in the list.
   242  func prefixDirectory(directory string, names []string) []string {
   243  	if directory == "." {
   244  		return names
   245  	}
   246  	ret := make([]string, len(names))
   247  	for i, name := range names {
   248  		ret[i] = filepath.Join(directory, name)
   249  	}
   250  	return ret
   251  }