gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/tools/go_stateify/main.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Stateify provides a simple way to generate Load/Save methods based on 16 // existing types and struct tags. 17 package main 18 19 import ( 20 "flag" 21 "fmt" 22 "go/ast" 23 "go/parser" 24 "go/token" 25 "os" 26 "path/filepath" 27 "reflect" 28 "strings" 29 "sync" 30 31 "gvisor.dev/gvisor/tools/constraintutil" 32 ) 33 34 var ( 35 fullPkg = flag.String("fullpkg", "", "fully qualified output package") 36 imports = flag.String("imports", "", "extra imports for the output file") 37 output = flag.String("output", "", "output file") 38 statePkg = flag.String("statepkg", "", "state import package; defaults to empty") 39 ) 40 41 // resolveTypeName returns a qualified type name. 42 func resolveTypeName(typ ast.Expr) (field string, qualified string) { 43 for done := false; !done; { 44 // Resolve star expressions. 45 switch rs := typ.(type) { 46 case *ast.StarExpr: 47 qualified += "*" 48 typ = rs.X 49 case *ast.ArrayType: 50 if rs.Len == nil { 51 // Slice type declaration. 52 qualified += "[]" 53 } else { 54 // Array type declaration. 55 qualified += "[" + rs.Len.(*ast.BasicLit).Value + "]" 56 } 57 typ = rs.Elt 58 default: 59 // No more descent. 60 done = true 61 } 62 } 63 64 // Resolve a package selector. 65 sel, ok := typ.(*ast.SelectorExpr) 66 if ok { 67 qualified = qualified + sel.X.(*ast.Ident).Name + "." 68 typ = sel.Sel 69 } 70 71 // Figure out actual type name. 72 field = typ.(*ast.Ident).Name 73 qualified = qualified + field 74 return 75 } 76 77 // extractStateTag pulls the relevant state tag. 78 func extractStateTag(tag *ast.BasicLit) string { 79 if tag == nil { 80 return "" 81 } 82 if len(tag.Value) < 2 { 83 return "" 84 } 85 return reflect.StructTag(tag.Value[1 : len(tag.Value)-1]).Get("state") 86 } 87 88 // scanFunctions is a set of functions passed to scanFields. 89 type scanFunctions struct { 90 zerovalue func(name string) 91 normal func(name string) 92 wait func(name string) 93 value func(name, typName string) 94 } 95 96 // scanFields scans the fields of a struct. 97 // 98 // Each provided function will be applied to appropriately tagged fields, or 99 // skipped if nil. 100 // 101 // Fields tagged nosave are skipped. 102 func scanFields(ss *ast.StructType, fn scanFunctions) { 103 if ss.Fields.List == nil { 104 // No fields. 105 return 106 } 107 108 // Scan all fields. 109 for _, field := range ss.Fields.List { 110 if field.Names == nil { 111 // Anonymous types can't be embedded, so we don't need 112 // to worry about providing a useful name here. 113 name, _ := resolveTypeName(field.Type) 114 scanField(name, field, fn) 115 continue 116 } 117 118 // Iterate over potentially multiple fields defined on the same line. 119 for _, nameI := range field.Names { 120 name := nameI.Name 121 // Skip _ fields. 122 if name == "_" { 123 continue 124 } 125 scanField(name, field, fn) 126 } 127 } 128 } 129 130 // scanField scans a single struct field with a resolved name. 131 func scanField(name string, field *ast.Field, fn scanFunctions) { 132 // Is this a anonymous struct? If yes, then continue the 133 // recursion with the given prefix. We don't pay attention to 134 // any tags on the top-level struct field. 135 tag := extractStateTag(field.Tag) 136 if anon, ok := field.Type.(*ast.StructType); ok && tag == "" { 137 scanFields(anon, fn) 138 return 139 } 140 141 switch tag { 142 case "zerovalue": 143 if fn.zerovalue != nil { 144 fn.zerovalue(name) 145 } 146 147 case "": 148 if fn.normal != nil { 149 fn.normal(name) 150 } 151 152 case "wait": 153 if fn.wait != nil { 154 fn.wait(name) 155 } 156 157 case "manual", "nosave", "ignore": 158 // Do nothing. 159 160 default: 161 if strings.HasPrefix(tag, ".(") && strings.HasSuffix(tag, ")") { 162 if fn.value != nil { 163 fn.value(name, tag[2:len(tag)-1]) 164 } 165 } 166 } 167 } 168 169 func camelCased(name string) string { 170 return strings.ToUpper(name[:1]) + name[1:] 171 } 172 173 func main() { 174 // Parse flags. 175 flag.Usage = func() { 176 fmt.Fprintf(os.Stderr, "Usage: %s [options]\n", os.Args[0]) 177 flag.PrintDefaults() 178 } 179 flag.Parse() 180 if len(flag.Args()) == 0 { 181 flag.Usage() 182 os.Exit(1) 183 } 184 if *fullPkg == "" { 185 fmt.Fprintf(os.Stderr, "Error: package required.") 186 os.Exit(1) 187 } 188 189 // Open the output file. 190 var ( 191 outputFile *os.File 192 err error 193 ) 194 if *output == "" || *output == "-" { 195 outputFile = os.Stdout 196 } else { 197 outputFile, err = os.OpenFile(*output, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) 198 if err != nil { 199 fmt.Fprintf(os.Stderr, "Error opening output %q: %v", *output, err) 200 } 201 defer outputFile.Close() 202 } 203 204 // Set the statePrefix for below, depending on the import. 205 statePrefix := "" 206 if *statePkg != "" { 207 parts := strings.Split(*statePkg, "/") 208 statePrefix = parts[len(parts)-1] + "." 209 } 210 211 // initCalls is dumped at the end. 212 var initCalls []string 213 214 // Common closures. 215 emitRegister := func(name string) { 216 initCalls = append(initCalls, fmt.Sprintf("%sRegister((*%s)(nil))", statePrefix, name)) 217 } 218 219 // Automated warning. 220 fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n") 221 222 // Emit build constraints. 223 bcexpr, err := constraintutil.CombineFromFiles(flag.Args()) 224 if err != nil { 225 fmt.Fprintf(os.Stderr, "Failed to infer build constraints: %v", err) 226 os.Exit(1) 227 } 228 outputFile.WriteString(constraintutil.Lines(bcexpr)) 229 230 // Emit the package name. 231 _, pkg := filepath.Split(*fullPkg) 232 fmt.Fprintf(outputFile, "package %s\n\n", pkg) 233 234 // Emit the imports lazily. 235 var once sync.Once 236 maybeEmitImports := func() { 237 once.Do(func() { 238 // Emit the imports. 239 fmt.Fprint(outputFile, "import (\n") 240 fmt.Fprint(outputFile, " \"context\"\n") 241 if *statePkg != "" { 242 fmt.Fprintf(outputFile, " \"%s\"\n", *statePkg) 243 } 244 if *imports != "" { 245 for _, i := range strings.Split(*imports, ",") { 246 fmt.Fprintf(outputFile, " \"%s\"\n", i) 247 } 248 } 249 fmt.Fprint(outputFile, ")\n\n") 250 }) 251 } 252 253 files := make([]*ast.File, 0, len(flag.Args())) 254 255 // Parse the input files. 256 for _, filename := range flag.Args() { 257 // Parse the file. 258 fset := token.NewFileSet() 259 f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments) 260 if err != nil { 261 // Not a valid input file? 262 fmt.Fprintf(os.Stderr, "Input %q can't be parsed: %v\n", filename, err) 263 os.Exit(1) 264 } 265 266 files = append(files, f) 267 } 268 269 type method struct { 270 typeName string 271 methodName string 272 } 273 274 // Search for and add all method to a set. We auto-detecting several 275 // different methods (and insert them if we don't find them, in order 276 // to ensure that expectations match reality). 277 // 278 // While we do this, figure out the right receiver name. If there are 279 // multiple distinct receivers, then we will just pick the last one. 280 simpleMethods := make(map[method]struct{}) 281 receiverNames := make(map[string]string) 282 for _, f := range files { 283 // Go over all functions. 284 for _, decl := range f.Decls { 285 d, ok := decl.(*ast.FuncDecl) 286 if !ok { 287 continue 288 } 289 if d.Recv == nil || len(d.Recv.List) != 1 { 290 // Not a named method. 291 continue 292 } 293 294 // Save the method and the receiver. 295 name, _ := resolveTypeName(d.Recv.List[0].Type) 296 simpleMethods[method{ 297 typeName: name, 298 methodName: d.Name.Name, 299 }] = struct{}{} 300 if len(d.Recv.List[0].Names) > 0 { 301 receiverNames[name] = d.Recv.List[0].Names[0].Name 302 } 303 } 304 } 305 306 for _, f := range files { 307 // Go over all named types. 308 for _, decl := range f.Decls { 309 d, ok := decl.(*ast.GenDecl) 310 if !ok || d.Tok != token.TYPE { 311 continue 312 } 313 314 // Only generate code for types marked "// +stateify 315 // savable" in one of the proceeding comment lines. If 316 // the line is marked "// +stateify type" then only 317 // generate type information and register the type. 318 // If the type also has a "// +stateify identtype" 319 // comment, the functions are instead generated to refer to 320 // the type that this newly-defined type is identical to, rather 321 // than about the newly-defined type itself. 322 if d.Doc == nil { 323 continue 324 } 325 var ( 326 generateTypeInfo = false 327 generateSaverLoader = false 328 isIdentType = false 329 ) 330 for _, l := range d.Doc.List { 331 if l.Text == "// +stateify savable" { 332 generateTypeInfo = true 333 generateSaverLoader = true 334 } 335 if l.Text == "// +stateify type" { 336 generateTypeInfo = true 337 } 338 if l.Text == "// +stateify identtype" { 339 isIdentType = true 340 } 341 } 342 if !generateTypeInfo && !generateSaverLoader { 343 continue 344 } 345 346 for _, gs := range d.Specs { 347 ts := gs.(*ast.TypeSpec) 348 recv, ok := receiverNames[ts.Name.Name] 349 if !ok { 350 // Maybe no methods were defined? 351 recv = strings.ToLower(ts.Name.Name[:1]) 352 } 353 switch x := ts.Type.(type) { 354 case *ast.StructType: 355 maybeEmitImports() 356 if isIdentType { 357 fmt.Fprintf(os.Stderr, "Cannot use `+stateify identtype` on a struct type (%v); must be a type definition of an identical type.", ts.Name.Name) 358 os.Exit(1) 359 } 360 361 // Record the slot for each field. 362 fieldCount := 0 363 fields := make(map[string]int) 364 emitField := func(name string) { 365 fmt.Fprintf(outputFile, " \"%s\",\n", name) 366 fields[name] = fieldCount 367 fieldCount++ 368 } 369 emitFieldValue := func(name string, _ string) { 370 emitField(name) 371 } 372 emitLoadValue := func(name, typName string) { 373 fmt.Fprintf(outputFile, " stateSourceObject.LoadValue(%d, new(%s), func(y any) { %s.load%s(ctx, y.(%s)) })\n", fields[name], typName, recv, camelCased(name), typName) 374 } 375 emitLoad := func(name string) { 376 fmt.Fprintf(outputFile, " stateSourceObject.Load(%d, &%s.%s)\n", fields[name], recv, name) 377 } 378 emitLoadWait := func(name string) { 379 fmt.Fprintf(outputFile, " stateSourceObject.LoadWait(%d, &%s.%s)\n", fields[name], recv, name) 380 } 381 emitSaveValue := func(name, typName string) { 382 // Emit typName to be more robust against code generation bugs, 383 // but instead of one line make two lines to silence ST1023 384 // finding (i.e. avoid nogo finding: "should omit type $typName 385 // from declaration; it will be inferred from the right-hand side") 386 fmt.Fprintf(outputFile, " var %sValue %s\n", name, typName) 387 fmt.Fprintf(outputFile, " %sValue = %s.save%s()\n", name, recv, camelCased(name)) 388 fmt.Fprintf(outputFile, " stateSinkObject.SaveValue(%d, %sValue)\n", fields[name], name) 389 } 390 emitSave := func(name string) { 391 fmt.Fprintf(outputFile, " stateSinkObject.Save(%d, &%s.%s)\n", fields[name], recv, name) 392 } 393 emitZeroCheck := func(name string) { 394 fmt.Fprintf(outputFile, " if !%sIsZeroValue(&%s.%s) { %sFailf(\"%s is %%#v, expected zero\", &%s.%s) }\n", statePrefix, recv, name, statePrefix, name, recv, name) 395 } 396 397 // Generate the type name method. 398 fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", recv, ts.Name.Name) 399 fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name) 400 fmt.Fprintf(outputFile, "}\n\n") 401 402 // Generate the fields method. 403 fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", recv, ts.Name.Name) 404 fmt.Fprintf(outputFile, " return []string{\n") 405 scanFields(x, scanFunctions{ 406 normal: emitField, 407 wait: emitField, 408 value: emitFieldValue, 409 }) 410 fmt.Fprintf(outputFile, " }\n") 411 fmt.Fprintf(outputFile, "}\n\n") 412 413 // Define beforeSave if a definition was not found. This prevents 414 // the code from compiling if a custom beforeSave was defined in a 415 // file not provided to this binary and prevents inherited methods 416 // from being called multiple times by overriding them. 417 if _, ok := simpleMethods[method{ 418 typeName: ts.Name.Name, 419 methodName: "beforeSave", 420 }]; !ok && generateSaverLoader { 421 fmt.Fprintf(outputFile, "func (%s *%s) beforeSave() {}\n\n", recv, ts.Name.Name) 422 } 423 424 // Generate the save method. 425 // 426 // N.B. For historical reasons, we perform the value saves first, 427 // and perform the value loads last. There should be no dependency 428 // on this specific behavior, but the ability to specify slots 429 // allows a manual implementation to be order-dependent. 430 if generateSaverLoader { 431 fmt.Fprintf(outputFile, "// +checklocksignore\n") 432 fmt.Fprintf(outputFile, "func (%s *%s) StateSave(stateSinkObject %sSink) {\n", recv, ts.Name.Name, statePrefix) 433 fmt.Fprintf(outputFile, " %s.beforeSave()\n", recv) 434 scanFields(x, scanFunctions{zerovalue: emitZeroCheck}) 435 scanFields(x, scanFunctions{value: emitSaveValue}) 436 scanFields(x, scanFunctions{normal: emitSave, wait: emitSave}) 437 fmt.Fprintf(outputFile, "}\n\n") 438 } 439 440 // Define afterLoad if a definition was not found. We do this for 441 // the same reason that we do it for beforeSave. 442 _, hasAfterLoad := simpleMethods[method{ 443 typeName: ts.Name.Name, 444 methodName: "afterLoad", 445 }] 446 if !hasAfterLoad && generateSaverLoader { 447 fmt.Fprintf(outputFile, "func (%s *%s) afterLoad(context.Context) {}\n\n", recv, ts.Name.Name) 448 } 449 450 // Generate the load method. 451 // 452 // N.B. See the comment above for the save method. 453 if generateSaverLoader { 454 fmt.Fprintf(outputFile, "// +checklocksignore\n") 455 fmt.Fprintf(outputFile, "func (%s *%s) StateLoad(ctx context.Context, stateSourceObject %sSource) {\n", recv, ts.Name.Name, statePrefix) 456 scanFields(x, scanFunctions{normal: emitLoad, wait: emitLoadWait}) 457 scanFields(x, scanFunctions{value: emitLoadValue}) 458 if hasAfterLoad { 459 // The call to afterLoad is made conditionally, because when 460 // AfterLoad is called, the object encodes a dependency on 461 // referred objects (i.e. fields). This means that afterLoad 462 // will not be called until the other afterLoads are called. 463 fmt.Fprintf(outputFile, " stateSourceObject.AfterLoad(func () { %s.afterLoad(ctx) })\n", recv) 464 } 465 fmt.Fprintf(outputFile, "}\n\n") 466 } 467 468 // Add to our registration. 469 emitRegister(ts.Name.Name) 470 471 case *ast.Ident, *ast.SelectorExpr, *ast.ArrayType: 472 maybeEmitImports() 473 474 // Generate the info methods. 475 fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", recv, ts.Name.Name) 476 fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name) 477 fmt.Fprintf(outputFile, "}\n\n") 478 479 if !isIdentType { 480 fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", recv, ts.Name.Name) 481 fmt.Fprintf(outputFile, " return nil\n") 482 fmt.Fprintf(outputFile, "}\n\n") 483 } else { 484 var typeName string 485 switch y := x.(type) { 486 case *ast.Ident: 487 typeName = y.Name 488 case *ast.SelectorExpr: 489 expIdent, ok := y.X.(*ast.Ident) 490 if !ok { 491 fmt.Fprintf(os.Stderr, "Cannot use non-ident %v (type %T) in type selector expression %v", y.X, y.X, y) 492 os.Exit(1) 493 } 494 typeName = fmt.Sprintf("%s.%s", expIdent.Name, y.Sel.Name) 495 default: 496 fmt.Fprintf(os.Stderr, "Cannot use `+stateify identtype` on a non-identifier/non-selector type definition (%v => %v of type %T); must be a type definition of an identical type.", ts.Name.Name, x, x) 497 os.Exit(1) 498 } 499 fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", recv, ts.Name.Name) 500 fmt.Fprintf(outputFile, " return (*%s)(%s).StateFields()\n", typeName, recv) 501 fmt.Fprintf(outputFile, "}\n\n") 502 if generateSaverLoader { 503 fmt.Fprintf(outputFile, "// +checklocksignore\n") 504 fmt.Fprintf(outputFile, "func (%s *%s) StateSave(stateSinkObject %sSink) {\n", recv, ts.Name.Name, statePrefix) 505 fmt.Fprintf(outputFile, " (*%s)(%s).StateSave(stateSinkObject)\n", typeName, recv) 506 fmt.Fprintf(outputFile, "}\n\n") 507 fmt.Fprintf(outputFile, "// +checklocksignore\n") 508 fmt.Fprintf(outputFile, "func (%s *%s) StateLoad(ctx context.Context, stateSourceObject %sSource) {\n", recv, ts.Name.Name, statePrefix) 509 fmt.Fprintf(outputFile, " (*%s)(%s).StateLoad(ctx, stateSourceObject)\n", typeName, recv) 510 fmt.Fprintf(outputFile, "}\n\n") 511 } 512 } 513 514 // See above. 515 emitRegister(ts.Name.Name) 516 } 517 } 518 } 519 } 520 521 if len(initCalls) > 0 { 522 // Emit the init() function. 523 fmt.Fprintf(outputFile, "func init() {\n") 524 for _, ic := range initCalls { 525 fmt.Fprintf(outputFile, " %s\n", ic) 526 } 527 fmt.Fprintf(outputFile, "}\n") 528 } 529 }