github.com/database64128/tfo-go/v2@v2.2.0/tfo_test.go (about) 1 package tfo 2 3 import ( 4 "bytes" 5 "context" 6 "errors" 7 "io" 8 "net" 9 "os" 10 "runtime" 11 "sync" 12 "syscall" 13 "testing" 14 "time" 15 ) 16 17 type mptcpStatus uint8 18 19 const ( 20 mptcpUseDefault mptcpStatus = iota 21 mptcpEnabled 22 mptcpDisabled 23 ) 24 25 type runtimeFallbackHelperFunc func(*testing.T) 26 27 func runtimeFallbackAsIs(t *testing.T) {} 28 29 func runtimeFallbackSetListenNoTFO(t *testing.T) { 30 if runtimeListenNoTFO.CompareAndSwap(false, true) { 31 t.Cleanup(func() { 32 runtimeListenNoTFO.Store(false) 33 }) 34 } 35 } 36 37 func runtimeFallbackSetDialNoTFO(t *testing.T) { 38 if v := runtimeDialTFOSupport.v.Swap(uint32(dialTFOSupportNone)); v != uint32(dialTFOSupportNone) { 39 t.Cleanup(func() { 40 runtimeDialTFOSupport.v.Store(v) 41 }) 42 } 43 } 44 45 func runtimeFallbackSetDialLinuxSendto(t *testing.T) { 46 if v := runtimeDialTFOSupport.v.Swap(uint32(dialTFOSupportLinuxSendto)); v != uint32(dialTFOSupportLinuxSendto) { 47 t.Cleanup(func() { 48 runtimeDialTFOSupport.v.Store(v) 49 }) 50 } 51 } 52 53 var listenConfigCases = []struct { 54 name string 55 listenConfig ListenConfig 56 mptcp mptcpStatus 57 setRuntimeFallback runtimeFallbackHelperFunc 58 }{ 59 {"TFO", ListenConfig{}, mptcpUseDefault, runtimeFallbackAsIs}, 60 {"TFO+RuntimeNoTFO", ListenConfig{}, mptcpUseDefault, runtimeFallbackSetListenNoTFO}, 61 {"TFO+MPTCPEnabled", ListenConfig{}, mptcpEnabled, runtimeFallbackAsIs}, 62 {"TFO+MPTCPEnabled+RuntimeNoTFO", ListenConfig{}, mptcpEnabled, runtimeFallbackSetListenNoTFO}, 63 {"TFO+MPTCPDisabled", ListenConfig{}, mptcpDisabled, runtimeFallbackAsIs}, 64 {"TFO+MPTCPDisabled+RuntimeNoTFO", ListenConfig{}, mptcpDisabled, runtimeFallbackSetListenNoTFO}, 65 {"TFO+Backlog1024", ListenConfig{Backlog: 1024}, mptcpUseDefault, runtimeFallbackAsIs}, 66 {"TFO+Backlog1024+MPTCPEnabled", ListenConfig{Backlog: 1024}, mptcpEnabled, runtimeFallbackAsIs}, 67 {"TFO+Backlog1024+MPTCPDisabled", ListenConfig{Backlog: 1024}, mptcpDisabled, runtimeFallbackAsIs}, 68 {"TFO+Backlog-1", ListenConfig{Backlog: -1}, mptcpUseDefault, runtimeFallbackAsIs}, 69 {"TFO+Backlog-1+MPTCPEnabled", ListenConfig{Backlog: -1}, mptcpEnabled, runtimeFallbackAsIs}, 70 {"TFO+Backlog-1+MPTCPDisabled", ListenConfig{Backlog: -1}, mptcpDisabled, runtimeFallbackAsIs}, 71 {"TFO+Fallback", ListenConfig{Fallback: true}, mptcpUseDefault, runtimeFallbackAsIs}, 72 {"TFO+Fallback+RuntimeNoTFO", ListenConfig{Fallback: true}, mptcpUseDefault, runtimeFallbackSetListenNoTFO}, 73 {"TFO+Fallback+MPTCPEnabled", ListenConfig{Fallback: true}, mptcpEnabled, runtimeFallbackAsIs}, 74 {"TFO+Fallback+MPTCPEnabled+RuntimeNoTFO", ListenConfig{Fallback: true}, mptcpEnabled, runtimeFallbackSetListenNoTFO}, 75 {"TFO+Fallback+MPTCPDisabled", ListenConfig{Fallback: true}, mptcpDisabled, runtimeFallbackAsIs}, 76 {"TFO+Fallback+MPTCPDisabled+RuntimeNoTFO", ListenConfig{Fallback: true}, mptcpDisabled, runtimeFallbackSetListenNoTFO}, 77 {"NoTFO", ListenConfig{DisableTFO: true}, mptcpUseDefault, runtimeFallbackAsIs}, 78 {"NoTFO+MPTCPEnabled", ListenConfig{DisableTFO: true}, mptcpEnabled, runtimeFallbackAsIs}, 79 {"NoTFO+MPTCPDisabled", ListenConfig{DisableTFO: true}, mptcpDisabled, runtimeFallbackAsIs}, 80 } 81 82 var dialerCases = []struct { 83 name string 84 dialer Dialer 85 mptcp mptcpStatus 86 setRuntimeFallback runtimeFallbackHelperFunc 87 linuxOnly bool 88 }{ 89 {"TFO", Dialer{}, mptcpUseDefault, runtimeFallbackAsIs, false}, 90 {"TFO+RuntimeNoTFO", Dialer{}, mptcpUseDefault, runtimeFallbackSetDialNoTFO, false}, 91 {"TFO+RuntimeLinuxSendto", Dialer{}, mptcpUseDefault, runtimeFallbackSetDialLinuxSendto, true}, 92 {"TFO+MPTCPEnabled", Dialer{}, mptcpEnabled, runtimeFallbackAsIs, false}, 93 {"TFO+MPTCPEnabled+RuntimeNoTFO", Dialer{}, mptcpEnabled, runtimeFallbackSetDialNoTFO, false}, 94 {"TFO+MPTCPEnabled+RuntimeLinuxSendto", Dialer{}, mptcpEnabled, runtimeFallbackSetDialLinuxSendto, true}, 95 {"TFO+MPTCPDisabled", Dialer{}, mptcpDisabled, runtimeFallbackAsIs, false}, 96 {"TFO+MPTCPDisabled+RuntimeNoTFO", Dialer{}, mptcpDisabled, runtimeFallbackSetDialNoTFO, false}, 97 {"TFO+MPTCPDisabled+RuntimeLinuxSendto", Dialer{}, mptcpDisabled, runtimeFallbackSetDialLinuxSendto, true}, 98 {"TFO+Fallback", Dialer{Fallback: true}, mptcpUseDefault, runtimeFallbackAsIs, false}, 99 {"TFO+Fallback+RuntimeNoTFO", Dialer{Fallback: true}, mptcpUseDefault, runtimeFallbackSetDialNoTFO, false}, 100 {"TFO+Fallback+RuntimeLinuxSendto", Dialer{Fallback: true}, mptcpUseDefault, runtimeFallbackSetDialLinuxSendto, true}, 101 {"TFO+Fallback+MPTCPEnabled", Dialer{Fallback: true}, mptcpEnabled, runtimeFallbackAsIs, false}, 102 {"TFO+Fallback+MPTCPEnabled+RuntimeNoTFO", Dialer{Fallback: true}, mptcpEnabled, runtimeFallbackSetDialNoTFO, false}, 103 {"TFO+Fallback+MPTCPEnabled+RuntimeLinuxSendto", Dialer{Fallback: true}, mptcpEnabled, runtimeFallbackSetDialLinuxSendto, true}, 104 {"TFO+Fallback+MPTCPDisabled", Dialer{Fallback: true}, mptcpDisabled, runtimeFallbackAsIs, false}, 105 {"TFO+Fallback+MPTCPDisabled+RuntimeNoTFO", Dialer{Fallback: true}, mptcpDisabled, runtimeFallbackSetDialNoTFO, false}, 106 {"TFO+Fallback+MPTCPDisabled+RuntimeLinuxSendto", Dialer{Fallback: true}, mptcpDisabled, runtimeFallbackSetDialLinuxSendto, true}, 107 {"NoTFO", Dialer{DisableTFO: true}, mptcpUseDefault, runtimeFallbackAsIs, false}, 108 {"NoTFO+MPTCPEnabled", Dialer{DisableTFO: true}, mptcpEnabled, runtimeFallbackAsIs, false}, 109 {"NoTFO+MPTCPDisabled", Dialer{DisableTFO: true}, mptcpDisabled, runtimeFallbackAsIs, false}, 110 } 111 112 type testCase struct { 113 name string 114 listenConfig ListenConfig 115 dialer Dialer 116 setRuntimeFallbackListen runtimeFallbackHelperFunc 117 setRuntimeFallbackDial runtimeFallbackHelperFunc 118 } 119 120 func (c testCase) Run(t *testing.T, f func(*testing.T, ListenConfig, Dialer)) { 121 t.Run(c.name, func(t *testing.T) { 122 c.setRuntimeFallbackListen(t) 123 c.setRuntimeFallbackDial(t) 124 f(t, c.listenConfig, c.dialer) 125 }) 126 } 127 128 // cases is a list of [ListenConfig] and [Dialer] combinations to test. 129 var cases []testCase 130 131 func init() { 132 // Initialize [listenConfigCases]. 133 for i := range listenConfigCases { 134 c := &listenConfigCases[i] 135 switch c.mptcp { 136 case mptcpUseDefault: 137 case mptcpEnabled: 138 c.listenConfig.SetMultipathTCP(true) 139 case mptcpDisabled: 140 c.listenConfig.SetMultipathTCP(false) 141 default: 142 panic("unreachable") 143 } 144 } 145 146 // Initialize [dialerCases]. 147 for i := range dialerCases { 148 c := &dialerCases[i] 149 switch c.mptcp { 150 case mptcpUseDefault: 151 case mptcpEnabled: 152 c.dialer.SetMultipathTCP(true) 153 case mptcpDisabled: 154 c.dialer.SetMultipathTCP(false) 155 default: 156 panic("unreachable") 157 } 158 } 159 160 // Generate [cases]. 161 cases = make([]testCase, 0, len(listenConfigCases)*len(dialerCases)) 162 for _, lc := range listenConfigCases { 163 if comptimeNoTFO && !lc.listenConfig.tfoDisabled() { 164 continue 165 } 166 for _, d := range dialerCases { 167 if comptimeNoTFO && !d.dialer.DisableTFO { 168 continue 169 } 170 switch runtime.GOOS { 171 case "linux", "android": 172 default: 173 if d.linuxOnly { 174 continue 175 } 176 } 177 cases = append(cases, testCase{ 178 name: lc.name + "/" + d.name, 179 listenConfig: lc.listenConfig, 180 dialer: d.dialer, 181 setRuntimeFallbackListen: lc.setRuntimeFallback, 182 setRuntimeFallbackDial: d.setRuntimeFallback, 183 }) 184 } 185 } 186 } 187 188 // discardTCPServer is a TCP server that accepts and drains incoming connections. 189 type discardTCPServer struct { 190 ln *net.TCPListener 191 wg sync.WaitGroup 192 } 193 194 // newDiscardTCPServer creates a new [discardTCPServer] that listens on a random port. 195 func newDiscardTCPServer(ctx context.Context) (*discardTCPServer, error) { 196 lc := ListenConfig{DisableTFO: comptimeNoTFO} 197 ln, err := lc.Listen(ctx, "tcp", "[::1]:") 198 if err != nil { 199 return nil, err 200 } 201 return &discardTCPServer{ln: ln.(*net.TCPListener)}, nil 202 } 203 204 // Addr returns the server's address. 205 func (s *discardTCPServer) Addr() *net.TCPAddr { 206 return s.ln.Addr().(*net.TCPAddr) 207 } 208 209 // Start spins up a new goroutine that accepts and drains incoming connections 210 // until [discardTCPServer.Close] is called. 211 func (s *discardTCPServer) Start(t *testing.T) { 212 s.wg.Add(1) 213 214 go func() { 215 defer s.wg.Done() 216 217 for { 218 c, err := s.ln.AcceptTCP() 219 if err != nil { 220 if errors.Is(err, os.ErrDeadlineExceeded) { 221 return 222 } 223 t.Error("AcceptTCP:", err) 224 return 225 } 226 227 go func() { 228 defer c.Close() 229 230 n, err := io.Copy(io.Discard, c) 231 if err != nil { 232 t.Error("Copy:", err) 233 } 234 t.Logf("Discarded %d bytes from %s", n, c.RemoteAddr()) 235 }() 236 } 237 }() 238 } 239 240 // Close interrupts all running accept goroutines, waits for them to finish, 241 // and closes the listener. 242 func (s *discardTCPServer) Close() { 243 s.ln.SetDeadline(aLongTimeAgo) 244 s.wg.Wait() 245 s.ln.Close() 246 } 247 248 var ( 249 hello = []byte{'h', 'e', 'l', 'l', 'o'} 250 world = []byte{'w', 'o', 'r', 'l', 'd'} 251 helloworld = []byte{'h', 'e', 'l', 'l', 'o', 'w', 'o', 'r', 'l', 'd'} 252 worldhello = []byte{'w', 'o', 'r', 'l', 'd', 'h', 'e', 'l', 'l', 'o'} 253 helloWorldSentence = []byte{'h', 'e', 'l', 'l', 'o', ',', ' ', 'w', 'o', 'r', 'l', 'd', '!', '\n'} 254 ) 255 256 func testListenDialUDP(t *testing.T, lc ListenConfig, d Dialer) { 257 pc, err := lc.ListenPacket(context.Background(), "udp", "[::1]:") 258 if err != nil { 259 t.Fatal(err) 260 } 261 uc := pc.(*net.UDPConn) 262 defer uc.Close() 263 264 c, err := d.Dial("udp", uc.LocalAddr().String(), hello) 265 if err != nil { 266 t.Fatal(err) 267 } 268 defer c.Close() 269 270 b := make([]byte, 5) 271 n, _, err := uc.ReadFromUDPAddrPort(b) 272 if err != nil { 273 t.Fatal(err) 274 } 275 if n != 5 { 276 t.Fatalf("Expected 5 bytes, got %d", n) 277 } 278 if !bytes.Equal(b, hello) { 279 t.Fatalf("Expected %v, got %v", hello, b) 280 } 281 } 282 283 // TestListenDialUDP ensures that the UDP capabilities of [ListenConfig] and 284 // [Dialer] are not affected by this package. 285 func TestListenDialUDP(t *testing.T) { 286 for _, c := range cases { 287 c.Run(t, testListenDialUDP) 288 } 289 } 290 291 // TestListenCtrlFn ensures that the user-provided [ListenConfig.Control] function 292 // is called when [ListenConfig.Listen] is called. 293 func TestListenCtrlFn(t *testing.T) { 294 for _, c := range listenConfigCases { 295 t.Run(c.name, func(t *testing.T) { 296 c.setRuntimeFallback(t) 297 testListenCtrlFn(t, c.listenConfig) 298 }) 299 } 300 } 301 302 // TestDialCtrlFn ensures that [Dialer]'s user-provided control functions 303 // are used in the same way as [net.Dialer]. 304 func TestDialCtrlFn(t *testing.T) { 305 s, err := newDiscardTCPServer(context.Background()) 306 if err != nil { 307 t.Fatal(err) 308 } 309 defer s.Close() 310 311 address := s.Addr().String() 312 313 for _, c := range dialerCases { 314 t.Run(c.name, func(t *testing.T) { 315 c.setRuntimeFallback(t) 316 testDialCtrlFn(t, c.dialer, address) 317 testDialCtrlCtxFn(t, c.dialer, address) 318 testDialCtrlCtxFnSupersedesCtrlFn(t, c.dialer, address) 319 }) 320 } 321 } 322 323 // TestAddrFunctions ensures that the address methods on [*net.TCPListener] and 324 // [*net.TCPConn] return the correct values. 325 func TestAddrFunctions(t *testing.T) { 326 for _, c := range cases { 327 c.Run(t, testAddrFunctions) 328 } 329 } 330 331 // TestClientWriteReadServerReadWrite ensures that a client can write to a server, 332 // the server can read from the client, and the server can write to the client. 333 func TestClientWriteReadServerReadWrite(t *testing.T) { 334 for _, c := range cases { 335 c.Run(t, testClientWriteReadServerReadWrite) 336 } 337 } 338 339 // TestServerWriteReadClientReadWrite ensures that a server can write to a client, 340 // the client can read from the server, and the client can write to the server. 341 func TestServerWriteReadClientReadWrite(t *testing.T) { 342 for _, c := range cases { 343 c.Run(t, testServerWriteReadClientReadWrite) 344 } 345 } 346 347 // TestClientServerReadFrom ensures that the ReadFrom method 348 // on accepted and dialed connections works as expected. 349 func TestClientServerReadFrom(t *testing.T) { 350 for _, c := range cases { 351 c.Run(t, testClientServerReadFrom) 352 } 353 } 354 355 // TestSetDeadline ensures that the SetDeadline, SetReadDeadline, and 356 // SetWriteDeadline methods on accepted and dialed connections work as expected. 357 func TestSetDeadline(t *testing.T) { 358 for _, c := range cases { 359 c.Run(t, testSetDeadline) 360 } 361 } 362 363 func testRawConnControl(t *testing.T, sc syscall.Conn) { 364 rawConn, err := sc.SyscallConn() 365 if err != nil { 366 t.Fatal(err) 367 } 368 369 var success bool 370 371 if err = rawConn.Control(func(fd uintptr) { 372 success = fd != 0 373 }); err != nil { 374 t.Fatal(err) 375 } 376 377 if !success { 378 t.Error("RawConn Control failed") 379 } 380 } 381 382 func testListenCtrlFn(t *testing.T, lc ListenConfig) { 383 var success bool 384 385 lc.Control = func(network, address string, c syscall.RawConn) error { 386 return c.Control(func(fd uintptr) { 387 success = fd != 0 388 }) 389 } 390 391 ln, err := lc.Listen(context.Background(), "tcp", "") 392 if err != nil { 393 t.Fatal(err) 394 } 395 defer ln.Close() 396 397 if !success { 398 t.Error("ListenConfig ctrlFn failed") 399 } 400 401 testRawConnControl(t, ln.(syscall.Conn)) 402 } 403 404 func testDialCtrlFn(t *testing.T, d Dialer, address string) { 405 var success bool 406 407 d.Control = func(network, address string, c syscall.RawConn) error { 408 return c.Control(func(fd uintptr) { 409 success = fd != 0 410 }) 411 } 412 413 c, err := d.Dial("tcp", address, hello) 414 if err != nil { 415 t.Fatal(err) 416 } 417 defer c.Close() 418 419 if !success { 420 t.Error("Dialer ctrlFn failed") 421 } 422 423 testRawConnControl(t, c.(syscall.Conn)) 424 } 425 426 func testDialCtrlCtxFn(t *testing.T, d Dialer, address string) { 427 type contextKey int 428 429 const ( 430 ctxKey = contextKey(64) 431 ctxVal = 128 432 ) 433 434 var success bool 435 436 d.ControlContext = func(ctx context.Context, network, address string, c syscall.RawConn) error { 437 return c.Control(func(fd uintptr) { 438 success = fd != 0 && ctx.Value(ctxKey) == ctxVal 439 }) 440 } 441 442 ctx := context.WithValue(context.Background(), ctxKey, ctxVal) 443 c, err := d.DialContext(ctx, "tcp", address, hello) 444 if err != nil { 445 t.Fatal(err) 446 } 447 defer c.Close() 448 449 if !success { 450 t.Error("Dialer ctrlCtxFn failed") 451 } 452 453 testRawConnControl(t, c.(syscall.Conn)) 454 } 455 456 func testDialCtrlCtxFnSupersedesCtrlFn(t *testing.T, d Dialer, address string) { 457 var ctrlCtxFnCalled bool 458 459 d.Control = func(network, address string, c syscall.RawConn) error { 460 t.Error("Dialer.Control called") 461 return nil 462 } 463 464 d.ControlContext = func(ctx context.Context, network, address string, c syscall.RawConn) error { 465 ctrlCtxFnCalled = true 466 return nil 467 } 468 469 c, err := d.Dial("tcp", address, hello) 470 if err != nil { 471 t.Fatal(err) 472 } 473 defer c.Close() 474 475 if !ctrlCtxFnCalled { 476 t.Error("Dialer.ControlContext not called") 477 } 478 } 479 480 func testAddrFunctions(t *testing.T, lc ListenConfig, d Dialer) { 481 ln, err := lc.Listen(context.Background(), "tcp", "[::1]:") 482 if err != nil { 483 t.Fatal(err) 484 } 485 lntcp := ln.(*net.TCPListener) 486 defer lntcp.Close() 487 488 addr := lntcp.Addr().(*net.TCPAddr) 489 if !addr.IP.Equal(net.IPv6loopback) { 490 t.Fatalf("expected unspecified IP, got %v", addr.IP) 491 } 492 if addr.Port == 0 { 493 t.Fatalf("expected non-zero port, got %d", addr.Port) 494 } 495 496 c, err := d.Dial("tcp", addr.String(), hello) 497 if err != nil { 498 t.Fatal(err) 499 } 500 defer c.Close() 501 502 if laddr := c.LocalAddr().(*net.TCPAddr); !laddr.IP.Equal(net.IPv6loopback) || laddr.Port == 0 { 503 t.Errorf("Bad local addr: %v", laddr) 504 } 505 if raddr := c.RemoteAddr().(*net.TCPAddr); !raddr.IP.Equal(net.IPv6loopback) || raddr.Port != addr.Port { 506 t.Errorf("Bad remote addr: %v", raddr) 507 } 508 } 509 510 func write(w io.Writer, data []byte, t *testing.T) { 511 dataLen := len(data) 512 n, err := w.Write(data) 513 if err != nil { 514 t.Error(err) 515 return 516 } 517 if n != dataLen { 518 t.Errorf("Wrote %d bytes, should have written %d bytes", n, dataLen) 519 } 520 } 521 522 func writeWithReadFrom(w io.ReaderFrom, data []byte, t *testing.T) { 523 r := bytes.NewReader(data) 524 n, err := w.ReadFrom(r) 525 if err != nil { 526 t.Error(err) 527 } 528 bytesWritten := int(n) 529 dataLen := len(data) 530 if bytesWritten != dataLen { 531 t.Errorf("Wrote %d bytes, should have written %d bytes", bytesWritten, dataLen) 532 } 533 } 534 535 func readExactlyOneByte(r io.Reader, expectedByte byte, t *testing.T) { 536 b := make([]byte, 1) 537 n, err := r.Read(b) 538 if err != nil { 539 t.Fatal(err) 540 } 541 if n != 1 { 542 t.Fatalf("Read %d bytes, expected 1 byte", n) 543 } 544 if b[0] != expectedByte { 545 t.Fatalf("Read unexpected byte: '%c', expected '%c'", b[0], expectedByte) 546 } 547 } 548 549 func readUntilEOF(r io.Reader, expectedData []byte, t *testing.T) { 550 b, err := io.ReadAll(r) 551 if err != nil { 552 t.Error(err) 553 return 554 } 555 if !bytes.Equal(b, expectedData) { 556 t.Errorf("Read data %v is different from original data %v", b, expectedData) 557 } 558 } 559 560 func testClientWriteReadServerReadWrite(t *testing.T, lc ListenConfig, d Dialer) { 561 t.Logf("c->s payload: %v", helloworld) 562 t.Logf("s->c payload: %v", worldhello) 563 564 ln, err := lc.Listen(context.Background(), "tcp", "[::1]:") 565 if err != nil { 566 t.Fatal(err) 567 } 568 lntcp := ln.(*net.TCPListener) 569 defer lntcp.Close() 570 t.Log("Started listener on", lntcp.Addr()) 571 572 ctrlCh := make(chan struct{}) 573 go func() { 574 conn, err := lntcp.AcceptTCP() 575 if err != nil { 576 t.Error(err) 577 return 578 } 579 defer conn.Close() 580 t.Log("Accepted", conn.RemoteAddr()) 581 582 readUntilEOF(conn, helloworld, t) 583 write(conn, world, t) 584 write(conn, hello, t) 585 conn.CloseWrite() 586 close(ctrlCh) 587 }() 588 589 c, err := d.Dial("tcp", ln.Addr().String(), hello) 590 if err != nil { 591 t.Fatal(err) 592 } 593 tc := c.(*net.TCPConn) 594 defer tc.Close() 595 596 write(tc, world, t) 597 tc.CloseWrite() 598 readUntilEOF(tc, worldhello, t) 599 <-ctrlCh 600 } 601 602 func testServerWriteReadClientReadWrite(t *testing.T, lc ListenConfig, d Dialer) { 603 t.Logf("c->s payload: %v", helloworld) 604 t.Logf("s->c payload: %v", worldhello) 605 606 ln, err := lc.Listen(context.Background(), "tcp", "[::1]:") 607 if err != nil { 608 t.Fatal(err) 609 } 610 lntcp := ln.(*net.TCPListener) 611 defer lntcp.Close() 612 t.Log("Started listener on", lntcp.Addr()) 613 614 ctrlCh := make(chan struct{}) 615 go func() { 616 conn, err := lntcp.AcceptTCP() 617 if err != nil { 618 t.Error(err) 619 return 620 } 621 t.Log("Accepted", conn.RemoteAddr()) 622 defer conn.Close() 623 624 write(conn, world, t) 625 write(conn, hello, t) 626 conn.CloseWrite() 627 readUntilEOF(conn, helloworld, t) 628 close(ctrlCh) 629 }() 630 631 c, err := d.Dial("tcp", ln.Addr().String(), nil) 632 if err != nil { 633 t.Fatal(err) 634 } 635 tc := c.(*net.TCPConn) 636 defer tc.Close() 637 638 readUntilEOF(tc, worldhello, t) 639 write(tc, hello, t) 640 write(tc, world, t) 641 tc.CloseWrite() 642 <-ctrlCh 643 } 644 645 func testClientServerReadFrom(t *testing.T, lc ListenConfig, d Dialer) { 646 t.Logf("c->s payload: %v", helloworld) 647 t.Logf("s->c payload: %v", worldhello) 648 649 ln, err := lc.Listen(context.Background(), "tcp", "[::1]:") 650 if err != nil { 651 t.Fatal(err) 652 } 653 lntcp := ln.(*net.TCPListener) 654 defer lntcp.Close() 655 t.Log("Started listener on", lntcp.Addr()) 656 657 ctrlCh := make(chan struct{}) 658 go func() { 659 conn, err := lntcp.AcceptTCP() 660 if err != nil { 661 t.Error(err) 662 return 663 } 664 defer conn.Close() 665 t.Log("Accepted", conn.RemoteAddr()) 666 667 readUntilEOF(conn, helloworld, t) 668 writeWithReadFrom(conn, world, t) 669 writeWithReadFrom(conn, hello, t) 670 conn.CloseWrite() 671 close(ctrlCh) 672 }() 673 674 c, err := d.Dial("tcp", ln.Addr().String(), hello) 675 if err != nil { 676 t.Fatal(err) 677 } 678 tc := c.(*net.TCPConn) 679 defer tc.Close() 680 681 writeWithReadFrom(tc, world, t) 682 tc.CloseWrite() 683 readUntilEOF(tc, worldhello, t) 684 <-ctrlCh 685 } 686 687 func testSetDeadline(t *testing.T, lc ListenConfig, d Dialer) { 688 t.Logf("payload: %v", helloWorldSentence) 689 690 ln, err := lc.Listen(context.Background(), "tcp", "[::1]:") 691 if err != nil { 692 t.Fatal(err) 693 } 694 lntcp := ln.(*net.TCPListener) 695 defer lntcp.Close() 696 t.Log("Started listener on", lntcp.Addr()) 697 698 ctrlCh := make(chan struct{}) 699 go func() { 700 conn, err := lntcp.AcceptTCP() 701 if err != nil { 702 t.Error(err) 703 return 704 } 705 t.Log("Accepted", conn.RemoteAddr()) 706 defer conn.Close() 707 708 write(conn, helloWorldSentence, t) 709 readUntilEOF(conn, []byte{'h', 'l', 'l', ','}, t) 710 close(ctrlCh) 711 }() 712 713 c, err := d.Dial("tcp", ln.Addr().String(), helloWorldSentence[:1]) 714 if err != nil { 715 t.Fatal(err) 716 } 717 tc := c.(*net.TCPConn) 718 defer tc.Close() 719 720 b := make([]byte, 1) 721 722 // SetReadDeadline 723 readExactlyOneByte(tc, 'h', t) 724 if err := tc.SetReadDeadline(time.Now().Add(-time.Second)); err != nil { 725 t.Fatal(err) 726 } 727 if n, err := tc.Read(b); n != 0 || !errors.Is(err, os.ErrDeadlineExceeded) { 728 t.Fatal(n, err) 729 } 730 if err := tc.SetReadDeadline(time.Time{}); err != nil { 731 t.Fatal(err) 732 } 733 readExactlyOneByte(tc, 'e', t) 734 735 // SetWriteDeadline 736 if err := tc.SetWriteDeadline(time.Now().Add(-time.Second)); err != nil { 737 t.Fatal(err) 738 } 739 if n, err := tc.Write(helloWorldSentence[1:2]); n != 0 || !errors.Is(err, os.ErrDeadlineExceeded) { 740 t.Fatal(n, err) 741 } 742 if err := tc.SetWriteDeadline(time.Time{}); err != nil { 743 t.Fatal(err) 744 } 745 write(tc, helloWorldSentence[2:3], t) 746 747 // SetDeadline 748 readExactlyOneByte(tc, 'l', t) 749 write(tc, helloWorldSentence[3:4], t) 750 if err := tc.SetDeadline(time.Now().Add(-time.Second)); err != nil { 751 t.Fatal(err) 752 } 753 if _, err := tc.Read(b); !errors.Is(err, os.ErrDeadlineExceeded) { 754 t.Fatal(err) 755 } 756 if n, err := tc.Write(helloWorldSentence[4:5]); n != 0 || !errors.Is(err, os.ErrDeadlineExceeded) { 757 t.Fatal(n, err) 758 } 759 if err := tc.SetDeadline(time.Time{}); err != nil { 760 t.Fatal(err) 761 } 762 readExactlyOneByte(tc, 'l', t) 763 write(tc, helloWorldSentence[5:6], t) 764 765 tc.CloseWrite() 766 <-ctrlCh 767 }