gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/tools/go_marshal/gomarshal/generator_interfaces.go (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package gomarshal
    16  
    17  import (
    18  	"fmt"
    19  	"go/ast"
    20  	"go/token"
    21  	"strings"
    22  )
    23  
    24  // interfaceGenerator generates marshalling interfaces for a single type.
    25  //
    26  // getState is not thread-safe.
    27  type interfaceGenerator struct {
    28  	sourceBuffer
    29  
    30  	// The type we're serializing.
    31  	t *ast.TypeSpec
    32  
    33  	// Receiver argument for generated methods.
    34  	r string
    35  
    36  	// FileSet containing the tokens for the type we're processing.
    37  	f *token.FileSet
    38  
    39  	// is records external packages referenced by the generated implementation.
    40  	is map[string]struct{}
    41  
    42  	// ms records Marshallable types referenced by the generated implementation
    43  	// of t's interfaces.
    44  	ms map[string]struct{}
    45  
    46  	// as records fields in t that are potentially not packed. The key is the
    47  	// accessor for the field.
    48  	as map[string]struct{}
    49  }
    50  
    51  // typeName returns the name of the type this g represents.
    52  func (g *interfaceGenerator) typeName() string {
    53  	return g.t.Name.Name
    54  }
    55  
    56  // newinterfaceGenerator creates a new interface generator.
    57  func newInterfaceGenerator(t *ast.TypeSpec, r string, fset *token.FileSet) *interfaceGenerator {
    58  	g := &interfaceGenerator{
    59  		t:  t,
    60  		r:  r,
    61  		f:  fset,
    62  		is: make(map[string]struct{}),
    63  		ms: make(map[string]struct{}),
    64  		as: make(map[string]struct{}),
    65  	}
    66  	g.recordUsedMarshallable(g.typeName())
    67  	return g
    68  }
    69  
    70  func (g *interfaceGenerator) recordUsedMarshallable(m string) {
    71  	g.ms[m] = struct{}{}
    72  
    73  }
    74  
    75  func (g *interfaceGenerator) recordUsedImport(i string) {
    76  	g.is[i] = struct{}{}
    77  }
    78  
    79  func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) {
    80  	// Some calls to g.unmarshalScalar() occur in emitted loops that use "idx"
    81  	// as a loop variable, passing "field[idx]" as the accessor. When
    82  	// g.unmarshalScalar() calls this function, we need to convert such cases
    83  	// to "field[0]" for g.areFieldsPackedExpression(), which is used in
    84  	// contexts where "idx" is not defined.
    85  	fieldName = strings.ReplaceAll(fieldName, "[idx]", "[0]")
    86  	g.as[fieldName] = struct{}{}
    87  }
    88  
    89  // abortAt aborts the go_marshal tool with the given error message, with a
    90  // reference position to the input source. Same as abortAt, but uses g to
    91  // resolve p to position.
    92  func (g *interfaceGenerator) abortAt(p token.Pos, msg string) {
    93  	abortAt(g.f.Position(p), msg)
    94  }
    95  
    96  // scalarSize returns the size of type identified by t. If t isn't a primitive
    97  // type, the size isn't known at code generation time, and must be resolved via
    98  // the marshal.Marshallable interface.
    99  func (g *interfaceGenerator) scalarSize(t *ast.Ident) (size int, unknownSize bool) {
   100  	switch t.Name {
   101  	case "int8", "uint8", "byte":
   102  		return 1, false
   103  	case "int16", "uint16":
   104  		return 2, false
   105  	case "int32", "uint32":
   106  		return 4, false
   107  	case "int64", "uint64":
   108  		return 8, false
   109  	default:
   110  		return 0, true
   111  	}
   112  }
   113  
   114  func (g *interfaceGenerator) shift(bufVar string, n int) {
   115  	g.emit("%s = %s[%d:]\n", bufVar, bufVar, n)
   116  }
   117  
   118  func (g *interfaceGenerator) shiftDynamic(bufVar, name string) {
   119  	g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name)
   120  }
   121  
   122  // marshalScalar writes a single scalar to a byte slice.
   123  func (g *interfaceGenerator) marshalScalar(accessor, typ, bufVar string) {
   124  	switch typ {
   125  	case "int8", "uint8", "byte":
   126  		g.emit("%s[0] = byte(%s)\n", bufVar, accessor)
   127  		g.shift(bufVar, 1)
   128  	case "int16", "uint16":
   129  		g.recordUsedImport("hostarch")
   130  		g.emit("hostarch.ByteOrder.PutUint16(%s[:2], uint16(%s))\n", bufVar, accessor)
   131  		g.shift(bufVar, 2)
   132  	case "int32", "uint32":
   133  		g.recordUsedImport("hostarch")
   134  		g.emit("hostarch.ByteOrder.PutUint32(%s[:4], uint32(%s))\n", bufVar, accessor)
   135  		g.shift(bufVar, 4)
   136  	case "int64", "uint64":
   137  		g.recordUsedImport("hostarch")
   138  		g.emit("hostarch.ByteOrder.PutUint64(%s[:8], uint64(%s))\n", bufVar, accessor)
   139  		g.shift(bufVar, 8)
   140  	default:
   141  		g.emit("%s = %s.MarshalUnsafe(%s)\n", bufVar, accessor, bufVar)
   142  	}
   143  }
   144  
   145  // unmarshalScalar reads a single scalar from a byte slice.
   146  func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) {
   147  	switch typ {
   148  	case "byte":
   149  		g.emit("%s = %s[0]\n", accessor, bufVar)
   150  		g.shift(bufVar, 1)
   151  	case "int8", "uint8":
   152  		g.emit("%s = %s(%s[0])\n", accessor, typ, bufVar)
   153  		g.shift(bufVar, 1)
   154  	case "int16", "uint16":
   155  		g.recordUsedImport("hostarch")
   156  		g.emit("%s = %s(hostarch.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar)
   157  		g.shift(bufVar, 2)
   158  	case "int32", "uint32":
   159  		g.recordUsedImport("hostarch")
   160  		g.emit("%s = %s(hostarch.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar)
   161  		g.shift(bufVar, 4)
   162  	case "int64", "uint64":
   163  		g.recordUsedImport("hostarch")
   164  		g.emit("%s = %s(hostarch.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar)
   165  		g.shift(bufVar, 8)
   166  	default:
   167  		g.emit("%s = %s.UnmarshalUnsafe(%s)\n", bufVar, accessor, bufVar)
   168  		g.recordPotentiallyNonPackedField(accessor)
   169  	}
   170  }
   171  
   172  // emitCastToByteSlice unsafely casts an arbitrary type's underlying memory to a
   173  // byte slice, bypassing escape analysis. The caller is responsible for ensuring
   174  // srcPtr lives until they're done with dstVar, the runtime does not consider
   175  // dstVar dependent on srcPtr due to the escape analysis bypass.
   176  //
   177  // srcPtr must be a pointer.
   178  //
   179  // This function uses internally uses the identifier "hdr", and cannot be used
   180  // in a context where it is already bound.
   181  func (g *interfaceGenerator) emitCastToByteSlice(srcPtr, dstVar, lenExpr string) {
   182  	g.recordUsedImport("gohacks")
   183  	g.emit("// Construct a slice backed by dst's underlying memory.\n")
   184  	g.emit("var %s []byte\n", dstVar)
   185  	g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar)
   186  	g.emit("hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(%s)))\n", srcPtr)
   187  	g.emit("hdr.Len = %s\n", lenExpr)
   188  	g.emit("hdr.Cap = %s\n\n", lenExpr)
   189  }
   190  
   191  // emitCastToByteSlice unsafely casts a slice with elements of an abitrary type
   192  // to a byte slice. As part of the cast, the byte slice is made to look
   193  // independent of the src slice by bypassing escape analysis. This means the
   194  // byte slice can be used without causing the source to escape. The caller is
   195  // responsible for ensuring srcPtr lives until they're done with dstVar, as the
   196  // runtime no longer considers dstVar dependent on srcPtr and is free to GC it.
   197  //
   198  // srcPtr must be a pointer.
   199  //
   200  // This function uses internally uses the identifiers "ptr", "val" and "hdr",
   201  // and cannot be used in a context where these identifiers are already bound.
   202  func (g *interfaceGenerator) emitCastSliceToByteSlice(srcPtr, dstVar, lenExpr string) {
   203  	g.emitNoEscapeSliceDataPointer(srcPtr, "val")
   204  
   205  	g.emit("// Construct a slice backed by dst's underlying memory.\n")
   206  	g.emit("var %s []byte\n", dstVar)
   207  	g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar)
   208  	g.emit("hdr.Data = uintptr(val)\n")
   209  	g.emit("hdr.Len = %s\n", lenExpr)
   210  	g.emit("hdr.Cap = %s\n\n", lenExpr)
   211  }
   212  
   213  // emitNoEscapeSliceDataPointer unsafely casts a slice's data pointer to an
   214  // unsafe.Pointer, bypassing escape analysis. The caller is responsible for
   215  // ensuring srcPtr lives until they're done with dstVar, as the runtime no
   216  // longer considers dstVar dependent on srcPtr and is free to GC it.
   217  //
   218  // srcPtr must be a pointer.
   219  //
   220  // This function uses internally uses the identifier "ptr" cannot be used in a
   221  // context where this identifier is already bound.
   222  func (g *interfaceGenerator) emitNoEscapeSliceDataPointer(srcPtr, dstVar string) {
   223  	g.recordUsedImport("gohacks")
   224  	g.emit("ptr := unsafe.Pointer(%s)\n", srcPtr)
   225  	g.emit("%s := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data))\n\n", dstVar)
   226  }
   227  
   228  func (g *interfaceGenerator) emitKeepAlive(ptrVar string) {
   229  	g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", ptrVar)
   230  	g.emit("// must live until the use above.\n")
   231  	g.emit("runtime.KeepAlive(%s) // escapes: replaced by intrinsic.\n", ptrVar)
   232  }
   233  
   234  func (g *interfaceGenerator) expandBinaryExpr(b *strings.Builder, e *ast.BinaryExpr) {
   235  	switch x := e.X.(type) {
   236  	case *ast.BinaryExpr:
   237  		// Recursively expand sub-expression.
   238  		g.expandBinaryExpr(b, x)
   239  	case *ast.Ident:
   240  		fmt.Fprintf(b, "%s", x.Name)
   241  	case *ast.BasicLit:
   242  		fmt.Fprintf(b, "%s", x.Value)
   243  	default:
   244  		g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
   245  	}
   246  
   247  	fmt.Fprintf(b, "%s", e.Op)
   248  
   249  	switch y := e.Y.(type) {
   250  	case *ast.BinaryExpr:
   251  		// Recursively expand sub-expression.
   252  		g.expandBinaryExpr(b, y)
   253  	case *ast.Ident:
   254  		fmt.Fprintf(b, "%s", y.Name)
   255  	case *ast.BasicLit:
   256  		fmt.Fprintf(b, "%s", y.Value)
   257  	default:
   258  		g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
   259  	}
   260  }
   261  
   262  // arrayLenExpr returns a string containing a valid golang expression
   263  // representing the length of array a. The returned expression should be treated
   264  // as a single value, and will be already parenthesized as required.
   265  func (g *interfaceGenerator) arrayLenExpr(a *ast.ArrayType) string {
   266  	var b strings.Builder
   267  
   268  	switch l := a.Len.(type) {
   269  	case *ast.Ident:
   270  		fmt.Fprintf(&b, "%s", l.Name)
   271  	case *ast.BasicLit:
   272  		fmt.Fprintf(&b, "%s", l.Value)
   273  	case *ast.BinaryExpr:
   274  		g.expandBinaryExpr(&b, l)
   275  		return fmt.Sprintf("(%s)", b.String())
   276  	default:
   277  		g.abortAt(l.Pos(), "Cannot convert this array len expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
   278  	}
   279  	return b.String()
   280  }