github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/tools/go_marshal/gomarshal/generator.go (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package gomarshal implements the go_marshal code generator. See README.md.
    16  package gomarshal
    17  
    18  import (
    19  	"bytes"
    20  	"fmt"
    21  	"go/ast"
    22  	"go/parser"
    23  	"go/token"
    24  	"os"
    25  	"sort"
    26  	"strings"
    27  
    28  	"github.com/SagerNet/gvisor/tools/tags"
    29  )
    30  
    31  // List of identifiers we use in generated code that may conflict with a
    32  // similarly-named source identifier. Abort gracefully when we see these to
    33  // avoid potentially confusing compilation failures in generated code.
    34  //
    35  // This only applies to import aliases at the moment. All other identifiers
    36  // are qualified by a receiver argument, since they're struct fields.
    37  //
    38  // All recievers are single letters, so we don't allow import aliases to be a
    39  // single letter.
    40  var badIdents = []string{
    41  	"addr", "blk", "buf", "cc", "dst", "dsts", "count", "err", "hdr", "idx",
    42  	"inner", "length", "limit", "ptr", "size", "src", "srcs", "val",
    43  	// All single-letter identifiers.
    44  }
    45  
    46  // Constructed fromt badIdents in init().
    47  var badIdentsMap map[string]struct{}
    48  
    49  func init() {
    50  	badIdentsMap = make(map[string]struct{})
    51  	for _, ident := range badIdents {
    52  		badIdentsMap[ident] = struct{}{}
    53  	}
    54  }
    55  
// Generator drives code generation for a single invocation of the go_marshal
// utility.
//
// The Generator holds arguments passed to the tool, and drives parsing,
// processing and code generation for all types marked with +marshal declared
// in the input files.
//
// Create a Generator with NewGenerator; see Generator.Run() as the entry
// point.
type Generator struct {
	// Paths to input go source files.
	inputs []string
	// Output file to write generated go source.
	output *os.File
	// Output file to write generated tests.
	outputTest *os.File
	// Output file to write unconditionally generated tests.
	outputTestUC *os.File
	// Package name for the generated file.
	pkg string
	// Set of extra packages to import in the generated file.
	imports *importTable
}
    78  
    79  // NewGenerator creates a new code Generator.
    80  func NewGenerator(srcs []string, out, outTest, outTestUnconditional, pkg string, imports []string) (*Generator, error) {
    81  	f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
    82  	if err != nil {
    83  		return nil, fmt.Errorf("couldn't open output file %q: %w", out, err)
    84  	}
    85  	fTest, err := os.OpenFile(outTest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
    86  	if err != nil {
    87  		return nil, fmt.Errorf("couldn't open test output file %q: %w", out, err)
    88  	}
    89  	fTestUC, err := os.OpenFile(outTestUnconditional, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
    90  	if err != nil {
    91  		return nil, fmt.Errorf("couldn't open unconditional test output file %q: %w", out, err)
    92  	}
    93  	g := Generator{
    94  		inputs:       srcs,
    95  		output:       f,
    96  		outputTest:   fTest,
    97  		outputTestUC: fTestUC,
    98  		pkg:          pkg,
    99  		imports:      newImportTable(),
   100  	}
   101  	for _, i := range imports {
   102  		// All imports on the extra imports list are unconditionally marked as
   103  		// used, so that they're always added to the generated code.
   104  		g.imports.add(i).markUsed()
   105  	}
   106  
   107  	// The following imports may or may not be used by the generated code,
   108  	// depending on what's required for the target types. Don't mark these as
   109  	// used by default.
   110  	g.imports.add("io")
   111  	g.imports.add("reflect")
   112  	g.imports.add("runtime")
   113  	g.imports.add("unsafe")
   114  	g.imports.add("github.com/SagerNet/gvisor/pkg/gohacks")
   115  	g.imports.add("github.com/SagerNet/gvisor/pkg/hostarch")
   116  	g.imports.add("github.com/SagerNet/gvisor/pkg/marshal")
   117  	return &g, nil
   118  }
   119  
   120  // writeHeader writes the header for the generated source file. The header
   121  // includes the package name, package level comments and import statements.
   122  func (g *Generator) writeHeader() error {
   123  	var b sourceBuffer
   124  	b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n")
   125  
   126  	// Emit build tags.
   127  	b.emit("// If there are issues with build tag aggregation, see\n")
   128  	b.emit("// tools/go_marshal/gomarshal/generator.go:writeHeader(). The build tags here\n")
   129  	b.emit("// come from the input set of files used to generate this file. This input set\n")
   130  	b.emit("// is filtered based on pre-defined file suffixes related to build tags, see \n")
   131  	b.emit("// tools/defs.bzl:calculate_sets().\n\n")
   132  
   133  	if t := tags.Aggregate(g.inputs); len(t) > 0 {
   134  		b.emit(strings.Join(t.Lines(), "\n"))
   135  		b.emit("\n\n")
   136  	}
   137  
   138  	// Package header.
   139  	b.emit("package %s\n\n", g.pkg)
   140  	if err := b.write(g.output); err != nil {
   141  		return err
   142  	}
   143  
   144  	return g.imports.write(g.output)
   145  }
   146  
   147  // writeTypeChecks writes a statement to force the compiler to perform a type
   148  // check for all Marshallable types referenced by the generated code.
   149  func (g *Generator) writeTypeChecks(ms map[string]struct{}) error {
   150  	if len(ms) == 0 {
   151  		return nil
   152  	}
   153  
   154  	msl := make([]string, 0, len(ms))
   155  	for m := range ms {
   156  		msl = append(msl, m)
   157  	}
   158  	sort.Strings(msl)
   159  
   160  	var buf bytes.Buffer
   161  	fmt.Fprint(&buf, "// Marshallable types used by this file.\n")
   162  
   163  	for _, m := range msl {
   164  		fmt.Fprintf(&buf, "var _ marshal.Marshallable = (*%s)(nil)\n", m)
   165  	}
   166  	fmt.Fprint(&buf, "\n")
   167  
   168  	_, err := fmt.Fprint(g.output, buf.String())
   169  	return err
   170  }
   171  
   172  // parse processes all input files passed this generator and produces a set of
   173  // parsed go ASTs.
   174  func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) {
   175  	debugf("go_marshal invoked with %d input files:\n", len(g.inputs))
   176  	for _, path := range g.inputs {
   177  		debugf("  %s\n", path)
   178  	}
   179  
   180  	files := make([]*ast.File, 0, len(g.inputs))
   181  	fsets := make([]*token.FileSet, 0, len(g.inputs))
   182  
   183  	for _, path := range g.inputs {
   184  		fset := token.NewFileSet()
   185  		f, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
   186  		if err != nil {
   187  			// Not a valid input file?
   188  			return nil, nil, fmt.Errorf("input %q can't be parsed: %w", path, err)
   189  		}
   190  
   191  		if debugEnabled() {
   192  			debugf("AST for %q:\n", path)
   193  			ast.Print(fset, f)
   194  		}
   195  
   196  		files = append(files, f)
   197  		fsets = append(fsets, fset)
   198  	}
   199  
   200  	return files, fsets, nil
   201  }
   202  
// sliceAPI carries information about the 'slice' clause of a '+marshal'
// directive, which has the format 'slice:<IDENTIFIER>[:inner]'.
type sliceAPI struct {
	// Comment node in the AST containing the +marshal tag. Used to report
	// the position of the directive in error messages.
	comment *ast.Comment
	// Identifier fragment to use when naming generated functions for the slice
	// API. This is the <IDENTIFIER> part of the clause and is never empty.
	ident string
	// Whether the generated functions should reference the newtype name, or the
	// inner type name. Only meaningful on newtype declarations on primitives.
	// Set when the clause carries the optional ':inner' suffix.
	inner bool
}
   214  
// marshallableType carries information about a type marked with the '+marshal'
// directive.
//
// spec is the AST node of the type declaration. slice describes the 'slice'
// clause of the directive, and is nil when the directive has none. recv is
// the receiver name to use for the type; it is left empty by
// newMarshallableType and filled in by collectMarshallableTypes. dynamic is
// set when the directive carries the 'dynamic' clause.
type marshallableType struct {
	spec    *ast.TypeSpec
	slice   *sliceAPI
	recv    string
	dynamic bool
}
   223  
   224  func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) *marshallableType {
   225  	mt := &marshallableType{
   226  		spec:  spec,
   227  		slice: nil,
   228  	}
   229  
   230  	var unhandledTags []string
   231  
   232  	for _, tag := range strings.Fields(strings.TrimPrefix(tagLine.Text, "// +marshal")) {
   233  		if strings.HasPrefix(tag, "slice:") {
   234  			tokens := strings.Split(tag, ":")
   235  			if len(tokens) < 2 || len(tokens) > 3 {
   236  				abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive has invalid 'slice' clause. Expecting format 'slice:<IDENTIFIER>[:inner]', got '%v'", tag))
   237  			}
   238  			if len(tokens[1]) == 0 {
   239  				abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has empty identifier argument. Expecting '+marshal slice:identifier'")
   240  			}
   241  
   242  			sa := &sliceAPI{
   243  				comment: tagLine,
   244  				ident:   tokens[1],
   245  			}
   246  			mt.slice = sa
   247  
   248  			if len(tokens) == 3 {
   249  				if tokens[2] != "inner" {
   250  					abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has an invalid argument. Expecting '+marshal slice:<IDENTIFIER>[:inner]'")
   251  				}
   252  				sa.inner = true
   253  			}
   254  
   255  			continue
   256  		} else if tag == "dynamic" {
   257  			mt.dynamic = true
   258  			continue
   259  		}
   260  
   261  		unhandledTags = append(unhandledTags, tag)
   262  	}
   263  
   264  	if len(unhandledTags) > 0 {
   265  		abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive contained the following unknown clauses: %v", strings.Join(unhandledTags, " ")))
   266  	}
   267  
   268  	return mt
   269  }
   270  
   271  // collectMarshallableTypes walks the parsed AST and collects a list of type
   272  // declarations for which we need to generate the Marshallable interface.
   273  func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) map[*ast.TypeSpec]*marshallableType {
   274  	recv := make(map[string]string) // Type name to recevier name.
   275  	types := make(map[*ast.TypeSpec]*marshallableType)
   276  	for _, decl := range a.Decls {
   277  		gdecl, ok := decl.(*ast.GenDecl)
   278  		// Type declaration?
   279  		if !ok || gdecl.Tok != token.TYPE {
   280  			// Is this a function declaration? We remember receiver names.
   281  			d, ok := decl.(*ast.FuncDecl)
   282  			if ok && d.Recv != nil && len(d.Recv.List) == 1 {
   283  				// Accept concrete methods & pointer methods.
   284  				ident, ok := d.Recv.List[0].Type.(*ast.Ident)
   285  				if !ok {
   286  					var st *ast.StarExpr
   287  					st, ok = d.Recv.List[0].Type.(*ast.StarExpr)
   288  					if ok {
   289  						ident, ok = st.X.(*ast.Ident)
   290  					}
   291  				}
   292  				// The receiver name may be not present.
   293  				if ok && len(d.Recv.List[0].Names) == 1 {
   294  					// Recover the type receiver name in this case.
   295  					recv[ident.Name] = d.Recv.List[0].Names[0].Name
   296  				}
   297  			}
   298  			debugfAt(f.Position(decl.Pos()), "Skipping declaration since it's not a type declaration.\n")
   299  			continue
   300  		}
   301  		// Does it have a comment?
   302  		if gdecl.Doc == nil {
   303  			debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment.\n")
   304  			continue
   305  		}
   306  		// Does the comment contain a "+marshal" line?
   307  		marked := false
   308  		var tagLine *ast.Comment
   309  		for _, c := range gdecl.Doc.List {
   310  			if strings.HasPrefix(c.Text, "// +marshal") {
   311  				marked = true
   312  				tagLine = c
   313  				break
   314  			}
   315  		}
   316  		if !marked {
   317  			debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment containing +marshal line.\n")
   318  			continue
   319  		}
   320  		for _, spec := range gdecl.Specs {
   321  			// We already confirmed we're in a type declaration earlier, so this
   322  			// cast will succeed.
   323  			t := spec.(*ast.TypeSpec)
   324  			switch t.Type.(type) {
   325  			case *ast.StructType:
   326  				debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name)
   327  			case *ast.Ident: // Newtype on primitive.
   328  				debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name)
   329  			case *ast.ArrayType: // Newtype on array.
   330  				debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name)
   331  			default:
   332  				// A user specifically requested marshalling on this type, but we
   333  				// don't support it.
   334  				abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name))
   335  			}
   336  			types[t] = newMarshallableType(f, tagLine, t)
   337  		}
   338  	}
   339  	// Update the types with the last seen receiver. As long as the
   340  	// receiver name is consistent for the type, then we will generate
   341  	// code that is still consistent with itself.
   342  	for t, mt := range types {
   343  		r, ok := recv[t.Name.Name]
   344  		if !ok {
   345  			mt.recv = receiverName(t) // Default.
   346  			continue
   347  		}
   348  		mt.recv = r // Last seen.
   349  	}
   350  	return types
   351  }
   352  
   353  // collectImports collects all imports from all input source files. Some of
   354  // these imports are copied to the generated output, if they're referenced by
   355  // the generated code.
   356  //
   357  // collectImports de-duplicates imports while building the list, and ensures
   358  // identifiers in the generated code don't conflict with any imported package
   359  // names.
   360  func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt {
   361  	is := make(map[string]importStmt)
   362  	for _, decl := range a.Decls {
   363  		gdecl, ok := decl.(*ast.GenDecl)
   364  		// Import statement?
   365  		if !ok || gdecl.Tok != token.IMPORT {
   366  			continue
   367  		}
   368  		for _, spec := range gdecl.Specs {
   369  			i := g.imports.addFromSpec(spec.(*ast.ImportSpec), f)
   370  			debugf("Collected import '%s' as '%s'\n", i.path, i.name)
   371  
   372  			// Make sure we have an import that doesn't use any local names that
   373  			// would conflict with identifiers in the generated code.
   374  			if len(i.name) == 1 && i.name != "_" {
   375  				abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name))
   376  			}
   377  			if _, ok := badIdentsMap[i.name]; ok {
   378  				abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name))
   379  			}
   380  		}
   381  	}
   382  	return is
   383  
   384  }
   385  
   386  func (g *Generator) generateOne(t *marshallableType, fset *token.FileSet) *interfaceGenerator {
   387  	i := newInterfaceGenerator(t.spec, t.recv, fset)
   388  	if t.dynamic {
   389  		if t.slice != nil {
   390  			abortAt(fset.Position(t.slice.comment.Slash), "Slice API is not supported for dynamic types because it assumes that each slice element is statically sized.")
   391  		}
   392  		// No validation needed, assume the user knows what they are doing.
   393  		i.emitMarshallableForDynamicType()
   394  		return i
   395  	}
   396  	switch ty := t.spec.Type.(type) {
   397  	case *ast.StructType:
   398  		i.validateStruct(t.spec, ty)
   399  		i.emitMarshallableForStruct(ty)
   400  		if t.slice != nil {
   401  			i.emitMarshallableSliceForStruct(ty, t.slice)
   402  		}
   403  	case *ast.Ident:
   404  		i.validatePrimitiveNewtype(ty)
   405  		i.emitMarshallableForPrimitiveNewtype(ty)
   406  		if t.slice != nil {
   407  			i.emitMarshallableSliceForPrimitiveNewtype(ty, t.slice)
   408  		}
   409  	case *ast.ArrayType:
   410  		i.validateArrayNewtype(t.spec.Name, ty)
   411  		// After validate, we can safely call arrayLen.
   412  		i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident))
   413  		if t.slice != nil {
   414  			abortAt(fset.Position(t.slice.comment.Slash), "Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?")
   415  		}
   416  	default:
   417  		// This should've been filtered out by collectMarshallabeTypes.
   418  		panic(fmt.Sprintf("Unexpected type %+v", ty))
   419  	}
   420  	return i
   421  }
   422  
   423  // generateOneTestSuite generates a test suite for the automatically generated
   424  // implementations type t.
   425  func (g *Generator) generateOneTestSuite(t *marshallableType) *testGenerator {
   426  	i := newTestGenerator(t.spec, t.recv)
   427  	i.emitTests(t.slice)
   428  	return i
   429  }
   430  
   431  // Run is the entry point to code generation using g.
   432  //
   433  // Run parses all input source files specified in g and emits generated code.
   434  func (g *Generator) Run() error {
   435  	// Parse our input source files into ASTs and token sets.
   436  	asts, fsets, err := g.parse()
   437  	if err != nil {
   438  		return err
   439  	}
   440  
   441  	if len(asts) != len(fsets) {
   442  		panic("ASTs and FileSets don't match")
   443  	}
   444  
   445  	// Map of imports in source files; key = local package name, value = import
   446  	// path.
   447  	is := make(map[string]importStmt)
   448  	for i, a := range asts {
   449  		// Collect all imports from the source files. We may need to copy some
   450  		// of these to the generated code if they're referenced. This has to be
   451  		// done before the loop below because we need to process all ASTs before
   452  		// we start requesting imports to be copied one by one as we encounter
   453  		// them in each generated source.
   454  		for name, i := range g.collectImports(a, fsets[i]) {
   455  			is[name] = i
   456  		}
   457  	}
   458  
   459  	var impls []*interfaceGenerator
   460  	var ts []*testGenerator
   461  	// Set of Marshallable types referenced by generated code.
   462  	ms := make(map[string]struct{})
   463  	for i, a := range asts {
   464  		// Collect type declarations marked for code generation and generate
   465  		// Marshallable interfaces.
   466  		var sortedTypes []*marshallableType
   467  		for _, t := range g.collectMarshallableTypes(a, fsets[i]) {
   468  			sortedTypes = append(sortedTypes, t)
   469  		}
   470  		sort.Slice(sortedTypes, func(x, y int) bool {
   471  			// Sort by type name, which should be unique within a package.
   472  			return sortedTypes[x].spec.Name.String() < sortedTypes[y].spec.Name.String()
   473  		})
   474  		for _, t := range sortedTypes {
   475  			impl := g.generateOne(t, fsets[i])
   476  			// Collect Marshallable types referenced by the generated code.
   477  			for ref := range impl.ms {
   478  				ms[ref] = struct{}{}
   479  			}
   480  			impls = append(impls, impl)
   481  			// Collect imports referenced by the generated code and add them to
   482  			// the list of imports we need to copy to the generated code.
   483  			for name := range impl.is {
   484  				if !g.imports.markUsed(name) {
   485  					panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'. Either go-marshal needs to add an import to the generated file, or a package in an input source file has a package name differ from the final component of its path, which go-marshal doesn't know how to detect; use an import alias to work around this limitation.", impl.typeName(), name))
   486  				}
   487  			}
   488  			// Do not generate tests for dynamic types because they inherently
   489  			// violate some go_marshal requirements.
   490  			if !t.dynamic {
   491  				ts = append(ts, g.generateOneTestSuite(t))
   492  			}
   493  		}
   494  	}
   495  
   496  	// Write output file header. These include things like package name and
   497  	// import statements.
   498  	if err := g.writeHeader(); err != nil {
   499  		return err
   500  	}
   501  
   502  	// Write type checks for referenced marshallable types to output file.
   503  	if err := g.writeTypeChecks(ms); err != nil {
   504  		return err
   505  	}
   506  
   507  	// Write generated interfaces to output file.
   508  	for _, i := range impls {
   509  		if err := i.write(g.output); err != nil {
   510  			return err
   511  		}
   512  	}
   513  
   514  	// Write generated tests to test file.
   515  	return g.writeTests(ts)
   516  }
   517  
   518  // writeTests outputs tests for the generated interface implementations to a go
   519  // source file.
   520  func (g *Generator) writeTests(ts []*testGenerator) error {
   521  	var b sourceBuffer
   522  
   523  	// Write the unconditional test file. This file is always compiled,
   524  	// regardless of what build tags were specified on the original input
   525  	// files. We use this file to guarantee we never end up with an empty test
   526  	// file, as that causes the build to fail with "no tests/benchmarks/examples
   527  	// found".
   528  	//
   529  	// There's no easy way to determine ahead of time if we'll end up with an
   530  	// empty build file since build constraints can arbitrarily cause some of
   531  	// the original types to be not defined. We also have no way to tell bazel
   532  	// to omit the entire test suite since the output files are already defined
   533  	// before go-marshal is called.
   534  	b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n")
   535  	b.emit("package %s\n\n", g.pkg)
   536  	b.emit("func Example() {\n")
   537  	b.inIndent(func() {
   538  		b.emit("// This example is intentionally empty, and ensures this package contains at\n")
   539  		b.emit("// least one testable entity. go-marshal is forced to emit a test package if the\n")
   540  		b.emit("// input package is marked marshallable, but emitting no testable entities \n")
   541  		b.emit("// results in a build failure.\n")
   542  	})
   543  	b.emit("}\n")
   544  	if err := b.write(g.outputTestUC); err != nil {
   545  		return err
   546  	}
   547  
   548  	// Now generate the real test file that contains the real types we
   549  	// processed. These need to be conditionally compiled according to the build
   550  	// tags, as the original types may not be defined under all build
   551  	// configurations.
   552  
   553  	b.reset()
   554  	b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n")
   555  
   556  	// Emit build tags.
   557  	if t := tags.Aggregate(g.inputs); len(t) > 0 {
   558  		b.emit(strings.Join(t.Lines(), "\n"))
   559  		b.emit("\n\n")
   560  	}
   561  
   562  	b.emit("package %s\n\n", g.pkg)
   563  	if err := b.write(g.outputTest); err != nil {
   564  		return err
   565  	}
   566  
   567  	// Collect and write test import statements.
   568  	imports := newImportTable()
   569  	for _, t := range ts {
   570  		imports.merge(t.imports)
   571  	}
   572  
   573  	if err := imports.write(g.outputTest); err != nil {
   574  		return err
   575  	}
   576  
   577  	// Write test functions.
   578  	for _, t := range ts {
   579  		if err := t.write(g.outputTest); err != nil {
   580  			return err
   581  		}
   582  	}
   583  	return nil
   584  }