github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/tools/go_marshal/gomarshal/generator_interfaces.go (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package gomarshal
    16  
    17  import (
    18  	"fmt"
    19  	"go/ast"
    20  	"go/token"
    21  	"strings"
    22  )
    23  
    24  // interfaceGenerator generates marshalling interfaces for a single type.
    25  //
    26  // getState is not thread-safe.
    27  type interfaceGenerator struct {
    28  	sourceBuffer
    29  
    30  	// The type we're serializing.
    31  	t *ast.TypeSpec
    32  
    33  	// Receiver argument for generated methods.
    34  	r string
    35  
    36  	// FileSet containing the tokens for the type we're processing.
    37  	f *token.FileSet
    38  
    39  	// is records external packages referenced by the generated implementation.
    40  	is map[string]struct{}
    41  
    42  	// ms records Marshallable types referenced by the generated implementation
    43  	// of t's interfaces.
    44  	ms map[string]struct{}
    45  
    46  	// as records fields in t that are potentially not packed. The key is the
    47  	// accessor for the field.
    48  	as map[string]struct{}
    49  }
    50  
    51  // typeName returns the name of the type this g represents.
    52  func (g *interfaceGenerator) typeName() string {
    53  	return g.t.Name.Name
    54  }
    55  
    56  // newinterfaceGenerator creates a new interface generator.
    57  func newInterfaceGenerator(t *ast.TypeSpec, r string, fset *token.FileSet) *interfaceGenerator {
    58  	g := &interfaceGenerator{
    59  		t:  t,
    60  		r:  r,
    61  		f:  fset,
    62  		is: make(map[string]struct{}),
    63  		ms: make(map[string]struct{}),
    64  		as: make(map[string]struct{}),
    65  	}
    66  	g.recordUsedMarshallable(g.typeName())
    67  	return g
    68  }
    69  
    70  func (g *interfaceGenerator) recordUsedMarshallable(m string) {
    71  	g.ms[m] = struct{}{}
    72  
    73  }
    74  
    75  func (g *interfaceGenerator) recordUsedImport(i string) {
    76  	g.is[i] = struct{}{}
    77  }
    78  
    79  func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) {
    80  	g.as[fieldName] = struct{}{}
    81  }
    82  
    83  // abortAt aborts the go_marshal tool with the given error message, with a
    84  // reference position to the input source. Same as abortAt, but uses g to
    85  // resolve p to position.
    86  func (g *interfaceGenerator) abortAt(p token.Pos, msg string) {
    87  	abortAt(g.f.Position(p), msg)
    88  }
    89  
    90  // scalarSize returns the size of type identified by t. If t isn't a primitive
    91  // type, the size isn't known at code generation time, and must be resolved via
    92  // the marshal.Marshallable interface.
    93  func (g *interfaceGenerator) scalarSize(t *ast.Ident) (size int, unknownSize bool) {
    94  	switch t.Name {
    95  	case "int8", "uint8", "byte":
    96  		return 1, false
    97  	case "int16", "uint16":
    98  		return 2, false
    99  	case "int32", "uint32":
   100  		return 4, false
   101  	case "int64", "uint64":
   102  		return 8, false
   103  	default:
   104  		return 0, true
   105  	}
   106  }
   107  
   108  func (g *interfaceGenerator) shift(bufVar string, n int) {
   109  	g.emit("%s = %s[%d:]\n", bufVar, bufVar, n)
   110  }
   111  
   112  func (g *interfaceGenerator) shiftDynamic(bufVar, name string) {
   113  	g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name)
   114  }
   115  
   116  // marshalScalar writes a single scalar to a byte slice.
   117  func (g *interfaceGenerator) marshalScalar(accessor, typ, bufVar string) {
   118  	switch typ {
   119  	case "int8", "uint8", "byte":
   120  		g.emit("%s[0] = byte(%s)\n", bufVar, accessor)
   121  		g.shift(bufVar, 1)
   122  	case "int16", "uint16":
   123  		g.recordUsedImport("hostarch")
   124  		g.emit("hostarch.ByteOrder.PutUint16(%s[:2], uint16(%s))\n", bufVar, accessor)
   125  		g.shift(bufVar, 2)
   126  	case "int32", "uint32":
   127  		g.recordUsedImport("hostarch")
   128  		g.emit("hostarch.ByteOrder.PutUint32(%s[:4], uint32(%s))\n", bufVar, accessor)
   129  		g.shift(bufVar, 4)
   130  	case "int64", "uint64":
   131  		g.recordUsedImport("hostarch")
   132  		g.emit("hostarch.ByteOrder.PutUint64(%s[:8], uint64(%s))\n", bufVar, accessor)
   133  		g.shift(bufVar, 8)
   134  	default:
   135  		g.emit("%s.MarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
   136  		g.shiftDynamic(bufVar, accessor)
   137  	}
   138  }
   139  
   140  // unmarshalScalar reads a single scalar from a byte slice.
   141  func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) {
   142  	switch typ {
   143  	case "byte":
   144  		g.emit("%s = %s[0]\n", accessor, bufVar)
   145  		g.shift(bufVar, 1)
   146  	case "int8", "uint8":
   147  		g.emit("%s = %s(%s[0])\n", accessor, typ, bufVar)
   148  		g.shift(bufVar, 1)
   149  	case "int16", "uint16":
   150  		g.recordUsedImport("hostarch")
   151  		g.emit("%s = %s(hostarch.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar)
   152  		g.shift(bufVar, 2)
   153  	case "int32", "uint32":
   154  		g.recordUsedImport("hostarch")
   155  		g.emit("%s = %s(hostarch.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar)
   156  		g.shift(bufVar, 4)
   157  	case "int64", "uint64":
   158  		g.recordUsedImport("hostarch")
   159  		g.emit("%s = %s(hostarch.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar)
   160  		g.shift(bufVar, 8)
   161  	default:
   162  		g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
   163  		g.shiftDynamic(bufVar, accessor)
   164  		g.recordPotentiallyNonPackedField(accessor)
   165  	}
   166  }
   167  
   168  // emitCastToByteSlice unsafely casts an arbitrary type's underlying memory to a
   169  // byte slice, bypassing escape analysis. The caller is responsible for ensuring
   170  // srcPtr lives until they're done with dstVar, the runtime does not consider
   171  // dstVar dependent on srcPtr due to the escape analysis bypass.
   172  //
   173  // srcPtr must be a pointer.
   174  //
   175  // This function uses internally uses the identifier "hdr", and cannot be used
   176  // in a context where it is already bound.
   177  func (g *interfaceGenerator) emitCastToByteSlice(srcPtr, dstVar, lenExpr string) {
   178  	g.recordUsedImport("gohacks")
   179  	g.emit("// Construct a slice backed by dst's underlying memory.\n")
   180  	g.emit("var %s []byte\n", dstVar)
   181  	g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar)
   182  	g.emit("hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(%s)))\n", srcPtr)
   183  	g.emit("hdr.Len = %s\n", lenExpr)
   184  	g.emit("hdr.Cap = %s\n\n", lenExpr)
   185  }
   186  
   187  // emitCastToByteSlice unsafely casts a slice with elements of an abitrary type
   188  // to a byte slice. As part of the cast, the byte slice is made to look
   189  // independent of the src slice by bypassing escape analysis. This means the
   190  // byte slice can be used without causing the source to escape. The caller is
   191  // responsible for ensuring srcPtr lives until they're done with dstVar, as the
   192  // runtime no longer considers dstVar dependent on srcPtr and is free to GC it.
   193  //
   194  // srcPtr must be a pointer.
   195  //
   196  // This function uses internally uses the identifiers "ptr", "val" and "hdr",
   197  // and cannot be used in a context where these identifiers are already bound.
   198  func (g *interfaceGenerator) emitCastSliceToByteSlice(srcPtr, dstVar, lenExpr string) {
   199  	g.emitNoEscapeSliceDataPointer(srcPtr, "val")
   200  
   201  	g.emit("// Construct a slice backed by dst's underlying memory.\n")
   202  	g.emit("var %s []byte\n", dstVar)
   203  	g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar)
   204  	g.emit("hdr.Data = uintptr(val)\n")
   205  	g.emit("hdr.Len = %s\n", lenExpr)
   206  	g.emit("hdr.Cap = %s\n\n", lenExpr)
   207  }
   208  
   209  // emitNoEscapeSliceDataPointer unsafely casts a slice's data pointer to an
   210  // unsafe.Pointer, bypassing escape analysis. The caller is responsible for
   211  // ensuring srcPtr lives until they're done with dstVar, as the runtime no
   212  // longer considers dstVar dependent on srcPtr and is free to GC it.
   213  //
   214  // srcPtr must be a pointer.
   215  //
   216  // This function uses internally uses the identifier "ptr" cannot be used in a
   217  // context where this identifier is already bound.
   218  func (g *interfaceGenerator) emitNoEscapeSliceDataPointer(srcPtr, dstVar string) {
   219  	g.recordUsedImport("gohacks")
   220  	g.emit("ptr := unsafe.Pointer(%s)\n", srcPtr)
   221  	g.emit("%s := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data))\n\n", dstVar)
   222  }
   223  
   224  func (g *interfaceGenerator) emitKeepAlive(ptrVar string) {
   225  	g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", ptrVar)
   226  	g.emit("// must live until the use above.\n")
   227  	g.emit("runtime.KeepAlive(%s) // escapes: replaced by intrinsic.\n", ptrVar)
   228  }
   229  
   230  func (g *interfaceGenerator) expandBinaryExpr(b *strings.Builder, e *ast.BinaryExpr) {
   231  	switch x := e.X.(type) {
   232  	case *ast.BinaryExpr:
   233  		// Recursively expand sub-expression.
   234  		g.expandBinaryExpr(b, x)
   235  	case *ast.Ident:
   236  		fmt.Fprintf(b, "%s", x.Name)
   237  	case *ast.BasicLit:
   238  		fmt.Fprintf(b, "%s", x.Value)
   239  	default:
   240  		g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
   241  	}
   242  
   243  	fmt.Fprintf(b, "%s", e.Op)
   244  
   245  	switch y := e.Y.(type) {
   246  	case *ast.BinaryExpr:
   247  		// Recursively expand sub-expression.
   248  		g.expandBinaryExpr(b, y)
   249  	case *ast.Ident:
   250  		fmt.Fprintf(b, "%s", y.Name)
   251  	case *ast.BasicLit:
   252  		fmt.Fprintf(b, "%s", y.Value)
   253  	default:
   254  		g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
   255  	}
   256  }
   257  
   258  // arrayLenExpr returns a string containing a valid golang expression
   259  // representing the length of array a. The returned expression should be treated
   260  // as a single value, and will be already parenthesized as required.
   261  func (g *interfaceGenerator) arrayLenExpr(a *ast.ArrayType) string {
   262  	var b strings.Builder
   263  
   264  	switch l := a.Len.(type) {
   265  	case *ast.Ident:
   266  		fmt.Fprintf(&b, "%s", l.Name)
   267  	case *ast.BasicLit:
   268  		fmt.Fprintf(&b, "%s", l.Value)
   269  	case *ast.BinaryExpr:
   270  		g.expandBinaryExpr(&b, l)
   271  		return fmt.Sprintf("(%s)", b.String())
   272  	default:
   273  		g.abortAt(l.Pos(), "Cannot convert this array len expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
   274  	}
   275  	return b.String()
   276  }