github.com/mtsmfm/go/src@v0.0.0-20221020090648-44bdcb9f8fde/net/splice_test.go (about) 1 // Copyright 2018 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 //go:build linux 6 7 package net 8 9 import ( 10 "io" 11 "log" 12 "os" 13 "os/exec" 14 "strconv" 15 "sync" 16 "testing" 17 "time" 18 ) 19 20 func TestSplice(t *testing.T) { 21 t.Run("tcp-to-tcp", func(t *testing.T) { testSplice(t, "tcp", "tcp") }) 22 if !testableNetwork("unixgram") { 23 t.Skip("skipping unix-to-tcp tests") 24 } 25 t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") }) 26 t.Run("no-unixpacket", testSpliceNoUnixpacket) 27 t.Run("no-unixgram", testSpliceNoUnixgram) 28 } 29 30 func testSplice(t *testing.T, upNet, downNet string) { 31 t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.test) 32 t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.test) 33 t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.test) 34 t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.test) 35 t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.test) 36 t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.test) 37 t.Run("readerAtEOF", func(t *testing.T) { testSpliceReaderAtEOF(t, upNet, downNet) }) 38 t.Run("issue25985", func(t *testing.T) { testSpliceIssue25985(t, upNet, downNet) }) 39 } 40 41 type spliceTestCase struct { 42 upNet, downNet string 43 44 chunkSize, totalSize int 45 limitReadSize int 46 } 47 48 func (tc spliceTestCase) test(t *testing.T) { 49 clientUp, serverUp := spliceTestSocketPair(t, tc.upNet) 50 defer serverUp.Close() 51 cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.totalSize) 52 if err != nil { 53 t.Fatal(err) 54 } 55 defer cleanup() 56 clientDown, serverDown := spliceTestSocketPair(t, tc.downNet) 57 defer serverDown.Close() 58 cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.totalSize) 59 if err != nil { 60 t.Fatal(err) 61 } 62 defer cleanup() 63 var ( 64 r io.Reader = serverUp 65 size = tc.totalSize 66 ) 67 if tc.limitReadSize > 0 { 68 if tc.limitReadSize < size { 69 size = tc.limitReadSize 70 } 71 72 r = &io.LimitedReader{ 73 N: int64(tc.limitReadSize), 74 R: serverUp, 75 } 76 defer serverUp.Close() 77 } 78 n, err := io.Copy(serverDown, r) 79 serverDown.Close() 80 if err != nil { 81 t.Fatal(err) 82 } 83 if want := int64(size); want != n { 84 t.Errorf("want %d bytes spliced, got %d", want, n) 85 } 86 87 if tc.limitReadSize > 0 { 88 wantN := 0 89 if tc.limitReadSize > size { 90 wantN = tc.limitReadSize - size 91 } 92 93 if n := r.(*io.LimitedReader).N; n != int64(wantN) { 94 t.Errorf("r.N = %d, want %d", n, wantN) 95 } 96 } 97 } 98 99 func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) { 100 clientUp, serverUp := spliceTestSocketPair(t, upNet) 101 defer clientUp.Close() 102 clientDown, serverDown := spliceTestSocketPair(t, downNet) 103 defer clientDown.Close() 104 105 serverUp.Close() 106 107 // We'd like to call net.splice here and check the handled return 108 // value, but we disable splice on old Linux kernels. 109 // 110 // In that case, poll.Splice and net.splice return a non-nil error 111 // and handled == false. We'd ideally like to see handled == true 112 // because the source reader is at EOF, but if we're running on an old 113 // kernel, and splice is disabled, we won't see EOF from net.splice, 114 // because we won't touch the reader at all. 115 // 116 // Trying to untangle the errors from net.splice and match them 117 // against the errors created by the poll package would be brittle, 118 // so this is a higher level test. 119 // 120 // The following ReadFrom should return immediately, regardless of 121 // whether splice is disabled or not. The other side should then 122 // get a goodbye signal. Test for the goodbye signal. 123 msg := "bye" 124 go func() { 125 serverDown.(io.ReaderFrom).ReadFrom(serverUp) 126 io.WriteString(serverDown, msg) 127 serverDown.Close() 128 }() 129 130 buf := make([]byte, 3) 131 _, err := io.ReadFull(clientDown, buf) 132 if err != nil { 133 t.Errorf("clientDown: %v", err) 134 } 135 if string(buf) != msg { 136 t.Errorf("clientDown got %q, want %q", buf, msg) 137 } 138 } 139 140 func testSpliceIssue25985(t *testing.T, upNet, downNet string) { 141 front := newLocalListener(t, upNet) 142 defer front.Close() 143 back := newLocalListener(t, downNet) 144 defer back.Close() 145 146 var wg sync.WaitGroup 147 wg.Add(2) 148 149 proxy := func() { 150 src, err := front.Accept() 151 if err != nil { 152 return 153 } 154 dst, err := Dial(downNet, back.Addr().String()) 155 if err != nil { 156 return 157 } 158 defer dst.Close() 159 defer src.Close() 160 go func() { 161 io.Copy(src, dst) 162 wg.Done() 163 }() 164 go func() { 165 io.Copy(dst, src) 166 wg.Done() 167 }() 168 } 169 170 go proxy() 171 172 toFront, err := Dial(upNet, front.Addr().String()) 173 if err != nil { 174 t.Fatal(err) 175 } 176 177 io.WriteString(toFront, "foo") 178 toFront.Close() 179 180 fromProxy, err := back.Accept() 181 if err != nil { 182 t.Fatal(err) 183 } 184 defer fromProxy.Close() 185 186 _, err = io.ReadAll(fromProxy) 187 if err != nil { 188 t.Fatal(err) 189 } 190 191 wg.Wait() 192 } 193 194 func testSpliceNoUnixpacket(t *testing.T) { 195 clientUp, serverUp := spliceTestSocketPair(t, "unixpacket") 196 defer clientUp.Close() 197 defer serverUp.Close() 198 clientDown, serverDown := spliceTestSocketPair(t, "tcp") 199 defer clientDown.Close() 200 defer serverDown.Close() 201 // If splice called poll.Splice here, we'd get err == syscall.EINVAL 202 // and handled == false. If poll.Splice gets an EINVAL on the first 203 // try, it assumes the kernel it's running on doesn't support splice 204 // for unix sockets and returns handled == false. This works for our 205 // purposes by somewhat of an accident, but is not entirely correct. 206 // 207 // What we want is err == nil and handled == false, i.e. we never 208 // called poll.Splice, because we know the unix socket's network. 209 _, err, handled := splice(serverDown.(*TCPConn).fd, serverUp) 210 if err != nil || handled != false { 211 t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled) 212 } 213 } 214 215 func testSpliceNoUnixgram(t *testing.T) { 216 addr, err := ResolveUnixAddr("unixgram", testUnixAddr(t)) 217 if err != nil { 218 t.Fatal(err) 219 } 220 defer os.Remove(addr.Name) 221 up, err := ListenUnixgram("unixgram", addr) 222 if err != nil { 223 t.Fatal(err) 224 } 225 defer up.Close() 226 clientDown, serverDown := spliceTestSocketPair(t, "tcp") 227 defer clientDown.Close() 228 defer serverDown.Close() 229 // Analogous to testSpliceNoUnixpacket. 230 _, err, handled := splice(serverDown.(*TCPConn).fd, up) 231 if err != nil || handled != false { 232 t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled) 233 } 234 } 235 236 func BenchmarkSplice(b *testing.B) { 237 testHookUninstaller.Do(uninstallTestHooks) 238 239 b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") }) 240 b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") }) 241 } 242 243 func benchSplice(b *testing.B, upNet, downNet string) { 244 for i := 0; i <= 10; i++ { 245 chunkSize := 1 << uint(i+10) 246 tc := spliceTestCase{ 247 upNet: upNet, 248 downNet: downNet, 249 chunkSize: chunkSize, 250 } 251 252 b.Run(strconv.Itoa(chunkSize), tc.bench) 253 } 254 } 255 256 func (tc spliceTestCase) bench(b *testing.B) { 257 // To benchmark the genericReadFrom code path, set this to false. 258 useSplice := true 259 260 clientUp, serverUp := spliceTestSocketPair(b, tc.upNet) 261 defer serverUp.Close() 262 263 cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.chunkSize*b.N) 264 if err != nil { 265 b.Fatal(err) 266 } 267 defer cleanup() 268 269 clientDown, serverDown := spliceTestSocketPair(b, tc.downNet) 270 defer serverDown.Close() 271 272 cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.chunkSize*b.N) 273 if err != nil { 274 b.Fatal(err) 275 } 276 defer cleanup() 277 278 b.SetBytes(int64(tc.chunkSize)) 279 b.ResetTimer() 280 281 if useSplice { 282 _, err := io.Copy(serverDown, serverUp) 283 if err != nil { 284 b.Fatal(err) 285 } 286 } else { 287 type onlyReader struct { 288 io.Reader 289 } 290 _, err := io.Copy(serverDown, onlyReader{serverUp}) 291 if err != nil { 292 b.Fatal(err) 293 } 294 } 295 } 296 297 func spliceTestSocketPair(t testing.TB, net string) (client, server Conn) { 298 t.Helper() 299 ln := newLocalListener(t, net) 300 defer ln.Close() 301 var cerr, serr error 302 acceptDone := make(chan struct{}) 303 go func() { 304 server, serr = ln.Accept() 305 acceptDone <- struct{}{} 306 }() 307 client, cerr = Dial(ln.Addr().Network(), ln.Addr().String()) 308 <-acceptDone 309 if cerr != nil { 310 if server != nil { 311 server.Close() 312 } 313 t.Fatal(cerr) 314 } 315 if serr != nil { 316 if client != nil { 317 client.Close() 318 } 319 t.Fatal(serr) 320 } 321 return client, server 322 } 323 324 func startSpliceClient(conn Conn, op string, chunkSize, totalSize int) (func(), error) { 325 f, err := conn.(interface{ File() (*os.File, error) }).File() 326 if err != nil { 327 return nil, err 328 } 329 330 cmd := exec.Command(os.Args[0], os.Args[1:]...) 331 cmd.Env = []string{ 332 "GO_NET_TEST_SPLICE=1", 333 "GO_NET_TEST_SPLICE_OP=" + op, 334 "GO_NET_TEST_SPLICE_CHUNK_SIZE=" + strconv.Itoa(chunkSize), 335 "GO_NET_TEST_SPLICE_TOTAL_SIZE=" + strconv.Itoa(totalSize), 336 "TMPDIR=" + os.Getenv("TMPDIR"), 337 } 338 cmd.ExtraFiles = append(cmd.ExtraFiles, f) 339 cmd.Stdout = os.Stdout 340 cmd.Stderr = os.Stderr 341 342 if err := cmd.Start(); err != nil { 343 return nil, err 344 } 345 346 donec := make(chan struct{}) 347 go func() { 348 cmd.Wait() 349 conn.Close() 350 f.Close() 351 close(donec) 352 }() 353 354 return func() { 355 select { 356 case <-donec: 357 case <-time.After(5 * time.Second): 358 log.Printf("killing splice client after 5 second shutdown timeout") 359 cmd.Process.Kill() 360 select { 361 case <-donec: 362 case <-time.After(5 * time.Second): 363 log.Printf("splice client didn't die after 10 seconds") 364 } 365 } 366 }, nil 367 } 368 369 func init() { 370 if os.Getenv("GO_NET_TEST_SPLICE") == "" { 371 return 372 } 373 defer os.Exit(0) 374 375 f := os.NewFile(uintptr(3), "splice-test-conn") 376 defer f.Close() 377 378 conn, err := FileConn(f) 379 if err != nil { 380 log.Fatal(err) 381 } 382 383 var chunkSize int 384 if chunkSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_CHUNK_SIZE")); err != nil { 385 log.Fatal(err) 386 } 387 buf := make([]byte, chunkSize) 388 389 var totalSize int 390 if totalSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_TOTAL_SIZE")); err != nil { 391 log.Fatal(err) 392 } 393 394 var fn func([]byte) (int, error) 395 switch op := os.Getenv("GO_NET_TEST_SPLICE_OP"); op { 396 case "r": 397 fn = conn.Read 398 case "w": 399 defer conn.Close() 400 401 fn = conn.Write 402 default: 403 log.Fatalf("unknown op %q", op) 404 } 405 406 var n int 407 for count := 0; count < totalSize; count += n { 408 if count+chunkSize > totalSize { 409 buf = buf[:totalSize-count] 410 } 411 412 var err error 413 if n, err = fn(buf); err != nil { 414 return 415 } 416 } 417 }