github.com/database64128/tfo-go/v2@v2.2.0/tfo_test.go (about)

     1  package tfo
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"io"
     8  	"net"
     9  	"os"
    10  	"runtime"
    11  	"sync"
    12  	"syscall"
    13  	"testing"
    14  	"time"
    15  )
    16  
    17  type mptcpStatus uint8
    18  
    19  const (
    20  	mptcpUseDefault mptcpStatus = iota
    21  	mptcpEnabled
    22  	mptcpDisabled
    23  )
    24  
    25  type runtimeFallbackHelperFunc func(*testing.T)
    26  
    27  func runtimeFallbackAsIs(t *testing.T) {}
    28  
    29  func runtimeFallbackSetListenNoTFO(t *testing.T) {
    30  	if runtimeListenNoTFO.CompareAndSwap(false, true) {
    31  		t.Cleanup(func() {
    32  			runtimeListenNoTFO.Store(false)
    33  		})
    34  	}
    35  }
    36  
    37  func runtimeFallbackSetDialNoTFO(t *testing.T) {
    38  	if v := runtimeDialTFOSupport.v.Swap(uint32(dialTFOSupportNone)); v != uint32(dialTFOSupportNone) {
    39  		t.Cleanup(func() {
    40  			runtimeDialTFOSupport.v.Store(v)
    41  		})
    42  	}
    43  }
    44  
    45  func runtimeFallbackSetDialLinuxSendto(t *testing.T) {
    46  	if v := runtimeDialTFOSupport.v.Swap(uint32(dialTFOSupportLinuxSendto)); v != uint32(dialTFOSupportLinuxSendto) {
    47  		t.Cleanup(func() {
    48  			runtimeDialTFOSupport.v.Store(v)
    49  		})
    50  	}
    51  }
    52  
    53  var listenConfigCases = []struct {
    54  	name               string
    55  	listenConfig       ListenConfig
    56  	mptcp              mptcpStatus
    57  	setRuntimeFallback runtimeFallbackHelperFunc
    58  }{
    59  	{"TFO", ListenConfig{}, mptcpUseDefault, runtimeFallbackAsIs},
    60  	{"TFO+RuntimeNoTFO", ListenConfig{}, mptcpUseDefault, runtimeFallbackSetListenNoTFO},
    61  	{"TFO+MPTCPEnabled", ListenConfig{}, mptcpEnabled, runtimeFallbackAsIs},
    62  	{"TFO+MPTCPEnabled+RuntimeNoTFO", ListenConfig{}, mptcpEnabled, runtimeFallbackSetListenNoTFO},
    63  	{"TFO+MPTCPDisabled", ListenConfig{}, mptcpDisabled, runtimeFallbackAsIs},
    64  	{"TFO+MPTCPDisabled+RuntimeNoTFO", ListenConfig{}, mptcpDisabled, runtimeFallbackSetListenNoTFO},
    65  	{"TFO+Backlog1024", ListenConfig{Backlog: 1024}, mptcpUseDefault, runtimeFallbackAsIs},
    66  	{"TFO+Backlog1024+MPTCPEnabled", ListenConfig{Backlog: 1024}, mptcpEnabled, runtimeFallbackAsIs},
    67  	{"TFO+Backlog1024+MPTCPDisabled", ListenConfig{Backlog: 1024}, mptcpDisabled, runtimeFallbackAsIs},
    68  	{"TFO+Backlog-1", ListenConfig{Backlog: -1}, mptcpUseDefault, runtimeFallbackAsIs},
    69  	{"TFO+Backlog-1+MPTCPEnabled", ListenConfig{Backlog: -1}, mptcpEnabled, runtimeFallbackAsIs},
    70  	{"TFO+Backlog-1+MPTCPDisabled", ListenConfig{Backlog: -1}, mptcpDisabled, runtimeFallbackAsIs},
    71  	{"TFO+Fallback", ListenConfig{Fallback: true}, mptcpUseDefault, runtimeFallbackAsIs},
    72  	{"TFO+Fallback+RuntimeNoTFO", ListenConfig{Fallback: true}, mptcpUseDefault, runtimeFallbackSetListenNoTFO},
    73  	{"TFO+Fallback+MPTCPEnabled", ListenConfig{Fallback: true}, mptcpEnabled, runtimeFallbackAsIs},
    74  	{"TFO+Fallback+MPTCPEnabled+RuntimeNoTFO", ListenConfig{Fallback: true}, mptcpEnabled, runtimeFallbackSetListenNoTFO},
    75  	{"TFO+Fallback+MPTCPDisabled", ListenConfig{Fallback: true}, mptcpDisabled, runtimeFallbackAsIs},
    76  	{"TFO+Fallback+MPTCPDisabled+RuntimeNoTFO", ListenConfig{Fallback: true}, mptcpDisabled, runtimeFallbackSetListenNoTFO},
    77  	{"NoTFO", ListenConfig{DisableTFO: true}, mptcpUseDefault, runtimeFallbackAsIs},
    78  	{"NoTFO+MPTCPEnabled", ListenConfig{DisableTFO: true}, mptcpEnabled, runtimeFallbackAsIs},
    79  	{"NoTFO+MPTCPDisabled", ListenConfig{DisableTFO: true}, mptcpDisabled, runtimeFallbackAsIs},
    80  }
    81  
    82  var dialerCases = []struct {
    83  	name               string
    84  	dialer             Dialer
    85  	mptcp              mptcpStatus
    86  	setRuntimeFallback runtimeFallbackHelperFunc
    87  	linuxOnly          bool
    88  }{
    89  	{"TFO", Dialer{}, mptcpUseDefault, runtimeFallbackAsIs, false},
    90  	{"TFO+RuntimeNoTFO", Dialer{}, mptcpUseDefault, runtimeFallbackSetDialNoTFO, false},
    91  	{"TFO+RuntimeLinuxSendto", Dialer{}, mptcpUseDefault, runtimeFallbackSetDialLinuxSendto, true},
    92  	{"TFO+MPTCPEnabled", Dialer{}, mptcpEnabled, runtimeFallbackAsIs, false},
    93  	{"TFO+MPTCPEnabled+RuntimeNoTFO", Dialer{}, mptcpEnabled, runtimeFallbackSetDialNoTFO, false},
    94  	{"TFO+MPTCPEnabled+RuntimeLinuxSendto", Dialer{}, mptcpEnabled, runtimeFallbackSetDialLinuxSendto, true},
    95  	{"TFO+MPTCPDisabled", Dialer{}, mptcpDisabled, runtimeFallbackAsIs, false},
    96  	{"TFO+MPTCPDisabled+RuntimeNoTFO", Dialer{}, mptcpDisabled, runtimeFallbackSetDialNoTFO, false},
    97  	{"TFO+MPTCPDisabled+RuntimeLinuxSendto", Dialer{}, mptcpDisabled, runtimeFallbackSetDialLinuxSendto, true},
    98  	{"TFO+Fallback", Dialer{Fallback: true}, mptcpUseDefault, runtimeFallbackAsIs, false},
    99  	{"TFO+Fallback+RuntimeNoTFO", Dialer{Fallback: true}, mptcpUseDefault, runtimeFallbackSetDialNoTFO, false},
   100  	{"TFO+Fallback+RuntimeLinuxSendto", Dialer{Fallback: true}, mptcpUseDefault, runtimeFallbackSetDialLinuxSendto, true},
   101  	{"TFO+Fallback+MPTCPEnabled", Dialer{Fallback: true}, mptcpEnabled, runtimeFallbackAsIs, false},
   102  	{"TFO+Fallback+MPTCPEnabled+RuntimeNoTFO", Dialer{Fallback: true}, mptcpEnabled, runtimeFallbackSetDialNoTFO, false},
   103  	{"TFO+Fallback+MPTCPEnabled+RuntimeLinuxSendto", Dialer{Fallback: true}, mptcpEnabled, runtimeFallbackSetDialLinuxSendto, true},
   104  	{"TFO+Fallback+MPTCPDisabled", Dialer{Fallback: true}, mptcpDisabled, runtimeFallbackAsIs, false},
   105  	{"TFO+Fallback+MPTCPDisabled+RuntimeNoTFO", Dialer{Fallback: true}, mptcpDisabled, runtimeFallbackSetDialNoTFO, false},
   106  	{"TFO+Fallback+MPTCPDisabled+RuntimeLinuxSendto", Dialer{Fallback: true}, mptcpDisabled, runtimeFallbackSetDialLinuxSendto, true},
   107  	{"NoTFO", Dialer{DisableTFO: true}, mptcpUseDefault, runtimeFallbackAsIs, false},
   108  	{"NoTFO+MPTCPEnabled", Dialer{DisableTFO: true}, mptcpEnabled, runtimeFallbackAsIs, false},
   109  	{"NoTFO+MPTCPDisabled", Dialer{DisableTFO: true}, mptcpDisabled, runtimeFallbackAsIs, false},
   110  }
   111  
   112  type testCase struct {
   113  	name                     string
   114  	listenConfig             ListenConfig
   115  	dialer                   Dialer
   116  	setRuntimeFallbackListen runtimeFallbackHelperFunc
   117  	setRuntimeFallbackDial   runtimeFallbackHelperFunc
   118  }
   119  
   120  func (c testCase) Run(t *testing.T, f func(*testing.T, ListenConfig, Dialer)) {
   121  	t.Run(c.name, func(t *testing.T) {
   122  		c.setRuntimeFallbackListen(t)
   123  		c.setRuntimeFallbackDial(t)
   124  		f(t, c.listenConfig, c.dialer)
   125  	})
   126  }
   127  
   128  // cases is a list of [ListenConfig] and [Dialer] combinations to test.
   129  var cases []testCase
   130  
   131  func init() {
   132  	// Initialize [listenConfigCases].
   133  	for i := range listenConfigCases {
   134  		c := &listenConfigCases[i]
   135  		switch c.mptcp {
   136  		case mptcpUseDefault:
   137  		case mptcpEnabled:
   138  			c.listenConfig.SetMultipathTCP(true)
   139  		case mptcpDisabled:
   140  			c.listenConfig.SetMultipathTCP(false)
   141  		default:
   142  			panic("unreachable")
   143  		}
   144  	}
   145  
   146  	// Initialize [dialerCases].
   147  	for i := range dialerCases {
   148  		c := &dialerCases[i]
   149  		switch c.mptcp {
   150  		case mptcpUseDefault:
   151  		case mptcpEnabled:
   152  			c.dialer.SetMultipathTCP(true)
   153  		case mptcpDisabled:
   154  			c.dialer.SetMultipathTCP(false)
   155  		default:
   156  			panic("unreachable")
   157  		}
   158  	}
   159  
   160  	// Generate [cases].
   161  	cases = make([]testCase, 0, len(listenConfigCases)*len(dialerCases))
   162  	for _, lc := range listenConfigCases {
   163  		if comptimeNoTFO && !lc.listenConfig.tfoDisabled() {
   164  			continue
   165  		}
   166  		for _, d := range dialerCases {
   167  			if comptimeNoTFO && !d.dialer.DisableTFO {
   168  				continue
   169  			}
   170  			switch runtime.GOOS {
   171  			case "linux", "android":
   172  			default:
   173  				if d.linuxOnly {
   174  					continue
   175  				}
   176  			}
   177  			cases = append(cases, testCase{
   178  				name:                     lc.name + "/" + d.name,
   179  				listenConfig:             lc.listenConfig,
   180  				dialer:                   d.dialer,
   181  				setRuntimeFallbackListen: lc.setRuntimeFallback,
   182  				setRuntimeFallbackDial:   d.setRuntimeFallback,
   183  			})
   184  		}
   185  	}
   186  }
   187  
   188  // discardTCPServer is a TCP server that accepts and drains incoming connections.
   189  type discardTCPServer struct {
   190  	ln *net.TCPListener
   191  	wg sync.WaitGroup
   192  }
   193  
   194  // newDiscardTCPServer creates a new [discardTCPServer] that listens on a random port.
   195  func newDiscardTCPServer(ctx context.Context) (*discardTCPServer, error) {
   196  	lc := ListenConfig{DisableTFO: comptimeNoTFO}
   197  	ln, err := lc.Listen(ctx, "tcp", "[::1]:")
   198  	if err != nil {
   199  		return nil, err
   200  	}
   201  	return &discardTCPServer{ln: ln.(*net.TCPListener)}, nil
   202  }
   203  
   204  // Addr returns the server's address.
   205  func (s *discardTCPServer) Addr() *net.TCPAddr {
   206  	return s.ln.Addr().(*net.TCPAddr)
   207  }
   208  
   209  // Start spins up a new goroutine that accepts and drains incoming connections
   210  // until [discardTCPServer.Close] is called.
   211  func (s *discardTCPServer) Start(t *testing.T) {
   212  	s.wg.Add(1)
   213  
   214  	go func() {
   215  		defer s.wg.Done()
   216  
   217  		for {
   218  			c, err := s.ln.AcceptTCP()
   219  			if err != nil {
   220  				if errors.Is(err, os.ErrDeadlineExceeded) {
   221  					return
   222  				}
   223  				t.Error("AcceptTCP:", err)
   224  				return
   225  			}
   226  
   227  			go func() {
   228  				defer c.Close()
   229  
   230  				n, err := io.Copy(io.Discard, c)
   231  				if err != nil {
   232  					t.Error("Copy:", err)
   233  				}
   234  				t.Logf("Discarded %d bytes from %s", n, c.RemoteAddr())
   235  			}()
   236  		}
   237  	}()
   238  }
   239  
   240  // Close interrupts all running accept goroutines, waits for them to finish,
   241  // and closes the listener.
   242  func (s *discardTCPServer) Close() {
   243  	s.ln.SetDeadline(aLongTimeAgo)
   244  	s.wg.Wait()
   245  	s.ln.Close()
   246  }
   247  
   248  var (
   249  	hello              = []byte{'h', 'e', 'l', 'l', 'o'}
   250  	world              = []byte{'w', 'o', 'r', 'l', 'd'}
   251  	helloworld         = []byte{'h', 'e', 'l', 'l', 'o', 'w', 'o', 'r', 'l', 'd'}
   252  	worldhello         = []byte{'w', 'o', 'r', 'l', 'd', 'h', 'e', 'l', 'l', 'o'}
   253  	helloWorldSentence = []byte{'h', 'e', 'l', 'l', 'o', ',', ' ', 'w', 'o', 'r', 'l', 'd', '!', '\n'}
   254  )
   255  
   256  func testListenDialUDP(t *testing.T, lc ListenConfig, d Dialer) {
   257  	pc, err := lc.ListenPacket(context.Background(), "udp", "[::1]:")
   258  	if err != nil {
   259  		t.Fatal(err)
   260  	}
   261  	uc := pc.(*net.UDPConn)
   262  	defer uc.Close()
   263  
   264  	c, err := d.Dial("udp", uc.LocalAddr().String(), hello)
   265  	if err != nil {
   266  		t.Fatal(err)
   267  	}
   268  	defer c.Close()
   269  
   270  	b := make([]byte, 5)
   271  	n, _, err := uc.ReadFromUDPAddrPort(b)
   272  	if err != nil {
   273  		t.Fatal(err)
   274  	}
   275  	if n != 5 {
   276  		t.Fatalf("Expected 5 bytes, got %d", n)
   277  	}
   278  	if !bytes.Equal(b, hello) {
   279  		t.Fatalf("Expected %v, got %v", hello, b)
   280  	}
   281  }
   282  
   283  // TestListenDialUDP ensures that the UDP capabilities of [ListenConfig] and
   284  // [Dialer] are not affected by this package.
   285  func TestListenDialUDP(t *testing.T) {
   286  	for _, c := range cases {
   287  		c.Run(t, testListenDialUDP)
   288  	}
   289  }
   290  
   291  // TestListenCtrlFn ensures that the user-provided [ListenConfig.Control] function
   292  // is called when [ListenConfig.Listen] is called.
   293  func TestListenCtrlFn(t *testing.T) {
   294  	for _, c := range listenConfigCases {
   295  		t.Run(c.name, func(t *testing.T) {
   296  			c.setRuntimeFallback(t)
   297  			testListenCtrlFn(t, c.listenConfig)
   298  		})
   299  	}
   300  }
   301  
   302  // TestDialCtrlFn ensures that [Dialer]'s user-provided control functions
   303  // are used in the same way as [net.Dialer].
   304  func TestDialCtrlFn(t *testing.T) {
   305  	s, err := newDiscardTCPServer(context.Background())
   306  	if err != nil {
   307  		t.Fatal(err)
   308  	}
   309  	defer s.Close()
   310  
   311  	address := s.Addr().String()
   312  
   313  	for _, c := range dialerCases {
   314  		t.Run(c.name, func(t *testing.T) {
   315  			c.setRuntimeFallback(t)
   316  			testDialCtrlFn(t, c.dialer, address)
   317  			testDialCtrlCtxFn(t, c.dialer, address)
   318  			testDialCtrlCtxFnSupersedesCtrlFn(t, c.dialer, address)
   319  		})
   320  	}
   321  }
   322  
   323  // TestAddrFunctions ensures that the address methods on [*net.TCPListener] and
   324  // [*net.TCPConn] return the correct values.
   325  func TestAddrFunctions(t *testing.T) {
   326  	for _, c := range cases {
   327  		c.Run(t, testAddrFunctions)
   328  	}
   329  }
   330  
   331  // TestClientWriteReadServerReadWrite ensures that a client can write to a server,
   332  // the server can read from the client, and the server can write to the client.
   333  func TestClientWriteReadServerReadWrite(t *testing.T) {
   334  	for _, c := range cases {
   335  		c.Run(t, testClientWriteReadServerReadWrite)
   336  	}
   337  }
   338  
   339  // TestServerWriteReadClientReadWrite ensures that a server can write to a client,
   340  // the client can read from the server, and the client can write to the server.
   341  func TestServerWriteReadClientReadWrite(t *testing.T) {
   342  	for _, c := range cases {
   343  		c.Run(t, testServerWriteReadClientReadWrite)
   344  	}
   345  }
   346  
   347  // TestClientServerReadFrom ensures that the ReadFrom method
   348  // on accepted and dialed connections works as expected.
   349  func TestClientServerReadFrom(t *testing.T) {
   350  	for _, c := range cases {
   351  		c.Run(t, testClientServerReadFrom)
   352  	}
   353  }
   354  
   355  // TestSetDeadline ensures that the SetDeadline, SetReadDeadline, and
   356  // SetWriteDeadline methods on accepted and dialed connections work as expected.
   357  func TestSetDeadline(t *testing.T) {
   358  	for _, c := range cases {
   359  		c.Run(t, testSetDeadline)
   360  	}
   361  }
   362  
   363  func testRawConnControl(t *testing.T, sc syscall.Conn) {
   364  	rawConn, err := sc.SyscallConn()
   365  	if err != nil {
   366  		t.Fatal(err)
   367  	}
   368  
   369  	var success bool
   370  
   371  	if err = rawConn.Control(func(fd uintptr) {
   372  		success = fd != 0
   373  	}); err != nil {
   374  		t.Fatal(err)
   375  	}
   376  
   377  	if !success {
   378  		t.Error("RawConn Control failed")
   379  	}
   380  }
   381  
   382  func testListenCtrlFn(t *testing.T, lc ListenConfig) {
   383  	var success bool
   384  
   385  	lc.Control = func(network, address string, c syscall.RawConn) error {
   386  		return c.Control(func(fd uintptr) {
   387  			success = fd != 0
   388  		})
   389  	}
   390  
   391  	ln, err := lc.Listen(context.Background(), "tcp", "")
   392  	if err != nil {
   393  		t.Fatal(err)
   394  	}
   395  	defer ln.Close()
   396  
   397  	if !success {
   398  		t.Error("ListenConfig ctrlFn failed")
   399  	}
   400  
   401  	testRawConnControl(t, ln.(syscall.Conn))
   402  }
   403  
   404  func testDialCtrlFn(t *testing.T, d Dialer, address string) {
   405  	var success bool
   406  
   407  	d.Control = func(network, address string, c syscall.RawConn) error {
   408  		return c.Control(func(fd uintptr) {
   409  			success = fd != 0
   410  		})
   411  	}
   412  
   413  	c, err := d.Dial("tcp", address, hello)
   414  	if err != nil {
   415  		t.Fatal(err)
   416  	}
   417  	defer c.Close()
   418  
   419  	if !success {
   420  		t.Error("Dialer ctrlFn failed")
   421  	}
   422  
   423  	testRawConnControl(t, c.(syscall.Conn))
   424  }
   425  
   426  func testDialCtrlCtxFn(t *testing.T, d Dialer, address string) {
   427  	type contextKey int
   428  
   429  	const (
   430  		ctxKey = contextKey(64)
   431  		ctxVal = 128
   432  	)
   433  
   434  	var success bool
   435  
   436  	d.ControlContext = func(ctx context.Context, network, address string, c syscall.RawConn) error {
   437  		return c.Control(func(fd uintptr) {
   438  			success = fd != 0 && ctx.Value(ctxKey) == ctxVal
   439  		})
   440  	}
   441  
   442  	ctx := context.WithValue(context.Background(), ctxKey, ctxVal)
   443  	c, err := d.DialContext(ctx, "tcp", address, hello)
   444  	if err != nil {
   445  		t.Fatal(err)
   446  	}
   447  	defer c.Close()
   448  
   449  	if !success {
   450  		t.Error("Dialer ctrlCtxFn failed")
   451  	}
   452  
   453  	testRawConnControl(t, c.(syscall.Conn))
   454  }
   455  
   456  func testDialCtrlCtxFnSupersedesCtrlFn(t *testing.T, d Dialer, address string) {
   457  	var ctrlCtxFnCalled bool
   458  
   459  	d.Control = func(network, address string, c syscall.RawConn) error {
   460  		t.Error("Dialer.Control called")
   461  		return nil
   462  	}
   463  
   464  	d.ControlContext = func(ctx context.Context, network, address string, c syscall.RawConn) error {
   465  		ctrlCtxFnCalled = true
   466  		return nil
   467  	}
   468  
   469  	c, err := d.Dial("tcp", address, hello)
   470  	if err != nil {
   471  		t.Fatal(err)
   472  	}
   473  	defer c.Close()
   474  
   475  	if !ctrlCtxFnCalled {
   476  		t.Error("Dialer.ControlContext not called")
   477  	}
   478  }
   479  
   480  func testAddrFunctions(t *testing.T, lc ListenConfig, d Dialer) {
   481  	ln, err := lc.Listen(context.Background(), "tcp", "[::1]:")
   482  	if err != nil {
   483  		t.Fatal(err)
   484  	}
   485  	lntcp := ln.(*net.TCPListener)
   486  	defer lntcp.Close()
   487  
   488  	addr := lntcp.Addr().(*net.TCPAddr)
   489  	if !addr.IP.Equal(net.IPv6loopback) {
   490  		t.Fatalf("expected unspecified IP, got %v", addr.IP)
   491  	}
   492  	if addr.Port == 0 {
   493  		t.Fatalf("expected non-zero port, got %d", addr.Port)
   494  	}
   495  
   496  	c, err := d.Dial("tcp", addr.String(), hello)
   497  	if err != nil {
   498  		t.Fatal(err)
   499  	}
   500  	defer c.Close()
   501  
   502  	if laddr := c.LocalAddr().(*net.TCPAddr); !laddr.IP.Equal(net.IPv6loopback) || laddr.Port == 0 {
   503  		t.Errorf("Bad local addr: %v", laddr)
   504  	}
   505  	if raddr := c.RemoteAddr().(*net.TCPAddr); !raddr.IP.Equal(net.IPv6loopback) || raddr.Port != addr.Port {
   506  		t.Errorf("Bad remote addr: %v", raddr)
   507  	}
   508  }
   509  
   510  func write(w io.Writer, data []byte, t *testing.T) {
   511  	dataLen := len(data)
   512  	n, err := w.Write(data)
   513  	if err != nil {
   514  		t.Error(err)
   515  		return
   516  	}
   517  	if n != dataLen {
   518  		t.Errorf("Wrote %d bytes, should have written %d bytes", n, dataLen)
   519  	}
   520  }
   521  
   522  func writeWithReadFrom(w io.ReaderFrom, data []byte, t *testing.T) {
   523  	r := bytes.NewReader(data)
   524  	n, err := w.ReadFrom(r)
   525  	if err != nil {
   526  		t.Error(err)
   527  	}
   528  	bytesWritten := int(n)
   529  	dataLen := len(data)
   530  	if bytesWritten != dataLen {
   531  		t.Errorf("Wrote %d bytes, should have written %d bytes", bytesWritten, dataLen)
   532  	}
   533  }
   534  
   535  func readExactlyOneByte(r io.Reader, expectedByte byte, t *testing.T) {
   536  	b := make([]byte, 1)
   537  	n, err := r.Read(b)
   538  	if err != nil {
   539  		t.Fatal(err)
   540  	}
   541  	if n != 1 {
   542  		t.Fatalf("Read %d bytes, expected 1 byte", n)
   543  	}
   544  	if b[0] != expectedByte {
   545  		t.Fatalf("Read unexpected byte: '%c', expected '%c'", b[0], expectedByte)
   546  	}
   547  }
   548  
   549  func readUntilEOF(r io.Reader, expectedData []byte, t *testing.T) {
   550  	b, err := io.ReadAll(r)
   551  	if err != nil {
   552  		t.Error(err)
   553  		return
   554  	}
   555  	if !bytes.Equal(b, expectedData) {
   556  		t.Errorf("Read data %v is different from original data %v", b, expectedData)
   557  	}
   558  }
   559  
   560  func testClientWriteReadServerReadWrite(t *testing.T, lc ListenConfig, d Dialer) {
   561  	t.Logf("c->s payload: %v", helloworld)
   562  	t.Logf("s->c payload: %v", worldhello)
   563  
   564  	ln, err := lc.Listen(context.Background(), "tcp", "[::1]:")
   565  	if err != nil {
   566  		t.Fatal(err)
   567  	}
   568  	lntcp := ln.(*net.TCPListener)
   569  	defer lntcp.Close()
   570  	t.Log("Started listener on", lntcp.Addr())
   571  
   572  	ctrlCh := make(chan struct{})
   573  	go func() {
   574  		conn, err := lntcp.AcceptTCP()
   575  		if err != nil {
   576  			t.Error(err)
   577  			return
   578  		}
   579  		defer conn.Close()
   580  		t.Log("Accepted", conn.RemoteAddr())
   581  
   582  		readUntilEOF(conn, helloworld, t)
   583  		write(conn, world, t)
   584  		write(conn, hello, t)
   585  		conn.CloseWrite()
   586  		close(ctrlCh)
   587  	}()
   588  
   589  	c, err := d.Dial("tcp", ln.Addr().String(), hello)
   590  	if err != nil {
   591  		t.Fatal(err)
   592  	}
   593  	tc := c.(*net.TCPConn)
   594  	defer tc.Close()
   595  
   596  	write(tc, world, t)
   597  	tc.CloseWrite()
   598  	readUntilEOF(tc, worldhello, t)
   599  	<-ctrlCh
   600  }
   601  
   602  func testServerWriteReadClientReadWrite(t *testing.T, lc ListenConfig, d Dialer) {
   603  	t.Logf("c->s payload: %v", helloworld)
   604  	t.Logf("s->c payload: %v", worldhello)
   605  
   606  	ln, err := lc.Listen(context.Background(), "tcp", "[::1]:")
   607  	if err != nil {
   608  		t.Fatal(err)
   609  	}
   610  	lntcp := ln.(*net.TCPListener)
   611  	defer lntcp.Close()
   612  	t.Log("Started listener on", lntcp.Addr())
   613  
   614  	ctrlCh := make(chan struct{})
   615  	go func() {
   616  		conn, err := lntcp.AcceptTCP()
   617  		if err != nil {
   618  			t.Error(err)
   619  			return
   620  		}
   621  		t.Log("Accepted", conn.RemoteAddr())
   622  		defer conn.Close()
   623  
   624  		write(conn, world, t)
   625  		write(conn, hello, t)
   626  		conn.CloseWrite()
   627  		readUntilEOF(conn, helloworld, t)
   628  		close(ctrlCh)
   629  	}()
   630  
   631  	c, err := d.Dial("tcp", ln.Addr().String(), nil)
   632  	if err != nil {
   633  		t.Fatal(err)
   634  	}
   635  	tc := c.(*net.TCPConn)
   636  	defer tc.Close()
   637  
   638  	readUntilEOF(tc, worldhello, t)
   639  	write(tc, hello, t)
   640  	write(tc, world, t)
   641  	tc.CloseWrite()
   642  	<-ctrlCh
   643  }
   644  
   645  func testClientServerReadFrom(t *testing.T, lc ListenConfig, d Dialer) {
   646  	t.Logf("c->s payload: %v", helloworld)
   647  	t.Logf("s->c payload: %v", worldhello)
   648  
   649  	ln, err := lc.Listen(context.Background(), "tcp", "[::1]:")
   650  	if err != nil {
   651  		t.Fatal(err)
   652  	}
   653  	lntcp := ln.(*net.TCPListener)
   654  	defer lntcp.Close()
   655  	t.Log("Started listener on", lntcp.Addr())
   656  
   657  	ctrlCh := make(chan struct{})
   658  	go func() {
   659  		conn, err := lntcp.AcceptTCP()
   660  		if err != nil {
   661  			t.Error(err)
   662  			return
   663  		}
   664  		defer conn.Close()
   665  		t.Log("Accepted", conn.RemoteAddr())
   666  
   667  		readUntilEOF(conn, helloworld, t)
   668  		writeWithReadFrom(conn, world, t)
   669  		writeWithReadFrom(conn, hello, t)
   670  		conn.CloseWrite()
   671  		close(ctrlCh)
   672  	}()
   673  
   674  	c, err := d.Dial("tcp", ln.Addr().String(), hello)
   675  	if err != nil {
   676  		t.Fatal(err)
   677  	}
   678  	tc := c.(*net.TCPConn)
   679  	defer tc.Close()
   680  
   681  	writeWithReadFrom(tc, world, t)
   682  	tc.CloseWrite()
   683  	readUntilEOF(tc, worldhello, t)
   684  	<-ctrlCh
   685  }
   686  
   687  func testSetDeadline(t *testing.T, lc ListenConfig, d Dialer) {
   688  	t.Logf("payload: %v", helloWorldSentence)
   689  
   690  	ln, err := lc.Listen(context.Background(), "tcp", "[::1]:")
   691  	if err != nil {
   692  		t.Fatal(err)
   693  	}
   694  	lntcp := ln.(*net.TCPListener)
   695  	defer lntcp.Close()
   696  	t.Log("Started listener on", lntcp.Addr())
   697  
   698  	ctrlCh := make(chan struct{})
   699  	go func() {
   700  		conn, err := lntcp.AcceptTCP()
   701  		if err != nil {
   702  			t.Error(err)
   703  			return
   704  		}
   705  		t.Log("Accepted", conn.RemoteAddr())
   706  		defer conn.Close()
   707  
   708  		write(conn, helloWorldSentence, t)
   709  		readUntilEOF(conn, []byte{'h', 'l', 'l', ','}, t)
   710  		close(ctrlCh)
   711  	}()
   712  
   713  	c, err := d.Dial("tcp", ln.Addr().String(), helloWorldSentence[:1])
   714  	if err != nil {
   715  		t.Fatal(err)
   716  	}
   717  	tc := c.(*net.TCPConn)
   718  	defer tc.Close()
   719  
   720  	b := make([]byte, 1)
   721  
   722  	// SetReadDeadline
   723  	readExactlyOneByte(tc, 'h', t)
   724  	if err := tc.SetReadDeadline(time.Now().Add(-time.Second)); err != nil {
   725  		t.Fatal(err)
   726  	}
   727  	if n, err := tc.Read(b); n != 0 || !errors.Is(err, os.ErrDeadlineExceeded) {
   728  		t.Fatal(n, err)
   729  	}
   730  	if err := tc.SetReadDeadline(time.Time{}); err != nil {
   731  		t.Fatal(err)
   732  	}
   733  	readExactlyOneByte(tc, 'e', t)
   734  
   735  	// SetWriteDeadline
   736  	if err := tc.SetWriteDeadline(time.Now().Add(-time.Second)); err != nil {
   737  		t.Fatal(err)
   738  	}
   739  	if n, err := tc.Write(helloWorldSentence[1:2]); n != 0 || !errors.Is(err, os.ErrDeadlineExceeded) {
   740  		t.Fatal(n, err)
   741  	}
   742  	if err := tc.SetWriteDeadline(time.Time{}); err != nil {
   743  		t.Fatal(err)
   744  	}
   745  	write(tc, helloWorldSentence[2:3], t)
   746  
   747  	// SetDeadline
   748  	readExactlyOneByte(tc, 'l', t)
   749  	write(tc, helloWorldSentence[3:4], t)
   750  	if err := tc.SetDeadline(time.Now().Add(-time.Second)); err != nil {
   751  		t.Fatal(err)
   752  	}
   753  	if _, err := tc.Read(b); !errors.Is(err, os.ErrDeadlineExceeded) {
   754  		t.Fatal(err)
   755  	}
   756  	if n, err := tc.Write(helloWorldSentence[4:5]); n != 0 || !errors.Is(err, os.ErrDeadlineExceeded) {
   757  		t.Fatal(n, err)
   758  	}
   759  	if err := tc.SetDeadline(time.Time{}); err != nil {
   760  		t.Fatal(err)
   761  	}
   762  	readExactlyOneByte(tc, 'l', t)
   763  	write(tc, helloWorldSentence[5:6], t)
   764  
   765  	tc.CloseWrite()
   766  	<-ctrlCh
   767  }