gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/tools/go_stateify/main.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Stateify provides a simple way to generate Load/Save methods based on
    16  // existing types and struct tags.
    17  package main
    18  
    19  import (
    20  	"flag"
    21  	"fmt"
    22  	"go/ast"
    23  	"go/parser"
    24  	"go/token"
    25  	"os"
    26  	"path/filepath"
    27  	"reflect"
    28  	"strings"
    29  	"sync"
    30  
    31  	"gvisor.dev/gvisor/tools/constraintutil"
    32  )
    33  
    34  var (
    35  	fullPkg  = flag.String("fullpkg", "", "fully qualified output package")
    36  	imports  = flag.String("imports", "", "extra imports for the output file")
    37  	output   = flag.String("output", "", "output file")
    38  	statePkg = flag.String("statepkg", "", "state import package; defaults to empty")
    39  )
    40  
    41  // resolveTypeName returns a qualified type name.
    42  func resolveTypeName(typ ast.Expr) (field string, qualified string) {
    43  	for done := false; !done; {
    44  		// Resolve star expressions.
    45  		switch rs := typ.(type) {
    46  		case *ast.StarExpr:
    47  			qualified += "*"
    48  			typ = rs.X
    49  		case *ast.ArrayType:
    50  			if rs.Len == nil {
    51  				// Slice type declaration.
    52  				qualified += "[]"
    53  			} else {
    54  				// Array type declaration.
    55  				qualified += "[" + rs.Len.(*ast.BasicLit).Value + "]"
    56  			}
    57  			typ = rs.Elt
    58  		default:
    59  			// No more descent.
    60  			done = true
    61  		}
    62  	}
    63  
    64  	// Resolve a package selector.
    65  	sel, ok := typ.(*ast.SelectorExpr)
    66  	if ok {
    67  		qualified = qualified + sel.X.(*ast.Ident).Name + "."
    68  		typ = sel.Sel
    69  	}
    70  
    71  	// Figure out actual type name.
    72  	field = typ.(*ast.Ident).Name
    73  	qualified = qualified + field
    74  	return
    75  }
    76  
    77  // extractStateTag pulls the relevant state tag.
    78  func extractStateTag(tag *ast.BasicLit) string {
    79  	if tag == nil {
    80  		return ""
    81  	}
    82  	if len(tag.Value) < 2 {
    83  		return ""
    84  	}
    85  	return reflect.StructTag(tag.Value[1 : len(tag.Value)-1]).Get("state")
    86  }
    87  
    88  // scanFunctions is a set of functions passed to scanFields.
    89  type scanFunctions struct {
    90  	zerovalue func(name string)
    91  	normal    func(name string)
    92  	wait      func(name string)
    93  	value     func(name, typName string)
    94  }
    95  
    96  // scanFields scans the fields of a struct.
    97  //
    98  // Each provided function will be applied to appropriately tagged fields, or
    99  // skipped if nil.
   100  //
   101  // Fields tagged nosave are skipped.
   102  func scanFields(ss *ast.StructType, fn scanFunctions) {
   103  	if ss.Fields.List == nil {
   104  		// No fields.
   105  		return
   106  	}
   107  
   108  	// Scan all fields.
   109  	for _, field := range ss.Fields.List {
   110  		if field.Names == nil {
   111  			// Anonymous types can't be embedded, so we don't need
   112  			// to worry about providing a useful name here.
   113  			name, _ := resolveTypeName(field.Type)
   114  			scanField(name, field, fn)
   115  			continue
   116  		}
   117  
   118  		// Iterate over potentially multiple fields defined on the same line.
   119  		for _, nameI := range field.Names {
   120  			name := nameI.Name
   121  			// Skip _ fields.
   122  			if name == "_" {
   123  				continue
   124  			}
   125  			scanField(name, field, fn)
   126  		}
   127  	}
   128  }
   129  
   130  // scanField scans a single struct field with a resolved name.
   131  func scanField(name string, field *ast.Field, fn scanFunctions) {
   132  	// Is this a anonymous struct? If yes, then continue the
   133  	// recursion with the given prefix. We don't pay attention to
   134  	// any tags on the top-level struct field.
   135  	tag := extractStateTag(field.Tag)
   136  	if anon, ok := field.Type.(*ast.StructType); ok && tag == "" {
   137  		scanFields(anon, fn)
   138  		return
   139  	}
   140  
   141  	switch tag {
   142  	case "zerovalue":
   143  		if fn.zerovalue != nil {
   144  			fn.zerovalue(name)
   145  		}
   146  
   147  	case "":
   148  		if fn.normal != nil {
   149  			fn.normal(name)
   150  		}
   151  
   152  	case "wait":
   153  		if fn.wait != nil {
   154  			fn.wait(name)
   155  		}
   156  
   157  	case "manual", "nosave", "ignore":
   158  		// Do nothing.
   159  
   160  	default:
   161  		if strings.HasPrefix(tag, ".(") && strings.HasSuffix(tag, ")") {
   162  			if fn.value != nil {
   163  				fn.value(name, tag[2:len(tag)-1])
   164  			}
   165  		}
   166  	}
   167  }
   168  
   169  func camelCased(name string) string {
   170  	return strings.ToUpper(name[:1]) + name[1:]
   171  }
   172  
   173  func main() {
   174  	// Parse flags.
   175  	flag.Usage = func() {
   176  		fmt.Fprintf(os.Stderr, "Usage: %s [options]\n", os.Args[0])
   177  		flag.PrintDefaults()
   178  	}
   179  	flag.Parse()
   180  	if len(flag.Args()) == 0 {
   181  		flag.Usage()
   182  		os.Exit(1)
   183  	}
   184  	if *fullPkg == "" {
   185  		fmt.Fprintf(os.Stderr, "Error: package required.")
   186  		os.Exit(1)
   187  	}
   188  
   189  	// Open the output file.
   190  	var (
   191  		outputFile *os.File
   192  		err        error
   193  	)
   194  	if *output == "" || *output == "-" {
   195  		outputFile = os.Stdout
   196  	} else {
   197  		outputFile, err = os.OpenFile(*output, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
   198  		if err != nil {
   199  			fmt.Fprintf(os.Stderr, "Error opening output %q: %v", *output, err)
   200  		}
   201  		defer outputFile.Close()
   202  	}
   203  
   204  	// Set the statePrefix for below, depending on the import.
   205  	statePrefix := ""
   206  	if *statePkg != "" {
   207  		parts := strings.Split(*statePkg, "/")
   208  		statePrefix = parts[len(parts)-1] + "."
   209  	}
   210  
   211  	// initCalls is dumped at the end.
   212  	var initCalls []string
   213  
   214  	// Common closures.
   215  	emitRegister := func(name string) {
   216  		initCalls = append(initCalls, fmt.Sprintf("%sRegister((*%s)(nil))", statePrefix, name))
   217  	}
   218  
   219  	// Automated warning.
   220  	fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n")
   221  
   222  	// Emit build constraints.
   223  	bcexpr, err := constraintutil.CombineFromFiles(flag.Args())
   224  	if err != nil {
   225  		fmt.Fprintf(os.Stderr, "Failed to infer build constraints: %v", err)
   226  		os.Exit(1)
   227  	}
   228  	outputFile.WriteString(constraintutil.Lines(bcexpr))
   229  
   230  	// Emit the package name.
   231  	_, pkg := filepath.Split(*fullPkg)
   232  	fmt.Fprintf(outputFile, "package %s\n\n", pkg)
   233  
   234  	// Emit the imports lazily.
   235  	var once sync.Once
   236  	maybeEmitImports := func() {
   237  		once.Do(func() {
   238  			// Emit the imports.
   239  			fmt.Fprint(outputFile, "import (\n")
   240  			fmt.Fprint(outputFile, " \"context\"\n")
   241  			if *statePkg != "" {
   242  				fmt.Fprintf(outputFile, "	\"%s\"\n", *statePkg)
   243  			}
   244  			if *imports != "" {
   245  				for _, i := range strings.Split(*imports, ",") {
   246  					fmt.Fprintf(outputFile, "	\"%s\"\n", i)
   247  				}
   248  			}
   249  			fmt.Fprint(outputFile, ")\n\n")
   250  		})
   251  	}
   252  
   253  	files := make([]*ast.File, 0, len(flag.Args()))
   254  
   255  	// Parse the input files.
   256  	for _, filename := range flag.Args() {
   257  		// Parse the file.
   258  		fset := token.NewFileSet()
   259  		f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
   260  		if err != nil {
   261  			// Not a valid input file?
   262  			fmt.Fprintf(os.Stderr, "Input %q can't be parsed: %v\n", filename, err)
   263  			os.Exit(1)
   264  		}
   265  
   266  		files = append(files, f)
   267  	}
   268  
   269  	type method struct {
   270  		typeName   string
   271  		methodName string
   272  	}
   273  
   274  	// Search for and add all method to a set. We auto-detecting several
   275  	// different methods (and insert them if we don't find them, in order
   276  	// to ensure that expectations match reality).
   277  	//
   278  	// While we do this, figure out the right receiver name. If there are
   279  	// multiple distinct receivers, then we will just pick the last one.
   280  	simpleMethods := make(map[method]struct{})
   281  	receiverNames := make(map[string]string)
   282  	for _, f := range files {
   283  		// Go over all functions.
   284  		for _, decl := range f.Decls {
   285  			d, ok := decl.(*ast.FuncDecl)
   286  			if !ok {
   287  				continue
   288  			}
   289  			if d.Recv == nil || len(d.Recv.List) != 1 {
   290  				// Not a named method.
   291  				continue
   292  			}
   293  
   294  			// Save the method and the receiver.
   295  			name, _ := resolveTypeName(d.Recv.List[0].Type)
   296  			simpleMethods[method{
   297  				typeName:   name,
   298  				methodName: d.Name.Name,
   299  			}] = struct{}{}
   300  			if len(d.Recv.List[0].Names) > 0 {
   301  				receiverNames[name] = d.Recv.List[0].Names[0].Name
   302  			}
   303  		}
   304  	}
   305  
   306  	for _, f := range files {
   307  		// Go over all named types.
   308  		for _, decl := range f.Decls {
   309  			d, ok := decl.(*ast.GenDecl)
   310  			if !ok || d.Tok != token.TYPE {
   311  				continue
   312  			}
   313  
   314  			// Only generate code for types marked "// +stateify
   315  			// savable" in one of the proceeding comment lines. If
   316  			// the line is marked "// +stateify type" then only
   317  			// generate type information and register the type.
   318  			// If the type also has a "// +stateify identtype"
   319  			// comment, the functions are instead generated to refer to
   320  			// the type that this newly-defined type is identical to, rather
   321  			// than about the newly-defined type itself.
   322  			if d.Doc == nil {
   323  				continue
   324  			}
   325  			var (
   326  				generateTypeInfo    = false
   327  				generateSaverLoader = false
   328  				isIdentType         = false
   329  			)
   330  			for _, l := range d.Doc.List {
   331  				if l.Text == "// +stateify savable" {
   332  					generateTypeInfo = true
   333  					generateSaverLoader = true
   334  				}
   335  				if l.Text == "// +stateify type" {
   336  					generateTypeInfo = true
   337  				}
   338  				if l.Text == "// +stateify identtype" {
   339  					isIdentType = true
   340  				}
   341  			}
   342  			if !generateTypeInfo && !generateSaverLoader {
   343  				continue
   344  			}
   345  
   346  			for _, gs := range d.Specs {
   347  				ts := gs.(*ast.TypeSpec)
   348  				recv, ok := receiverNames[ts.Name.Name]
   349  				if !ok {
   350  					// Maybe no methods were defined?
   351  					recv = strings.ToLower(ts.Name.Name[:1])
   352  				}
   353  				switch x := ts.Type.(type) {
   354  				case *ast.StructType:
   355  					maybeEmitImports()
   356  					if isIdentType {
   357  						fmt.Fprintf(os.Stderr, "Cannot use `+stateify identtype` on a struct type (%v); must be a type definition of an identical type.", ts.Name.Name)
   358  						os.Exit(1)
   359  					}
   360  
   361  					// Record the slot for each field.
   362  					fieldCount := 0
   363  					fields := make(map[string]int)
   364  					emitField := func(name string) {
   365  						fmt.Fprintf(outputFile, "		\"%s\",\n", name)
   366  						fields[name] = fieldCount
   367  						fieldCount++
   368  					}
   369  					emitFieldValue := func(name string, _ string) {
   370  						emitField(name)
   371  					}
   372  					emitLoadValue := func(name, typName string) {
   373  						fmt.Fprintf(outputFile, "	stateSourceObject.LoadValue(%d, new(%s), func(y any) { %s.load%s(ctx, y.(%s)) })\n", fields[name], typName, recv, camelCased(name), typName)
   374  					}
   375  					emitLoad := func(name string) {
   376  						fmt.Fprintf(outputFile, "	stateSourceObject.Load(%d, &%s.%s)\n", fields[name], recv, name)
   377  					}
   378  					emitLoadWait := func(name string) {
   379  						fmt.Fprintf(outputFile, "	stateSourceObject.LoadWait(%d, &%s.%s)\n", fields[name], recv, name)
   380  					}
   381  					emitSaveValue := func(name, typName string) {
   382  						// Emit typName to be more robust against code generation bugs,
   383  						// but instead of one line make two lines to silence ST1023
   384  						// finding (i.e. avoid nogo finding: "should omit type $typName
   385  						// from declaration; it will be inferred from the right-hand side")
   386  						fmt.Fprintf(outputFile, "	var %sValue %s\n", name, typName)
   387  						fmt.Fprintf(outputFile, "	%sValue = %s.save%s()\n", name, recv, camelCased(name))
   388  						fmt.Fprintf(outputFile, "	stateSinkObject.SaveValue(%d, %sValue)\n", fields[name], name)
   389  					}
   390  					emitSave := func(name string) {
   391  						fmt.Fprintf(outputFile, "	stateSinkObject.Save(%d, &%s.%s)\n", fields[name], recv, name)
   392  					}
   393  					emitZeroCheck := func(name string) {
   394  						fmt.Fprintf(outputFile, "	if !%sIsZeroValue(&%s.%s) { %sFailf(\"%s is %%#v, expected zero\", &%s.%s) }\n", statePrefix, recv, name, statePrefix, name, recv, name)
   395  					}
   396  
   397  					// Generate the type name method.
   398  					fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", recv, ts.Name.Name)
   399  					fmt.Fprintf(outputFile, "	return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
   400  					fmt.Fprintf(outputFile, "}\n\n")
   401  
   402  					// Generate the fields method.
   403  					fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", recv, ts.Name.Name)
   404  					fmt.Fprintf(outputFile, "	return []string{\n")
   405  					scanFields(x, scanFunctions{
   406  						normal: emitField,
   407  						wait:   emitField,
   408  						value:  emitFieldValue,
   409  					})
   410  					fmt.Fprintf(outputFile, "	}\n")
   411  					fmt.Fprintf(outputFile, "}\n\n")
   412  
   413  					// Define beforeSave if a definition was not found. This prevents
   414  					// the code from compiling if a custom beforeSave was defined in a
   415  					// file not provided to this binary and prevents inherited methods
   416  					// from being called multiple times by overriding them.
   417  					if _, ok := simpleMethods[method{
   418  						typeName:   ts.Name.Name,
   419  						methodName: "beforeSave",
   420  					}]; !ok && generateSaverLoader {
   421  						fmt.Fprintf(outputFile, "func (%s *%s) beforeSave() {}\n\n", recv, ts.Name.Name)
   422  					}
   423  
   424  					// Generate the save method.
   425  					//
   426  					// N.B. For historical reasons, we perform the value saves first,
   427  					// and perform the value loads last. There should be no dependency
   428  					// on this specific behavior, but the ability to specify slots
   429  					// allows a manual implementation to be order-dependent.
   430  					if generateSaverLoader {
   431  						fmt.Fprintf(outputFile, "// +checklocksignore\n")
   432  						fmt.Fprintf(outputFile, "func (%s *%s) StateSave(stateSinkObject %sSink) {\n", recv, ts.Name.Name, statePrefix)
   433  						fmt.Fprintf(outputFile, "	%s.beforeSave()\n", recv)
   434  						scanFields(x, scanFunctions{zerovalue: emitZeroCheck})
   435  						scanFields(x, scanFunctions{value: emitSaveValue})
   436  						scanFields(x, scanFunctions{normal: emitSave, wait: emitSave})
   437  						fmt.Fprintf(outputFile, "}\n\n")
   438  					}
   439  
   440  					// Define afterLoad if a definition was not found. We do this for
   441  					// the same reason that we do it for beforeSave.
   442  					_, hasAfterLoad := simpleMethods[method{
   443  						typeName:   ts.Name.Name,
   444  						methodName: "afterLoad",
   445  					}]
   446  					if !hasAfterLoad && generateSaverLoader {
   447  						fmt.Fprintf(outputFile, "func (%s *%s) afterLoad(context.Context) {}\n\n", recv, ts.Name.Name)
   448  					}
   449  
   450  					// Generate the load method.
   451  					//
   452  					// N.B. See the comment above for the save method.
   453  					if generateSaverLoader {
   454  						fmt.Fprintf(outputFile, "// +checklocksignore\n")
   455  						fmt.Fprintf(outputFile, "func (%s *%s) StateLoad(ctx context.Context, stateSourceObject %sSource) {\n", recv, ts.Name.Name, statePrefix)
   456  						scanFields(x, scanFunctions{normal: emitLoad, wait: emitLoadWait})
   457  						scanFields(x, scanFunctions{value: emitLoadValue})
   458  						if hasAfterLoad {
   459  							// The call to afterLoad is made conditionally, because when
   460  							// AfterLoad is called, the object encodes a dependency on
   461  							// referred objects (i.e. fields). This means that afterLoad
   462  							// will not be called until the other afterLoads are called.
   463  							fmt.Fprintf(outputFile, "	stateSourceObject.AfterLoad(func () { %s.afterLoad(ctx) })\n", recv)
   464  						}
   465  						fmt.Fprintf(outputFile, "}\n\n")
   466  					}
   467  
   468  					// Add to our registration.
   469  					emitRegister(ts.Name.Name)
   470  
   471  				case *ast.Ident, *ast.SelectorExpr, *ast.ArrayType:
   472  					maybeEmitImports()
   473  
   474  					// Generate the info methods.
   475  					fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", recv, ts.Name.Name)
   476  					fmt.Fprintf(outputFile, "	return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
   477  					fmt.Fprintf(outputFile, "}\n\n")
   478  
   479  					if !isIdentType {
   480  						fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", recv, ts.Name.Name)
   481  						fmt.Fprintf(outputFile, "	return nil\n")
   482  						fmt.Fprintf(outputFile, "}\n\n")
   483  					} else {
   484  						var typeName string
   485  						switch y := x.(type) {
   486  						case *ast.Ident:
   487  							typeName = y.Name
   488  						case *ast.SelectorExpr:
   489  							expIdent, ok := y.X.(*ast.Ident)
   490  							if !ok {
   491  								fmt.Fprintf(os.Stderr, "Cannot use non-ident %v (type %T) in type selector expression %v", y.X, y.X, y)
   492  								os.Exit(1)
   493  							}
   494  							typeName = fmt.Sprintf("%s.%s", expIdent.Name, y.Sel.Name)
   495  						default:
   496  							fmt.Fprintf(os.Stderr, "Cannot use `+stateify identtype` on a non-identifier/non-selector type definition (%v => %v of type %T); must be a type definition of an identical type.", ts.Name.Name, x, x)
   497  							os.Exit(1)
   498  						}
   499  						fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", recv, ts.Name.Name)
   500  						fmt.Fprintf(outputFile, "	return (*%s)(%s).StateFields()\n", typeName, recv)
   501  						fmt.Fprintf(outputFile, "}\n\n")
   502  						if generateSaverLoader {
   503  							fmt.Fprintf(outputFile, "// +checklocksignore\n")
   504  							fmt.Fprintf(outputFile, "func (%s *%s) StateSave(stateSinkObject %sSink) {\n", recv, ts.Name.Name, statePrefix)
   505  							fmt.Fprintf(outputFile, "	(*%s)(%s).StateSave(stateSinkObject)\n", typeName, recv)
   506  							fmt.Fprintf(outputFile, "}\n\n")
   507  							fmt.Fprintf(outputFile, "// +checklocksignore\n")
   508  							fmt.Fprintf(outputFile, "func (%s *%s) StateLoad(ctx context.Context, stateSourceObject %sSource) {\n", recv, ts.Name.Name, statePrefix)
   509  							fmt.Fprintf(outputFile, "	(*%s)(%s).StateLoad(ctx, stateSourceObject)\n", typeName, recv)
   510  							fmt.Fprintf(outputFile, "}\n\n")
   511  						}
   512  					}
   513  
   514  					// See above.
   515  					emitRegister(ts.Name.Name)
   516  				}
   517  			}
   518  		}
   519  	}
   520  
   521  	if len(initCalls) > 0 {
   522  		// Emit the init() function.
   523  		fmt.Fprintf(outputFile, "func init() {\n")
   524  		for _, ic := range initCalls {
   525  			fmt.Fprintf(outputFile, "	%s\n", ic)
   526  		}
   527  		fmt.Fprintf(outputFile, "}\n")
   528  	}
   529  }