gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/tools/go_generics/main.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
// go_generics reads a Go source file and writes a new version of that file
// with the following transformations applied:
    17  //
    18  //  1. Global types can be explicitly renamed with the -t option. For example,
    19  //     if -t=A=B is passed in, all references to A will be replaced with
    20  //     references to B; a function declaration like:
    21  //
    22  //     func f(arg *A)
    23  //
    24  //     would be renamed to:
    25  //
    26  //     func f(arg *B)
    27  //
    28  //  2. Global type definitions and their method sets will be removed when they're
    29  //     being renamed with -t. For example, if -t=A=B is passed in, the following
    30  //     definition and methods that existed in the input file wouldn't exist at
    31  //     all in the output file:
    32  //
    33  //     type A struct{}
    34  //
    35  //     func (*A) f() {}
    36  //
    37  //  3. All global types, variables, constants and functions (not methods) are
    38  //     prefixed and suffixed based on the option -prefix and -suffix arguments.
    39  //     For example, if -suffix=A is passed in, the following globals:
    40  //
    41  //     func f()
    42  //     type t struct{}
    43  //
    44  //     would be renamed to:
    45  //
    46  //     func fA()
    47  //     type tA struct{}
    48  //
    49  //     Some special tags are also modified. For example:
    50  //
    51  //     "state:.(t)"
    52  //
    53  //     would become:
    54  //
    55  //     "state:.(tA)"
    56  //
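//  4. The package is renamed to the value passed via the -p argument.
//  5. The values of constants can be reassigned with the -c argument.
//  6. Sub-strings of the input and output can be replaced with the
//     -in-substr and -out-substr arguments, respectively.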
    59  //
// Note that not only the top-level declarations are renamed: all references to
// them are renamed as well, taking into account visibility rules and
// shadowing. For example, if -suffix=A is passed in, the following:
    63  //
    64  // var b = 100
    65  //
    66  //	func f() {
    67  //		g(b)
    68  //		b := 0
    69  //		g(b)
    70  //	}
    71  //
    72  // Would be replaced with:
    73  //
    74  // var bA = 100
    75  //
    76  //	func f() {
    77  //		g(bA)
    78  //		b := 0
    79  //		g(b)
    80  //	}
    81  //
    82  // Note that the second call to g() kept "b" as an argument because it refers to
    83  // the local variable "b".
    84  //
    85  // Note that go_generics can handle anonymous fields with renamed types if
    86  // -anon is passed in, however it does not perform strict checking on parameter
    87  // types that share the same name as the global type and therefore will rename
    88  // them as well.
    89  //
    90  // You can see an example in the tools/go_generics/generics_tests/interface test.
    91  package main
    92  
import (
	"bytes"
	"flag"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"io/ioutil"
	"os"
	"regexp"
	"sort"
	"strings"

	"gvisor.dev/gvisor/tools/go_generics/globals"
)
   108  
   109  var (
   110  	input        = flag.String("i", "", "input `file`")
   111  	output       = flag.String("o", "", "output `file`")
   112  	suffix       = flag.String("suffix", "", "`suffix` to add to each global symbol")
   113  	prefix       = flag.String("prefix", "", "`prefix` to add to each global symbol")
   114  	packageName  = flag.String("p", "main", "output package `name`")
   115  	printAST     = flag.Bool("ast", false, "prints the AST")
   116  	processAnon  = flag.Bool("anon", false, "process anonymous fields")
   117  	types        = make(mapValue)
   118  	consts       = make(mapValue)
   119  	imports      = make(mapValue)
   120  	inputSubstr  = make(mapValue)
   121  	outputSubstr = make(mapValue)
   122  )
   123  
   124  // mapValue implements flag.Value. We use a mapValue flag instead of a regular
   125  // string flag when we want to allow more than one instance of the flag. For
   126  // example, we allow several "-t A=B" arguments, and will rename them all.
   127  type mapValue map[string]string
   128  
   129  func (m mapValue) String() string {
   130  	var b bytes.Buffer
   131  	first := true
   132  	for k, v := range m {
   133  		if !first {
   134  			b.WriteRune(',')
   135  		} else {
   136  			first = false
   137  		}
   138  		b.WriteString(k)
   139  		b.WriteRune('=')
   140  		b.WriteString(v)
   141  	}
   142  	return b.String()
   143  }
   144  
   145  func (m mapValue) Set(s string) error {
   146  	sep := strings.Index(s, "=")
   147  	if sep == -1 {
   148  		return fmt.Errorf("missing '=' from '%s'", s)
   149  	}
   150  
   151  	m[s[:sep]] = s[sep+1:]
   152  
   153  	return nil
   154  }
   155  
   156  // stateTagRegexp matches against the 'typed' state tags.
   157  var stateTagRegexp = regexp.MustCompile(`^(.*[^a-z0-9_])state:"\.\(([^\)]*)\)"(.*)$`)
   158  
   159  var identifierRegexp = regexp.MustCompile(`^(.*[^a-zA-Z_])([a-zA-Z_][a-zA-Z0-9_]*)(.*)$`)
   160  
   161  func main() {
   162  	flag.Usage = func() {
   163  		fmt.Fprintf(os.Stderr, "Usage: %s [options]\n", os.Args[0])
   164  		flag.PrintDefaults()
   165  	}
   166  
   167  	flag.Var(types, "t", "rename type A to B when `A=B` is passed in. Multiple such mappings are allowed.")
   168  	flag.Var(consts, "c", "reassign constant A to value B when `A=B` is passed in. Multiple such mappings are allowed.")
   169  	flag.Var(imports, "import", "specifies the import libraries to use when types are not local. `name=path` specifies that 'name', used in types as name.type, refers to the package living in 'path'.")
   170  	flag.Var(inputSubstr, "in-substr", "replace input sub-string A with B when `A=B` is passed in. Multiple such mappings are allowed.")
   171  	flag.Var(outputSubstr, "out-substr", "replace output sub-string A with B when `A=B` is passed in. Multiple such mappings are allowed.")
   172  	flag.Parse()
   173  
   174  	if *input == "" || *output == "" {
   175  		flag.Usage()
   176  		os.Exit(1)
   177  	}
   178  
   179  	// Parse the input file.
   180  	fset := token.NewFileSet()
   181  	inputBytes, err := os.ReadFile(*input)
   182  	if err != nil {
   183  		fmt.Fprintf(os.Stderr, "%v\n", err)
   184  		os.Exit(1)
   185  	}
   186  	for old, new := range inputSubstr {
   187  		inputBytes = bytes.ReplaceAll(inputBytes, []byte(old), []byte(new))
   188  	}
   189  	f, err := parser.ParseFile(fset, *input, inputBytes, parser.ParseComments|parser.DeclarationErrors|parser.SpuriousErrors)
   190  	if err != nil {
   191  		fmt.Fprintf(os.Stderr, "%v\n", err)
   192  		os.Exit(1)
   193  	}
   194  
   195  	// Print the AST if requested.
   196  	if *printAST {
   197  		ast.Print(fset, f)
   198  	}
   199  
   200  	cmap := ast.NewCommentMap(fset, f, f.Comments)
   201  
   202  	// Update imports based on what's used in types and consts.
   203  	maps := []mapValue{types, consts}
   204  	importDecl, err := updateImports(maps, imports)
   205  	if err != nil {
   206  		fmt.Fprintf(os.Stderr, "%v\n", err)
   207  		os.Exit(1)
   208  	}
   209  	types = maps[0]
   210  	consts = maps[1]
   211  
   212  	// Reassign all specified constants.
   213  	for _, decl := range f.Decls {
   214  		d, ok := decl.(*ast.GenDecl)
   215  		if !ok || d.Tok != token.CONST {
   216  			continue
   217  		}
   218  
   219  		for _, gs := range d.Specs {
   220  			s := gs.(*ast.ValueSpec)
   221  			for i, id := range s.Names {
   222  				if n, ok := consts[id.Name]; ok {
   223  					s.Values[i] = &ast.BasicLit{Value: n}
   224  				}
   225  			}
   226  		}
   227  	}
   228  
   229  	// Go through all globals and their uses in the AST and rename the types
   230  	// with explicitly provided names, and rename all types, variables,
   231  	// consts and functions with the provided prefix and suffix.
   232  	globals.Visit(fset, f, func(ident *ast.Ident, kind globals.SymKind) {
   233  		if n, ok := types[ident.Name]; ok && kind == globals.KindType {
   234  			ident.Name = n
   235  		} else {
   236  			switch kind {
   237  			case globals.KindType, globals.KindVar, globals.KindConst, globals.KindFunction:
   238  				if ident.Name != "_" && !(ident.Name == "init" && kind == globals.KindFunction) {
   239  					ident.Name = *prefix + ident.Name + *suffix
   240  				}
   241  			case globals.KindTag:
   242  				// Modify the state tag appropriately.
   243  				if m := stateTagRegexp.FindStringSubmatch(ident.Name); m != nil {
   244  					if t := identifierRegexp.FindStringSubmatch(m[2]); t != nil {
   245  						typeName := *prefix + t[2] + *suffix
   246  						if n, ok := types[t[2]]; ok {
   247  							typeName = n
   248  						}
   249  						ident.Name = m[1] + `state:".(` + t[1] + typeName + t[3] + `)"` + m[3]
   250  					}
   251  				}
   252  			}
   253  		}
   254  	}, *processAnon)
   255  
   256  	// Remove the definition of all types that are being remapped.
   257  	set := make(typeSet)
   258  	for _, v := range types {
   259  		set[v] = struct{}{}
   260  	}
   261  	removeTypes(set, f)
   262  
   263  	// Add the new imports, if any, to the top.
   264  	if importDecl != nil {
   265  		newDecls := make([]ast.Decl, 0, len(f.Decls)+1)
   266  		newDecls = append(newDecls, importDecl)
   267  		newDecls = append(newDecls, f.Decls...)
   268  		f.Decls = newDecls
   269  	}
   270  
   271  	// Update comments to remove the ones potentially associated with the
   272  	// type T that we removed.
   273  	f.Comments = cmap.Filter(f).Comments()
   274  
   275  	// If there are file (package) comments, delete them.
   276  	if f.Doc != nil {
   277  		for i, cg := range f.Comments {
   278  			if cg == f.Doc {
   279  				f.Comments = append(f.Comments[:i], f.Comments[i+1:]...)
   280  				break
   281  			}
   282  		}
   283  	}
   284  
   285  	// Write the output file.
   286  	f.Name.Name = *packageName
   287  
   288  	var buf bytes.Buffer
   289  	if err := format.Node(&buf, fset, f); err != nil {
   290  		fmt.Fprintf(os.Stderr, "%v\n", err)
   291  		os.Exit(1)
   292  	}
   293  
   294  	byteBuf := buf.Bytes()
   295  	for old, new := range outputSubstr {
   296  		byteBuf = bytes.ReplaceAll(byteBuf, []byte(old), []byte(new))
   297  	}
   298  
   299  	if err := ioutil.WriteFile(*output, byteBuf, 0644); err != nil {
   300  		fmt.Fprintf(os.Stderr, "%v\n", err)
   301  		os.Exit(1)
   302  	}
   303  }