github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/tools/go_marshal/gomarshal/util.go (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package gomarshal
    16  
    17  import (
    18  	"bytes"
    19  	"flag"
    20  	"fmt"
    21  	"go/ast"
    22  	"go/token"
    23  	"io"
    24  	"os"
    25  	"path"
    26  	"reflect"
    27  	"sort"
    28  	"strings"
    29  )
    30  
    31  var debug = flag.Bool("debug", false, "enables debugging output")
    32  
    33  // receiverName returns an appropriate receiver name given a type spec.
    34  func receiverName(t *ast.TypeSpec) string {
    35  	if len(t.Name.Name) < 1 {
    36  		// Zero length type name?
    37  		panic("unreachable")
    38  	}
    39  	return strings.ToLower(t.Name.Name[:1])
    40  }
    41  
    42  // kindString returns a user-friendly representation of an AST expr type.
    43  func kindString(e ast.Expr) string {
    44  	switch e.(type) {
    45  	case *ast.Ident:
    46  		return "scalar"
    47  	case *ast.ArrayType:
    48  		return "array"
    49  	case *ast.StructType:
    50  		return "struct"
    51  	case *ast.StarExpr:
    52  		return "pointer"
    53  	case *ast.FuncType:
    54  		return "function"
    55  	case *ast.InterfaceType:
    56  		return "interface"
    57  	case *ast.MapType:
    58  		return "map"
    59  	case *ast.ChanType:
    60  		return "channel"
    61  	default:
    62  		return reflect.TypeOf(e).String()
    63  	}
    64  }
    65  
    66  func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) {
    67  	for _, field := range st.Fields.List {
    68  		fn(field)
    69  	}
    70  }
    71  
    72  // fieldDispatcher is a collection of callbacks for handling different types of
    73  // fields in a struct declaration.
    74  type fieldDispatcher struct {
    75  	primitive func(n, t *ast.Ident)
    76  	selector  func(n, tX, tSel *ast.Ident)
    77  	array     func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident)
    78  	unhandled func(n *ast.Ident)
    79  }
    80  
    81  // Precondition: All dispatch callbacks that will be invoked must be
    82  // provided.
    83  func (fd fieldDispatcher) dispatch(f *ast.Field) {
    84  	// Each field declaration may actually be multiple declarations of the same
    85  	// type. For example, consider:
    86  	//
    87  	// type Point struct {
    88  	//     x, y, z int
    89  	// }
    90  	//
    91  	// We invoke the call-backs once per such instance.
    92  
    93  	// Handle embedded fields. Embedded fields have no names, but can be
    94  	// referenced by the type name.
    95  	if len(f.Names) < 1 {
    96  		switch v := f.Type.(type) {
    97  		case *ast.Ident:
    98  			fd.primitive(v, v)
    99  		case *ast.SelectorExpr:
   100  			fd.selector(v.Sel, v.X.(*ast.Ident), v.Sel)
   101  		default:
   102  			// Note: Arrays can't be embedded, which is handled here.
   103  			panic(fmt.Sprintf("Attempted to dispatch on embedded field of unsupported kind: %#v", f.Type))
   104  		}
   105  		return
   106  	}
   107  
   108  	// Non-embedded field.
   109  	for _, name := range f.Names {
   110  		switch v := f.Type.(type) {
   111  		case *ast.Ident:
   112  			fd.primitive(name, v)
   113  		case *ast.SelectorExpr:
   114  			fd.selector(name, v.X.(*ast.Ident), v.Sel)
   115  		case *ast.ArrayType:
   116  			switch t := v.Elt.(type) {
   117  			case *ast.Ident:
   118  				fd.array(name, v, t)
   119  			default:
   120  				// Should be handled with a better error message during validate.
   121  				panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t))
   122  			}
   123  		default:
   124  			fd.unhandled(name)
   125  		}
   126  	}
   127  }
   128  
   129  // debugEnabled indicates whether debugging is enabled for gomarshal.
   130  func debugEnabled() bool {
   131  	return *debug
   132  }
   133  
   134  // abort aborts the go_marshal tool with the given error message.
   135  func abort(msg string) {
   136  	if !strings.HasSuffix(msg, "\n") {
   137  		msg += "\n"
   138  	}
   139  	fmt.Print(msg)
   140  	os.Exit(1)
   141  }
   142  
   143  // abortAt aborts the go_marshal tool with the given error message, with
   144  // a reference position to the input source.
   145  func abortAt(p token.Position, msg string) {
   146  	abort(fmt.Sprintf("%v:\n  %s\n", p, msg))
   147  }
   148  
   149  // debugf conditionally prints a debug message.
   150  func debugf(f string, a ...interface{}) {
   151  	if debugEnabled() {
   152  		fmt.Printf(f, a...)
   153  	}
   154  }
   155  
   156  // debugfAt conditionally prints a debug message with a reference to a position
   157  // in the input source.
   158  func debugfAt(p token.Position, f string, a ...interface{}) {
   159  	if debugEnabled() {
   160  		fmt.Printf("%s:\n  %s", p, fmt.Sprintf(f, a...))
   161  	}
   162  }
   163  
   164  // emit generates a line of code in the output file.
   165  //
   166  // emit is a wrapper around writing a formatted string to the output
   167  // buffer. emit can be invoked in one of two ways:
   168  //
   169  // (1) emit("some string")
   170  //     When emit is called with a single string argument, it is simply copied to
   171  //     the output buffer without any further formatting.
   172  // (2) emit(fmtString, args...)
   173  //     emit can also be invoked in a similar fashion to *Printf() functions,
   174  //     where the first argument is a format string.
   175  //
   176  // Calling emit with a single argument that is not a string will result in a
   177  // panic, as the caller's intent is ambiguous.
   178  func emit(out io.Writer, indent int, a ...interface{}) {
   179  	const spacesPerIndentLevel = 4
   180  
   181  	if len(a) < 1 {
   182  		panic("emit() called with no arguments")
   183  	}
   184  
   185  	if indent > 0 {
   186  		if _, err := fmt.Fprint(out, strings.Repeat(" ", indent*spacesPerIndentLevel)); err != nil {
   187  			// Writing to the emit output should not fail. Typically the output
   188  			// is a byte.Buffer; writes to these never fail.
   189  			panic(err)
   190  		}
   191  	}
   192  
   193  	first, ok := a[0].(string)
   194  	if !ok {
   195  		// First argument must be either the string to emit (case 1 from
   196  		// function-level comment), or a format string (case 2).
   197  		panic(fmt.Sprintf("First argument to emit() is not a string: %+v", a[0]))
   198  	}
   199  
   200  	if len(a) == 1 {
   201  		// Single string argument. Assume no formatting requested.
   202  		if _, err := fmt.Fprint(out, first); err != nil {
   203  			// Writing to out should not fail.
   204  			panic(err)
   205  		}
   206  		return
   207  
   208  	}
   209  
   210  	// Formatting requested.
   211  	if _, err := fmt.Fprintf(out, first, a[1:]...); err != nil {
   212  		// Writing to out should not fail.
   213  		panic(err)
   214  	}
   215  }
   216  
   217  // sourceBuffer represents fragments of generated go source code.
   218  //
   219  // sourceBuffer provides a convenient way to build up go souce fragments in
   220  // memory. May be safely zero-value initialized. Not thread-safe.
   221  type sourceBuffer struct {
   222  	// Current indentation level.
   223  	indent int
   224  
   225  	// Memory buffer containing contents while they're being generated.
   226  	b bytes.Buffer
   227  }
   228  
   229  func (b *sourceBuffer) reset() {
   230  	b.indent = 0
   231  	b.b.Reset()
   232  }
   233  
   234  func (b *sourceBuffer) incIndent() {
   235  	b.indent++
   236  }
   237  
   238  func (b *sourceBuffer) decIndent() {
   239  	if b.indent <= 0 {
   240  		panic("decIndent() without matching incIndent()")
   241  	}
   242  	b.indent--
   243  }
   244  
   245  func (b *sourceBuffer) emit(a ...interface{}) {
   246  	emit(&b.b, b.indent, a...)
   247  }
   248  
   249  func (b *sourceBuffer) emitNoIndent(a ...interface{}) {
   250  	emit(&b.b, 0 /*indent*/, a...)
   251  }
   252  
   253  func (b *sourceBuffer) inIndent(body func()) {
   254  	b.incIndent()
   255  	body()
   256  	b.decIndent()
   257  }
   258  
   259  func (b *sourceBuffer) write(out io.Writer) error {
   260  	_, err := fmt.Fprint(out, b.b.String())
   261  	return err
   262  }
   263  
   264  // Write implements io.Writer.Write.
   265  func (b *sourceBuffer) Write(buf []byte) (int, error) {
   266  	return (b.b.Write(buf))
   267  }
   268  
   269  // importStmt represents a single import statement.
   270  type importStmt struct {
   271  	// Local name of the imported package.
   272  	name string
   273  	// Import path.
   274  	path string
   275  	// Indicates whether the local name is an alias, or simply the final
   276  	// component of the path.
   277  	aliased bool
   278  	// Indicates whether this import was referenced by generated code.
   279  	used bool
   280  	// AST node and file set representing the import statement, if any. These
   281  	// are only non-nil if the import statement originates from an input source
   282  	// file.
   283  	spec *ast.ImportSpec
   284  	fset *token.FileSet
   285  }
   286  
   287  func newImport(p string) *importStmt {
   288  	name := path.Base(p)
   289  	return &importStmt{
   290  		name:    name,
   291  		path:    p,
   292  		aliased: false,
   293  	}
   294  }
   295  
   296  func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
   297  	p := spec.Path.Value[1 : len(spec.Path.Value)-1] // Strip the " quotes around path.
   298  	name := path.Base(p)
   299  	if name == "" || name == "/" || name == "." {
   300  		panic(fmt.Sprintf("Couldn't process local package name for import at %s, (processed as %s)",
   301  			f.Position(spec.Path.Pos()), name))
   302  	}
   303  	if spec.Name != nil {
   304  		name = spec.Name.Name
   305  	}
   306  	return &importStmt{
   307  		name:    name,
   308  		path:    p,
   309  		aliased: spec.Name != nil,
   310  		spec:    spec,
   311  		fset:    f,
   312  	}
   313  }
   314  
   315  // String implements fmt.Stringer.String. This generates a string for the import
   316  // statement appropriate for writing directly to generated code.
   317  func (i *importStmt) String() string {
   318  	if i.aliased {
   319  		return fmt.Sprintf("%s %q", i.name, i.path)
   320  	}
   321  	return fmt.Sprintf("%q", i.path)
   322  }
   323  
   324  // debugString returns a debug string representing an import statement. This
   325  // representation is not valid golang code and is used for debugging output.
   326  func (i *importStmt) debugString() string {
   327  	if i.spec != nil && i.fset != nil {
   328  		return fmt.Sprintf("%s: %s", i.fset.Position(i.spec.Path.Pos()), i)
   329  	}
   330  	return fmt.Sprintf("(go-marshal import): %s", i)
   331  }
   332  
   333  func (i *importStmt) markUsed() {
   334  	i.used = true
   335  }
   336  
   337  func (i *importStmt) equivalent(other *importStmt) bool {
   338  	return i.name == other.name && i.path == other.path && i.aliased == other.aliased
   339  }
   340  
   341  // importTable represents a collection of importStmts.
   342  //
   343  // An importTable may contain multiple import statements referencing the same
   344  // local name. All import statements aliasing to the same local name are
   345  // technically ambiguous, as if such an import name is used in the generated
   346  // code, it's not clear which import statement it refers to. We ignore any
   347  // potential collisions until actually writing the import table to the generated
   348  // source file. See importTable.write.
   349  //
   350  // Given the following import statements across all the files comprising a
   351  // package marshalled:
   352  //
   353  // "sync"
   354  // "pkg/sync"
   355  // "pkg/sentry/kernel"
   356  // ktime "pkg/sentry/kernel/time"
   357  //
   358  // An importTable representing them would look like this:
   359  //
   360  // importTable {
   361  //     is: map[string][]*importStmt {
   362  //         "sync": []*importStmt{
   363  //             importStmt{name:"sync", path:"sync", aliased:false}
   364  //             importStmt{name:"sync", path:"pkg/sync", aliased:false}
   365  //         },
   366  //         "kernel": []*importStmt{importStmt{
   367  //            name: "kernel",
   368  //            path: "pkg/sentry/kernel",
   369  //            aliased: false
   370  //         }},
   371  //         "ktime": []*importStmt{importStmt{
   372  //             name: "ktime",
   373  //             path: "pkg/sentry/kernel/time",
   374  //             aliased: true,
   375  //         }},
   376  //     }
   377  // }
   378  //
   379  // Note that the local name "sync" is assigned to two different import
   380  // statements. This is possible if the import statements are from different
   381  // source files in the same package.
   382  //
   383  // Since go-marshal generates a single output file per package regardless of the
   384  // number of input files, if "sync" is referenced by any generated code, it's
   385  // unclear which import statement "sync" refers to. While it's theoretically
   386  // possible to resolve this by assigning a unique local alias to each instance
   387  // of the sync package, go-marshal currently aborts when it encounters such an
   388  // ambiguity.
   389  //
   390  // TODO(b/151478251): importTable considers the final component of an import
   391  // path to be the package name, but this is only a convention. The actual
   392  // package name is determined by the package statement in the source files for
   393  // the package.
   394  type importTable struct {
   395  	// Map of imports and whether they should be copied to the output.
   396  	is map[string][]*importStmt
   397  }
   398  
   399  func newImportTable() *importTable {
   400  	return &importTable{
   401  		is: make(map[string][]*importStmt),
   402  	}
   403  }
   404  
   405  // Merges import statements from other into i.
   406  func (i *importTable) merge(other *importTable) {
   407  	for name, ims := range other.is {
   408  		i.is[name] = append(i.is[name], ims...)
   409  	}
   410  }
   411  
   412  func (i *importTable) addStmt(s *importStmt) *importStmt {
   413  	i.is[s.name] = append(i.is[s.name], s)
   414  	return s
   415  }
   416  
   417  func (i *importTable) add(s string) *importStmt {
   418  	n := newImport(s)
   419  	return i.addStmt(n)
   420  }
   421  
   422  func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
   423  	return i.addStmt(newImportFromSpec(spec, f))
   424  }
   425  
   426  // Marks the import named n as used. If no such import is in the table, returns
   427  // false.
   428  func (i *importTable) markUsed(n string) bool {
   429  	if ns, ok := i.is[n]; ok {
   430  		for _, n := range ns {
   431  			n.markUsed()
   432  		}
   433  		return true
   434  	}
   435  	return false
   436  }
   437  
   438  func (i *importTable) clear() {
   439  	for _, is := range i.is {
   440  		for _, i := range is {
   441  			i.used = false
   442  		}
   443  	}
   444  }
   445  
   446  func (i *importTable) write(out io.Writer) error {
   447  	if len(i.is) == 0 {
   448  		// Nothing to import, we're done.
   449  		return nil
   450  	}
   451  
   452  	imports := make([]string, 0, len(i.is))
   453  	for name, is := range i.is {
   454  		var lastUsed *importStmt
   455  		var ambiguous bool
   456  
   457  		for _, i := range is {
   458  			if i.used {
   459  				if lastUsed != nil {
   460  					if !i.equivalent(lastUsed) {
   461  						ambiguous = true
   462  					}
   463  				}
   464  				lastUsed = i
   465  			}
   466  		}
   467  
   468  		if ambiguous {
   469  			// We have two or more import statements across the different source
   470  			// files that share a local name, and at least one of these imports
   471  			// are used by the generated code. This ambiguity can't be resolved
   472  			// by go-marshal and requires the user intervention. Dump a list of
   473  			// the colliding import statements and let the user modify the input
   474  			// files as appropriate.
   475  			var b strings.Builder
   476  			fmt.Fprintf(&b, "The imported name %q is used by one of the types marked for marshalling, and which import statement the code refers to is ambiguous. Perhaps give the imports unique local names?\n\n", name)
   477  			fmt.Fprintf(&b, "The following %d import statements are ambiguous for the local name %q:\n", len(is), name)
   478  			// Note: len(is) is guaranteed to be 1 or greater or ambiguous can't
   479  			// be true. Therefore the slicing below is safe.
   480  			for _, i := range is[:len(is)-1] {
   481  				fmt.Fprintf(&b, "  %v\n", i.debugString())
   482  			}
   483  			fmt.Fprintf(&b, "  %v", is[len(is)-1].debugString())
   484  			panic(b.String())
   485  		}
   486  
   487  		if lastUsed != nil {
   488  			imports = append(imports, lastUsed.String())
   489  		}
   490  	}
   491  	sort.Strings(imports)
   492  
   493  	var b sourceBuffer
   494  	b.emit("import (\n")
   495  	b.incIndent()
   496  	for _, i := range imports {
   497  		b.emit("%s\n", i)
   498  	}
   499  	b.decIndent()
   500  	b.emit(")\n\n")
   501  
   502  	return b.write(out)
   503  }