github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/tools/go_stateify/main.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Stateify provides a simple way to generate Load/Save methods based on
    16  // existing types and struct tags.
    17  package main
    18  
import (
	"flag"
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"go/types"
	"os"
	"path/filepath"
	"reflect"
	"strconv"
	"strings"
	"sync"

	"github.com/SagerNet/gvisor/tools/tags"
)
    33  
    34  var (
    35  	fullPkg  = flag.String("fullpkg", "", "fully qualified output package")
    36  	imports  = flag.String("imports", "", "extra imports for the output file")
    37  	output   = flag.String("output", "", "output file")
    38  	statePkg = flag.String("statepkg", "", "state import package; defaults to empty")
    39  )
    40  
    41  // resolveTypeName returns a qualified type name.
    42  func resolveTypeName(typ ast.Expr) (field string, qualified string) {
    43  	for done := false; !done; {
    44  		// Resolve star expressions.
    45  		switch rs := typ.(type) {
    46  		case *ast.StarExpr:
    47  			qualified += "*"
    48  			typ = rs.X
    49  		case *ast.ArrayType:
    50  			if rs.Len == nil {
    51  				// Slice type declaration.
    52  				qualified += "[]"
    53  			} else {
    54  				// Array type declaration.
    55  				qualified += "[" + rs.Len.(*ast.BasicLit).Value + "]"
    56  			}
    57  			typ = rs.Elt
    58  		default:
    59  			// No more descent.
    60  			done = true
    61  		}
    62  	}
    63  
    64  	// Resolve a package selector.
    65  	sel, ok := typ.(*ast.SelectorExpr)
    66  	if ok {
    67  		qualified = qualified + sel.X.(*ast.Ident).Name + "."
    68  		typ = sel.Sel
    69  	}
    70  
    71  	// Figure out actual type name.
    72  	field = typ.(*ast.Ident).Name
    73  	qualified = qualified + field
    74  	return
    75  }
    76  
    77  // extractStateTag pulls the relevant state tag.
    78  func extractStateTag(tag *ast.BasicLit) string {
    79  	if tag == nil {
    80  		return ""
    81  	}
    82  	if len(tag.Value) < 2 {
    83  		return ""
    84  	}
    85  	return reflect.StructTag(tag.Value[1 : len(tag.Value)-1]).Get("state")
    86  }
    87  
    88  // scanFunctions is a set of functions passed to scanFields.
    89  type scanFunctions struct {
    90  	zerovalue func(name string)
    91  	normal    func(name string)
    92  	wait      func(name string)
    93  	value     func(name, typName string)
    94  }
    95  
    96  // scanFields scans the fields of a struct.
    97  //
    98  // Each provided function will be applied to appropriately tagged fields, or
    99  // skipped if nil.
   100  //
   101  // Fields tagged nosave are skipped.
   102  func scanFields(ss *ast.StructType, prefix string, fn scanFunctions) {
   103  	if ss.Fields.List == nil {
   104  		// No fields.
   105  		return
   106  	}
   107  
   108  	// Scan all fields.
   109  	for _, field := range ss.Fields.List {
   110  		// Calculate the name.
   111  		name := ""
   112  		if field.Names != nil {
   113  			// It's a named field; override.
   114  			name = field.Names[0].Name
   115  		} else {
   116  			// Anonymous types can't be embedded, so we don't need
   117  			// to worry about providing a useful name here.
   118  			name, _ = resolveTypeName(field.Type)
   119  		}
   120  
   121  		// Skip _ fields.
   122  		if name == "_" {
   123  			continue
   124  		}
   125  
   126  		// Is this a anonymous struct? If yes, then continue the
   127  		// recursion with the given prefix. We don't pay attention to
   128  		// any tags on the top-level struct field.
   129  		tag := extractStateTag(field.Tag)
   130  		if anon, ok := field.Type.(*ast.StructType); ok && tag == "" {
   131  			scanFields(anon, name+".", fn)
   132  			continue
   133  		}
   134  
   135  		switch tag {
   136  		case "zerovalue":
   137  			if fn.zerovalue != nil {
   138  				fn.zerovalue(name)
   139  			}
   140  
   141  		case "":
   142  			if fn.normal != nil {
   143  				fn.normal(name)
   144  			}
   145  
   146  		case "wait":
   147  			if fn.wait != nil {
   148  				fn.wait(name)
   149  			}
   150  
   151  		case "manual", "nosave", "ignore":
   152  			// Do nothing.
   153  
   154  		default:
   155  			if strings.HasPrefix(tag, ".(") && strings.HasSuffix(tag, ")") {
   156  				if fn.value != nil {
   157  					fn.value(name, tag[2:len(tag)-1])
   158  				}
   159  			}
   160  		}
   161  	}
   162  }
   163  
   164  func camelCased(name string) string {
   165  	return strings.ToUpper(name[:1]) + name[1:]
   166  }
   167  
   168  func main() {
   169  	// Parse flags.
   170  	flag.Usage = func() {
   171  		fmt.Fprintf(os.Stderr, "Usage: %s [options]\n", os.Args[0])
   172  		flag.PrintDefaults()
   173  	}
   174  	flag.Parse()
   175  	if len(flag.Args()) == 0 {
   176  		flag.Usage()
   177  		os.Exit(1)
   178  	}
   179  	if *fullPkg == "" {
   180  		fmt.Fprintf(os.Stderr, "Error: package required.")
   181  		os.Exit(1)
   182  	}
   183  
   184  	// Open the output file.
   185  	var (
   186  		outputFile *os.File
   187  		err        error
   188  	)
   189  	if *output == "" || *output == "-" {
   190  		outputFile = os.Stdout
   191  	} else {
   192  		outputFile, err = os.OpenFile(*output, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
   193  		if err != nil {
   194  			fmt.Fprintf(os.Stderr, "Error opening output %q: %v", *output, err)
   195  		}
   196  		defer outputFile.Close()
   197  	}
   198  
   199  	// Set the statePrefix for below, depending on the import.
   200  	statePrefix := ""
   201  	if *statePkg != "" {
   202  		parts := strings.Split(*statePkg, "/")
   203  		statePrefix = parts[len(parts)-1] + "."
   204  	}
   205  
   206  	// initCalls is dumped at the end.
   207  	var initCalls []string
   208  
   209  	// Common closures.
   210  	emitRegister := func(name string) {
   211  		initCalls = append(initCalls, fmt.Sprintf("%sRegister((*%s)(nil))", statePrefix, name))
   212  	}
   213  
   214  	// Automated warning.
   215  	fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n")
   216  
   217  	// Emit build tags.
   218  	if t := tags.Aggregate(flag.Args()); len(t) > 0 {
   219  		fmt.Fprintf(outputFile, "%s\n\n", strings.Join(t.Lines(), "\n"))
   220  	}
   221  
   222  	// Emit the package name.
   223  	_, pkg := filepath.Split(*fullPkg)
   224  	fmt.Fprintf(outputFile, "package %s\n\n", pkg)
   225  
   226  	// Emit the imports lazily.
   227  	var once sync.Once
   228  	maybeEmitImports := func() {
   229  		once.Do(func() {
   230  			// Emit the imports.
   231  			fmt.Fprint(outputFile, "import (\n")
   232  			if *statePkg != "" {
   233  				fmt.Fprintf(outputFile, "	\"%s\"\n", *statePkg)
   234  			}
   235  			if *imports != "" {
   236  				for _, i := range strings.Split(*imports, ",") {
   237  					fmt.Fprintf(outputFile, "	\"%s\"\n", i)
   238  				}
   239  			}
   240  			fmt.Fprint(outputFile, ")\n\n")
   241  		})
   242  	}
   243  
   244  	files := make([]*ast.File, 0, len(flag.Args()))
   245  
   246  	// Parse the input files.
   247  	for _, filename := range flag.Args() {
   248  		// Parse the file.
   249  		fset := token.NewFileSet()
   250  		f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
   251  		if err != nil {
   252  			// Not a valid input file?
   253  			fmt.Fprintf(os.Stderr, "Input %q can't be parsed: %v\n", filename, err)
   254  			os.Exit(1)
   255  		}
   256  
   257  		files = append(files, f)
   258  	}
   259  
   260  	type method struct {
   261  		typeName   string
   262  		methodName string
   263  	}
   264  
   265  	// Search for and add all method to a set. We auto-detecting several
   266  	// different methods (and insert them if we don't find them, in order
   267  	// to ensure that expectations match reality).
   268  	//
   269  	// While we do this, figure out the right receiver name. If there are
   270  	// multiple distinct receivers, then we will just pick the last one.
   271  	simpleMethods := make(map[method]struct{})
   272  	receiverNames := make(map[string]string)
   273  	for _, f := range files {
   274  		// Go over all functions.
   275  		for _, decl := range f.Decls {
   276  			d, ok := decl.(*ast.FuncDecl)
   277  			if !ok {
   278  				continue
   279  			}
   280  			if d.Recv == nil || len(d.Recv.List) != 1 {
   281  				// Not a named method.
   282  				continue
   283  			}
   284  
   285  			// Save the method and the receiver.
   286  			name, _ := resolveTypeName(d.Recv.List[0].Type)
   287  			simpleMethods[method{
   288  				typeName:   name,
   289  				methodName: d.Name.Name,
   290  			}] = struct{}{}
   291  			if len(d.Recv.List[0].Names) > 0 {
   292  				receiverNames[name] = d.Recv.List[0].Names[0].Name
   293  			}
   294  		}
   295  	}
   296  
   297  	for _, f := range files {
   298  		// Go over all named types.
   299  		for _, decl := range f.Decls {
   300  			d, ok := decl.(*ast.GenDecl)
   301  			if !ok || d.Tok != token.TYPE {
   302  				continue
   303  			}
   304  
   305  			// Only generate code for types marked "// +stateify
   306  			// savable" in one of the proceeding comment lines. If
   307  			// the line is marked "// +stateify type" then only
   308  			// generate type information and register the type.
   309  			if d.Doc == nil {
   310  				continue
   311  			}
   312  			var (
   313  				generateTypeInfo    = false
   314  				generateSaverLoader = false
   315  			)
   316  			for _, l := range d.Doc.List {
   317  				if l.Text == "// +stateify savable" {
   318  					generateTypeInfo = true
   319  					generateSaverLoader = true
   320  					break
   321  				}
   322  				if l.Text == "// +stateify type" {
   323  					generateTypeInfo = true
   324  				}
   325  			}
   326  			if !generateTypeInfo && !generateSaverLoader {
   327  				continue
   328  			}
   329  
   330  			for _, gs := range d.Specs {
   331  				ts := gs.(*ast.TypeSpec)
   332  				recv, ok := receiverNames[ts.Name.Name]
   333  				if !ok {
   334  					// Maybe no methods were defined?
   335  					recv = strings.ToLower(ts.Name.Name[:1])
   336  				}
   337  				switch x := ts.Type.(type) {
   338  				case *ast.StructType:
   339  					maybeEmitImports()
   340  
   341  					// Record the slot for each field.
   342  					fieldCount := 0
   343  					fields := make(map[string]int)
   344  					emitField := func(name string) {
   345  						fmt.Fprintf(outputFile, "		\"%s\",\n", name)
   346  						fields[name] = fieldCount
   347  						fieldCount++
   348  					}
   349  					emitFieldValue := func(name string, _ string) {
   350  						emitField(name)
   351  					}
   352  					emitLoadValue := func(name, typName string) {
   353  						fmt.Fprintf(outputFile, "	stateSourceObject.LoadValue(%d, new(%s), func(y interface{}) { %s.load%s(y.(%s)) })\n", fields[name], typName, recv, camelCased(name), typName)
   354  					}
   355  					emitLoad := func(name string) {
   356  						fmt.Fprintf(outputFile, "	stateSourceObject.Load(%d, &%s.%s)\n", fields[name], recv, name)
   357  					}
   358  					emitLoadWait := func(name string) {
   359  						fmt.Fprintf(outputFile, "	stateSourceObject.LoadWait(%d, &%s.%s)\n", fields[name], recv, name)
   360  					}
   361  					emitSaveValue := func(name, typName string) {
   362  						fmt.Fprintf(outputFile, "	var %sValue %s = %s.save%s()\n", name, typName, recv, camelCased(name))
   363  						fmt.Fprintf(outputFile, "	stateSinkObject.SaveValue(%d, %sValue)\n", fields[name], name)
   364  					}
   365  					emitSave := func(name string) {
   366  						fmt.Fprintf(outputFile, "	stateSinkObject.Save(%d, &%s.%s)\n", fields[name], recv, name)
   367  					}
   368  					emitZeroCheck := func(name string) {
   369  						fmt.Fprintf(outputFile, "	if !%sIsZeroValue(&%s.%s) { %sFailf(\"%s is %%#v, expected zero\", &%s.%s) }\n", statePrefix, recv, name, statePrefix, name, recv, name)
   370  					}
   371  
   372  					// Generate the type name method.
   373  					fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", recv, ts.Name.Name)
   374  					fmt.Fprintf(outputFile, "	return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
   375  					fmt.Fprintf(outputFile, "}\n\n")
   376  
   377  					// Generate the fields method.
   378  					fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", recv, ts.Name.Name)
   379  					fmt.Fprintf(outputFile, "	return []string{\n")
   380  					scanFields(x, "", scanFunctions{
   381  						normal: emitField,
   382  						wait:   emitField,
   383  						value:  emitFieldValue,
   384  					})
   385  					fmt.Fprintf(outputFile, "	}\n")
   386  					fmt.Fprintf(outputFile, "}\n\n")
   387  
   388  					// Define beforeSave if a definition was not found. This prevents
   389  					// the code from compiling if a custom beforeSave was defined in a
   390  					// file not provided to this binary and prevents inherited methods
   391  					// from being called multiple times by overriding them.
   392  					if _, ok := simpleMethods[method{
   393  						typeName:   ts.Name.Name,
   394  						methodName: "beforeSave",
   395  					}]; !ok && generateSaverLoader {
   396  						fmt.Fprintf(outputFile, "func (%s *%s) beforeSave() {}\n\n", recv, ts.Name.Name)
   397  					}
   398  
   399  					// Generate the save method.
   400  					//
   401  					// N.B. For historical reasons, we perform the value saves first,
   402  					// and perform the value loads last. There should be no dependency
   403  					// on this specific behavior, but the ability to specify slots
   404  					// allows a manual implementation to be order-dependent.
   405  					if generateSaverLoader {
   406  						fmt.Fprintf(outputFile, "// +checklocksignore\n")
   407  						fmt.Fprintf(outputFile, "func (%s *%s) StateSave(stateSinkObject %sSink) {\n", recv, ts.Name.Name, statePrefix)
   408  						fmt.Fprintf(outputFile, "	%s.beforeSave()\n", recv)
   409  						scanFields(x, "", scanFunctions{zerovalue: emitZeroCheck})
   410  						scanFields(x, "", scanFunctions{value: emitSaveValue})
   411  						scanFields(x, "", scanFunctions{normal: emitSave, wait: emitSave})
   412  						fmt.Fprintf(outputFile, "}\n\n")
   413  					}
   414  
   415  					// Define afterLoad if a definition was not found. We do this for
   416  					// the same reason that we do it for beforeSave.
   417  					_, hasAfterLoad := simpleMethods[method{
   418  						typeName:   ts.Name.Name,
   419  						methodName: "afterLoad",
   420  					}]
   421  					if !hasAfterLoad && generateSaverLoader {
   422  						fmt.Fprintf(outputFile, "func (%s *%s) afterLoad() {}\n\n", recv, ts.Name.Name)
   423  					}
   424  
   425  					// Generate the load method.
   426  					//
   427  					// N.B. See the comment above for the save method.
   428  					if generateSaverLoader {
   429  						fmt.Fprintf(outputFile, "// +checklocksignore\n")
   430  						fmt.Fprintf(outputFile, "func (%s *%s) StateLoad(stateSourceObject %sSource) {\n", recv, ts.Name.Name, statePrefix)
   431  						scanFields(x, "", scanFunctions{normal: emitLoad, wait: emitLoadWait})
   432  						scanFields(x, "", scanFunctions{value: emitLoadValue})
   433  						if hasAfterLoad {
   434  							// The call to afterLoad is made conditionally, because when
   435  							// AfterLoad is called, the object encodes a dependency on
   436  							// referred objects (i.e. fields). This means that afterLoad
   437  							// will not be called until the other afterLoads are called.
   438  							fmt.Fprintf(outputFile, "	stateSourceObject.AfterLoad(%s.afterLoad)\n", recv)
   439  						}
   440  						fmt.Fprintf(outputFile, "}\n\n")
   441  					}
   442  
   443  					// Add to our registration.
   444  					emitRegister(ts.Name.Name)
   445  
   446  				case *ast.Ident, *ast.SelectorExpr, *ast.ArrayType:
   447  					maybeEmitImports()
   448  
   449  					// Generate the info methods.
   450  					fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", recv, ts.Name.Name)
   451  					fmt.Fprintf(outputFile, "	return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
   452  					fmt.Fprintf(outputFile, "}\n\n")
   453  					fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", recv, ts.Name.Name)
   454  					fmt.Fprintf(outputFile, "	return nil\n")
   455  					fmt.Fprintf(outputFile, "}\n\n")
   456  
   457  					// See above.
   458  					emitRegister(ts.Name.Name)
   459  				}
   460  			}
   461  		}
   462  	}
   463  
   464  	if len(initCalls) > 0 {
   465  		// Emit the init() function.
   466  		fmt.Fprintf(outputFile, "func init() {\n")
   467  		for _, ic := range initCalls {
   468  			fmt.Fprintf(outputFile, "	%s\n", ic)
   469  		}
   470  		fmt.Fprintf(outputFile, "}\n")
   471  	}
   472  }