gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/tools/go_marshal/gomarshal/generator.go (about) 1 // Copyright 2019 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package gomarshal implements the go_marshal code generator. See README.md. 16 package gomarshal 17 18 import ( 19 "bytes" 20 "fmt" 21 "go/ast" 22 "go/parser" 23 "go/token" 24 "os" 25 "sort" 26 "strings" 27 28 "gvisor.dev/gvisor/tools/constraintutil" 29 ) 30 31 // List of identifiers we use in generated code that may conflict with a 32 // similarly-named source identifier. Abort gracefully when we see these to 33 // avoid potentially confusing compilation failures in generated code. 34 // 35 // This only applies to import aliases at the moment. All other identifiers 36 // are qualified by a receiver argument, since they're struct fields. 37 // 38 // All recievers are single letters, so we don't allow import aliases to be a 39 // single letter. 40 var badIdents = []string{ 41 "addr", "blk", "buf", "cc", "dst", "dsts", "count", "err", "hdr", "idx", 42 "inner", "length", "limit", "ptr", "size", "src", "srcs", "val", 43 // All single-letter identifiers. 44 } 45 46 // Constructed fromt badIdents in init(). 47 var badIdentsMap map[string]struct{} 48 49 func init() { 50 badIdentsMap = make(map[string]struct{}) 51 for _, ident := range badIdents { 52 badIdentsMap[ident] = struct{}{} 53 } 54 } 55 56 // Generator drives code generation for a single invocation of the go_marshal 57 // utility. 58 // 59 // The Generator holds arguments passed to the tool, and drives parsing, 60 // processing and code Generator for all types marked with +marshal declared in 61 // the input files. 62 // 63 // See Generator.run() as the entry point. 64 type Generator struct { 65 // Paths to input go source files. 66 inputs []string 67 // Output file to write generated go source. 68 output *os.File 69 // Output file to write generated tests. 70 outputTest *os.File 71 // Output file to write unconditionally generated tests. 72 outputTestUC *os.File 73 // Package name for the generated file. 74 pkg string 75 // Set of extra packages to import in the generated file. 76 imports *importTable 77 } 78 79 // NewGenerator creates a new code Generator. 80 func NewGenerator(srcs []string, out, outTest, outTestUnconditional, pkg string, imports []string) (*Generator, error) { 81 f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) 82 if err != nil { 83 return nil, fmt.Errorf("couldn't open output file %q: %w", out, err) 84 } 85 fTest, err := os.OpenFile(outTest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) 86 if err != nil { 87 return nil, fmt.Errorf("couldn't open test output file %q: %w", out, err) 88 } 89 fTestUC, err := os.OpenFile(outTestUnconditional, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) 90 if err != nil { 91 return nil, fmt.Errorf("couldn't open unconditional test output file %q: %w", out, err) 92 } 93 g := Generator{ 94 inputs: srcs, 95 output: f, 96 outputTest: fTest, 97 outputTestUC: fTestUC, 98 pkg: pkg, 99 imports: newImportTable(), 100 } 101 for _, i := range imports { 102 // All imports on the extra imports list are unconditionally marked as 103 // used, so that they're always added to the generated code. 104 g.imports.add(i).markUsed() 105 } 106 107 // The following imports may or may not be used by the generated code, 108 // depending on what's required for the target types. Don't mark these as 109 // used by default. 110 g.imports.add("io") 111 g.imports.add("reflect") 112 g.imports.add("runtime") 113 g.imports.add("unsafe") 114 g.imports.add("gvisor.dev/gvisor/pkg/gohacks") 115 g.imports.add("gvisor.dev/gvisor/pkg/hostarch") 116 g.imports.add("gvisor.dev/gvisor/pkg/marshal") 117 return &g, nil 118 } 119 120 // writeHeader writes the header for the generated source file. The header 121 // includes the package name, package level comments and import statements. 122 func (g *Generator) writeHeader() error { 123 var b sourceBuffer 124 b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n") 125 126 bcexpr, err := constraintutil.CombineFromFiles(g.inputs) 127 if err != nil { 128 return err 129 } 130 if bcexpr != nil { 131 // Emit build constraints. 132 b.emit("// If there are issues with build constraint aggregation, see\n") 133 b.emit("// tools/go_marshal/gomarshal/generator.go:writeHeader(). The constraints here\n") 134 b.emit("// come from the input set of files used to generate this file. This input set\n") 135 b.emit("// is filtered based on pre-defined file suffixes related to build constraints,\n") 136 b.emit("// see tools/defs.bzl:calculate_sets().\n\n") 137 b.emit(constraintutil.Lines(bcexpr)) 138 } 139 140 // Package header. 141 b.emit("package %s\n\n", g.pkg) 142 if err := b.write(g.output); err != nil { 143 return err 144 } 145 146 return g.imports.write(g.output) 147 } 148 149 // writeTypeChecks writes a statement to force the compiler to perform a type 150 // check for all Marshallable types referenced by the generated code. 151 func (g *Generator) writeTypeChecks(ms map[string]struct{}) error { 152 if len(ms) == 0 { 153 return nil 154 } 155 156 msl := make([]string, 0, len(ms)) 157 for m := range ms { 158 msl = append(msl, m) 159 } 160 sort.Strings(msl) 161 162 var buf bytes.Buffer 163 fmt.Fprint(&buf, "// Marshallable types used by this file.\n") 164 165 for _, m := range msl { 166 fmt.Fprintf(&buf, "var _ marshal.Marshallable = (*%s)(nil)\n", m) 167 } 168 fmt.Fprint(&buf, "\n") 169 170 _, err := fmt.Fprint(g.output, buf.String()) 171 return err 172 } 173 174 // parse processes all input files passed this generator and produces a set of 175 // parsed go ASTs. 176 func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) { 177 debugf("go_marshal invoked with %d input files:\n", len(g.inputs)) 178 for _, path := range g.inputs { 179 debugf(" %s\n", path) 180 } 181 182 files := make([]*ast.File, 0, len(g.inputs)) 183 fsets := make([]*token.FileSet, 0, len(g.inputs)) 184 185 for _, path := range g.inputs { 186 fset := token.NewFileSet() 187 f, err := parser.ParseFile(fset, path, nil, parser.ParseComments) 188 if err != nil { 189 // Not a valid input file? 190 return nil, nil, fmt.Errorf("input %q can't be parsed: %w", path, err) 191 } 192 193 if debugEnabled() { 194 debugf("AST for %q:\n", path) 195 ast.Print(fset, f) 196 } 197 198 files = append(files, f) 199 fsets = append(fsets, fset) 200 } 201 202 return files, fsets, nil 203 } 204 205 // sliceAPI carries information about the '+marshal slice' directive. 206 type sliceAPI struct { 207 // Comment node in the AST containing the +marshal tag. 208 comment *ast.Comment 209 // Identifier fragment to use when naming generated functions for the slice 210 // API. 211 ident string 212 // Whether the generated functions should reference the newtype name, or the 213 // inner type name. Only meaningful on newtype declarations on primitives. 214 inner bool 215 } 216 217 // marshallableType carries information about a type marked with the '+marshal' 218 // directive. 219 type marshallableType struct { 220 spec *ast.TypeSpec 221 slice *sliceAPI 222 recv string 223 dynamic bool 224 boundCheck bool 225 } 226 227 func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) *marshallableType { 228 mt := &marshallableType{ 229 spec: spec, 230 slice: nil, 231 } 232 233 var unhandledTags []string 234 235 for _, tag := range strings.Fields(strings.TrimPrefix(tagLine.Text, "// +marshal")) { 236 if strings.HasPrefix(tag, "slice:") { 237 tokens := strings.Split(tag, ":") 238 if len(tokens) < 2 || len(tokens) > 3 { 239 abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive has invalid 'slice' clause. Expecting format 'slice:<IDENTIFIER>[:inner]', got '%v'", tag)) 240 } 241 if len(tokens[1]) == 0 { 242 abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has empty identifier argument. Expecting '+marshal slice:identifier'") 243 } 244 245 sa := &sliceAPI{ 246 comment: tagLine, 247 ident: tokens[1], 248 } 249 mt.slice = sa 250 251 if len(tokens) == 3 { 252 if tokens[2] != "inner" { 253 abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has an invalid argument. Expecting '+marshal slice:<IDENTIFIER>[:inner]'") 254 } 255 sa.inner = true 256 } 257 258 continue 259 } else if tag == "dynamic" { 260 mt.dynamic = true 261 continue 262 } else if tag == "boundCheck" { 263 mt.boundCheck = true 264 continue 265 } 266 267 unhandledTags = append(unhandledTags, tag) 268 } 269 270 if len(unhandledTags) > 0 { 271 abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive contained the following unknown clauses: %v", strings.Join(unhandledTags, " "))) 272 } 273 274 return mt 275 } 276 277 // collectMarshallableTypes walks the parsed AST and collects a list of type 278 // declarations for which we need to generate the Marshallable interface. 279 func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) map[*ast.TypeSpec]*marshallableType { 280 recv := make(map[string]string) // Type name to recevier name. 281 types := make(map[*ast.TypeSpec]*marshallableType) 282 for _, decl := range a.Decls { 283 gdecl, ok := decl.(*ast.GenDecl) 284 // Type declaration? 285 if !ok || gdecl.Tok != token.TYPE { 286 // Is this a function declaration? We remember receiver names. 287 d, ok := decl.(*ast.FuncDecl) 288 if ok && d.Recv != nil && len(d.Recv.List) == 1 { 289 // Accept concrete methods & pointer methods. 290 ident, ok := d.Recv.List[0].Type.(*ast.Ident) 291 if !ok { 292 var st *ast.StarExpr 293 st, ok = d.Recv.List[0].Type.(*ast.StarExpr) 294 if ok { 295 ident, ok = st.X.(*ast.Ident) 296 } 297 } 298 // The receiver name may be not present. 299 if ok && len(d.Recv.List[0].Names) == 1 { 300 // Recover the type receiver name in this case. 301 recv[ident.Name] = d.Recv.List[0].Names[0].Name 302 } 303 } 304 debugfAt(f.Position(decl.Pos()), "Skipping declaration since it's not a type declaration.\n") 305 continue 306 } 307 // Does it have a comment? 308 if gdecl.Doc == nil { 309 debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment.\n") 310 continue 311 } 312 // Does the comment contain a "+marshal" line? 313 marked := false 314 var tagLine *ast.Comment 315 for _, c := range gdecl.Doc.List { 316 if strings.HasPrefix(c.Text, "// +marshal") { 317 marked = true 318 tagLine = c 319 break 320 } 321 } 322 if !marked { 323 debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment containing +marshal line.\n") 324 continue 325 } 326 for _, spec := range gdecl.Specs { 327 // We already confirmed we're in a type declaration earlier, so this 328 // cast will succeed. 329 t := spec.(*ast.TypeSpec) 330 switch t.Type.(type) { 331 case *ast.StructType: 332 debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name) 333 case *ast.Ident: // Newtype on primitive. 334 debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name) 335 case *ast.ArrayType: // Newtype on array. 336 debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name) 337 default: 338 // A user specifically requested marshalling on this type, but we 339 // don't support it. 340 abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name)) 341 } 342 types[t] = newMarshallableType(f, tagLine, t) 343 } 344 } 345 // Update the types with the last seen receiver. As long as the 346 // receiver name is consistent for the type, then we will generate 347 // code that is still consistent with itself. 348 for t, mt := range types { 349 r, ok := recv[t.Name.Name] 350 if !ok { 351 mt.recv = receiverName(t) // Default. 352 continue 353 } 354 mt.recv = r // Last seen. 355 } 356 return types 357 } 358 359 // collectImports collects all imports from all input source files. Some of 360 // these imports are copied to the generated output, if they're referenced by 361 // the generated code. 362 // 363 // collectImports de-duplicates imports while building the list, and ensures 364 // identifiers in the generated code don't conflict with any imported package 365 // names. 366 func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt { 367 is := make(map[string]importStmt) 368 for _, decl := range a.Decls { 369 gdecl, ok := decl.(*ast.GenDecl) 370 // Import statement? 371 if !ok || gdecl.Tok != token.IMPORT { 372 continue 373 } 374 for _, spec := range gdecl.Specs { 375 i := g.imports.addFromSpec(spec.(*ast.ImportSpec), f) 376 debugf("Collected import '%s' as '%s'\n", i.path, i.name) 377 378 // Make sure we have an import that doesn't use any local names that 379 // would conflict with identifiers in the generated code. 380 if len(i.name) == 1 && i.name != "_" { 381 abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name)) 382 } 383 if _, ok := badIdentsMap[i.name]; ok { 384 abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name)) 385 } 386 } 387 } 388 return is 389 390 } 391 392 func (g *Generator) generateOne(t *marshallableType, fset *token.FileSet) *interfaceGenerator { 393 i := newInterfaceGenerator(t.spec, t.recv, fset) 394 if t.dynamic { 395 if t.slice != nil { 396 abortAt(fset.Position(t.slice.comment.Slash), "Slice API is not supported for dynamic types because it assumes that each slice element is statically sized.") 397 } 398 if t.boundCheck { 399 abortAt(fset.Position(t.slice.comment.Slash), "Can not generate Checked methods for dynamic types. Has to be implemented manually.") 400 } 401 // No validation needed, assume the user knows what they are doing. 402 i.emitMarshallableForDynamicType() 403 return i 404 } 405 switch ty := t.spec.Type.(type) { 406 case *ast.StructType: 407 i.validateStruct(t.spec, ty) 408 i.emitMarshallableForStruct(ty) 409 if t.boundCheck { 410 i.emitCheckedMarshallableForStruct() 411 } 412 if t.slice != nil { 413 i.emitMarshallableSliceForStruct(ty, t.slice) 414 } 415 case *ast.Ident: 416 i.validatePrimitiveNewtype(ty) 417 i.emitMarshallableForPrimitiveNewtype(ty) 418 if t.boundCheck { 419 i.emitCheckedMarshallableForPrimitiveNewtype() 420 } 421 if t.slice != nil { 422 i.emitMarshallableSliceForPrimitiveNewtype(ty, t.slice) 423 } 424 case *ast.ArrayType: 425 i.validateArrayNewtype(t.spec.Name, ty) 426 // After validate, we can safely call arrayLen. 427 i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident)) 428 if t.boundCheck { 429 i.emitCheckedMarshallableForArrayNewtype() 430 } 431 if t.slice != nil { 432 abortAt(fset.Position(t.slice.comment.Slash), "Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?") 433 } 434 default: 435 // This should've been filtered out by collectMarshallabeTypes. 436 panic(fmt.Sprintf("Unexpected type %+v", ty)) 437 } 438 return i 439 } 440 441 // generateOneTestSuite generates a test suite for the automatically generated 442 // implementations type t. 443 func (g *Generator) generateOneTestSuite(t *marshallableType) *testGenerator { 444 i := newTestGenerator(t.spec, t.recv) 445 i.emitTests(t.slice, t.boundCheck) 446 return i 447 } 448 449 // Run is the entry point to code generation using g. 450 // 451 // Run parses all input source files specified in g and emits generated code. 452 func (g *Generator) Run() error { 453 // Parse our input source files into ASTs and token sets. 454 asts, fsets, err := g.parse() 455 if err != nil { 456 return err 457 } 458 459 if len(asts) != len(fsets) { 460 panic("ASTs and FileSets don't match") 461 } 462 463 // Map of imports in source files; key = local package name, value = import 464 // path. 465 is := make(map[string]importStmt) 466 for i, a := range asts { 467 // Collect all imports from the source files. We may need to copy some 468 // of these to the generated code if they're referenced. This has to be 469 // done before the loop below because we need to process all ASTs before 470 // we start requesting imports to be copied one by one as we encounter 471 // them in each generated source. 472 for name, i := range g.collectImports(a, fsets[i]) { 473 is[name] = i 474 } 475 } 476 477 var impls []*interfaceGenerator 478 var ts []*testGenerator 479 // Set of Marshallable types referenced by generated code. 480 ms := make(map[string]struct{}) 481 for i, a := range asts { 482 // Collect type declarations marked for code generation and generate 483 // Marshallable interfaces. 484 var sortedTypes []*marshallableType 485 for _, t := range g.collectMarshallableTypes(a, fsets[i]) { 486 sortedTypes = append(sortedTypes, t) 487 } 488 sort.Slice(sortedTypes, func(x, y int) bool { 489 // Sort by type name, which should be unique within a package. 490 return sortedTypes[x].spec.Name.String() < sortedTypes[y].spec.Name.String() 491 }) 492 for _, t := range sortedTypes { 493 impl := g.generateOne(t, fsets[i]) 494 // Collect Marshallable types referenced by the generated code. 495 for ref := range impl.ms { 496 ms[ref] = struct{}{} 497 } 498 impls = append(impls, impl) 499 // Collect imports referenced by the generated code and add them to 500 // the list of imports we need to copy to the generated code. 501 for name := range impl.is { 502 if !g.imports.markUsed(name) { 503 panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'. Either go-marshal needs to add an import to the generated file, or a package in an input source file has a package name differ from the final component of its path, which go-marshal doesn't know how to detect; use an import alias to work around this limitation.", impl.typeName(), name)) 504 } 505 } 506 // Do not generate tests for dynamic types because they inherently 507 // violate some go_marshal requirements. 508 if !t.dynamic { 509 ts = append(ts, g.generateOneTestSuite(t)) 510 } 511 } 512 } 513 514 // Write output file header. These include things like package name and 515 // import statements. 516 if err := g.writeHeader(); err != nil { 517 return err 518 } 519 520 // Write type checks for referenced marshallable types to output file. 521 if err := g.writeTypeChecks(ms); err != nil { 522 return err 523 } 524 525 // Write generated interfaces to output file. 526 for _, i := range impls { 527 if err := i.write(g.output); err != nil { 528 return err 529 } 530 } 531 532 // Write generated tests to test file. 533 return g.writeTests(ts) 534 } 535 536 // writeTests outputs tests for the generated interface implementations to a go 537 // source file. 538 func (g *Generator) writeTests(ts []*testGenerator) error { 539 var b sourceBuffer 540 541 // Write the unconditional test file. This file is always compiled, 542 // regardless of what build tags were specified on the original input 543 // files. We use this file to guarantee we never end up with an empty test 544 // file, as that causes the build to fail with "no tests/benchmarks/examples 545 // found". 546 // 547 // There's no easy way to determine ahead of time if we'll end up with an 548 // empty build file since build constraints can arbitrarily cause some of 549 // the original types to be not defined. We also have no way to tell bazel 550 // to omit the entire test suite since the output files are already defined 551 // before go-marshal is called. 552 b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n") 553 b.emit("package %s\n\n", g.pkg) 554 b.emit("func Example() {\n") 555 b.inIndent(func() { 556 b.emit("// This example is intentionally empty, and ensures this package contains at\n") 557 b.emit("// least one testable entity. go-marshal is forced to emit a test package if the\n") 558 b.emit("// input package is marked marshallable, but emitting no testable entities \n") 559 b.emit("// results in a build failure.\n") 560 }) 561 b.emit("}\n") 562 if err := b.write(g.outputTestUC); err != nil { 563 return err 564 } 565 566 // Now generate the real test file that contains the real types we 567 // processed. These need to be conditionally compiled according to the build 568 // tags, as the original types may not be defined under all build 569 // configurations. 570 571 b.reset() 572 b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n") 573 574 // Emit build constraints. 575 bcexpr, err := constraintutil.CombineFromFiles(g.inputs) 576 if err != nil { 577 return err 578 } 579 b.emit(constraintutil.Lines(bcexpr)) 580 581 b.emit("package %s\n\n", g.pkg) 582 if err := b.write(g.outputTest); err != nil { 583 return err 584 } 585 586 // Collect and write test import statements. 587 imports := newImportTable() 588 for _, t := range ts { 589 imports.merge(t.imports) 590 } 591 592 if err := imports.write(g.outputTest); err != nil { 593 return err 594 } 595 596 // Write test functions. 597 for _, t := range ts { 598 if err := t.write(g.outputTest); err != nil { 599 return err 600 } 601 } 602 return nil 603 }