github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/tools/go_stateify/main.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Stateify provides a simple way to generate Load/Save methods based on 16 // existing types and struct tags. 17 package main 18 19 import ( 20 "flag" 21 "fmt" 22 "go/ast" 23 "go/parser" 24 "go/token" 25 "os" 26 "path/filepath" 27 "reflect" 28 "strings" 29 "sync" 30 31 "github.com/SagerNet/gvisor/tools/tags" 32 ) 33 34 var ( 35 fullPkg = flag.String("fullpkg", "", "fully qualified output package") 36 imports = flag.String("imports", "", "extra imports for the output file") 37 output = flag.String("output", "", "output file") 38 statePkg = flag.String("statepkg", "", "state import package; defaults to empty") 39 ) 40 41 // resolveTypeName returns a qualified type name. 42 func resolveTypeName(typ ast.Expr) (field string, qualified string) { 43 for done := false; !done; { 44 // Resolve star expressions. 45 switch rs := typ.(type) { 46 case *ast.StarExpr: 47 qualified += "*" 48 typ = rs.X 49 case *ast.ArrayType: 50 if rs.Len == nil { 51 // Slice type declaration. 52 qualified += "[]" 53 } else { 54 // Array type declaration. 55 qualified += "[" + rs.Len.(*ast.BasicLit).Value + "]" 56 } 57 typ = rs.Elt 58 default: 59 // No more descent. 60 done = true 61 } 62 } 63 64 // Resolve a package selector. 65 sel, ok := typ.(*ast.SelectorExpr) 66 if ok { 67 qualified = qualified + sel.X.(*ast.Ident).Name + "." 68 typ = sel.Sel 69 } 70 71 // Figure out actual type name. 72 field = typ.(*ast.Ident).Name 73 qualified = qualified + field 74 return 75 } 76 77 // extractStateTag pulls the relevant state tag. 78 func extractStateTag(tag *ast.BasicLit) string { 79 if tag == nil { 80 return "" 81 } 82 if len(tag.Value) < 2 { 83 return "" 84 } 85 return reflect.StructTag(tag.Value[1 : len(tag.Value)-1]).Get("state") 86 } 87 88 // scanFunctions is a set of functions passed to scanFields. 89 type scanFunctions struct { 90 zerovalue func(name string) 91 normal func(name string) 92 wait func(name string) 93 value func(name, typName string) 94 } 95 96 // scanFields scans the fields of a struct. 97 // 98 // Each provided function will be applied to appropriately tagged fields, or 99 // skipped if nil. 100 // 101 // Fields tagged nosave are skipped. 102 func scanFields(ss *ast.StructType, prefix string, fn scanFunctions) { 103 if ss.Fields.List == nil { 104 // No fields. 105 return 106 } 107 108 // Scan all fields. 109 for _, field := range ss.Fields.List { 110 // Calculate the name. 111 name := "" 112 if field.Names != nil { 113 // It's a named field; override. 114 name = field.Names[0].Name 115 } else { 116 // Anonymous types can't be embedded, so we don't need 117 // to worry about providing a useful name here. 118 name, _ = resolveTypeName(field.Type) 119 } 120 121 // Skip _ fields. 122 if name == "_" { 123 continue 124 } 125 126 // Is this a anonymous struct? If yes, then continue the 127 // recursion with the given prefix. We don't pay attention to 128 // any tags on the top-level struct field. 129 tag := extractStateTag(field.Tag) 130 if anon, ok := field.Type.(*ast.StructType); ok && tag == "" { 131 scanFields(anon, name+".", fn) 132 continue 133 } 134 135 switch tag { 136 case "zerovalue": 137 if fn.zerovalue != nil { 138 fn.zerovalue(name) 139 } 140 141 case "": 142 if fn.normal != nil { 143 fn.normal(name) 144 } 145 146 case "wait": 147 if fn.wait != nil { 148 fn.wait(name) 149 } 150 151 case "manual", "nosave", "ignore": 152 // Do nothing. 153 154 default: 155 if strings.HasPrefix(tag, ".(") && strings.HasSuffix(tag, ")") { 156 if fn.value != nil { 157 fn.value(name, tag[2:len(tag)-1]) 158 } 159 } 160 } 161 } 162 } 163 164 func camelCased(name string) string { 165 return strings.ToUpper(name[:1]) + name[1:] 166 } 167 168 func main() { 169 // Parse flags. 170 flag.Usage = func() { 171 fmt.Fprintf(os.Stderr, "Usage: %s [options]\n", os.Args[0]) 172 flag.PrintDefaults() 173 } 174 flag.Parse() 175 if len(flag.Args()) == 0 { 176 flag.Usage() 177 os.Exit(1) 178 } 179 if *fullPkg == "" { 180 fmt.Fprintf(os.Stderr, "Error: package required.") 181 os.Exit(1) 182 } 183 184 // Open the output file. 185 var ( 186 outputFile *os.File 187 err error 188 ) 189 if *output == "" || *output == "-" { 190 outputFile = os.Stdout 191 } else { 192 outputFile, err = os.OpenFile(*output, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) 193 if err != nil { 194 fmt.Fprintf(os.Stderr, "Error opening output %q: %v", *output, err) 195 } 196 defer outputFile.Close() 197 } 198 199 // Set the statePrefix for below, depending on the import. 200 statePrefix := "" 201 if *statePkg != "" { 202 parts := strings.Split(*statePkg, "/") 203 statePrefix = parts[len(parts)-1] + "." 204 } 205 206 // initCalls is dumped at the end. 207 var initCalls []string 208 209 // Common closures. 210 emitRegister := func(name string) { 211 initCalls = append(initCalls, fmt.Sprintf("%sRegister((*%s)(nil))", statePrefix, name)) 212 } 213 214 // Automated warning. 215 fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n") 216 217 // Emit build tags. 218 if t := tags.Aggregate(flag.Args()); len(t) > 0 { 219 fmt.Fprintf(outputFile, "%s\n\n", strings.Join(t.Lines(), "\n")) 220 } 221 222 // Emit the package name. 223 _, pkg := filepath.Split(*fullPkg) 224 fmt.Fprintf(outputFile, "package %s\n\n", pkg) 225 226 // Emit the imports lazily. 227 var once sync.Once 228 maybeEmitImports := func() { 229 once.Do(func() { 230 // Emit the imports. 231 fmt.Fprint(outputFile, "import (\n") 232 if *statePkg != "" { 233 fmt.Fprintf(outputFile, " \"%s\"\n", *statePkg) 234 } 235 if *imports != "" { 236 for _, i := range strings.Split(*imports, ",") { 237 fmt.Fprintf(outputFile, " \"%s\"\n", i) 238 } 239 } 240 fmt.Fprint(outputFile, ")\n\n") 241 }) 242 } 243 244 files := make([]*ast.File, 0, len(flag.Args())) 245 246 // Parse the input files. 247 for _, filename := range flag.Args() { 248 // Parse the file. 249 fset := token.NewFileSet() 250 f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments) 251 if err != nil { 252 // Not a valid input file? 253 fmt.Fprintf(os.Stderr, "Input %q can't be parsed: %v\n", filename, err) 254 os.Exit(1) 255 } 256 257 files = append(files, f) 258 } 259 260 type method struct { 261 typeName string 262 methodName string 263 } 264 265 // Search for and add all method to a set. We auto-detecting several 266 // different methods (and insert them if we don't find them, in order 267 // to ensure that expectations match reality). 268 // 269 // While we do this, figure out the right receiver name. If there are 270 // multiple distinct receivers, then we will just pick the last one. 271 simpleMethods := make(map[method]struct{}) 272 receiverNames := make(map[string]string) 273 for _, f := range files { 274 // Go over all functions. 275 for _, decl := range f.Decls { 276 d, ok := decl.(*ast.FuncDecl) 277 if !ok { 278 continue 279 } 280 if d.Recv == nil || len(d.Recv.List) != 1 { 281 // Not a named method. 282 continue 283 } 284 285 // Save the method and the receiver. 286 name, _ := resolveTypeName(d.Recv.List[0].Type) 287 simpleMethods[method{ 288 typeName: name, 289 methodName: d.Name.Name, 290 }] = struct{}{} 291 if len(d.Recv.List[0].Names) > 0 { 292 receiverNames[name] = d.Recv.List[0].Names[0].Name 293 } 294 } 295 } 296 297 for _, f := range files { 298 // Go over all named types. 299 for _, decl := range f.Decls { 300 d, ok := decl.(*ast.GenDecl) 301 if !ok || d.Tok != token.TYPE { 302 continue 303 } 304 305 // Only generate code for types marked "// +stateify 306 // savable" in one of the proceeding comment lines. If 307 // the line is marked "// +stateify type" then only 308 // generate type information and register the type. 309 if d.Doc == nil { 310 continue 311 } 312 var ( 313 generateTypeInfo = false 314 generateSaverLoader = false 315 ) 316 for _, l := range d.Doc.List { 317 if l.Text == "// +stateify savable" { 318 generateTypeInfo = true 319 generateSaverLoader = true 320 break 321 } 322 if l.Text == "// +stateify type" { 323 generateTypeInfo = true 324 } 325 } 326 if !generateTypeInfo && !generateSaverLoader { 327 continue 328 } 329 330 for _, gs := range d.Specs { 331 ts := gs.(*ast.TypeSpec) 332 recv, ok := receiverNames[ts.Name.Name] 333 if !ok { 334 // Maybe no methods were defined? 335 recv = strings.ToLower(ts.Name.Name[:1]) 336 } 337 switch x := ts.Type.(type) { 338 case *ast.StructType: 339 maybeEmitImports() 340 341 // Record the slot for each field. 342 fieldCount := 0 343 fields := make(map[string]int) 344 emitField := func(name string) { 345 fmt.Fprintf(outputFile, " \"%s\",\n", name) 346 fields[name] = fieldCount 347 fieldCount++ 348 } 349 emitFieldValue := func(name string, _ string) { 350 emitField(name) 351 } 352 emitLoadValue := func(name, typName string) { 353 fmt.Fprintf(outputFile, " stateSourceObject.LoadValue(%d, new(%s), func(y interface{}) { %s.load%s(y.(%s)) })\n", fields[name], typName, recv, camelCased(name), typName) 354 } 355 emitLoad := func(name string) { 356 fmt.Fprintf(outputFile, " stateSourceObject.Load(%d, &%s.%s)\n", fields[name], recv, name) 357 } 358 emitLoadWait := func(name string) { 359 fmt.Fprintf(outputFile, " stateSourceObject.LoadWait(%d, &%s.%s)\n", fields[name], recv, name) 360 } 361 emitSaveValue := func(name, typName string) { 362 fmt.Fprintf(outputFile, " var %sValue %s = %s.save%s()\n", name, typName, recv, camelCased(name)) 363 fmt.Fprintf(outputFile, " stateSinkObject.SaveValue(%d, %sValue)\n", fields[name], name) 364 } 365 emitSave := func(name string) { 366 fmt.Fprintf(outputFile, " stateSinkObject.Save(%d, &%s.%s)\n", fields[name], recv, name) 367 } 368 emitZeroCheck := func(name string) { 369 fmt.Fprintf(outputFile, " if !%sIsZeroValue(&%s.%s) { %sFailf(\"%s is %%#v, expected zero\", &%s.%s) }\n", statePrefix, recv, name, statePrefix, name, recv, name) 370 } 371 372 // Generate the type name method. 373 fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", recv, ts.Name.Name) 374 fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name) 375 fmt.Fprintf(outputFile, "}\n\n") 376 377 // Generate the fields method. 378 fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", recv, ts.Name.Name) 379 fmt.Fprintf(outputFile, " return []string{\n") 380 scanFields(x, "", scanFunctions{ 381 normal: emitField, 382 wait: emitField, 383 value: emitFieldValue, 384 }) 385 fmt.Fprintf(outputFile, " }\n") 386 fmt.Fprintf(outputFile, "}\n\n") 387 388 // Define beforeSave if a definition was not found. This prevents 389 // the code from compiling if a custom beforeSave was defined in a 390 // file not provided to this binary and prevents inherited methods 391 // from being called multiple times by overriding them. 392 if _, ok := simpleMethods[method{ 393 typeName: ts.Name.Name, 394 methodName: "beforeSave", 395 }]; !ok && generateSaverLoader { 396 fmt.Fprintf(outputFile, "func (%s *%s) beforeSave() {}\n\n", recv, ts.Name.Name) 397 } 398 399 // Generate the save method. 400 // 401 // N.B. For historical reasons, we perform the value saves first, 402 // and perform the value loads last. There should be no dependency 403 // on this specific behavior, but the ability to specify slots 404 // allows a manual implementation to be order-dependent. 405 if generateSaverLoader { 406 fmt.Fprintf(outputFile, "// +checklocksignore\n") 407 fmt.Fprintf(outputFile, "func (%s *%s) StateSave(stateSinkObject %sSink) {\n", recv, ts.Name.Name, statePrefix) 408 fmt.Fprintf(outputFile, " %s.beforeSave()\n", recv) 409 scanFields(x, "", scanFunctions{zerovalue: emitZeroCheck}) 410 scanFields(x, "", scanFunctions{value: emitSaveValue}) 411 scanFields(x, "", scanFunctions{normal: emitSave, wait: emitSave}) 412 fmt.Fprintf(outputFile, "}\n\n") 413 } 414 415 // Define afterLoad if a definition was not found. We do this for 416 // the same reason that we do it for beforeSave. 417 _, hasAfterLoad := simpleMethods[method{ 418 typeName: ts.Name.Name, 419 methodName: "afterLoad", 420 }] 421 if !hasAfterLoad && generateSaverLoader { 422 fmt.Fprintf(outputFile, "func (%s *%s) afterLoad() {}\n\n", recv, ts.Name.Name) 423 } 424 425 // Generate the load method. 426 // 427 // N.B. See the comment above for the save method. 428 if generateSaverLoader { 429 fmt.Fprintf(outputFile, "// +checklocksignore\n") 430 fmt.Fprintf(outputFile, "func (%s *%s) StateLoad(stateSourceObject %sSource) {\n", recv, ts.Name.Name, statePrefix) 431 scanFields(x, "", scanFunctions{normal: emitLoad, wait: emitLoadWait}) 432 scanFields(x, "", scanFunctions{value: emitLoadValue}) 433 if hasAfterLoad { 434 // The call to afterLoad is made conditionally, because when 435 // AfterLoad is called, the object encodes a dependency on 436 // referred objects (i.e. fields). This means that afterLoad 437 // will not be called until the other afterLoads are called. 438 fmt.Fprintf(outputFile, " stateSourceObject.AfterLoad(%s.afterLoad)\n", recv) 439 } 440 fmt.Fprintf(outputFile, "}\n\n") 441 } 442 443 // Add to our registration. 444 emitRegister(ts.Name.Name) 445 446 case *ast.Ident, *ast.SelectorExpr, *ast.ArrayType: 447 maybeEmitImports() 448 449 // Generate the info methods. 450 fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", recv, ts.Name.Name) 451 fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name) 452 fmt.Fprintf(outputFile, "}\n\n") 453 fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", recv, ts.Name.Name) 454 fmt.Fprintf(outputFile, " return nil\n") 455 fmt.Fprintf(outputFile, "}\n\n") 456 457 // See above. 458 emitRegister(ts.Name.Name) 459 } 460 } 461 } 462 } 463 464 if len(initCalls) > 0 { 465 // Emit the init() function. 466 fmt.Fprintf(outputFile, "func init() {\n") 467 for _, ic := range initCalls { 468 fmt.Fprintf(outputFile, " %s\n", ic) 469 } 470 fmt.Fprintf(outputFile, "}\n") 471 } 472 }