gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/tools/go_marshal/gomarshal/util.go (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package gomarshal
    16  
import (
	"bytes"
	"flag"
	"fmt"
	"go/ast"
	"go/token"
	"io"
	"os"
	"path"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"unicode/utf8"
)
    30  
    31  var debug = flag.Bool("debug", false, "enables debugging output")
    32  
    33  // receiverName returns an appropriate receiver name given a type spec.
    34  func receiverName(t *ast.TypeSpec) string {
    35  	if len(t.Name.Name) < 1 {
    36  		// Zero length type name?
    37  		panic("unreachable")
    38  	}
    39  	return strings.ToLower(t.Name.Name[:1])
    40  }
    41  
    42  // kindString returns a user-friendly representation of an AST expr type.
    43  func kindString(e ast.Expr) string {
    44  	switch e.(type) {
    45  	case *ast.Ident:
    46  		return "scalar"
    47  	case *ast.ArrayType:
    48  		return "array"
    49  	case *ast.StructType:
    50  		return "struct"
    51  	case *ast.StarExpr:
    52  		return "pointer"
    53  	case *ast.FuncType:
    54  		return "function"
    55  	case *ast.InterfaceType:
    56  		return "interface"
    57  	case *ast.MapType:
    58  		return "map"
    59  	case *ast.ChanType:
    60  		return "channel"
    61  	default:
    62  		return reflect.TypeOf(e).String()
    63  	}
    64  }
    65  
    66  func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) {
    67  	for _, field := range st.Fields.List {
    68  		fn(field)
    69  	}
    70  }
    71  
    72  // fieldDispatcher is a collection of callbacks for handling different types of
    73  // fields in a struct declaration.
    74  type fieldDispatcher struct {
    75  	primitive func(n, t *ast.Ident)
    76  	selector  func(n, tX, tSel *ast.Ident)
    77  	array     func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident)
    78  	unhandled func(n *ast.Ident)
    79  }
    80  
    81  // Precondition: All dispatch callbacks that will be invoked must be
    82  // provided.
    83  func (fd fieldDispatcher) dispatch(f *ast.Field) {
    84  	// Each field declaration may actually be multiple declarations of the same
    85  	// type. For example, consider:
    86  	//
    87  	// type Point struct {
    88  	//     x, y, z int
    89  	// }
    90  	//
    91  	// We invoke the call-backs once per such instance.
    92  
    93  	// Handle embedded fields. Embedded fields have no names, but can be
    94  	// referenced by the type name.
    95  	if len(f.Names) < 1 {
    96  		switch v := f.Type.(type) {
    97  		case *ast.Ident:
    98  			fd.primitive(v, v)
    99  		case *ast.SelectorExpr:
   100  			fd.selector(v.Sel, v.X.(*ast.Ident), v.Sel)
   101  		default:
   102  			// Note: Arrays can't be embedded, which is handled here.
   103  			panic(fmt.Sprintf("Attempted to dispatch on embedded field of unsupported kind: %#v", f.Type))
   104  		}
   105  		return
   106  	}
   107  
   108  	// Non-embedded field.
   109  	for _, name := range f.Names {
   110  		switch v := f.Type.(type) {
   111  		case *ast.Ident:
   112  			fd.primitive(name, v)
   113  		case *ast.SelectorExpr:
   114  			fd.selector(name, v.X.(*ast.Ident), v.Sel)
   115  		case *ast.ArrayType:
   116  			switch t := v.Elt.(type) {
   117  			case *ast.Ident:
   118  				fd.array(name, v, t)
   119  			default:
   120  				// Should be handled with a better error message during validate.
   121  				panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t))
   122  			}
   123  		default:
   124  			fd.unhandled(name)
   125  		}
   126  	}
   127  }
   128  
   129  // debugEnabled indicates whether debugging is enabled for gomarshal.
   130  func debugEnabled() bool {
   131  	return *debug
   132  }
   133  
   134  // abort aborts the go_marshal tool with the given error message.
   135  func abort(msg string) {
   136  	if !strings.HasSuffix(msg, "\n") {
   137  		msg += "\n"
   138  	}
   139  	fmt.Print(msg)
   140  	os.Exit(1)
   141  }
   142  
   143  // abortAt aborts the go_marshal tool with the given error message, with
   144  // a reference position to the input source.
   145  func abortAt(p token.Position, msg string) {
   146  	abort(fmt.Sprintf("%v:\n  %s\n", p, msg))
   147  }
   148  
   149  // debugf conditionally prints a debug message.
   150  func debugf(f string, a ...any) {
   151  	if debugEnabled() {
   152  		fmt.Printf(f, a...)
   153  	}
   154  }
   155  
   156  // debugfAt conditionally prints a debug message with a reference to a position
   157  // in the input source.
   158  func debugfAt(p token.Position, f string, a ...any) {
   159  	if debugEnabled() {
   160  		fmt.Printf("%s:\n  %s", p, fmt.Sprintf(f, a...))
   161  	}
   162  }
   163  
   164  // emit generates a line of code in the output file.
   165  //
   166  // emit is a wrapper around writing a formatted string to the output
   167  // buffer. emit can be invoked in one of two ways:
   168  //
   169  // (1) emit("some string")
   170  //
   171  //	When emit is called with a single string argument, it is simply copied to
   172  //	the output buffer without any further formatting.
   173  //
   174  // (2) emit(fmtString, args...)
   175  //
   176  //	emit can also be invoked in a similar fashion to *Printf() functions,
   177  //	where the first argument is a format string.
   178  //
   179  // Calling emit with a single argument that is not a string will result in a
   180  // panic, as the caller's intent is ambiguous.
   181  func emit(out io.Writer, indent int, a ...any) {
   182  	const spacesPerIndentLevel = 4
   183  
   184  	if len(a) < 1 {
   185  		panic("emit() called with no arguments")
   186  	}
   187  
   188  	if indent > 0 {
   189  		if _, err := fmt.Fprint(out, strings.Repeat(" ", indent*spacesPerIndentLevel)); err != nil {
   190  			// Writing to the emit output should not fail. Typically the output
   191  			// is a byte.Buffer; writes to these never fail.
   192  			panic(err)
   193  		}
   194  	}
   195  
   196  	first, ok := a[0].(string)
   197  	if !ok {
   198  		// First argument must be either the string to emit (case 1 from
   199  		// function-level comment), or a format string (case 2).
   200  		panic(fmt.Sprintf("First argument to emit() is not a string: %+v", a[0]))
   201  	}
   202  
   203  	if len(a) == 1 {
   204  		// Single string argument. Assume no formatting requested.
   205  		if _, err := fmt.Fprint(out, first); err != nil {
   206  			// Writing to out should not fail.
   207  			panic(err)
   208  		}
   209  		return
   210  
   211  	}
   212  
   213  	// Formatting requested.
   214  	if _, err := fmt.Fprintf(out, first, a[1:]...); err != nil {
   215  		// Writing to out should not fail.
   216  		panic(err)
   217  	}
   218  }
   219  
   220  // sourceBuffer represents fragments of generated go source code.
   221  //
   222  // sourceBuffer provides a convenient way to build up go souce fragments in
   223  // memory. May be safely zero-value initialized. Not thread-safe.
   224  type sourceBuffer struct {
   225  	// Current indentation level.
   226  	indent int
   227  
   228  	// Memory buffer containing contents while they're being generated.
   229  	b bytes.Buffer
   230  }
   231  
   232  func (b *sourceBuffer) reset() {
   233  	b.indent = 0
   234  	b.b.Reset()
   235  }
   236  
   237  func (b *sourceBuffer) incIndent() {
   238  	b.indent++
   239  }
   240  
   241  func (b *sourceBuffer) decIndent() {
   242  	if b.indent <= 0 {
   243  		panic("decIndent() without matching incIndent()")
   244  	}
   245  	b.indent--
   246  }
   247  
   248  func (b *sourceBuffer) emit(a ...any) {
   249  	emit(&b.b, b.indent, a...)
   250  }
   251  
   252  func (b *sourceBuffer) emitNoIndent(a ...any) {
   253  	emit(&b.b, 0 /*indent*/, a...)
   254  }
   255  
   256  func (b *sourceBuffer) inIndent(body func()) {
   257  	b.incIndent()
   258  	body()
   259  	b.decIndent()
   260  }
   261  
   262  func (b *sourceBuffer) write(out io.Writer) error {
   263  	_, err := fmt.Fprint(out, b.b.String())
   264  	return err
   265  }
   266  
   267  // Write implements io.Writer.Write.
   268  func (b *sourceBuffer) Write(buf []byte) (int, error) {
   269  	return (b.b.Write(buf))
   270  }
   271  
   272  // importStmt represents a single import statement.
   273  type importStmt struct {
   274  	// Local name of the imported package.
   275  	name string
   276  	// Import path.
   277  	path string
   278  	// Indicates whether the local name is an alias, or simply the final
   279  	// component of the path.
   280  	aliased bool
   281  	// Indicates whether this import was referenced by generated code.
   282  	used bool
   283  	// AST node and file set representing the import statement, if any. These
   284  	// are only non-nil if the import statement originates from an input source
   285  	// file.
   286  	spec *ast.ImportSpec
   287  	fset *token.FileSet
   288  }
   289  
   290  func newImport(p string) *importStmt {
   291  	name := path.Base(p)
   292  	return &importStmt{
   293  		name:    name,
   294  		path:    p,
   295  		aliased: false,
   296  	}
   297  }
   298  
   299  func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
   300  	p := spec.Path.Value[1 : len(spec.Path.Value)-1] // Strip the " quotes around path.
   301  	name := path.Base(p)
   302  	if name == "" || name == "/" || name == "." {
   303  		panic(fmt.Sprintf("Couldn't process local package name for import at %s, (processed as %s)",
   304  			f.Position(spec.Path.Pos()), name))
   305  	}
   306  	if spec.Name != nil {
   307  		name = spec.Name.Name
   308  	}
   309  	return &importStmt{
   310  		name:    name,
   311  		path:    p,
   312  		aliased: spec.Name != nil,
   313  		spec:    spec,
   314  		fset:    f,
   315  	}
   316  }
   317  
   318  // String implements fmt.Stringer.String. This generates a string for the import
   319  // statement appropriate for writing directly to generated code.
   320  func (i *importStmt) String() string {
   321  	if i.aliased {
   322  		return fmt.Sprintf("%s %q", i.name, i.path)
   323  	}
   324  	return fmt.Sprintf("%q", i.path)
   325  }
   326  
   327  // debugString returns a debug string representing an import statement. This
   328  // representation is not valid golang code and is used for debugging output.
   329  func (i *importStmt) debugString() string {
   330  	if i.spec != nil && i.fset != nil {
   331  		return fmt.Sprintf("%s: %s", i.fset.Position(i.spec.Path.Pos()), i)
   332  	}
   333  	return fmt.Sprintf("(go-marshal import): %s", i)
   334  }
   335  
   336  func (i *importStmt) markUsed() {
   337  	i.used = true
   338  }
   339  
   340  func (i *importStmt) equivalent(other *importStmt) bool {
   341  	return i.name == other.name && i.path == other.path && i.aliased == other.aliased
   342  }
   343  
   344  // importTable represents a collection of importStmts.
   345  //
   346  // An importTable may contain multiple import statements referencing the same
   347  // local name. All import statements aliasing to the same local name are
   348  // technically ambiguous, as if such an import name is used in the generated
   349  // code, it's not clear which import statement it refers to. We ignore any
   350  // potential collisions until actually writing the import table to the generated
   351  // source file. See importTable.write.
   352  //
   353  // Given the following import statements across all the files comprising a
   354  // package marshalled:
   355  //
   356  // "sync"
   357  // "pkg/sync"
   358  // "pkg/sentry/kernel"
   359  // ktime "pkg/sentry/kernel/time"
   360  //
   361  // An importTable representing them would look like this:
   362  //
   363  //	 importTable {
   364  //		is: map[string][]*importStmt {
   365  //		    "sync": []*importStmt{
   366  //		        importStmt{name:"sync", path:"sync", aliased:false}
   367  //		        importStmt{name:"sync", path:"pkg/sync", aliased:false}
   368  //		    },
   369  //		    "kernel": []*importStmt{importStmt{
   370  //		       name: "kernel",
   371  //		       path: "pkg/sentry/kernel",
   372  //		       aliased: false
   373  //		    }},
   374  //		    "ktime": []*importStmt{importStmt{
   375  //		        name: "ktime",
   376  //		        path: "pkg/sentry/kernel/time",
   377  //		        aliased: true,
   378  //		    }},
   379  //		}
   380  //	 }
   381  //
   382  // Note that the local name "sync" is assigned to two different import
   383  // statements. This is possible if the import statements are from different
   384  // source files in the same package.
   385  //
   386  // Since go-marshal generates a single output file per package regardless of the
   387  // number of input files, if "sync" is referenced by any generated code, it's
   388  // unclear which import statement "sync" refers to. While it's theoretically
   389  // possible to resolve this by assigning a unique local alias to each instance
   390  // of the sync package, go-marshal currently aborts when it encounters such an
   391  // ambiguity.
   392  //
   393  // TODO(b/151478251): importTable considers the final component of an import
   394  // path to be the package name, but this is only a convention. The actual
   395  // package name is determined by the package statement in the source files for
   396  // the package.
   397  type importTable struct {
   398  	// Map of imports and whether they should be copied to the output.
   399  	is map[string][]*importStmt
   400  }
   401  
   402  func newImportTable() *importTable {
   403  	return &importTable{
   404  		is: make(map[string][]*importStmt),
   405  	}
   406  }
   407  
   408  // Merges import statements from other into i.
   409  func (i *importTable) merge(other *importTable) {
   410  	for name, ims := range other.is {
   411  		i.is[name] = append(i.is[name], ims...)
   412  	}
   413  }
   414  
   415  func (i *importTable) addStmt(s *importStmt) *importStmt {
   416  	i.is[s.name] = append(i.is[s.name], s)
   417  	return s
   418  }
   419  
   420  func (i *importTable) add(s string) *importStmt {
   421  	n := newImport(s)
   422  	return i.addStmt(n)
   423  }
   424  
   425  func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
   426  	return i.addStmt(newImportFromSpec(spec, f))
   427  }
   428  
   429  // Marks the import named n as used. If no such import is in the table, returns
   430  // false.
   431  func (i *importTable) markUsed(n string) bool {
   432  	if ns, ok := i.is[n]; ok {
   433  		for _, n := range ns {
   434  			n.markUsed()
   435  		}
   436  		return true
   437  	}
   438  	return false
   439  }
   440  
   441  func (i *importTable) clear() {
   442  	for _, is := range i.is {
   443  		for _, i := range is {
   444  			i.used = false
   445  		}
   446  	}
   447  }
   448  
   449  func (i *importTable) write(out io.Writer) error {
   450  	if len(i.is) == 0 {
   451  		// Nothing to import, we're done.
   452  		return nil
   453  	}
   454  
   455  	imports := make([]string, 0, len(i.is))
   456  	for name, is := range i.is {
   457  		var lastUsed *importStmt
   458  		var ambiguous bool
   459  
   460  		for _, i := range is {
   461  			if i.used {
   462  				if lastUsed != nil {
   463  					if !i.equivalent(lastUsed) {
   464  						ambiguous = true
   465  					}
   466  				}
   467  				lastUsed = i
   468  			}
   469  		}
   470  
   471  		if ambiguous {
   472  			// We have two or more import statements across the different source
   473  			// files that share a local name, and at least one of these imports
   474  			// are used by the generated code. This ambiguity can't be resolved
   475  			// by go-marshal and requires the user intervention. Dump a list of
   476  			// the colliding import statements and let the user modify the input
   477  			// files as appropriate.
   478  			var b strings.Builder
   479  			fmt.Fprintf(&b, "The imported name %q is used by one of the types marked for marshalling, and which import statement the code refers to is ambiguous. Perhaps give the imports unique local names?\n\n", name)
   480  			fmt.Fprintf(&b, "The following %d import statements are ambiguous for the local name %q:\n", len(is), name)
   481  			// Note: len(is) is guaranteed to be 1 or greater or ambiguous can't
   482  			// be true. Therefore the slicing below is safe.
   483  			for _, i := range is[:len(is)-1] {
   484  				fmt.Fprintf(&b, "  %v\n", i.debugString())
   485  			}
   486  			fmt.Fprintf(&b, "  %v", is[len(is)-1].debugString())
   487  			panic(b.String())
   488  		}
   489  
   490  		if lastUsed != nil {
   491  			imports = append(imports, lastUsed.String())
   492  		}
   493  	}
   494  	sort.Strings(imports)
   495  
   496  	var b sourceBuffer
   497  	b.emit("import (\n")
   498  	b.incIndent()
   499  	for _, i := range imports {
   500  		b.emit("%s\n", i)
   501  	}
   502  	b.decIndent()
   503  	b.emit(")\n\n")
   504  
   505  	return b.write(out)
   506  }