github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/tools/go_marshal/gomarshal/generator_tests.go (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package gomarshal
    16  
    17  import (
    18  	"fmt"
    19  	"go/ast"
    20  	"io"
    21  	"strings"
    22  )
    23  
    24  var standardImports = []string{
    25  	"bytes",
    26  	"fmt",
    27  	"reflect",
    28  	"testing",
    29  
    30  	"github.com/SagerNet/gvisor/tools/go_marshal/analysis",
    31  }
    32  
    33  var sliceAPIImports = []string{
    34  	"encoding/binary",
    35  	"github.com/SagerNet/gvisor/pkg/hostarch",
    36  }
    37  
// testGenerator produces the source of a Go test file that exercises the
// marshalling methods generated for a single type.
type testGenerator struct {
	// Holds the generated source; write() hands it to an io.Writer.
	// NOTE(review): the g.emit / g.inIndent calls in this file presumably
	// resolve to methods promoted from this embedded buffer — confirm.
	sourceBuffer

	// The type we're serializing.
	t *ast.TypeSpec

	// Receiver argument for generated methods. (Set by newTestGenerator but
	// not read by any of the emit* methods in this file.)
	r string

	// Imports used by generated code. newTestGenerator registers
	// standardImports (marked used) and sliceAPIImports (marked used only
	// when the slice API test is emitted).
	imports *importTable

	// Import statement for the package declaring the type we generated code
	// for. We need this to construct test instances for the type, since the
	// tests aren't written in the same package. Not set by newTestGenerator;
	// presumably assigned by the caller.
	decl *importStmt
}
    55  
    56  func newTestGenerator(t *ast.TypeSpec, r string) *testGenerator {
    57  	g := &testGenerator{
    58  		t:       t,
    59  		r:       r,
    60  		imports: newImportTable(),
    61  	}
    62  
    63  	for _, i := range standardImports {
    64  		g.imports.add(i).markUsed()
    65  	}
    66  	// These imports are used if a type requests the slice API. Don't
    67  	// mark them as used by default.
    68  	for _, i := range sliceAPIImports {
    69  		g.imports.add(i)
    70  	}
    71  
    72  	return g
    73  }
    74  
    75  func (g *testGenerator) typeName() string {
    76  	return g.t.Name.Name
    77  }
    78  
    79  func (g *testGenerator) testFuncName(base string) string {
    80  	return fmt.Sprintf("%s%s", base, strings.Title(g.t.Name.Name))
    81  }
    82  
    83  func (g *testGenerator) inTestFunction(name string, body func()) {
    84  	g.emit("func %s(t *testing.T) {\n", g.testFuncName(name))
    85  	g.inIndent(body)
    86  	g.emit("}\n\n")
    87  }
    88  
    89  func (g *testGenerator) emitTestNonZeroSize() {
    90  	g.inTestFunction("TestSizeNonZero", func() {
    91  		g.emit("var x %v\n", g.typeName())
    92  		g.emit("if x.SizeBytes() == 0 {\n")
    93  		g.inIndent(func() {
    94  			g.emit("t.Fatal(\"Marshallable.SizeBytes() should not return zero\")\n")
    95  		})
    96  		g.emit("}\n")
    97  	})
    98  }
    99  
   100  func (g *testGenerator) emitTestSuspectAlignment() {
   101  	g.inTestFunction("TestSuspectAlignment", func() {
   102  		g.emit("var x %v\n", g.typeName())
   103  		g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n")
   104  	})
   105  }
   106  
   107  func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() {
   108  	g.inTestFunction("TestSafeMarshalUnmarshalPreservesData", func() {
   109  		g.emit("var x, y, z, yUnsafe, zUnsafe %s\n", g.typeName())
   110  		g.emit("analysis.RandomizeValue(&x)\n\n")
   111  
   112  		g.emit("buf := make([]byte, x.SizeBytes())\n")
   113  		g.emit("x.MarshalBytes(buf)\n")
   114  		g.emit("bufUnsafe := make([]byte, x.SizeBytes())\n")
   115  		g.emit("x.MarshalUnsafe(bufUnsafe)\n\n")
   116  
   117  		g.emit("y.UnmarshalBytes(buf)\n")
   118  		g.emit("if !reflect.DeepEqual(x, y) {\n")
   119  		g.inIndent(func() {
   120  			g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n")
   121  		})
   122  		g.emit("}\n")
   123  		g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n")
   124  		g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n")
   125  		g.inIndent(func() {
   126  			g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n")
   127  		})
   128  		g.emit("}\n\n")
   129  
   130  		g.emit("z.UnmarshalUnsafe(buf)\n")
   131  		g.emit("if !reflect.DeepEqual(x, z) {\n")
   132  		g.inIndent(func() {
   133  			g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, z))\n")
   134  		})
   135  		g.emit("}\n")
   136  		g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n")
   137  		g.emit("if !reflect.DeepEqual(x, zUnsafe) {\n")
   138  		g.inIndent(func() {
   139  			g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, zUnsafe))\n")
   140  		})
   141  		g.emit("}\n")
   142  	})
   143  }
   144  
   145  func (g *testGenerator) emitTestMarshalUnmarshalSlicePreservesData(slice *sliceAPI) {
   146  	for _, name := range []string{"binary", "hostarch"} {
   147  		if !g.imports.markUsed(name) {
   148  			panic(fmt.Sprintf("Generated test for '%s' referenced a non-existent import with local name '%s'", g.typeName(), name))
   149  		}
   150  	}
   151  
   152  	g.inTestFunction("TestSafeMarshalUnmarshalSlicePreservesData", func() {
   153  		g.emit("var x, y, yUnsafe [8]%s\n", g.typeName())
   154  		g.emit("analysis.RandomizeValue(&x)\n\n")
   155  		g.emit("size := (*%s)(nil).SizeBytes() * len(x)\n", g.typeName())
   156  		g.emit("buf := bytes.NewBuffer(make([]byte, size))\n")
   157  		g.emit("buf.Reset()\n")
   158  		g.emit("if err := binary.Write(buf, hostarch.ByteOrder, x[:]); err != nil {\n")
   159  		g.inIndent(func() {
   160  			g.emit("t.Fatal(fmt.Sprintf(\"binary.Write failed: %v\", err))\n")
   161  		})
   162  		g.emit("}\n")
   163  		g.emit("bufUnsafe := make([]byte, size)\n")
   164  		g.emit("MarshalUnsafe%s(x[:], bufUnsafe)\n\n", slice.ident)
   165  
   166  		g.emit("UnmarshalUnsafe%s(y[:], buf.Bytes())\n", slice.ident)
   167  		g.emit("if !reflect.DeepEqual(x, y) {\n")
   168  		g.inIndent(func() {
   169  			g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across binary.Write/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n")
   170  		})
   171  		g.emit("}\n")
   172  		g.emit("UnmarshalUnsafe%s(yUnsafe[:], bufUnsafe)\n", slice.ident)
   173  		g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n")
   174  		g.inIndent(func() {
   175  			g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafeSlice/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n")
   176  		})
   177  		g.emit("}\n\n")
   178  	})
   179  }
   180  
   181  func (g *testGenerator) emitTestWriteToUnmarshalPreservesData() {
   182  	g.inTestFunction("TestWriteToUnmarshalPreservesData", func() {
   183  		g.emit("var x, y, yUnsafe %s\n", g.typeName())
   184  		g.emit("analysis.RandomizeValue(&x)\n\n")
   185  
   186  		g.emit("var buf bytes.Buffer\n\n")
   187  
   188  		g.emit("x.WriteTo(&buf)\n")
   189  		g.emit("y.UnmarshalBytes(buf.Bytes())\n\n")
   190  		g.emit("yUnsafe.UnmarshalUnsafe(buf.Bytes())\n\n")
   191  
   192  		g.emit("if !reflect.DeepEqual(x, y) {\n")
   193  		g.inIndent(func() {
   194  			g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n")
   195  		})
   196  		g.emit("}\n")
   197  		g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n")
   198  		g.inIndent(func() {
   199  			g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n")
   200  		})
   201  		g.emit("}\n")
   202  	})
   203  }
   204  
   205  func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() {
   206  	g.inTestFunction("TestSizeBytesOnTypedNilPtr", func() {
   207  		g.emit("var x %s\n", g.typeName())
   208  		g.emit("sizeFromConcrete := x.SizeBytes()\n")
   209  		g.emit("sizeFromTypedNilPtr := (*%s)(nil).SizeBytes()\n\n", g.typeName())
   210  
   211  		g.emit("if sizeFromTypedNilPtr != sizeFromConcrete {\n")
   212  		g.inIndent(func() {
   213  			g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)\n")
   214  		})
   215  		g.emit("}\n")
   216  	})
   217  }
   218  
   219  func (g *testGenerator) emitTests(slice *sliceAPI) {
   220  	g.emitTestNonZeroSize()
   221  	g.emitTestSuspectAlignment()
   222  	g.emitTestMarshalUnmarshalPreservesData()
   223  	g.emitTestWriteToUnmarshalPreservesData()
   224  	g.emitTestSizeBytesOnTypedNilPtr()
   225  
   226  	if slice != nil {
   227  		g.emitTestMarshalUnmarshalSlicePreservesData(slice)
   228  	}
   229  }
   230  
   231  func (g *testGenerator) write(out io.Writer) error {
   232  	return g.sourceBuffer.write(out)
   233  }