github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/tools/go_marshal/gomarshal/generator_tests.go (about) 1 // Copyright 2019 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package gomarshal 16 17 import ( 18 "fmt" 19 "go/ast" 20 "io" 21 "strings" 22 ) 23 24 var standardImports = []string{ 25 "bytes", 26 "fmt", 27 "reflect", 28 "testing", 29 30 "github.com/SagerNet/gvisor/tools/go_marshal/analysis", 31 } 32 33 var sliceAPIImports = []string{ 34 "encoding/binary", 35 "github.com/SagerNet/gvisor/pkg/hostarch", 36 } 37 38 type testGenerator struct { 39 sourceBuffer 40 41 // The type we're serializing. 42 t *ast.TypeSpec 43 44 // Receiver argument for generated methods. 45 r string 46 47 // Imports used by generated code. 48 imports *importTable 49 50 // Import statement for the package declaring the type we generated code 51 // for. We need this to construct test instances for the type, since the 52 // tests aren't written in the same package. 53 decl *importStmt 54 } 55 56 func newTestGenerator(t *ast.TypeSpec, r string) *testGenerator { 57 g := &testGenerator{ 58 t: t, 59 r: r, 60 imports: newImportTable(), 61 } 62 63 for _, i := range standardImports { 64 g.imports.add(i).markUsed() 65 } 66 // These imports are used if a type requests the slice API. Don't 67 // mark them as used by default. 68 for _, i := range sliceAPIImports { 69 g.imports.add(i) 70 } 71 72 return g 73 } 74 75 func (g *testGenerator) typeName() string { 76 return g.t.Name.Name 77 } 78 79 func (g *testGenerator) testFuncName(base string) string { 80 return fmt.Sprintf("%s%s", base, strings.Title(g.t.Name.Name)) 81 } 82 83 func (g *testGenerator) inTestFunction(name string, body func()) { 84 g.emit("func %s(t *testing.T) {\n", g.testFuncName(name)) 85 g.inIndent(body) 86 g.emit("}\n\n") 87 } 88 89 func (g *testGenerator) emitTestNonZeroSize() { 90 g.inTestFunction("TestSizeNonZero", func() { 91 g.emit("var x %v\n", g.typeName()) 92 g.emit("if x.SizeBytes() == 0 {\n") 93 g.inIndent(func() { 94 g.emit("t.Fatal(\"Marshallable.SizeBytes() should not return zero\")\n") 95 }) 96 g.emit("}\n") 97 }) 98 } 99 100 func (g *testGenerator) emitTestSuspectAlignment() { 101 g.inTestFunction("TestSuspectAlignment", func() { 102 g.emit("var x %v\n", g.typeName()) 103 g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n") 104 }) 105 } 106 107 func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() { 108 g.inTestFunction("TestSafeMarshalUnmarshalPreservesData", func() { 109 g.emit("var x, y, z, yUnsafe, zUnsafe %s\n", g.typeName()) 110 g.emit("analysis.RandomizeValue(&x)\n\n") 111 112 g.emit("buf := make([]byte, x.SizeBytes())\n") 113 g.emit("x.MarshalBytes(buf)\n") 114 g.emit("bufUnsafe := make([]byte, x.SizeBytes())\n") 115 g.emit("x.MarshalUnsafe(bufUnsafe)\n\n") 116 117 g.emit("y.UnmarshalBytes(buf)\n") 118 g.emit("if !reflect.DeepEqual(x, y) {\n") 119 g.inIndent(func() { 120 g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") 121 }) 122 g.emit("}\n") 123 g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n") 124 g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") 125 g.inIndent(func() { 126 g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") 127 }) 128 g.emit("}\n\n") 129 130 g.emit("z.UnmarshalUnsafe(buf)\n") 131 g.emit("if !reflect.DeepEqual(x, z) {\n") 132 g.inIndent(func() { 133 g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, z))\n") 134 }) 135 g.emit("}\n") 136 g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n") 137 g.emit("if !reflect.DeepEqual(x, zUnsafe) {\n") 138 g.inIndent(func() { 139 g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, zUnsafe))\n") 140 }) 141 g.emit("}\n") 142 }) 143 } 144 145 func (g *testGenerator) emitTestMarshalUnmarshalSlicePreservesData(slice *sliceAPI) { 146 for _, name := range []string{"binary", "hostarch"} { 147 if !g.imports.markUsed(name) { 148 panic(fmt.Sprintf("Generated test for '%s' referenced a non-existent import with local name '%s'", g.typeName(), name)) 149 } 150 } 151 152 g.inTestFunction("TestSafeMarshalUnmarshalSlicePreservesData", func() { 153 g.emit("var x, y, yUnsafe [8]%s\n", g.typeName()) 154 g.emit("analysis.RandomizeValue(&x)\n\n") 155 g.emit("size := (*%s)(nil).SizeBytes() * len(x)\n", g.typeName()) 156 g.emit("buf := bytes.NewBuffer(make([]byte, size))\n") 157 g.emit("buf.Reset()\n") 158 g.emit("if err := binary.Write(buf, hostarch.ByteOrder, x[:]); err != nil {\n") 159 g.inIndent(func() { 160 g.emit("t.Fatal(fmt.Sprintf(\"binary.Write failed: %v\", err))\n") 161 }) 162 g.emit("}\n") 163 g.emit("bufUnsafe := make([]byte, size)\n") 164 g.emit("MarshalUnsafe%s(x[:], bufUnsafe)\n\n", slice.ident) 165 166 g.emit("UnmarshalUnsafe%s(y[:], buf.Bytes())\n", slice.ident) 167 g.emit("if !reflect.DeepEqual(x, y) {\n") 168 g.inIndent(func() { 169 g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across binary.Write/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") 170 }) 171 g.emit("}\n") 172 g.emit("UnmarshalUnsafe%s(yUnsafe[:], bufUnsafe)\n", slice.ident) 173 g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") 174 g.inIndent(func() { 175 g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafeSlice/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") 176 }) 177 g.emit("}\n\n") 178 }) 179 } 180 181 func (g *testGenerator) emitTestWriteToUnmarshalPreservesData() { 182 g.inTestFunction("TestWriteToUnmarshalPreservesData", func() { 183 g.emit("var x, y, yUnsafe %s\n", g.typeName()) 184 g.emit("analysis.RandomizeValue(&x)\n\n") 185 186 g.emit("var buf bytes.Buffer\n\n") 187 188 g.emit("x.WriteTo(&buf)\n") 189 g.emit("y.UnmarshalBytes(buf.Bytes())\n\n") 190 g.emit("yUnsafe.UnmarshalUnsafe(buf.Bytes())\n\n") 191 192 g.emit("if !reflect.DeepEqual(x, y) {\n") 193 g.inIndent(func() { 194 g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") 195 }) 196 g.emit("}\n") 197 g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") 198 g.inIndent(func() { 199 g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") 200 }) 201 g.emit("}\n") 202 }) 203 } 204 205 func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() { 206 g.inTestFunction("TestSizeBytesOnTypedNilPtr", func() { 207 g.emit("var x %s\n", g.typeName()) 208 g.emit("sizeFromConcrete := x.SizeBytes()\n") 209 g.emit("sizeFromTypedNilPtr := (*%s)(nil).SizeBytes()\n\n", g.typeName()) 210 211 g.emit("if sizeFromTypedNilPtr != sizeFromConcrete {\n") 212 g.inIndent(func() { 213 g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)\n") 214 }) 215 g.emit("}\n") 216 }) 217 } 218 219 func (g *testGenerator) emitTests(slice *sliceAPI) { 220 g.emitTestNonZeroSize() 221 g.emitTestSuspectAlignment() 222 g.emitTestMarshalUnmarshalPreservesData() 223 g.emitTestWriteToUnmarshalPreservesData() 224 g.emitTestSizeBytesOnTypedNilPtr() 225 226 if slice != nil { 227 g.emitTestMarshalUnmarshalSlicePreservesData(slice) 228 } 229 } 230 231 func (g *testGenerator) write(out io.Writer) error { 232 return g.sourceBuffer.write(out) 233 }