github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/tools/go_marshal/gomarshal/generator.go (about) 1 // Copyright 2019 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package gomarshal implements the go_marshal code generator. See README.md. 16 package gomarshal 17 18 import ( 19 "bytes" 20 "fmt" 21 "go/ast" 22 "go/parser" 23 "go/token" 24 "os" 25 "sort" 26 "strings" 27 28 "github.com/SagerNet/gvisor/tools/tags" 29 ) 30 31 // List of identifiers we use in generated code that may conflict with a 32 // similarly-named source identifier. Abort gracefully when we see these to 33 // avoid potentially confusing compilation failures in generated code. 34 // 35 // This only applies to import aliases at the moment. All other identifiers 36 // are qualified by a receiver argument, since they're struct fields. 37 // 38 // All recievers are single letters, so we don't allow import aliases to be a 39 // single letter. 40 var badIdents = []string{ 41 "addr", "blk", "buf", "cc", "dst", "dsts", "count", "err", "hdr", "idx", 42 "inner", "length", "limit", "ptr", "size", "src", "srcs", "val", 43 // All single-letter identifiers. 44 } 45 46 // Constructed fromt badIdents in init(). 47 var badIdentsMap map[string]struct{} 48 49 func init() { 50 badIdentsMap = make(map[string]struct{}) 51 for _, ident := range badIdents { 52 badIdentsMap[ident] = struct{}{} 53 } 54 } 55 56 // Generator drives code generation for a single invocation of the go_marshal 57 // utility. 58 // 59 // The Generator holds arguments passed to the tool, and drives parsing, 60 // processing and code Generator for all types marked with +marshal declared in 61 // the input files. 62 // 63 // See Generator.run() as the entry point. 64 type Generator struct { 65 // Paths to input go source files. 66 inputs []string 67 // Output file to write generated go source. 68 output *os.File 69 // Output file to write generated tests. 70 outputTest *os.File 71 // Output file to write unconditionally generated tests. 72 outputTestUC *os.File 73 // Package name for the generated file. 74 pkg string 75 // Set of extra packages to import in the generated file. 76 imports *importTable 77 } 78 79 // NewGenerator creates a new code Generator. 80 func NewGenerator(srcs []string, out, outTest, outTestUnconditional, pkg string, imports []string) (*Generator, error) { 81 f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) 82 if err != nil { 83 return nil, fmt.Errorf("couldn't open output file %q: %w", out, err) 84 } 85 fTest, err := os.OpenFile(outTest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) 86 if err != nil { 87 return nil, fmt.Errorf("couldn't open test output file %q: %w", out, err) 88 } 89 fTestUC, err := os.OpenFile(outTestUnconditional, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) 90 if err != nil { 91 return nil, fmt.Errorf("couldn't open unconditional test output file %q: %w", out, err) 92 } 93 g := Generator{ 94 inputs: srcs, 95 output: f, 96 outputTest: fTest, 97 outputTestUC: fTestUC, 98 pkg: pkg, 99 imports: newImportTable(), 100 } 101 for _, i := range imports { 102 // All imports on the extra imports list are unconditionally marked as 103 // used, so that they're always added to the generated code. 104 g.imports.add(i).markUsed() 105 } 106 107 // The following imports may or may not be used by the generated code, 108 // depending on what's required for the target types. Don't mark these as 109 // used by default. 110 g.imports.add("io") 111 g.imports.add("reflect") 112 g.imports.add("runtime") 113 g.imports.add("unsafe") 114 g.imports.add("github.com/SagerNet/gvisor/pkg/gohacks") 115 g.imports.add("github.com/SagerNet/gvisor/pkg/hostarch") 116 g.imports.add("github.com/SagerNet/gvisor/pkg/marshal") 117 return &g, nil 118 } 119 120 // writeHeader writes the header for the generated source file. The header 121 // includes the package name, package level comments and import statements. 122 func (g *Generator) writeHeader() error { 123 var b sourceBuffer 124 b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n") 125 126 // Emit build tags. 127 b.emit("// If there are issues with build tag aggregation, see\n") 128 b.emit("// tools/go_marshal/gomarshal/generator.go:writeHeader(). The build tags here\n") 129 b.emit("// come from the input set of files used to generate this file. This input set\n") 130 b.emit("// is filtered based on pre-defined file suffixes related to build tags, see \n") 131 b.emit("// tools/defs.bzl:calculate_sets().\n\n") 132 133 if t := tags.Aggregate(g.inputs); len(t) > 0 { 134 b.emit(strings.Join(t.Lines(), "\n")) 135 b.emit("\n\n") 136 } 137 138 // Package header. 139 b.emit("package %s\n\n", g.pkg) 140 if err := b.write(g.output); err != nil { 141 return err 142 } 143 144 return g.imports.write(g.output) 145 } 146 147 // writeTypeChecks writes a statement to force the compiler to perform a type 148 // check for all Marshallable types referenced by the generated code. 149 func (g *Generator) writeTypeChecks(ms map[string]struct{}) error { 150 if len(ms) == 0 { 151 return nil 152 } 153 154 msl := make([]string, 0, len(ms)) 155 for m := range ms { 156 msl = append(msl, m) 157 } 158 sort.Strings(msl) 159 160 var buf bytes.Buffer 161 fmt.Fprint(&buf, "// Marshallable types used by this file.\n") 162 163 for _, m := range msl { 164 fmt.Fprintf(&buf, "var _ marshal.Marshallable = (*%s)(nil)\n", m) 165 } 166 fmt.Fprint(&buf, "\n") 167 168 _, err := fmt.Fprint(g.output, buf.String()) 169 return err 170 } 171 172 // parse processes all input files passed this generator and produces a set of 173 // parsed go ASTs. 174 func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) { 175 debugf("go_marshal invoked with %d input files:\n", len(g.inputs)) 176 for _, path := range g.inputs { 177 debugf(" %s\n", path) 178 } 179 180 files := make([]*ast.File, 0, len(g.inputs)) 181 fsets := make([]*token.FileSet, 0, len(g.inputs)) 182 183 for _, path := range g.inputs { 184 fset := token.NewFileSet() 185 f, err := parser.ParseFile(fset, path, nil, parser.ParseComments) 186 if err != nil { 187 // Not a valid input file? 188 return nil, nil, fmt.Errorf("input %q can't be parsed: %w", path, err) 189 } 190 191 if debugEnabled() { 192 debugf("AST for %q:\n", path) 193 ast.Print(fset, f) 194 } 195 196 files = append(files, f) 197 fsets = append(fsets, fset) 198 } 199 200 return files, fsets, nil 201 } 202 203 // sliceAPI carries information about the '+marshal slice' directive. 204 type sliceAPI struct { 205 // Comment node in the AST containing the +marshal tag. 206 comment *ast.Comment 207 // Identifier fragment to use when naming generated functions for the slice 208 // API. 209 ident string 210 // Whether the generated functions should reference the newtype name, or the 211 // inner type name. Only meaningful on newtype declarations on primitives. 212 inner bool 213 } 214 215 // marshallableType carries information about a type marked with the '+marshal' 216 // directive. 217 type marshallableType struct { 218 spec *ast.TypeSpec 219 slice *sliceAPI 220 recv string 221 dynamic bool 222 } 223 224 func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) *marshallableType { 225 mt := &marshallableType{ 226 spec: spec, 227 slice: nil, 228 } 229 230 var unhandledTags []string 231 232 for _, tag := range strings.Fields(strings.TrimPrefix(tagLine.Text, "// +marshal")) { 233 if strings.HasPrefix(tag, "slice:") { 234 tokens := strings.Split(tag, ":") 235 if len(tokens) < 2 || len(tokens) > 3 { 236 abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive has invalid 'slice' clause. Expecting format 'slice:<IDENTIFIER>[:inner]', got '%v'", tag)) 237 } 238 if len(tokens[1]) == 0 { 239 abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has empty identifier argument. Expecting '+marshal slice:identifier'") 240 } 241 242 sa := &sliceAPI{ 243 comment: tagLine, 244 ident: tokens[1], 245 } 246 mt.slice = sa 247 248 if len(tokens) == 3 { 249 if tokens[2] != "inner" { 250 abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has an invalid argument. Expecting '+marshal slice:<IDENTIFIER>[:inner]'") 251 } 252 sa.inner = true 253 } 254 255 continue 256 } else if tag == "dynamic" { 257 mt.dynamic = true 258 continue 259 } 260 261 unhandledTags = append(unhandledTags, tag) 262 } 263 264 if len(unhandledTags) > 0 { 265 abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive contained the following unknown clauses: %v", strings.Join(unhandledTags, " "))) 266 } 267 268 return mt 269 } 270 271 // collectMarshallableTypes walks the parsed AST and collects a list of type 272 // declarations for which we need to generate the Marshallable interface. 273 func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) map[*ast.TypeSpec]*marshallableType { 274 recv := make(map[string]string) // Type name to recevier name. 275 types := make(map[*ast.TypeSpec]*marshallableType) 276 for _, decl := range a.Decls { 277 gdecl, ok := decl.(*ast.GenDecl) 278 // Type declaration? 279 if !ok || gdecl.Tok != token.TYPE { 280 // Is this a function declaration? We remember receiver names. 281 d, ok := decl.(*ast.FuncDecl) 282 if ok && d.Recv != nil && len(d.Recv.List) == 1 { 283 // Accept concrete methods & pointer methods. 284 ident, ok := d.Recv.List[0].Type.(*ast.Ident) 285 if !ok { 286 var st *ast.StarExpr 287 st, ok = d.Recv.List[0].Type.(*ast.StarExpr) 288 if ok { 289 ident, ok = st.X.(*ast.Ident) 290 } 291 } 292 // The receiver name may be not present. 293 if ok && len(d.Recv.List[0].Names) == 1 { 294 // Recover the type receiver name in this case. 295 recv[ident.Name] = d.Recv.List[0].Names[0].Name 296 } 297 } 298 debugfAt(f.Position(decl.Pos()), "Skipping declaration since it's not a type declaration.\n") 299 continue 300 } 301 // Does it have a comment? 302 if gdecl.Doc == nil { 303 debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment.\n") 304 continue 305 } 306 // Does the comment contain a "+marshal" line? 307 marked := false 308 var tagLine *ast.Comment 309 for _, c := range gdecl.Doc.List { 310 if strings.HasPrefix(c.Text, "// +marshal") { 311 marked = true 312 tagLine = c 313 break 314 } 315 } 316 if !marked { 317 debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment containing +marshal line.\n") 318 continue 319 } 320 for _, spec := range gdecl.Specs { 321 // We already confirmed we're in a type declaration earlier, so this 322 // cast will succeed. 323 t := spec.(*ast.TypeSpec) 324 switch t.Type.(type) { 325 case *ast.StructType: 326 debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name) 327 case *ast.Ident: // Newtype on primitive. 328 debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name) 329 case *ast.ArrayType: // Newtype on array. 330 debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name) 331 default: 332 // A user specifically requested marshalling on this type, but we 333 // don't support it. 334 abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name)) 335 } 336 types[t] = newMarshallableType(f, tagLine, t) 337 } 338 } 339 // Update the types with the last seen receiver. As long as the 340 // receiver name is consistent for the type, then we will generate 341 // code that is still consistent with itself. 342 for t, mt := range types { 343 r, ok := recv[t.Name.Name] 344 if !ok { 345 mt.recv = receiverName(t) // Default. 346 continue 347 } 348 mt.recv = r // Last seen. 349 } 350 return types 351 } 352 353 // collectImports collects all imports from all input source files. Some of 354 // these imports are copied to the generated output, if they're referenced by 355 // the generated code. 356 // 357 // collectImports de-duplicates imports while building the list, and ensures 358 // identifiers in the generated code don't conflict with any imported package 359 // names. 360 func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt { 361 is := make(map[string]importStmt) 362 for _, decl := range a.Decls { 363 gdecl, ok := decl.(*ast.GenDecl) 364 // Import statement? 365 if !ok || gdecl.Tok != token.IMPORT { 366 continue 367 } 368 for _, spec := range gdecl.Specs { 369 i := g.imports.addFromSpec(spec.(*ast.ImportSpec), f) 370 debugf("Collected import '%s' as '%s'\n", i.path, i.name) 371 372 // Make sure we have an import that doesn't use any local names that 373 // would conflict with identifiers in the generated code. 374 if len(i.name) == 1 && i.name != "_" { 375 abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name)) 376 } 377 if _, ok := badIdentsMap[i.name]; ok { 378 abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name)) 379 } 380 } 381 } 382 return is 383 384 } 385 386 func (g *Generator) generateOne(t *marshallableType, fset *token.FileSet) *interfaceGenerator { 387 i := newInterfaceGenerator(t.spec, t.recv, fset) 388 if t.dynamic { 389 if t.slice != nil { 390 abortAt(fset.Position(t.slice.comment.Slash), "Slice API is not supported for dynamic types because it assumes that each slice element is statically sized.") 391 } 392 // No validation needed, assume the user knows what they are doing. 393 i.emitMarshallableForDynamicType() 394 return i 395 } 396 switch ty := t.spec.Type.(type) { 397 case *ast.StructType: 398 i.validateStruct(t.spec, ty) 399 i.emitMarshallableForStruct(ty) 400 if t.slice != nil { 401 i.emitMarshallableSliceForStruct(ty, t.slice) 402 } 403 case *ast.Ident: 404 i.validatePrimitiveNewtype(ty) 405 i.emitMarshallableForPrimitiveNewtype(ty) 406 if t.slice != nil { 407 i.emitMarshallableSliceForPrimitiveNewtype(ty, t.slice) 408 } 409 case *ast.ArrayType: 410 i.validateArrayNewtype(t.spec.Name, ty) 411 // After validate, we can safely call arrayLen. 412 i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident)) 413 if t.slice != nil { 414 abortAt(fset.Position(t.slice.comment.Slash), "Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?") 415 } 416 default: 417 // This should've been filtered out by collectMarshallabeTypes. 418 panic(fmt.Sprintf("Unexpected type %+v", ty)) 419 } 420 return i 421 } 422 423 // generateOneTestSuite generates a test suite for the automatically generated 424 // implementations type t. 425 func (g *Generator) generateOneTestSuite(t *marshallableType) *testGenerator { 426 i := newTestGenerator(t.spec, t.recv) 427 i.emitTests(t.slice) 428 return i 429 } 430 431 // Run is the entry point to code generation using g. 432 // 433 // Run parses all input source files specified in g and emits generated code. 434 func (g *Generator) Run() error { 435 // Parse our input source files into ASTs and token sets. 436 asts, fsets, err := g.parse() 437 if err != nil { 438 return err 439 } 440 441 if len(asts) != len(fsets) { 442 panic("ASTs and FileSets don't match") 443 } 444 445 // Map of imports in source files; key = local package name, value = import 446 // path. 447 is := make(map[string]importStmt) 448 for i, a := range asts { 449 // Collect all imports from the source files. We may need to copy some 450 // of these to the generated code if they're referenced. This has to be 451 // done before the loop below because we need to process all ASTs before 452 // we start requesting imports to be copied one by one as we encounter 453 // them in each generated source. 454 for name, i := range g.collectImports(a, fsets[i]) { 455 is[name] = i 456 } 457 } 458 459 var impls []*interfaceGenerator 460 var ts []*testGenerator 461 // Set of Marshallable types referenced by generated code. 462 ms := make(map[string]struct{}) 463 for i, a := range asts { 464 // Collect type declarations marked for code generation and generate 465 // Marshallable interfaces. 466 var sortedTypes []*marshallableType 467 for _, t := range g.collectMarshallableTypes(a, fsets[i]) { 468 sortedTypes = append(sortedTypes, t) 469 } 470 sort.Slice(sortedTypes, func(x, y int) bool { 471 // Sort by type name, which should be unique within a package. 472 return sortedTypes[x].spec.Name.String() < sortedTypes[y].spec.Name.String() 473 }) 474 for _, t := range sortedTypes { 475 impl := g.generateOne(t, fsets[i]) 476 // Collect Marshallable types referenced by the generated code. 477 for ref := range impl.ms { 478 ms[ref] = struct{}{} 479 } 480 impls = append(impls, impl) 481 // Collect imports referenced by the generated code and add them to 482 // the list of imports we need to copy to the generated code. 483 for name := range impl.is { 484 if !g.imports.markUsed(name) { 485 panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'. Either go-marshal needs to add an import to the generated file, or a package in an input source file has a package name differ from the final component of its path, which go-marshal doesn't know how to detect; use an import alias to work around this limitation.", impl.typeName(), name)) 486 } 487 } 488 // Do not generate tests for dynamic types because they inherently 489 // violate some go_marshal requirements. 490 if !t.dynamic { 491 ts = append(ts, g.generateOneTestSuite(t)) 492 } 493 } 494 } 495 496 // Write output file header. These include things like package name and 497 // import statements. 498 if err := g.writeHeader(); err != nil { 499 return err 500 } 501 502 // Write type checks for referenced marshallable types to output file. 503 if err := g.writeTypeChecks(ms); err != nil { 504 return err 505 } 506 507 // Write generated interfaces to output file. 508 for _, i := range impls { 509 if err := i.write(g.output); err != nil { 510 return err 511 } 512 } 513 514 // Write generated tests to test file. 515 return g.writeTests(ts) 516 } 517 518 // writeTests outputs tests for the generated interface implementations to a go 519 // source file. 520 func (g *Generator) writeTests(ts []*testGenerator) error { 521 var b sourceBuffer 522 523 // Write the unconditional test file. This file is always compiled, 524 // regardless of what build tags were specified on the original input 525 // files. We use this file to guarantee we never end up with an empty test 526 // file, as that causes the build to fail with "no tests/benchmarks/examples 527 // found". 528 // 529 // There's no easy way to determine ahead of time if we'll end up with an 530 // empty build file since build constraints can arbitrarily cause some of 531 // the original types to be not defined. We also have no way to tell bazel 532 // to omit the entire test suite since the output files are already defined 533 // before go-marshal is called. 534 b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n") 535 b.emit("package %s\n\n", g.pkg) 536 b.emit("func Example() {\n") 537 b.inIndent(func() { 538 b.emit("// This example is intentionally empty, and ensures this package contains at\n") 539 b.emit("// least one testable entity. go-marshal is forced to emit a test package if the\n") 540 b.emit("// input package is marked marshallable, but emitting no testable entities \n") 541 b.emit("// results in a build failure.\n") 542 }) 543 b.emit("}\n") 544 if err := b.write(g.outputTestUC); err != nil { 545 return err 546 } 547 548 // Now generate the real test file that contains the real types we 549 // processed. These need to be conditionally compiled according to the build 550 // tags, as the original types may not be defined under all build 551 // configurations. 552 553 b.reset() 554 b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n") 555 556 // Emit build tags. 557 if t := tags.Aggregate(g.inputs); len(t) > 0 { 558 b.emit(strings.Join(t.Lines(), "\n")) 559 b.emit("\n\n") 560 } 561 562 b.emit("package %s\n\n", g.pkg) 563 if err := b.write(g.outputTest); err != nil { 564 return err 565 } 566 567 // Collect and write test import statements. 568 imports := newImportTable() 569 for _, t := range ts { 570 imports.merge(t.imports) 571 } 572 573 if err := imports.write(g.outputTest); err != nil { 574 return err 575 } 576 577 // Write test functions. 578 for _, t := range ts { 579 if err := t.write(g.outputTest); err != nil { 580 return err 581 } 582 } 583 return nil 584 }