github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/transport/tcp/tcp_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tcp_test
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"io/ioutil"
    21  	"math"
    22  	"strings"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/google/go-cmp/cmp"
    27  	"github.com/SagerNet/gvisor/pkg/rand"
    28  	"github.com/SagerNet/gvisor/pkg/sync"
    29  	"github.com/SagerNet/gvisor/pkg/tcpip"
    30  	"github.com/SagerNet/gvisor/pkg/tcpip/checker"
    31  	"github.com/SagerNet/gvisor/pkg/tcpip/header"
    32  	"github.com/SagerNet/gvisor/pkg/tcpip/link/loopback"
    33  	"github.com/SagerNet/gvisor/pkg/tcpip/link/sniffer"
    34  	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv4"
    35  	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv6"
    36  	"github.com/SagerNet/gvisor/pkg/tcpip/seqnum"
    37  	"github.com/SagerNet/gvisor/pkg/tcpip/stack"
    38  	tcpiptestutil "github.com/SagerNet/gvisor/pkg/tcpip/testutil"
    39  	"github.com/SagerNet/gvisor/pkg/tcpip/transport/tcp"
    40  	"github.com/SagerNet/gvisor/pkg/tcpip/transport/tcp/testing/context"
    41  	"github.com/SagerNet/gvisor/pkg/test/testutil"
    42  	"github.com/SagerNet/gvisor/pkg/waiter"
    43  )
    44  
    45  // endpointTester provides helper functions to test a tcpip.Endpoint.
    46  type endpointTester struct {
    47  	ep tcpip.Endpoint
    48  }
    49  
    50  // CheckReadError issues a read to the endpoint and checking for an error.
    51  func (e *endpointTester) CheckReadError(t *testing.T, want tcpip.Error) {
    52  	t.Helper()
    53  	res, got := e.ep.Read(ioutil.Discard, tcpip.ReadOptions{})
    54  	if got != want {
    55  		t.Fatalf("ep.Read = %s, want %s", got, want)
    56  	}
    57  	if diff := cmp.Diff(tcpip.ReadResult{}, res); diff != "" {
    58  		t.Errorf("ep.Read: unexpected non-zero result (-want +got):\n%s", diff)
    59  	}
    60  }
    61  
    62  // CheckRead issues a read to the endpoint and checking for a success, returning
    63  // the data read.
    64  func (e *endpointTester) CheckRead(t *testing.T) []byte {
    65  	t.Helper()
    66  	var buf bytes.Buffer
    67  	res, err := e.ep.Read(&buf, tcpip.ReadOptions{})
    68  	if err != nil {
    69  		t.Fatalf("ep.Read = _, %s; want _, nil", err)
    70  	}
    71  	if diff := cmp.Diff(tcpip.ReadResult{
    72  		Count: buf.Len(),
    73  		Total: buf.Len(),
    74  	}, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
    75  		t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
    76  	}
    77  	return buf.Bytes()
    78  }
    79  
    80  // CheckReadFull reads from the endpoint for exactly count bytes.
    81  func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-chan struct{}, timeout time.Duration) []byte {
    82  	t.Helper()
    83  	var buf bytes.Buffer
    84  	w := tcpip.LimitedWriter{
    85  		W: &buf,
    86  		N: int64(count),
    87  	}
    88  	for w.N != 0 {
    89  		_, err := e.ep.Read(&w, tcpip.ReadOptions{})
    90  		if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
    91  			// Wait for receive to be notified.
    92  			select {
    93  			case <-notifyRead:
    94  			case <-time.After(timeout):
    95  				t.Fatalf("Timed out waiting for data to arrive")
    96  			}
    97  			continue
    98  		} else if err != nil {
    99  			t.Fatalf("ep.Read = _, %s; want _, nil", err)
   100  		}
   101  	}
   102  	return buf.Bytes()
   103  }
   104  
   105  const (
   106  	// defaultMTU is the MTU, in bytes, used throughout the tests, except
   107  	// where another value is explicitly used. It is chosen to match the MTU
   108  	// of loopback interfaces on linux systems.
   109  	defaultMTU = 65535
   110  
   111  	// defaultIPv4MSS is the MSS sent by the network stack in SYN/SYN-ACK for an
   112  	// IPv4 endpoint when the MTU is set to defaultMTU in the test.
   113  	defaultIPv4MSS = defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize
   114  )
   115  
   116  func TestGiveUpConnect(t *testing.T) {
   117  	c := context.New(t, defaultMTU)
   118  	defer c.Cleanup()
   119  
   120  	var wq waiter.Queue
   121  	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
   122  	if err != nil {
   123  		t.Fatalf("NewEndpoint failed: %s", err)
   124  	}
   125  
   126  	// Register for notification, then start connection attempt.
   127  	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
   128  	wq.EventRegister(&waitEntry, waiter.EventHUp)
   129  	defer wq.EventUnregister(&waitEntry)
   130  
   131  	{
   132  		err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
   133  		if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
   134  			t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d)
   135  		}
   136  	}
   137  
   138  	// Close the connection, wait for completion.
   139  	ep.Close()
   140  
   141  	// Wait for ep to become writable.
   142  	<-notifyCh
   143  
   144  	// Call Connect again to retreive the handshake failure status
   145  	// and stats updates.
   146  	{
   147  		err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
   148  		if d := cmp.Diff(&tcpip.ErrAborted{}, err); d != "" {
   149  			t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d)
   150  		}
   151  	}
   152  
   153  	if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 {
   154  		t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = 1", got)
   155  	}
   156  
   157  	if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
   158  		t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
   159  	}
   160  }
   161  
   162  // Test for ICMP error handling without completing handshake.
   163  func TestConnectICMPError(t *testing.T) {
   164  	c := context.New(t, defaultMTU)
   165  	defer c.Cleanup()
   166  
   167  	var wq waiter.Queue
   168  	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
   169  	if err != nil {
   170  		t.Fatalf("NewEndpoint failed: %s", err)
   171  	}
   172  
   173  	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
   174  	wq.EventRegister(&waitEntry, waiter.EventHUp)
   175  	defer wq.EventUnregister(&waitEntry)
   176  
   177  	{
   178  		err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
   179  		if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
   180  			t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d)
   181  		}
   182  	}
   183  
   184  	syn := c.GetPacket()
   185  	checker.IPv4(t, syn, checker.TCP(checker.TCPFlags(header.TCPFlagSyn)))
   186  
   187  	wep := ep.(interface {
   188  		StopWork()
   189  		ResumeWork()
   190  		LastErrorLocked() tcpip.Error
   191  	})
   192  
   193  	// Stop the protocol loop, ensure that the ICMP error is processed and
   194  	// the last ICMP error is read before the loop is resumed. This sanity
   195  	// tests the handshake completion logic on ICMP errors.
   196  	wep.StopWork()
   197  
   198  	c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, nil, syn, defaultMTU)
   199  
   200  	for {
   201  		if err := wep.LastErrorLocked(); err != nil {
   202  			if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" {
   203  				t.Errorf("ep.LastErrorLocked() mismatch (-want +got):\n%s", d)
   204  			}
   205  			break
   206  		}
   207  		time.Sleep(time.Millisecond)
   208  	}
   209  
   210  	wep.ResumeWork()
   211  
   212  	<-notifyCh
   213  
   214  	// The stack would have unregistered the endpoint because of the ICMP error.
   215  	// Expect a RST for any subsequent packets sent to the endpoint.
   216  	c.SendPacket(nil, &context.Headers{
   217  		SrcPort: context.TestPort,
   218  		DstPort: context.StackPort,
   219  		Flags:   header.TCPFlagAck,
   220  		SeqNum:  seqnum.Value(context.TestInitialSequenceNumber) + 1,
   221  		AckNum:  c.IRS + 1,
   222  	})
   223  
   224  	checker.IPv4(t, c.GetPacket(), checker.TCP(
   225  		checker.SrcPort(context.StackPort),
   226  		checker.DstPort(context.TestPort),
   227  		checker.TCPSeqNum(uint32(c.IRS+1)),
   228  		checker.TCPAckNum(0),
   229  		checker.TCPFlags(header.TCPFlagRst)))
   230  }
   231  
   232  func TestConnectIncrementActiveConnection(t *testing.T) {
   233  	c := context.New(t, defaultMTU)
   234  	defer c.Cleanup()
   235  
   236  	stats := c.Stack().Stats()
   237  	want := stats.TCP.ActiveConnectionOpenings.Value() + 1
   238  
   239  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
   240  	if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want {
   241  		t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %d, want = %d", got, want)
   242  	}
   243  }
   244  
   245  func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) {
   246  	c := context.New(t, defaultMTU)
   247  	defer c.Cleanup()
   248  
   249  	stats := c.Stack().Stats()
   250  	want := stats.TCP.FailedConnectionAttempts.Value()
   251  
   252  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
   253  	if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
   254  		t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want)
   255  	}
   256  	if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
   257  		t.Errorf("got EP stats.FailedConnectionAttempts = %d, want = %d", got, want)
   258  	}
   259  }
   260  
   261  func TestActiveFailedConnectionAttemptIncrement(t *testing.T) {
   262  	c := context.New(t, defaultMTU)
   263  	defer c.Cleanup()
   264  
   265  	stats := c.Stack().Stats()
   266  	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
   267  	if err != nil {
   268  		t.Fatalf("NewEndpoint failed: %s", err)
   269  	}
   270  	c.EP = ep
   271  	want := stats.TCP.FailedConnectionAttempts.Value() + 1
   272  
   273  	{
   274  		err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort})
   275  		if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" {
   276  			t.Errorf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
   277  		}
   278  	}
   279  
   280  	if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
   281  		t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want)
   282  	}
   283  	if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
   284  		t.Errorf("got EP stats FailedConnectionAttempts = %d, want = %d", got, want)
   285  	}
   286  }
   287  
   288  func TestCloseWithoutConnect(t *testing.T) {
   289  	c := context.New(t, defaultMTU)
   290  	defer c.Cleanup()
   291  
   292  	// Create TCP endpoint.
   293  	var err tcpip.Error
   294  	c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
   295  	if err != nil {
   296  		t.Fatalf("NewEndpoint failed: %s", err)
   297  	}
   298  
   299  	c.EP.Close()
   300  
   301  	if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
   302  		t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
   303  	}
   304  }
   305  
   306  func TestTCPSegmentsSentIncrement(t *testing.T) {
   307  	c := context.New(t, defaultMTU)
   308  	defer c.Cleanup()
   309  
   310  	stats := c.Stack().Stats()
   311  	// SYN and ACK
   312  	want := stats.TCP.SegmentsSent.Value() + 2
   313  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
   314  
   315  	if got := stats.TCP.SegmentsSent.Value(); got != want {
   316  		t.Errorf("got stats.TCP.SegmentsSent.Value() = %d, want = %d", got, want)
   317  	}
   318  	if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want {
   319  		t.Errorf("got EP stats SegmentsSent.Value() = %d, want = %d", got, want)
   320  	}
   321  }
   322  
   323  func TestTCPResetsSentIncrement(t *testing.T) {
   324  	c := context.New(t, defaultMTU)
   325  	defer c.Cleanup()
   326  	stats := c.Stack().Stats()
   327  	wq := &waiter.Queue{}
   328  	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
   329  	if err != nil {
   330  		t.Fatalf("NewEndpoint failed: %s", err)
   331  	}
   332  	want := stats.TCP.SegmentsSent.Value() + 1
   333  
   334  	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
   335  		t.Fatalf("Bind failed: %s", err)
   336  	}
   337  
   338  	if err := ep.Listen(10); err != nil {
   339  		t.Fatalf("Listen failed: %s", err)
   340  	}
   341  
   342  	// Send a SYN request.
   343  	iss := seqnum.Value(context.TestInitialSequenceNumber)
   344  	c.SendPacket(nil, &context.Headers{
   345  		SrcPort: context.TestPort,
   346  		DstPort: context.StackPort,
   347  		Flags:   header.TCPFlagSyn,
   348  		SeqNum:  iss,
   349  	})
   350  
   351  	// Receive the SYN-ACK reply.
   352  	b := c.GetPacket()
   353  	tcpHdr := header.TCP(header.IPv4(b).Payload())
   354  	c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
   355  
   356  	ackHeaders := &context.Headers{
   357  		SrcPort: context.TestPort,
   358  		DstPort: context.StackPort,
   359  		Flags:   header.TCPFlagAck,
   360  		SeqNum:  iss + 1,
   361  		// If the AckNum is not the increment of the last sequence number, a RST
   362  		// segment is sent back in response.
   363  		AckNum: c.IRS + 2,
   364  	}
   365  
   366  	// Send ACK.
   367  	c.SendPacket(nil, ackHeaders)
   368  
   369  	c.GetPacket()
   370  
   371  	metricPollFn := func() error {
   372  		if got := stats.TCP.ResetsSent.Value(); got != want {
   373  			return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %d, want = %d", got, want)
   374  		}
   375  		return nil
   376  	}
   377  	if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
   378  		t.Error(err)
   379  	}
   380  }
   381  
   382  // TestTCPResetsSentNoICMP confirms that we don't get an ICMP
   383  // DstUnreachable packet when we try send a packet which is not part
   384  // of an active session.
   385  func TestTCPResetsSentNoICMP(t *testing.T) {
   386  	c := context.New(t, defaultMTU)
   387  	defer c.Cleanup()
   388  	stats := c.Stack().Stats()
   389  
   390  	// Send a SYN request for a closed port. This should elicit an RST
   391  	// but NOT an ICMPv4 DstUnreachable packet.
   392  	iss := seqnum.Value(context.TestInitialSequenceNumber)
   393  	c.SendPacket(nil, &context.Headers{
   394  		SrcPort: context.TestPort,
   395  		DstPort: context.StackPort,
   396  		Flags:   header.TCPFlagSyn,
   397  		SeqNum:  iss,
   398  	})
   399  
   400  	// Receive whatever comes back.
   401  	b := c.GetPacket()
   402  	ipHdr := header.IPv4(b)
   403  	if got, want := ipHdr.Protocol(), uint8(header.TCPProtocolNumber); got != want {
   404  		t.Errorf("unexpected protocol, got = %d, want = %d", got, want)
   405  	}
   406  
   407  	// Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded.
   408  	sent := stats.ICMP.V4.PacketsSent
   409  	if got, want := sent.DstUnreachable.Value(), uint64(0); got != want {
   410  		t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want)
   411  	}
   412  }
   413  
   414  // TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates
   415  // a RST if an ACK is received on the listening socket for which there is no
   416  // active handshake in progress and we are not using SYN cookies.
   417  func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
   418  	c := context.New(t, defaultMTU)
   419  	defer c.Cleanup()
   420  
   421  	// Set TCPLingerTimeout to 5 seconds so that sockets are marked closed
   422  	wq := &waiter.Queue{}
   423  	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
   424  	if err != nil {
   425  		t.Fatalf("NewEndpoint failed: %s", err)
   426  	}
   427  	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
   428  		t.Fatalf("Bind failed: %s", err)
   429  	}
   430  
   431  	if err := ep.Listen(10); err != nil {
   432  		t.Fatalf("Listen failed: %s", err)
   433  	}
   434  
   435  	// Send a SYN request.
   436  	iss := seqnum.Value(context.TestInitialSequenceNumber)
   437  	c.SendPacket(nil, &context.Headers{
   438  		SrcPort: context.TestPort,
   439  		DstPort: context.StackPort,
   440  		Flags:   header.TCPFlagSyn,
   441  		SeqNum:  iss,
   442  	})
   443  
   444  	// Receive the SYN-ACK reply.
   445  	b := c.GetPacket()
   446  	tcpHdr := header.TCP(header.IPv4(b).Payload())
   447  	c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
   448  
   449  	ackHeaders := &context.Headers{
   450  		SrcPort: context.TestPort,
   451  		DstPort: context.StackPort,
   452  		Flags:   header.TCPFlagAck,
   453  		SeqNum:  iss + 1,
   454  		AckNum:  c.IRS + 1,
   455  	}
   456  
   457  	// Send ACK.
   458  	c.SendPacket(nil, ackHeaders)
   459  
   460  	// Try to accept the connection.
   461  	we, ch := waiter.NewChannelEntry(nil)
   462  	wq.EventRegister(&we, waiter.ReadableEvents)
   463  	defer wq.EventUnregister(&we)
   464  
   465  	c.EP, _, err = ep.Accept(nil)
   466  	if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
   467  		// Wait for connection to be established.
   468  		select {
   469  		case <-ch:
   470  			c.EP, _, err = ep.Accept(nil)
   471  			if err != nil {
   472  				t.Fatalf("Accept failed: %s", err)
   473  			}
   474  
   475  		case <-time.After(1 * time.Second):
   476  			t.Fatalf("Timed out waiting for accept")
   477  		}
   478  	}
   479  
   480  	// Lower stackwide TIME_WAIT timeout so that the reservations
   481  	// are released instantly on Close.
   482  	tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond)
   483  	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTW); err != nil {
   484  		t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, tcpTW, tcpTW, err)
   485  	}
   486  
   487  	c.EP.Close()
   488  	checker.IPv4(t, c.GetPacket(), checker.TCP(
   489  		checker.SrcPort(context.StackPort),
   490  		checker.DstPort(context.TestPort),
   491  		checker.TCPSeqNum(uint32(c.IRS+1)),
   492  		checker.TCPAckNum(uint32(iss)+1),
   493  		checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
   494  	finHeaders := &context.Headers{
   495  		SrcPort: context.TestPort,
   496  		DstPort: context.StackPort,
   497  		Flags:   header.TCPFlagAck | header.TCPFlagFin,
   498  		SeqNum:  iss + 1,
   499  		AckNum:  c.IRS + 2,
   500  	}
   501  
   502  	c.SendPacket(nil, finHeaders)
   503  
   504  	// Get the ACK to the FIN we just sent.
   505  	c.GetPacket()
   506  
   507  	// Since an active close was done we need to wait for a little more than
   508  	// tcpLingerTimeout for the port reservations to be released and the
   509  	// socket to move to a CLOSED state.
   510  	time.Sleep(20 * time.Millisecond)
   511  
   512  	// Now resend the same ACK, this ACK should generate a RST as there
   513  	// should be no endpoint in SYN-RCVD state and we are not using
   514  	// syn-cookies yet. The reason we send the same ACK is we need a valid
   515  	// cookie(IRS) generated by the netstack without which the ACK will be
   516  	// rejected.
   517  	c.SendPacket(nil, ackHeaders)
   518  
   519  	checker.IPv4(t, c.GetPacket(), checker.TCP(
   520  		checker.SrcPort(context.StackPort),
   521  		checker.DstPort(context.TestPort),
   522  		checker.TCPSeqNum(uint32(c.IRS+1)),
   523  		checker.TCPAckNum(0),
   524  		checker.TCPFlags(header.TCPFlagRst)))
   525  }
   526  
   527  func TestTCPResetsReceivedIncrement(t *testing.T) {
   528  	c := context.New(t, defaultMTU)
   529  	defer c.Cleanup()
   530  
   531  	stats := c.Stack().Stats()
   532  	want := stats.TCP.ResetsReceived.Value() + 1
   533  	iss := seqnum.Value(context.TestInitialSequenceNumber)
   534  	rcvWnd := seqnum.Size(30000)
   535  	c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
   536  
   537  	c.SendPacket(nil, &context.Headers{
   538  		SrcPort: context.TestPort,
   539  		DstPort: c.Port,
   540  		SeqNum:  iss.Add(1),
   541  		AckNum:  c.IRS.Add(1),
   542  		RcvWnd:  rcvWnd,
   543  		Flags:   header.TCPFlagRst,
   544  	})
   545  
   546  	if got := stats.TCP.ResetsReceived.Value(); got != want {
   547  		t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want)
   548  	}
   549  }
   550  
   551  func TestTCPResetsDoNotGenerateResets(t *testing.T) {
   552  	c := context.New(t, defaultMTU)
   553  	defer c.Cleanup()
   554  
   555  	stats := c.Stack().Stats()
   556  	want := stats.TCP.ResetsReceived.Value() + 1
   557  	iss := seqnum.Value(context.TestInitialSequenceNumber)
   558  	rcvWnd := seqnum.Size(30000)
   559  	c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
   560  
   561  	c.SendPacket(nil, &context.Headers{
   562  		SrcPort: context.TestPort,
   563  		DstPort: c.Port,
   564  		SeqNum:  iss.Add(1),
   565  		AckNum:  c.IRS.Add(1),
   566  		RcvWnd:  rcvWnd,
   567  		Flags:   header.TCPFlagRst,
   568  	})
   569  
   570  	if got := stats.TCP.ResetsReceived.Value(); got != want {
   571  		t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want)
   572  	}
   573  	c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond)
   574  }
   575  
   576  func TestActiveHandshake(t *testing.T) {
   577  	c := context.New(t, defaultMTU)
   578  	defer c.Cleanup()
   579  
   580  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
   581  }
   582  
   583  func TestNonBlockingClose(t *testing.T) {
   584  	c := context.New(t, defaultMTU)
   585  	defer c.Cleanup()
   586  
   587  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
   588  	ep := c.EP
   589  	c.EP = nil
   590  
   591  	// Close the endpoint and measure how long it takes.
   592  	t0 := time.Now()
   593  	ep.Close()
   594  	if diff := time.Now().Sub(t0); diff > 3*time.Second {
   595  		t.Fatalf("Took too long to close: %s", diff)
   596  	}
   597  }
   598  
   599  func TestConnectResetAfterClose(t *testing.T) {
   600  	c := context.New(t, defaultMTU)
   601  	defer c.Cleanup()
   602  
   603  	// Set TCPLinger to 3 seconds so that sockets are marked closed
   604  	// after 3 second in FIN_WAIT2 state.
   605  	tcpLingerTimeout := 3 * time.Second
   606  	opt := tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)
   607  	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
   608  		t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
   609  	}
   610  
   611  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
   612  	ep := c.EP
   613  	c.EP = nil
   614  
   615  	// Close the endpoint, make sure we get a FIN segment, then acknowledge
   616  	// to complete closure of sender, but don't send our own FIN.
   617  	ep.Close()
   618  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
   619  	checker.IPv4(t, c.GetPacket(),
   620  		checker.TCP(
   621  			checker.DstPort(context.TestPort),
   622  			checker.TCPSeqNum(uint32(c.IRS)+1),
   623  			checker.TCPAckNum(uint32(iss)),
   624  			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
   625  		),
   626  	)
   627  	c.SendPacket(nil, &context.Headers{
   628  		SrcPort: context.TestPort,
   629  		DstPort: c.Port,
   630  		Flags:   header.TCPFlagAck,
   631  		SeqNum:  iss,
   632  		AckNum:  c.IRS.Add(2),
   633  		RcvWnd:  30000,
   634  	})
   635  
   636  	// Wait for the ep to give up waiting for a FIN.
   637  	time.Sleep(tcpLingerTimeout + 1*time.Second)
   638  
   639  	// Now send an ACK and it should trigger a RST as the endpoint should
   640  	// not exist anymore.
   641  	c.SendPacket(nil, &context.Headers{
   642  		SrcPort: context.TestPort,
   643  		DstPort: c.Port,
   644  		Flags:   header.TCPFlagAck,
   645  		SeqNum:  iss,
   646  		AckNum:  c.IRS.Add(2),
   647  		RcvWnd:  30000,
   648  	})
   649  
   650  	for {
   651  		b := c.GetPacket()
   652  		tcpHdr := header.TCP(header.IPv4(b).Payload())
   653  		if tcpHdr.Flags() == header.TCPFlagAck|header.TCPFlagFin {
   654  			// This is a retransmit of the FIN, ignore it.
   655  			continue
   656  		}
   657  
   658  		checker.IPv4(t, b,
   659  			checker.TCP(
   660  				checker.DstPort(context.TestPort),
   661  				// RST is always generated with sndNxt which if the FIN
   662  				// has been sent will be 1 higher than the sequence number
   663  				// of the FIN itself.
   664  				checker.TCPSeqNum(uint32(c.IRS)+2),
   665  				checker.TCPAckNum(0),
   666  				checker.TCPFlags(header.TCPFlagRst),
   667  			),
   668  		)
   669  		break
   670  	}
   671  }
   672  
   673  // TestCurrentConnectedIncrement tests increment of the current
   674  // established and connected counters.
   675  func TestCurrentConnectedIncrement(t *testing.T) {
   676  	c := context.New(t, defaultMTU)
   677  	defer c.Cleanup()
   678  
   679  	// Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed
   680  	// after 1 second in TIME_WAIT state.
   681  	tcpTimeWaitTimeout := 1 * time.Second
   682  	opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
   683  	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
   684  		t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
   685  	}
   686  
   687  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
   688  	ep := c.EP
   689  	c.EP = nil
   690  
   691  	if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 {
   692  		t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 1", got)
   693  	}
   694  	gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value()
   695  	if gotConnected != 1 {
   696  		t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 1", gotConnected)
   697  	}
   698  
   699  	ep.Close()
   700  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
   701  	checker.IPv4(t, c.GetPacket(),
   702  		checker.TCP(
   703  			checker.DstPort(context.TestPort),
   704  			checker.TCPSeqNum(uint32(c.IRS)+1),
   705  			checker.TCPAckNum(uint32(iss)),
   706  			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
   707  		),
   708  	)
   709  	c.SendPacket(nil, &context.Headers{
   710  		SrcPort: context.TestPort,
   711  		DstPort: c.Port,
   712  		Flags:   header.TCPFlagAck,
   713  		SeqNum:  iss,
   714  		AckNum:  c.IRS.Add(2),
   715  		RcvWnd:  30000,
   716  	})
   717  
   718  	if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
   719  		t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
   720  	}
   721  	if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected {
   722  		t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = %d", got, gotConnected)
   723  	}
   724  
   725  	// Ack and send FIN as well.
   726  	c.SendPacket(nil, &context.Headers{
   727  		SrcPort: context.TestPort,
   728  		DstPort: c.Port,
   729  		Flags:   header.TCPFlagAck | header.TCPFlagFin,
   730  		SeqNum:  iss,
   731  		AckNum:  c.IRS.Add(2),
   732  		RcvWnd:  30000,
   733  	})
   734  
   735  	// Check that the stack acks the FIN.
   736  	checker.IPv4(t, c.GetPacket(),
   737  		checker.PayloadLen(header.TCPMinimumSize),
   738  		checker.TCP(
   739  			checker.DstPort(context.TestPort),
   740  			checker.TCPSeqNum(uint32(c.IRS)+2),
   741  			checker.TCPAckNum(uint32(iss)+1),
   742  			checker.TCPFlags(header.TCPFlagAck),
   743  		),
   744  	)
   745  
   746  	// Wait for a little more than the TIME-WAIT duration for the socket to
   747  	// transition to CLOSED state.
   748  	time.Sleep(1200 * time.Millisecond)
   749  
   750  	if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
   751  		t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
   752  	}
   753  	if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
   754  		t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
   755  	}
   756  }
   757  
   758  // TestClosingWithEnqueuedSegments tests handling of still enqueued segments
   759  // when the endpoint transitions to StateClose. The in-flight segments would be
   760  // re-enqueued to a any listening endpoint.
   761  func TestClosingWithEnqueuedSegments(t *testing.T) {
   762  	c := context.New(t, defaultMTU)
   763  	defer c.Cleanup()
   764  
   765  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
   766  	ep := c.EP
   767  	c.EP = nil
   768  
   769  	if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want {
   770  		t.Errorf("unexpected endpoint state: want %d, got %d", want, got)
   771  	}
   772  
   773  	// Send a FIN for ESTABLISHED --> CLOSED-WAIT
   774  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
   775  	c.SendPacket(nil, &context.Headers{
   776  		SrcPort: context.TestPort,
   777  		DstPort: c.Port,
   778  		Flags:   header.TCPFlagFin | header.TCPFlagAck,
   779  		SeqNum:  iss,
   780  		AckNum:  c.IRS.Add(1),
   781  		RcvWnd:  30000,
   782  	})
   783  
   784  	// Get the ACK for the FIN we sent.
   785  	checker.IPv4(t, c.GetPacket(),
   786  		checker.TCP(
   787  			checker.DstPort(context.TestPort),
   788  			checker.TCPSeqNum(uint32(c.IRS)+1),
   789  			checker.TCPAckNum(uint32(iss)+1),
   790  			checker.TCPFlags(header.TCPFlagAck),
   791  		),
   792  	)
   793  
   794  	// Give the stack a few ms to transition the endpoint out of ESTABLISHED
   795  	// state.
   796  	time.Sleep(10 * time.Millisecond)
   797  
   798  	if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want {
   799  		t.Errorf("unexpected endpoint state: want %d, got %d", want, got)
   800  	}
   801  
   802  	// Close the application endpoint for CLOSE_WAIT --> LAST_ACK
   803  	ep.Close()
   804  
   805  	// Get the FIN
   806  	checker.IPv4(t, c.GetPacket(),
   807  		checker.TCP(
   808  			checker.DstPort(context.TestPort),
   809  			checker.TCPSeqNum(uint32(c.IRS)+1),
   810  			checker.TCPAckNum(uint32(iss)+1),
   811  			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
   812  		),
   813  	)
   814  
   815  	if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want {
   816  		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
   817  	}
   818  
   819  	// Pause the endpoint`s protocolMainLoop.
   820  	ep.(interface{ StopWork() }).StopWork()
   821  
   822  	// Enqueue last ACK followed by an ACK matching the endpoint
   823  	//
   824  	// Send Last ACK for LAST_ACK --> CLOSED
   825  	c.SendPacket(nil, &context.Headers{
   826  		SrcPort: context.TestPort,
   827  		DstPort: c.Port,
   828  		Flags:   header.TCPFlagAck,
   829  		SeqNum:  iss.Add(1),
   830  		AckNum:  c.IRS.Add(2),
   831  		RcvWnd:  30000,
   832  	})
   833  
   834  	// Send a packet with ACK set, this would generate RST when
   835  	// not using SYN cookies as in this test.
   836  	c.SendPacket(nil, &context.Headers{
   837  		SrcPort: context.TestPort,
   838  		DstPort: c.Port,
   839  		Flags:   header.TCPFlagAck | header.TCPFlagFin,
   840  		SeqNum:  iss.Add(2),
   841  		AckNum:  c.IRS.Add(2),
   842  		RcvWnd:  30000,
   843  	})
   844  
   845  	// Unpause endpoint`s protocolMainLoop.
   846  	ep.(interface{ ResumeWork() }).ResumeWork()
   847  
   848  	// Wait for the protocolMainLoop to resume and update state.
   849  	time.Sleep(10 * time.Millisecond)
   850  
   851  	// Expect the endpoint to be closed.
   852  	if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
   853  		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
   854  	}
   855  
   856  	if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 {
   857  		t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = 1", got)
   858  	}
   859  
   860  	if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
   861  		t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
   862  	}
   863  
   864  	// Check if the endpoint was moved to CLOSED and netstack a reset in
   865  	// response to the ACK packet that we sent after last-ACK.
   866  	checker.IPv4(t, c.GetPacket(),
   867  		checker.TCP(
   868  			checker.DstPort(context.TestPort),
   869  			checker.TCPSeqNum(uint32(c.IRS)+2),
   870  			checker.TCPAckNum(0),
   871  			checker.TCPFlags(header.TCPFlagRst),
   872  		),
   873  	)
   874  }
   875  
   876  func TestSimpleReceive(t *testing.T) {
   877  	c := context.New(t, defaultMTU)
   878  	defer c.Cleanup()
   879  
   880  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
   881  
   882  	we, ch := waiter.NewChannelEntry(nil)
   883  	c.WQ.EventRegister(&we, waiter.ReadableEvents)
   884  	defer c.WQ.EventUnregister(&we)
   885  
   886  	ept := endpointTester{c.EP}
   887  
   888  	data := []byte{1, 2, 3}
   889  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
   890  	c.SendPacket(data, &context.Headers{
   891  		SrcPort: context.TestPort,
   892  		DstPort: c.Port,
   893  		Flags:   header.TCPFlagAck,
   894  		SeqNum:  iss,
   895  		AckNum:  c.IRS.Add(1),
   896  		RcvWnd:  30000,
   897  	})
   898  
   899  	// Wait for receive to be notified.
   900  	select {
   901  	case <-ch:
   902  	case <-time.After(1 * time.Second):
   903  		t.Fatalf("Timed out waiting for data to arrive")
   904  	}
   905  
   906  	// Receive data.
   907  	v := ept.CheckRead(t)
   908  	if !bytes.Equal(data, v) {
   909  		t.Fatalf("got data = %v, want = %v", v, data)
   910  	}
   911  
   912  	// Check that ACK is received.
   913  	checker.IPv4(t, c.GetPacket(),
   914  		checker.TCP(
   915  			checker.DstPort(context.TestPort),
   916  			checker.TCPSeqNum(uint32(c.IRS)+1),
   917  			checker.TCPAckNum(uint32(iss)+uint32(len(data))),
   918  			checker.TCPFlags(header.TCPFlagAck),
   919  		),
   920  	)
   921  }
   922  
   923  // TestUserSuppliedMSSOnConnect tests that the user supplied MSS is used when
   924  // creating a new active TCP socket. It should be present in the sent TCP
   925  // SYN segment.
   926  func TestUserSuppliedMSSOnConnect(t *testing.T) {
   927  	const mtu = 5000
   928  
   929  	ips := []struct {
   930  		name        string
   931  		createEP    func(*context.Context)
   932  		connectAddr tcpip.Address
   933  		checker     func(*testing.T, *context.Context, uint16, int)
   934  		maxMSS      uint16
   935  	}{
   936  		{
   937  			name: "IPv4",
   938  			createEP: func(c *context.Context) {
   939  				c.Create(-1)
   940  			},
   941  			connectAddr: context.TestAddr,
   942  			checker: func(t *testing.T, c *context.Context, mss uint16, ws int) {
   943  				checker.IPv4(t, c.GetPacket(), checker.TCP(
   944  					checker.DstPort(context.TestPort),
   945  					checker.TCPFlags(header.TCPFlagSyn),
   946  					checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws})))
   947  			},
   948  			maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize,
   949  		},
   950  		{
   951  			name: "IPv6",
   952  			createEP: func(c *context.Context) {
   953  				c.CreateV6Endpoint(true)
   954  			},
   955  			connectAddr: context.TestV6Addr,
   956  			checker: func(t *testing.T, c *context.Context, mss uint16, ws int) {
   957  				checker.IPv6(t, c.GetV6Packet(), checker.TCP(
   958  					checker.DstPort(context.TestPort),
   959  					checker.TCPFlags(header.TCPFlagSyn),
   960  					checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws})))
   961  			},
   962  			maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize,
   963  		},
   964  	}
   965  
   966  	for _, ip := range ips {
   967  		t.Run(ip.name, func(t *testing.T) {
   968  			tests := []struct {
   969  				name   string
   970  				setMSS uint16
   971  				expMSS uint16
   972  			}{
   973  				{
   974  					name:   "EqualToMaxMSS",
   975  					setMSS: ip.maxMSS,
   976  					expMSS: ip.maxMSS,
   977  				},
   978  				{
   979  					name:   "LessThanMaxMSS",
   980  					setMSS: ip.maxMSS - 1,
   981  					expMSS: ip.maxMSS - 1,
   982  				},
   983  				{
   984  					name:   "GreaterThanMaxMSS",
   985  					setMSS: ip.maxMSS + 1,
   986  					expMSS: ip.maxMSS,
   987  				},
   988  			}
   989  
   990  			for _, test := range tests {
   991  				t.Run(test.name, func(t *testing.T) {
   992  					c := context.New(t, mtu)
   993  					defer c.Cleanup()
   994  
   995  					ip.createEP(c)
   996  
   997  					// Set the MSS socket option.
   998  					if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
   999  						t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err)
  1000  					}
  1001  
  1002  					// Get expected window size.
  1003  					rcvBufSize := c.EP.SocketOptions().GetReceiveBufferSize()
  1004  					ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
  1005  
  1006  					connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort}
  1007  					{
  1008  						err := c.EP.Connect(connectAddr)
  1009  						if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
  1010  							t.Fatalf("Connect(%+v) mismatch (-want +got):\n%s", connectAddr, d)
  1011  						}
  1012  					}
  1013  
  1014  					// Receive SYN packet with our user supplied MSS.
  1015  					ip.checker(t, c, test.expMSS, ws)
  1016  				})
  1017  			}
  1018  		})
  1019  	}
  1020  }
  1021  
  1022  // TestUserSuppliedMSSOnListenAccept tests that the user supplied MSS is used
  1023  // when completing the handshake for a new TCP connection from a TCP
  1024  // listening socket. It should be present in the sent TCP SYN-ACK segment.
  1025  func TestUserSuppliedMSSOnListenAccept(t *testing.T) {
  1026  	const mtu = 5000
  1027  
  1028  	ips := []struct {
  1029  		name     string
  1030  		createEP func(*context.Context)
  1031  		sendPkt  func(*context.Context, *context.Headers)
  1032  		checker  func(*testing.T, *context.Context, uint16, uint16)
  1033  		maxMSS   uint16
  1034  	}{
  1035  		{
  1036  			name: "IPv4",
  1037  			createEP: func(c *context.Context) {
  1038  				c.Create(-1)
  1039  			},
  1040  			sendPkt: func(c *context.Context, h *context.Headers) {
  1041  				c.SendPacket(nil, h)
  1042  			},
  1043  			checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) {
  1044  				checker.IPv4(t, c.GetPacket(), checker.TCP(
  1045  					checker.DstPort(srcPort),
  1046  					checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck),
  1047  					checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1})))
  1048  			},
  1049  			maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize,
  1050  		},
  1051  		{
  1052  			name: "IPv6",
  1053  			createEP: func(c *context.Context) {
  1054  				c.CreateV6Endpoint(false)
  1055  			},
  1056  			sendPkt: func(c *context.Context, h *context.Headers) {
  1057  				c.SendV6Packet(nil, h)
  1058  			},
  1059  			checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) {
  1060  				checker.IPv6(t, c.GetV6Packet(), checker.TCP(
  1061  					checker.DstPort(srcPort),
  1062  					checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck),
  1063  					checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1})))
  1064  			},
  1065  			maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize,
  1066  		},
  1067  	}
  1068  
  1069  	for _, ip := range ips {
  1070  		t.Run(ip.name, func(t *testing.T) {
  1071  			tests := []struct {
  1072  				name   string
  1073  				setMSS uint16
  1074  				expMSS uint16
  1075  			}{
  1076  				{
  1077  					name:   "EqualToMaxMSS",
  1078  					setMSS: ip.maxMSS,
  1079  					expMSS: ip.maxMSS,
  1080  				},
  1081  				{
  1082  					name:   "LessThanMaxMSS",
  1083  					setMSS: ip.maxMSS - 1,
  1084  					expMSS: ip.maxMSS - 1,
  1085  				},
  1086  				{
  1087  					name:   "GreaterThanMaxMSS",
  1088  					setMSS: ip.maxMSS + 1,
  1089  					expMSS: ip.maxMSS,
  1090  				},
  1091  			}
  1092  
  1093  			for _, test := range tests {
  1094  				t.Run(test.name, func(t *testing.T) {
  1095  					c := context.New(t, mtu)
  1096  					defer c.Cleanup()
  1097  
  1098  					ip.createEP(c)
  1099  
  1100  					if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
  1101  						t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err)
  1102  					}
  1103  
  1104  					bindAddr := tcpip.FullAddress{Port: context.StackPort}
  1105  					if err := c.EP.Bind(bindAddr); err != nil {
  1106  						t.Fatalf("Bind(%+v): %s:", bindAddr, err)
  1107  					}
  1108  
  1109  					backlog := 5
  1110  					// Keep the number of client requests twice to the backlog
  1111  					// such that half of the connections do not use syncookies
  1112  					// and the other half does.
  1113  					clientConnects := backlog * 2
  1114  
  1115  					if err := c.EP.Listen(backlog); err != nil {
  1116  						t.Fatalf("Listen(%d): %s:", backlog, err)
  1117  					}
  1118  
  1119  					for i := 0; i < clientConnects; i++ {
  1120  						// Send a SYN requests.
  1121  						iss := seqnum.Value(i)
  1122  						srcPort := context.TestPort + uint16(i)
  1123  						ip.sendPkt(c, &context.Headers{
  1124  							SrcPort: srcPort,
  1125  							DstPort: context.StackPort,
  1126  							Flags:   header.TCPFlagSyn,
  1127  							SeqNum:  iss,
  1128  						})
  1129  
  1130  						// Receive the SYN-ACK reply.
  1131  						ip.checker(t, c, srcPort, test.expMSS)
  1132  					}
  1133  				})
  1134  			}
  1135  		})
  1136  	}
  1137  }
  1138  func TestSendRstOnListenerRxSynAckV4(t *testing.T) {
  1139  	c := context.New(t, defaultMTU)
  1140  	defer c.Cleanup()
  1141  
  1142  	c.Create(-1)
  1143  
  1144  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  1145  		t.Fatal("Bind failed:", err)
  1146  	}
  1147  
  1148  	if err := c.EP.Listen(10); err != nil {
  1149  		t.Fatal("Listen failed:", err)
  1150  	}
  1151  
  1152  	c.SendPacket(nil, &context.Headers{
  1153  		SrcPort: context.TestPort,
  1154  		DstPort: context.StackPort,
  1155  		Flags:   header.TCPFlagSyn | header.TCPFlagAck,
  1156  		SeqNum:  100,
  1157  		AckNum:  200,
  1158  	})
  1159  
  1160  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  1161  		checker.DstPort(context.TestPort),
  1162  		checker.TCPFlags(header.TCPFlagRst),
  1163  		checker.TCPSeqNum(200)))
  1164  }
  1165  
  1166  func TestSendRstOnListenerRxSynAckV6(t *testing.T) {
  1167  	c := context.New(t, defaultMTU)
  1168  	defer c.Cleanup()
  1169  
  1170  	c.CreateV6Endpoint(true)
  1171  
  1172  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  1173  		t.Fatal("Bind failed:", err)
  1174  	}
  1175  
  1176  	if err := c.EP.Listen(10); err != nil {
  1177  		t.Fatal("Listen failed:", err)
  1178  	}
  1179  
  1180  	c.SendV6Packet(nil, &context.Headers{
  1181  		SrcPort: context.TestPort,
  1182  		DstPort: context.StackPort,
  1183  		Flags:   header.TCPFlagSyn | header.TCPFlagAck,
  1184  		SeqNum:  100,
  1185  		AckNum:  200,
  1186  	})
  1187  
  1188  	checker.IPv6(t, c.GetV6Packet(), checker.TCP(
  1189  		checker.DstPort(context.TestPort),
  1190  		checker.TCPFlags(header.TCPFlagRst),
  1191  		checker.TCPSeqNum(200)))
  1192  }
  1193  
  1194  // TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete,
  1195  // peers can send data and expect a response within a reasonable ammount of time
  1196  // without calling Accept on the listening endpoint first.
  1197  //
  1198  // This test uses IPv4.
  1199  func TestTCPAckBeforeAcceptV4(t *testing.T) {
  1200  	c := context.New(t, defaultMTU)
  1201  	defer c.Cleanup()
  1202  
  1203  	c.Create(-1)
  1204  
  1205  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  1206  		t.Fatal("Bind failed:", err)
  1207  	}
  1208  
  1209  	if err := c.EP.Listen(10); err != nil {
  1210  		t.Fatal("Listen failed:", err)
  1211  	}
  1212  
  1213  	irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
  1214  
  1215  	// Send data before accepting the connection.
  1216  	c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
  1217  		SrcPort: context.TestPort,
  1218  		DstPort: context.StackPort,
  1219  		Flags:   header.TCPFlagAck,
  1220  		SeqNum:  irs + 1,
  1221  		AckNum:  iss + 1,
  1222  	})
  1223  
  1224  	// Receive ACK for the data we sent.
  1225  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  1226  		checker.DstPort(context.TestPort),
  1227  		checker.TCPFlags(header.TCPFlagAck),
  1228  		checker.TCPSeqNum(uint32(iss+1)),
  1229  		checker.TCPAckNum(uint32(irs+5))))
  1230  }
  1231  
  1232  // TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete,
  1233  // peers can send data and expect a response within a reasonable ammount of time
  1234  // without calling Accept on the listening endpoint first.
  1235  //
  1236  // This test uses IPv6.
  1237  func TestTCPAckBeforeAcceptV6(t *testing.T) {
  1238  	c := context.New(t, defaultMTU)
  1239  	defer c.Cleanup()
  1240  
  1241  	c.CreateV6Endpoint(true)
  1242  
  1243  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  1244  		t.Fatal("Bind failed:", err)
  1245  	}
  1246  
  1247  	if err := c.EP.Listen(10); err != nil {
  1248  		t.Fatal("Listen failed:", err)
  1249  	}
  1250  
  1251  	irs, iss := executeV6Handshake(t, c, context.TestPort, false /* synCookiesInUse */)
  1252  
  1253  	// Send data before accepting the connection.
  1254  	c.SendV6Packet([]byte{1, 2, 3, 4}, &context.Headers{
  1255  		SrcPort: context.TestPort,
  1256  		DstPort: context.StackPort,
  1257  		Flags:   header.TCPFlagAck,
  1258  		SeqNum:  irs + 1,
  1259  		AckNum:  iss + 1,
  1260  	})
  1261  
  1262  	// Receive ACK for the data we sent.
  1263  	checker.IPv6(t, c.GetV6Packet(), checker.TCP(
  1264  		checker.DstPort(context.TestPort),
  1265  		checker.TCPFlags(header.TCPFlagAck),
  1266  		checker.TCPSeqNum(uint32(iss+1)),
  1267  		checker.TCPAckNum(uint32(irs+5))))
  1268  }
  1269  
  1270  func TestSendRstOnListenerRxAckV4(t *testing.T) {
  1271  	c := context.New(t, defaultMTU)
  1272  	defer c.Cleanup()
  1273  
  1274  	c.Create(-1 /* epRcvBuf */)
  1275  
  1276  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  1277  		t.Fatal("Bind failed:", err)
  1278  	}
  1279  
  1280  	if err := c.EP.Listen(10 /* backlog */); err != nil {
  1281  		t.Fatal("Listen failed:", err)
  1282  	}
  1283  
  1284  	c.SendPacket(nil, &context.Headers{
  1285  		SrcPort: context.TestPort,
  1286  		DstPort: context.StackPort,
  1287  		Flags:   header.TCPFlagFin | header.TCPFlagAck,
  1288  		SeqNum:  100,
  1289  		AckNum:  200,
  1290  	})
  1291  
  1292  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  1293  		checker.DstPort(context.TestPort),
  1294  		checker.TCPFlags(header.TCPFlagRst),
  1295  		checker.TCPSeqNum(200)))
  1296  }
  1297  
  1298  func TestSendRstOnListenerRxAckV6(t *testing.T) {
  1299  	c := context.New(t, defaultMTU)
  1300  	defer c.Cleanup()
  1301  
  1302  	c.CreateV6Endpoint(true /* v6Only */)
  1303  
  1304  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  1305  		t.Fatal("Bind failed:", err)
  1306  	}
  1307  
  1308  	if err := c.EP.Listen(10 /* backlog */); err != nil {
  1309  		t.Fatal("Listen failed:", err)
  1310  	}
  1311  
  1312  	c.SendV6Packet(nil, &context.Headers{
  1313  		SrcPort: context.TestPort,
  1314  		DstPort: context.StackPort,
  1315  		Flags:   header.TCPFlagFin | header.TCPFlagAck,
  1316  		SeqNum:  100,
  1317  		AckNum:  200,
  1318  	})
  1319  
  1320  	checker.IPv6(t, c.GetV6Packet(), checker.TCP(
  1321  		checker.DstPort(context.TestPort),
  1322  		checker.TCPFlags(header.TCPFlagRst),
  1323  		checker.TCPSeqNum(200)))
  1324  }
  1325  
  1326  // TestListenShutdown tests for the listening endpoint replying with RST
  1327  // on read shutdown.
  1328  func TestListenShutdown(t *testing.T) {
  1329  	c := context.New(t, defaultMTU)
  1330  	defer c.Cleanup()
  1331  
  1332  	c.Create(-1 /* epRcvBuf */)
  1333  
  1334  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  1335  		t.Fatal("Bind failed:", err)
  1336  	}
  1337  
  1338  	if err := c.EP.Listen(1 /* backlog */); err != nil {
  1339  		t.Fatal("Listen failed:", err)
  1340  	}
  1341  
  1342  	if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
  1343  		t.Fatal("Shutdown failed:", err)
  1344  	}
  1345  
  1346  	c.SendPacket(nil, &context.Headers{
  1347  		SrcPort: context.TestPort,
  1348  		DstPort: context.StackPort,
  1349  		Flags:   header.TCPFlagSyn,
  1350  		SeqNum:  100,
  1351  		AckNum:  200,
  1352  	})
  1353  
  1354  	// Expect the listening endpoint to reset the connection.
  1355  	checker.IPv4(t, c.GetPacket(),
  1356  		checker.TCP(
  1357  			checker.DstPort(context.TestPort),
  1358  			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
  1359  		))
  1360  }
  1361  
  1362  var _ waiter.EntryCallback = (callback)(nil)
  1363  
  1364  type callback func(*waiter.Entry, waiter.EventMask)
  1365  
  1366  func (cb callback) Callback(entry *waiter.Entry, mask waiter.EventMask) {
  1367  	cb(entry, mask)
  1368  }
  1369  
  1370  func TestListenerReadinessOnEvent(t *testing.T) {
  1371  	s := stack.New(stack.Options{
  1372  		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
  1373  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol},
  1374  	})
  1375  	{
  1376  		ep := loopback.New()
  1377  		if testing.Verbose() {
  1378  			ep = sniffer.New(ep)
  1379  		}
  1380  		const id = 1
  1381  		if err := s.CreateNIC(id, ep); err != nil {
  1382  			t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err)
  1383  		}
  1384  		if err := s.AddAddress(id, ipv4.ProtocolNumber, context.StackAddr); err != nil {
  1385  			t.Fatalf("AddAddress(%d, ipv4.ProtocolNumber, %s): %s", id, context.StackAddr, err)
  1386  		}
  1387  		s.SetRouteTable([]tcpip.Route{
  1388  			{Destination: header.IPv4EmptySubnet, NIC: id},
  1389  		})
  1390  	}
  1391  
  1392  	var wq waiter.Queue
  1393  	ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
  1394  	if err != nil {
  1395  		t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err)
  1396  	}
  1397  	defer ep.Close()
  1398  
  1399  	if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr}); err != nil {
  1400  		t.Fatalf("Bind(%s): %s", context.StackAddr, err)
  1401  	}
  1402  	const backlog = 1
  1403  	if err := ep.Listen(backlog); err != nil {
  1404  		t.Fatalf("Listen(%d): %s", backlog, err)
  1405  	}
  1406  
  1407  	address, err := ep.GetLocalAddress()
  1408  	if err != nil {
  1409  		t.Fatalf("GetLocalAddress(): %s", err)
  1410  	}
  1411  
  1412  	conn, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
  1413  	if err != nil {
  1414  		t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err)
  1415  	}
  1416  	defer conn.Close()
  1417  
  1418  	events := make(chan waiter.EventMask)
  1419  	// Scope `entry` to allow a binding of the same name below.
  1420  	{
  1421  		entry := waiter.Entry{Callback: callback(func(_ *waiter.Entry, mask waiter.EventMask) {
  1422  			events <- ep.Readiness(mask)
  1423  		})}
  1424  		wq.EventRegister(&entry, waiter.EventIn)
  1425  		defer wq.EventUnregister(&entry)
  1426  	}
  1427  
  1428  	entry, ch := waiter.NewChannelEntry(nil)
  1429  	wq.EventRegister(&entry, waiter.EventOut)
  1430  	defer wq.EventUnregister(&entry)
  1431  
  1432  	switch err := conn.Connect(address).(type) {
  1433  	case *tcpip.ErrConnectStarted:
  1434  	default:
  1435  		t.Fatalf("Connect(%#v): %v", address, err)
  1436  	}
  1437  
  1438  	// Read at least one event.
  1439  	got := <-events
  1440  	for {
  1441  		select {
  1442  		case event := <-events:
  1443  			got |= event
  1444  			continue
  1445  		case <-ch:
  1446  			if want := waiter.ReadableEvents; got != want {
  1447  				t.Errorf("observed events = %b, want %b", got, want)
  1448  			}
  1449  		}
  1450  		break
  1451  	}
  1452  }
  1453  
  1454  // TestListenCloseWhileConnect tests for the listening endpoint to
  1455  // drain the accept-queue when closed. This should reset all of the
  1456  // pending connections that are waiting to be accepted.
  1457  func TestListenCloseWhileConnect(t *testing.T) {
  1458  	c := context.New(t, defaultMTU)
  1459  	defer c.Cleanup()
  1460  
  1461  	c.Create(-1 /* epRcvBuf */)
  1462  
  1463  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  1464  		t.Fatal("Bind failed:", err)
  1465  	}
  1466  
  1467  	if err := c.EP.Listen(1 /* backlog */); err != nil {
  1468  		t.Fatal("Listen failed:", err)
  1469  	}
  1470  
  1471  	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
  1472  	c.WQ.EventRegister(&waitEntry, waiter.ReadableEvents)
  1473  	defer c.WQ.EventUnregister(&waitEntry)
  1474  
  1475  	executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
  1476  	// Wait for the new endpoint created because of handshake to be delivered
  1477  	// to the listening endpoint's accept queue.
  1478  	<-notifyCh
  1479  
  1480  	// Close the listening endpoint.
  1481  	c.EP.Close()
  1482  
  1483  	// Expect the listening endpoint to reset the connection.
  1484  	checker.IPv4(t, c.GetPacket(),
  1485  		checker.TCP(
  1486  			checker.DstPort(context.TestPort),
  1487  			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
  1488  		))
  1489  }
  1490  
  1491  func TestTOSV4(t *testing.T) {
  1492  	c := context.New(t, defaultMTU)
  1493  	defer c.Cleanup()
  1494  
  1495  	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
  1496  	if err != nil {
  1497  		t.Fatalf("NewEndpoint failed: %s", err)
  1498  	}
  1499  	c.EP = ep
  1500  
  1501  	const tos = 0xC0
  1502  	if err := c.EP.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil {
  1503  		t.Errorf("SetSockOptInt(IPv4TOSOption, %d) failed: %s", tos, err)
  1504  	}
  1505  
  1506  	v, err := c.EP.GetSockOptInt(tcpip.IPv4TOSOption)
  1507  	if err != nil {
  1508  		t.Errorf("GetSockoptInt(IPv4TOSOption) failed: %s", err)
  1509  	}
  1510  
  1511  	if v != tos {
  1512  		t.Errorf("got GetSockOptInt(IPv4TOSOption) = %d, want = %d", v, tos)
  1513  	}
  1514  
  1515  	testV4Connect(t, c, checker.TOS(tos, 0))
  1516  
  1517  	data := []byte{1, 2, 3}
  1518  	var r bytes.Reader
  1519  	r.Reset(data)
  1520  	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  1521  		t.Fatalf("Write failed: %s", err)
  1522  	}
  1523  
  1524  	// Check that data is received.
  1525  	b := c.GetPacket()
  1526  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  1527  	checker.IPv4(t, b,
  1528  		checker.PayloadLen(len(data)+header.TCPMinimumSize),
  1529  		checker.TCP(
  1530  			checker.DstPort(context.TestPort),
  1531  			checker.TCPSeqNum(uint32(c.IRS)+1),
  1532  			checker.TCPAckNum(uint32(iss)), // Acknum is initial sequence number + 1
  1533  			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  1534  		),
  1535  		checker.TOS(tos, 0),
  1536  	)
  1537  
  1538  	if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
  1539  		t.Errorf("got data = %x, want = %x", p, data)
  1540  	}
  1541  }
  1542  
  1543  func TestTrafficClassV6(t *testing.T) {
  1544  	c := context.New(t, defaultMTU)
  1545  	defer c.Cleanup()
  1546  
  1547  	c.CreateV6Endpoint(false)
  1548  
  1549  	const tos = 0xC0
  1550  	if err := c.EP.SetSockOptInt(tcpip.IPv6TrafficClassOption, tos); err != nil {
  1551  		t.Errorf("SetSockOpInt(IPv6TrafficClassOption, %d) failed: %s", tos, err)
  1552  	}
  1553  
  1554  	v, err := c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption)
  1555  	if err != nil {
  1556  		t.Fatalf("GetSockoptInt(IPv6TrafficClassOption) failed: %s", err)
  1557  	}
  1558  
  1559  	if v != tos {
  1560  		t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = %d, want = %d", v, tos)
  1561  	}
  1562  
  1563  	// Test the connection request.
  1564  	testV6Connect(t, c, checker.TOS(tos, 0))
  1565  
  1566  	data := []byte{1, 2, 3}
  1567  	var r bytes.Reader
  1568  	r.Reset(data)
  1569  	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  1570  		t.Fatalf("Write failed: %s", err)
  1571  	}
  1572  
  1573  	// Check that data is received.
  1574  	b := c.GetV6Packet()
  1575  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  1576  	checker.IPv6(t, b,
  1577  		checker.PayloadLen(len(data)+header.TCPMinimumSize),
  1578  		checker.TCP(
  1579  			checker.DstPort(context.TestPort),
  1580  			checker.TCPSeqNum(uint32(c.IRS)+1),
  1581  			checker.TCPAckNum(uint32(iss)),
  1582  			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  1583  		),
  1584  		checker.TOS(tos, 0),
  1585  	)
  1586  
  1587  	if p := b[header.IPv6MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
  1588  		t.Errorf("got data = %x, want = %x", p, data)
  1589  	}
  1590  }
  1591  
  1592  func TestConnectBindToDevice(t *testing.T) {
  1593  	for _, test := range []struct {
  1594  		name   string
  1595  		device tcpip.NICID
  1596  		want   tcp.EndpointState
  1597  	}{
  1598  		{"RightDevice", 1, tcp.StateEstablished},
  1599  		{"WrongDevice", 2, tcp.StateSynSent},
  1600  		{"AnyDevice", 0, tcp.StateEstablished},
  1601  	} {
  1602  		t.Run(test.name, func(t *testing.T) {
  1603  			c := context.New(t, defaultMTU)
  1604  			defer c.Cleanup()
  1605  
  1606  			c.Create(-1)
  1607  			if err := c.EP.SocketOptions().SetBindToDevice(int32(test.device)); err != nil {
  1608  				t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", test.device, test.device, err)
  1609  			}
  1610  			// Start connection attempt.
  1611  			waitEntry, _ := waiter.NewChannelEntry(nil)
  1612  			c.WQ.EventRegister(&waitEntry, waiter.WritableEvents)
  1613  			defer c.WQ.EventUnregister(&waitEntry)
  1614  
  1615  			err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
  1616  			if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
  1617  				t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
  1618  			}
  1619  
  1620  			// Receive SYN packet.
  1621  			b := c.GetPacket()
  1622  			checker.IPv4(t, b,
  1623  				checker.TCP(
  1624  					checker.DstPort(context.TestPort),
  1625  					checker.TCPFlags(header.TCPFlagSyn),
  1626  				),
  1627  			)
  1628  			if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
  1629  				t.Fatalf("unexpected endpoint state: want %s, got %s", want, got)
  1630  			}
  1631  			tcpHdr := header.TCP(header.IPv4(b).Payload())
  1632  			c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
  1633  
  1634  			iss := seqnum.Value(context.TestInitialSequenceNumber)
  1635  			rcvWnd := seqnum.Size(30000)
  1636  			c.SendPacket(nil, &context.Headers{
  1637  				SrcPort: tcpHdr.DestinationPort(),
  1638  				DstPort: tcpHdr.SourcePort(),
  1639  				Flags:   header.TCPFlagSyn | header.TCPFlagAck,
  1640  				SeqNum:  iss,
  1641  				AckNum:  c.IRS.Add(1),
  1642  				RcvWnd:  rcvWnd,
  1643  				TCPOpts: nil,
  1644  			})
  1645  
  1646  			c.GetPacket()
  1647  			if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want {
  1648  				t.Fatalf("unexpected endpoint state: want %s, got %s", want, got)
  1649  			}
  1650  		})
  1651  	}
  1652  }
  1653  
  1654  func TestSynSent(t *testing.T) {
  1655  	for _, test := range []struct {
  1656  		name  string
  1657  		reset bool
  1658  	}{
  1659  		{"RstOnSynSent", true},
  1660  		{"CloseOnSynSent", false},
  1661  	} {
  1662  		t.Run(test.name, func(t *testing.T) {
  1663  			c := context.New(t, defaultMTU)
  1664  			defer c.Cleanup()
  1665  
  1666  			// Create an endpoint, don't handshake because we want to interfere with the
  1667  			// handshake process.
  1668  			c.Create(-1)
  1669  
  1670  			// Start connection attempt.
  1671  			waitEntry, ch := waiter.NewChannelEntry(nil)
  1672  			c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
  1673  			defer c.WQ.EventUnregister(&waitEntry)
  1674  
  1675  			addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
  1676  			err := c.EP.Connect(addr)
  1677  			if d := cmp.Diff(err, &tcpip.ErrConnectStarted{}); d != "" {
  1678  				t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d)
  1679  			}
  1680  
  1681  			// Receive SYN packet.
  1682  			b := c.GetPacket()
  1683  			checker.IPv4(t, b,
  1684  				checker.TCP(
  1685  					checker.DstPort(context.TestPort),
  1686  					checker.TCPFlags(header.TCPFlagSyn),
  1687  				),
  1688  			)
  1689  
  1690  			if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
  1691  				t.Fatalf("got State() = %s, want %s", got, want)
  1692  			}
  1693  			tcpHdr := header.TCP(header.IPv4(b).Payload())
  1694  			c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
  1695  
  1696  			if test.reset {
  1697  				// Send a packet with a proper ACK and a RST flag to cause the socket
  1698  				// to error and close out.
  1699  				iss := seqnum.Value(context.TestInitialSequenceNumber)
  1700  				rcvWnd := seqnum.Size(30000)
  1701  				c.SendPacket(nil, &context.Headers{
  1702  					SrcPort: tcpHdr.DestinationPort(),
  1703  					DstPort: tcpHdr.SourcePort(),
  1704  					Flags:   header.TCPFlagRst | header.TCPFlagAck,
  1705  					SeqNum:  iss,
  1706  					AckNum:  c.IRS.Add(1),
  1707  					RcvWnd:  rcvWnd,
  1708  					TCPOpts: nil,
  1709  				})
  1710  			} else {
  1711  				c.EP.Close()
  1712  			}
  1713  
  1714  			// Wait for receive to be notified.
  1715  			select {
  1716  			case <-ch:
  1717  			case <-time.After(3 * time.Second):
  1718  				t.Fatal("timed out waiting for packet to arrive")
  1719  			}
  1720  
  1721  			ept := endpointTester{c.EP}
  1722  			if test.reset {
  1723  				ept.CheckReadError(t, &tcpip.ErrConnectionRefused{})
  1724  			} else {
  1725  				ept.CheckReadError(t, &tcpip.ErrAborted{})
  1726  			}
  1727  
  1728  			if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
  1729  				t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
  1730  			}
  1731  
  1732  			// Due to the RST the endpoint should be in an error state.
  1733  			if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
  1734  				t.Fatalf("got State() = %s, want %s", got, want)
  1735  			}
  1736  		})
  1737  	}
  1738  }
  1739  
  1740  func TestOutOfOrderReceive(t *testing.T) {
  1741  	c := context.New(t, defaultMTU)
  1742  	defer c.Cleanup()
  1743  
  1744  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  1745  
  1746  	we, ch := waiter.NewChannelEntry(nil)
  1747  	c.WQ.EventRegister(&we, waiter.ReadableEvents)
  1748  	defer c.WQ.EventUnregister(&we)
  1749  
  1750  	ept := endpointTester{c.EP}
  1751  	ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
  1752  
  1753  	// Send second half of data first, with seqnum 3 ahead of expected.
  1754  	data := []byte{1, 2, 3, 4, 5, 6}
  1755  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  1756  	c.SendPacket(data[3:], &context.Headers{
  1757  		SrcPort: context.TestPort,
  1758  		DstPort: c.Port,
  1759  		Flags:   header.TCPFlagAck,
  1760  		SeqNum:  iss.Add(3),
  1761  		AckNum:  c.IRS.Add(1),
  1762  		RcvWnd:  30000,
  1763  	})
  1764  
  1765  	// Check that we get an ACK specifying which seqnum is expected.
  1766  	checker.IPv4(t, c.GetPacket(),
  1767  		checker.TCP(
  1768  			checker.DstPort(context.TestPort),
  1769  			checker.TCPSeqNum(uint32(c.IRS)+1),
  1770  			checker.TCPAckNum(uint32(iss)),
  1771  			checker.TCPFlags(header.TCPFlagAck),
  1772  		),
  1773  	)
  1774  
  1775  	// Wait 200ms and check that no data has been received.
  1776  	time.Sleep(200 * time.Millisecond)
  1777  	ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
  1778  
  1779  	// Send the first 3 bytes now.
  1780  	c.SendPacket(data[:3], &context.Headers{
  1781  		SrcPort: context.TestPort,
  1782  		DstPort: c.Port,
  1783  		Flags:   header.TCPFlagAck,
  1784  		SeqNum:  iss,
  1785  		AckNum:  c.IRS.Add(1),
  1786  		RcvWnd:  30000,
  1787  	})
  1788  
  1789  	// Receive data.
  1790  	read := ept.CheckReadFull(t, 6, ch, 5*time.Second)
  1791  
  1792  	// Check that we received the data in proper order.
  1793  	if !bytes.Equal(data, read) {
  1794  		t.Fatalf("got data = %v, want = %v", read, data)
  1795  	}
  1796  
  1797  	// Check that the whole data is acknowledged.
  1798  	checker.IPv4(t, c.GetPacket(),
  1799  		checker.TCP(
  1800  			checker.DstPort(context.TestPort),
  1801  			checker.TCPSeqNum(uint32(c.IRS)+1),
  1802  			checker.TCPAckNum(uint32(iss)+uint32(len(data))),
  1803  			checker.TCPFlags(header.TCPFlagAck),
  1804  		),
  1805  	)
  1806  }
  1807  
  1808  func TestOutOfOrderFlood(t *testing.T) {
  1809  	c := context.New(t, defaultMTU)
  1810  	defer c.Cleanup()
  1811  
  1812  	rcvBufSz := math.MaxUint16
  1813  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz)
  1814  
  1815  	ept := endpointTester{c.EP}
  1816  	ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
  1817  
  1818  	// Send 100 packets before the actual one that is expected.
  1819  	data := []byte{1, 2, 3, 4, 5, 6}
  1820  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  1821  	for i := 0; i < 100; i++ {
  1822  		c.SendPacket(data[3:], &context.Headers{
  1823  			SrcPort: context.TestPort,
  1824  			DstPort: c.Port,
  1825  			Flags:   header.TCPFlagAck,
  1826  			SeqNum:  iss.Add(6),
  1827  			AckNum:  c.IRS.Add(1),
  1828  			RcvWnd:  30000,
  1829  		})
  1830  
  1831  		checker.IPv4(t, c.GetPacket(),
  1832  			checker.TCP(
  1833  				checker.DstPort(context.TestPort),
  1834  				checker.TCPSeqNum(uint32(c.IRS)+1),
  1835  				checker.TCPAckNum(uint32(iss)),
  1836  				checker.TCPFlags(header.TCPFlagAck),
  1837  			),
  1838  		)
  1839  	}
  1840  
  1841  	// Send packet with seqnum as initial + 3. It must be discarded because the
  1842  	// out-of-order buffer was filled by the previous packets.
  1843  	c.SendPacket(data[3:], &context.Headers{
  1844  		SrcPort: context.TestPort,
  1845  		DstPort: c.Port,
  1846  		Flags:   header.TCPFlagAck,
  1847  		SeqNum:  iss.Add(3),
  1848  		AckNum:  c.IRS.Add(1),
  1849  		RcvWnd:  30000,
  1850  	})
  1851  
  1852  	checker.IPv4(t, c.GetPacket(),
  1853  		checker.TCP(
  1854  			checker.DstPort(context.TestPort),
  1855  			checker.TCPSeqNum(uint32(c.IRS)+1),
  1856  			checker.TCPAckNum(uint32(iss)),
  1857  			checker.TCPFlags(header.TCPFlagAck),
  1858  		),
  1859  	)
  1860  
  1861  	// Now send the expected packet with initial sequence number.
  1862  	c.SendPacket(data[:3], &context.Headers{
  1863  		SrcPort: context.TestPort,
  1864  		DstPort: c.Port,
  1865  		Flags:   header.TCPFlagAck,
  1866  		SeqNum:  iss,
  1867  		AckNum:  c.IRS.Add(1),
  1868  		RcvWnd:  30000,
  1869  	})
  1870  
  1871  	// Check that only packet with initial sequence number is acknowledged.
  1872  	checker.IPv4(t, c.GetPacket(),
  1873  		checker.TCP(
  1874  			checker.DstPort(context.TestPort),
  1875  			checker.TCPSeqNum(uint32(c.IRS)+1),
  1876  			checker.TCPAckNum(uint32(iss)+3),
  1877  			checker.TCPFlags(header.TCPFlagAck),
  1878  		),
  1879  	)
  1880  }
  1881  
  1882  func TestRstOnCloseWithUnreadData(t *testing.T) {
  1883  	c := context.New(t, defaultMTU)
  1884  	defer c.Cleanup()
  1885  
  1886  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  1887  
  1888  	we, ch := waiter.NewChannelEntry(nil)
  1889  	c.WQ.EventRegister(&we, waiter.ReadableEvents)
  1890  	defer c.WQ.EventUnregister(&we)
  1891  
  1892  	ept := endpointTester{c.EP}
  1893  	ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
  1894  
  1895  	data := []byte{1, 2, 3}
  1896  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  1897  	c.SendPacket(data, &context.Headers{
  1898  		SrcPort: context.TestPort,
  1899  		DstPort: c.Port,
  1900  		Flags:   header.TCPFlagAck,
  1901  		SeqNum:  iss,
  1902  		AckNum:  c.IRS.Add(1),
  1903  		RcvWnd:  30000,
  1904  	})
  1905  
  1906  	// Wait for receive to be notified.
  1907  	select {
  1908  	case <-ch:
  1909  	case <-time.After(3 * time.Second):
  1910  		t.Fatalf("Timed out waiting for data to arrive")
  1911  	}
  1912  
  1913  	// Check that ACK is received, this happens regardless of the read.
  1914  	checker.IPv4(t, c.GetPacket(),
  1915  		checker.TCP(
  1916  			checker.DstPort(context.TestPort),
  1917  			checker.TCPSeqNum(uint32(c.IRS)+1),
  1918  			checker.TCPAckNum(uint32(iss)+uint32(len(data))),
  1919  			checker.TCPFlags(header.TCPFlagAck),
  1920  		),
  1921  	)
  1922  
  1923  	// Now that we know we have unread data, let's just close the connection
  1924  	// and verify that netstack sends an RST rather than a FIN.
  1925  	c.EP.Close()
  1926  
  1927  	checker.IPv4(t, c.GetPacket(),
  1928  		checker.TCP(
  1929  			checker.DstPort(context.TestPort),
  1930  			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
  1931  			// We shouldn't consume a sequence number on RST.
  1932  			checker.TCPSeqNum(uint32(c.IRS)+1),
  1933  		))
  1934  	// The RST puts the endpoint into an error state.
  1935  	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
  1936  		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
  1937  	}
  1938  
  1939  	// This final ACK should be ignored because an ACK on a reset doesn't mean
  1940  	// anything.
  1941  	c.SendPacket(nil, &context.Headers{
  1942  		SrcPort: context.TestPort,
  1943  		DstPort: c.Port,
  1944  		Flags:   header.TCPFlagAck,
  1945  		SeqNum:  iss.Add(seqnum.Size(len(data))),
  1946  		AckNum:  c.IRS.Add(seqnum.Size(2)),
  1947  		RcvWnd:  30000,
  1948  	})
  1949  }
  1950  
  1951  func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
  1952  	c := context.New(t, defaultMTU)
  1953  	defer c.Cleanup()
  1954  
  1955  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  1956  
  1957  	we, ch := waiter.NewChannelEntry(nil)
  1958  	c.WQ.EventRegister(&we, waiter.ReadableEvents)
  1959  	defer c.WQ.EventUnregister(&we)
  1960  
  1961  	ept := endpointTester{c.EP}
  1962  	ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
  1963  
  1964  	data := []byte{1, 2, 3}
  1965  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  1966  	c.SendPacket(data, &context.Headers{
  1967  		SrcPort: context.TestPort,
  1968  		DstPort: c.Port,
  1969  		Flags:   header.TCPFlagAck,
  1970  		SeqNum:  iss,
  1971  		AckNum:  c.IRS.Add(1),
  1972  		RcvWnd:  30000,
  1973  	})
  1974  
  1975  	// Wait for receive to be notified.
  1976  	select {
  1977  	case <-ch:
  1978  	case <-time.After(3 * time.Second):
  1979  		t.Fatalf("Timed out waiting for data to arrive")
  1980  	}
  1981  
  1982  	// Check that ACK is received, this happens regardless of the read.
  1983  	checker.IPv4(t, c.GetPacket(),
  1984  		checker.TCP(
  1985  			checker.DstPort(context.TestPort),
  1986  			checker.TCPSeqNum(uint32(c.IRS)+1),
  1987  			checker.TCPAckNum(uint32(iss)+uint32(len(data))),
  1988  			checker.TCPFlags(header.TCPFlagAck),
  1989  		),
  1990  	)
  1991  
  1992  	// Cause a FIN to be generated.
  1993  	c.EP.Shutdown(tcpip.ShutdownWrite)
  1994  
  1995  	// Make sure we get the FIN but DON't ACK IT.
  1996  	checker.IPv4(t, c.GetPacket(),
  1997  		checker.TCP(
  1998  			checker.DstPort(context.TestPort),
  1999  			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
  2000  			checker.TCPSeqNum(uint32(c.IRS)+1),
  2001  		))
  2002  
  2003  	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
  2004  		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
  2005  	}
  2006  
  2007  	// Cause a RST to be generated by closing the read end now since we have
  2008  	// unread data.
  2009  	c.EP.Shutdown(tcpip.ShutdownRead)
  2010  
  2011  	// Make sure we get the RST
  2012  	checker.IPv4(t, c.GetPacket(),
  2013  		checker.TCP(
  2014  			checker.DstPort(context.TestPort),
  2015  			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
  2016  			// RST is always generated with sndNxt which if the FIN
  2017  			// has been sent will be 1 higher than the sequence
  2018  			// number of the FIN itself.
  2019  			checker.TCPSeqNum(uint32(c.IRS)+2),
  2020  		))
  2021  	// The RST puts the endpoint into an error state.
  2022  	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
  2023  		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
  2024  	}
  2025  
  2026  	// The ACK to the FIN should now be rejected since the connection has been
  2027  	// closed by a RST.
  2028  	c.SendPacket(nil, &context.Headers{
  2029  		SrcPort: context.TestPort,
  2030  		DstPort: c.Port,
  2031  		Flags:   header.TCPFlagAck,
  2032  		SeqNum:  iss.Add(seqnum.Size(len(data))),
  2033  		AckNum:  c.IRS.Add(seqnum.Size(2)),
  2034  		RcvWnd:  30000,
  2035  	})
  2036  }
  2037  
  2038  func TestShutdownRead(t *testing.T) {
  2039  	c := context.New(t, defaultMTU)
  2040  	defer c.Cleanup()
  2041  
  2042  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  2043  
  2044  	ept := endpointTester{c.EP}
  2045  	ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
  2046  
  2047  	if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
  2048  		t.Fatalf("Shutdown failed: %s", err)
  2049  	}
  2050  
  2051  	ept.CheckReadError(t, &tcpip.ErrClosedForReceive{})
  2052  	var want uint64 = 1
  2053  	if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want {
  2054  		t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want)
  2055  	}
  2056  }
  2057  
  2058  func TestFullWindowReceive(t *testing.T) {
  2059  	c := context.New(t, defaultMTU)
  2060  	defer c.Cleanup()
  2061  
  2062  	const rcvBufSz = 10
  2063  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz)
  2064  
  2065  	we, ch := waiter.NewChannelEntry(nil)
  2066  	c.WQ.EventRegister(&we, waiter.ReadableEvents)
  2067  	defer c.WQ.EventUnregister(&we)
  2068  
  2069  	ept := endpointTester{c.EP}
  2070  	ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
  2071  
  2072  	// Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies
  2073  	// the provided buffer value by tcp.SegOverheadFactor to calculate the actual
  2074  	// receive buffer size.
  2075  	data := make([]byte, tcp.SegOverheadFactor*rcvBufSz)
  2076  	for i := range data {
  2077  		data[i] = byte(i % 255)
  2078  	}
  2079  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  2080  	c.SendPacket(data, &context.Headers{
  2081  		SrcPort: context.TestPort,
  2082  		DstPort: c.Port,
  2083  		Flags:   header.TCPFlagAck,
  2084  		SeqNum:  iss,
  2085  		AckNum:  c.IRS.Add(1),
  2086  		RcvWnd:  30000,
  2087  	})
  2088  
  2089  	// Wait for receive to be notified.
  2090  	select {
  2091  	case <-ch:
  2092  	case <-time.After(5 * time.Second):
  2093  		t.Fatalf("Timed out waiting for data to arrive")
  2094  	}
  2095  
  2096  	// Check that data is acknowledged, and window goes to zero.
  2097  	checker.IPv4(t, c.GetPacket(),
  2098  		checker.TCP(
  2099  			checker.DstPort(context.TestPort),
  2100  			checker.TCPSeqNum(uint32(c.IRS)+1),
  2101  			checker.TCPAckNum(uint32(iss)+uint32(len(data))),
  2102  			checker.TCPFlags(header.TCPFlagAck),
  2103  			checker.TCPWindow(0),
  2104  		),
  2105  	)
  2106  
  2107  	// Receive data and check it.
  2108  	v := ept.CheckRead(t)
  2109  	if !bytes.Equal(data, v) {
  2110  		t.Fatalf("got data = %v, want = %v", v, data)
  2111  	}
  2112  
  2113  	var want uint64 = 1
  2114  	if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want {
  2115  		t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %d want %d", got, want)
  2116  	}
  2117  
  2118  	// Check that we get an ACK for the newly non-zero window.
  2119  	checker.IPv4(t, c.GetPacket(),
  2120  		checker.TCP(
  2121  			checker.DstPort(context.TestPort),
  2122  			checker.TCPSeqNum(uint32(c.IRS)+1),
  2123  			checker.TCPAckNum(uint32(iss)+uint32(len(data))),
  2124  			checker.TCPFlags(header.TCPFlagAck),
  2125  			checker.TCPWindow(10),
  2126  		),
  2127  	)
  2128  }
  2129  
  2130  // Test the stack receive window advertisement on receiving segments smaller than
  2131  // segment overhead. It tests for the right edge of the window to not grow when
  2132  // the endpoint is not being read from.
  2133  func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
  2134  	c := context.New(t, defaultMTU)
  2135  	defer c.Cleanup()
  2136  
  2137  	opt := tcpip.TCPReceiveBufferSizeRangeOption{
  2138  		Min:     1,
  2139  		Default: tcp.DefaultReceiveBufferSize,
  2140  		Max:     tcp.DefaultReceiveBufferSize << tcp.FindWndScale(seqnum.Size(tcp.DefaultReceiveBufferSize)),
  2141  	}
  2142  	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
  2143  		t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
  2144  	}
  2145  
  2146  	c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS})
  2147  
  2148  	// Bump up the receive buffer size such that, when the receive window grows,
  2149  	// the scaled window exceeds maxUint16.
  2150  	c.EP.SocketOptions().SetReceiveBufferSize(int64(opt.Max), true)
  2151  
  2152  	// Keep the payload size < segment overhead and such that it is a multiple
  2153  	// of the window scaled value. This enables the test to perform equality
  2154  	// checks on the incoming receive window.
  2155  	payloadSize := 1 << c.RcvdWindowScale
  2156  	if payloadSize >= tcp.SegSize {
  2157  		t.Fatalf("payload size of %d is not less than the segment overhead of %d", payloadSize, tcp.SegSize)
  2158  	}
  2159  	payload := generateRandomPayload(t, payloadSize)
  2160  	payloadLen := seqnum.Size(len(payload))
  2161  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  2162  
  2163  	// Send payload to the endpoint and return the advertised receive window
  2164  	// from the endpoint.
  2165  	getIncomingRcvWnd := func() uint32 {
  2166  		c.SendPacket(payload, &context.Headers{
  2167  			SrcPort: context.TestPort,
  2168  			DstPort: c.Port,
  2169  			SeqNum:  iss,
  2170  			AckNum:  c.IRS.Add(1),
  2171  			Flags:   header.TCPFlagAck,
  2172  			RcvWnd:  30000,
  2173  		})
  2174  		iss = iss.Add(payloadLen)
  2175  
  2176  		pkt := c.GetPacket()
  2177  		return uint32(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.RcvdWindowScale
  2178  	}
  2179  
  2180  	// Read the advertised receive window with the ACK for payload.
  2181  	rcvWnd := getIncomingRcvWnd()
  2182  
  2183  	// Check if the subsequent ACK to our send has not grown the right edge of
  2184  	// the window.
  2185  	if got, want := getIncomingRcvWnd(), rcvWnd-uint32(len(payload)); got != want {
  2186  		t.Fatalf("got incomingRcvwnd %d want %d", got, want)
  2187  	}
  2188  
  2189  	// Read the data so that the subsequent ACK from the endpoint
  2190  	// grows the right edge of the window.
  2191  	var buf bytes.Buffer
  2192  	if _, err := c.EP.Read(&buf, tcpip.ReadOptions{}); err != nil {
  2193  		t.Fatalf("c.EP.Read: %s", err)
  2194  	}
  2195  
  2196  	// Check if we have received max uint16 as our advertised
  2197  	// scaled window now after a read above.
  2198  	maxRcv := uint32(math.MaxUint16 << c.RcvdWindowScale)
  2199  	if got, want := getIncomingRcvWnd(), maxRcv; got != want {
  2200  		t.Fatalf("got incomingRcvwnd %d want %d", got, want)
  2201  	}
  2202  
  2203  	// Check if the subsequent ACK to our send has not grown the right edge of
  2204  	// the window.
  2205  	if got, want := getIncomingRcvWnd(), maxRcv-uint32(len(payload)); got != want {
  2206  		t.Fatalf("got incomingRcvwnd %d want %d", got, want)
  2207  	}
  2208  }
  2209  
  2210  func TestNoWindowShrinking(t *testing.T) {
  2211  	c := context.New(t, defaultMTU)
  2212  	defer c.Cleanup()
  2213  
  2214  	// Start off with a certain receive buffer then cut it in half and verify that
  2215  	// the right edge of the window does not shrink.
  2216  	// NOTE: Netstack doubles the value specified here.
  2217  	rcvBufSize := 65536
  2218  	// Enable window scaling with a scale of zero from our end.
  2219  	c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, rcvBufSize, []byte{
  2220  		header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
  2221  	})
  2222  
  2223  	we, ch := waiter.NewChannelEntry(nil)
  2224  	c.WQ.EventRegister(&we, waiter.ReadableEvents)
  2225  	defer c.WQ.EventUnregister(&we)
  2226  
  2227  	ept := endpointTester{c.EP}
  2228  	ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
  2229  
  2230  	// Send a 1 byte payload so that we can record the current receive window.
  2231  	// Send a payload of half the size of rcvBufSize.
  2232  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  2233  	payload := []byte{1}
  2234  	c.SendPacket(payload, &context.Headers{
  2235  		SrcPort: context.TestPort,
  2236  		DstPort: c.Port,
  2237  		Flags:   header.TCPFlagAck,
  2238  		SeqNum:  iss,
  2239  		AckNum:  c.IRS.Add(1),
  2240  		RcvWnd:  30000,
  2241  	})
  2242  
  2243  	// Wait for receive to be notified.
  2244  	select {
  2245  	case <-ch:
  2246  	case <-time.After(5 * time.Second):
  2247  		t.Fatalf("Timed out waiting for data to arrive")
  2248  	}
  2249  
  2250  	// Read the 1 byte payload we just sent.
  2251  	if got, want := payload, ept.CheckRead(t); !bytes.Equal(got, want) {
  2252  		t.Fatalf("got data: %v, want: %v", got, want)
  2253  	}
  2254  
  2255  	// Verify that the ACK does not shrink the window.
  2256  	pkt := c.GetPacket()
  2257  	iss = iss.Add(1)
  2258  	checker.IPv4(t, pkt,
  2259  		checker.TCP(
  2260  			checker.DstPort(context.TestPort),
  2261  			checker.TCPSeqNum(uint32(c.IRS)+1),
  2262  			checker.TCPAckNum(uint32(iss)),
  2263  			checker.TCPFlags(header.TCPFlagAck),
  2264  		),
  2265  	)
  2266  	// Stash the initial window.
  2267  	initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale
  2268  	initialLastAcceptableSeq := iss.Add(seqnum.Size(initialWnd))
  2269  	// Now shrink the receive buffer to half its original size.
  2270  	c.EP.SocketOptions().SetReceiveBufferSize(int64(rcvBufSize/2), true)
  2271  
  2272  	data := generateRandomPayload(t, rcvBufSize)
  2273  	// Send a payload of half the size of rcvBufSize.
  2274  	c.SendPacket(data[:rcvBufSize/2], &context.Headers{
  2275  		SrcPort: context.TestPort,
  2276  		DstPort: c.Port,
  2277  		Flags:   header.TCPFlagAck,
  2278  		SeqNum:  iss,
  2279  		AckNum:  c.IRS.Add(1),
  2280  		RcvWnd:  30000,
  2281  	})
  2282  	iss = iss.Add(seqnum.Size(rcvBufSize / 2))
  2283  
  2284  	// Verify that the ACK does not shrink the window.
  2285  	pkt = c.GetPacket()
  2286  	checker.IPv4(t, pkt,
  2287  		checker.TCP(
  2288  			checker.DstPort(context.TestPort),
  2289  			checker.TCPSeqNum(uint32(c.IRS)+1),
  2290  			checker.TCPAckNum(uint32(iss)),
  2291  			checker.TCPFlags(header.TCPFlagAck),
  2292  		),
  2293  	)
  2294  	newWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale
  2295  	newLastAcceptableSeq := iss.Add(seqnum.Size(newWnd))
  2296  	if newLastAcceptableSeq.LessThan(initialLastAcceptableSeq) {
  2297  		t.Fatalf("receive window shrunk unexpectedly got: %d, want >= %d", newLastAcceptableSeq, initialLastAcceptableSeq)
  2298  	}
  2299  
  2300  	// Send another payload of half the size of rcvBufSize. This should fill up the
  2301  	// socket receive buffer and we should see a zero window.
  2302  	c.SendPacket(data[rcvBufSize/2:], &context.Headers{
  2303  		SrcPort: context.TestPort,
  2304  		DstPort: c.Port,
  2305  		Flags:   header.TCPFlagAck,
  2306  		SeqNum:  iss,
  2307  		AckNum:  c.IRS.Add(1),
  2308  		RcvWnd:  30000,
  2309  	})
  2310  	iss = iss.Add(seqnum.Size(rcvBufSize / 2))
  2311  
  2312  	checker.IPv4(t, c.GetPacket(),
  2313  		checker.TCP(
  2314  			checker.DstPort(context.TestPort),
  2315  			checker.TCPSeqNum(uint32(c.IRS)+1),
  2316  			checker.TCPAckNum(uint32(iss)),
  2317  			checker.TCPFlags(header.TCPFlagAck),
  2318  			checker.TCPWindow(0),
  2319  		),
  2320  	)
  2321  
  2322  	// Receive data and check it.
  2323  	read := ept.CheckReadFull(t, len(data), ch, 5*time.Second)
  2324  	if !bytes.Equal(data, read) {
  2325  		t.Fatalf("got data = %v, want = %v", read, data)
  2326  	}
  2327  
  2328  	// Check that we get an ACK for the newly non-zero window, which is the new
  2329  	// receive buffer size we set after the connection was established.
  2330  	checker.IPv4(t, c.GetPacket(),
  2331  		checker.TCP(
  2332  			checker.DstPort(context.TestPort),
  2333  			checker.TCPSeqNum(uint32(c.IRS)+1),
  2334  			checker.TCPAckNum(uint32(iss)),
  2335  			checker.TCPFlags(header.TCPFlagAck),
  2336  			checker.TCPWindow(uint16(rcvBufSize/2)>>c.RcvdWindowScale),
  2337  		),
  2338  	)
  2339  }
  2340  
  2341  func TestSimpleSend(t *testing.T) {
  2342  	c := context.New(t, defaultMTU)
  2343  	defer c.Cleanup()
  2344  
  2345  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  2346  
  2347  	data := []byte{1, 2, 3}
  2348  	var r bytes.Reader
  2349  	r.Reset(data)
  2350  	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  2351  		t.Fatalf("Write failed: %s", err)
  2352  	}
  2353  
  2354  	// Check that data is received.
  2355  	b := c.GetPacket()
  2356  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  2357  	checker.IPv4(t, b,
  2358  		checker.PayloadLen(len(data)+header.TCPMinimumSize),
  2359  		checker.TCP(
  2360  			checker.DstPort(context.TestPort),
  2361  			checker.TCPSeqNum(uint32(c.IRS)+1),
  2362  			checker.TCPAckNum(uint32(iss)),
  2363  			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  2364  		),
  2365  	)
  2366  
  2367  	if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
  2368  		t.Fatalf("got data = %v, want = %v", p, data)
  2369  	}
  2370  
  2371  	// Acknowledge the data.
  2372  	c.SendPacket(nil, &context.Headers{
  2373  		SrcPort: context.TestPort,
  2374  		DstPort: c.Port,
  2375  		Flags:   header.TCPFlagAck,
  2376  		SeqNum:  iss,
  2377  		AckNum:  c.IRS.Add(1 + seqnum.Size(len(data))),
  2378  		RcvWnd:  30000,
  2379  	})
  2380  }
  2381  
  2382  func TestZeroWindowSend(t *testing.T) {
  2383  	c := context.New(t, defaultMTU)
  2384  	defer c.Cleanup()
  2385  
  2386  	c.CreateConnected(context.TestInitialSequenceNumber, 0 /* rcvWnd */, -1 /* epRcvBuf */)
  2387  
  2388  	data := []byte{1, 2, 3}
  2389  	var r bytes.Reader
  2390  	r.Reset(data)
  2391  	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  2392  		t.Fatalf("Write failed: %s", err)
  2393  	}
  2394  
  2395  	// Check if we got a zero-window probe.
  2396  	b := c.GetPacket()
  2397  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  2398  	checker.IPv4(t, b,
  2399  		checker.PayloadLen(header.TCPMinimumSize),
  2400  		checker.TCP(
  2401  			checker.DstPort(context.TestPort),
  2402  			checker.TCPSeqNum(uint32(c.IRS)),
  2403  			checker.TCPAckNum(uint32(iss)),
  2404  			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  2405  		),
  2406  	)
  2407  
  2408  	// Open up the window. Data should be received now.
  2409  	c.SendPacket(nil, &context.Headers{
  2410  		SrcPort: context.TestPort,
  2411  		DstPort: c.Port,
  2412  		Flags:   header.TCPFlagAck,
  2413  		SeqNum:  iss,
  2414  		AckNum:  c.IRS.Add(1),
  2415  		RcvWnd:  30000,
  2416  	})
  2417  
  2418  	// Check that data is received.
  2419  	b = c.GetPacket()
  2420  	checker.IPv4(t, b,
  2421  		checker.PayloadLen(len(data)+header.TCPMinimumSize),
  2422  		checker.TCP(
  2423  			checker.DstPort(context.TestPort),
  2424  			checker.TCPSeqNum(uint32(c.IRS)+1),
  2425  			checker.TCPAckNum(uint32(iss)),
  2426  			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  2427  		),
  2428  	)
  2429  
  2430  	if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
  2431  		t.Fatalf("got data = %v, want = %v", p, data)
  2432  	}
  2433  
  2434  	// Acknowledge the data.
  2435  	c.SendPacket(nil, &context.Headers{
  2436  		SrcPort: context.TestPort,
  2437  		DstPort: c.Port,
  2438  		Flags:   header.TCPFlagAck,
  2439  		SeqNum:  iss,
  2440  		AckNum:  c.IRS.Add(1 + seqnum.Size(len(data))),
  2441  		RcvWnd:  30000,
  2442  	})
  2443  }
  2444  
  2445  func TestScaledWindowConnect(t *testing.T) {
  2446  	// This test ensures that window scaling is used when the peer
  2447  	// does advertise it and connection is established with Connect().
  2448  	c := context.New(t, defaultMTU)
  2449  	defer c.Cleanup()
  2450  
  2451  	// Set the window size greater than the maximum non-scaled window.
  2452  	c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, 65535*3, []byte{
  2453  		header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
  2454  	})
  2455  
  2456  	data := []byte{1, 2, 3}
  2457  	var r bytes.Reader
  2458  	r.Reset(data)
  2459  	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  2460  		t.Fatalf("Write failed: %s", err)
  2461  	}
  2462  
  2463  	// Check that data is received, and that advertised window is 0x5fff,
  2464  	// that is, that it is scaled.
  2465  	b := c.GetPacket()
  2466  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  2467  	checker.IPv4(t, b,
  2468  		checker.PayloadLen(len(data)+header.TCPMinimumSize),
  2469  		checker.TCP(
  2470  			checker.DstPort(context.TestPort),
  2471  			checker.TCPSeqNum(uint32(c.IRS)+1),
  2472  			checker.TCPAckNum(uint32(iss)),
  2473  			checker.TCPWindow(0x5fff),
  2474  			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  2475  		),
  2476  	)
  2477  }
  2478  
  2479  func TestNonScaledWindowConnect(t *testing.T) {
  2480  	// This test ensures that window scaling is not used when the peer
  2481  	// doesn't advertise it and connection is established with Connect().
  2482  	c := context.New(t, defaultMTU)
  2483  	defer c.Cleanup()
  2484  
  2485  	// Set the window size greater than the maximum non-scaled window.
  2486  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, 65535*3)
  2487  
  2488  	data := []byte{1, 2, 3}
  2489  	var r bytes.Reader
  2490  	r.Reset(data)
  2491  	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  2492  		t.Fatalf("Write failed: %s", err)
  2493  	}
  2494  
  2495  	// Check that data is received, and that advertised window is 0xffff,
  2496  	// that is, that it's not scaled.
  2497  	b := c.GetPacket()
  2498  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  2499  	checker.IPv4(t, b,
  2500  		checker.PayloadLen(len(data)+header.TCPMinimumSize),
  2501  		checker.TCP(
  2502  			checker.DstPort(context.TestPort),
  2503  			checker.TCPSeqNum(uint32(c.IRS)+1),
  2504  			checker.TCPAckNum(uint32(iss)),
  2505  			checker.TCPWindow(0xffff),
  2506  			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  2507  		),
  2508  	)
  2509  }
  2510  
  2511  func TestScaledWindowAccept(t *testing.T) {
  2512  	// This test ensures that window scaling is used when the peer
  2513  	// does advertise it and connection is established with Accept().
  2514  	c := context.New(t, defaultMTU)
  2515  	defer c.Cleanup()
  2516  
  2517  	// Create EP and start listening.
  2518  	wq := &waiter.Queue{}
  2519  	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
  2520  	if err != nil {
  2521  		t.Fatalf("NewEndpoint failed: %s", err)
  2522  	}
  2523  	defer ep.Close()
  2524  
  2525  	// Set the window size greater than the maximum non-scaled window.
  2526  	ep.SocketOptions().SetReceiveBufferSize(65535*3, true)
  2527  
  2528  	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  2529  		t.Fatalf("Bind failed: %s", err)
  2530  	}
  2531  
  2532  	if err := ep.Listen(10); err != nil {
  2533  		t.Fatalf("Listen failed: %s", err)
  2534  	}
  2535  
  2536  	// Do 3-way handshake.
  2537  	// wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2
  2538  	c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS})
  2539  
  2540  	// Try to accept the connection.
  2541  	we, ch := waiter.NewChannelEntry(nil)
  2542  	wq.EventRegister(&we, waiter.ReadableEvents)
  2543  	defer wq.EventUnregister(&we)
  2544  
  2545  	c.EP, _, err = ep.Accept(nil)
  2546  	if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
  2547  		// Wait for connection to be established.
  2548  		select {
  2549  		case <-ch:
  2550  			c.EP, _, err = ep.Accept(nil)
  2551  			if err != nil {
  2552  				t.Fatalf("Accept failed: %s", err)
  2553  			}
  2554  
  2555  		case <-time.After(1 * time.Second):
  2556  			t.Fatalf("Timed out waiting for accept")
  2557  		}
  2558  	}
  2559  
  2560  	data := []byte{1, 2, 3}
  2561  	var r bytes.Reader
  2562  	r.Reset(data)
  2563  	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  2564  		t.Fatalf("Write failed: %s", err)
  2565  	}
  2566  
  2567  	// Check that data is received, and that advertised window is 0x5fff,
  2568  	// that is, that it is scaled.
  2569  	b := c.GetPacket()
  2570  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  2571  	checker.IPv4(t, b,
  2572  		checker.PayloadLen(len(data)+header.TCPMinimumSize),
  2573  		checker.TCP(
  2574  			checker.DstPort(context.TestPort),
  2575  			checker.TCPSeqNum(uint32(c.IRS)+1),
  2576  			checker.TCPAckNum(uint32(iss)),
  2577  			checker.TCPWindow(0x5fff),
  2578  			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  2579  		),
  2580  	)
  2581  }
  2582  
  2583  func TestNonScaledWindowAccept(t *testing.T) {
  2584  	// This test ensures that window scaling is not used when the peer
  2585  	// doesn't advertise it and connection is established with Accept().
  2586  	c := context.New(t, defaultMTU)
  2587  	defer c.Cleanup()
  2588  
  2589  	// Create EP and start listening.
  2590  	wq := &waiter.Queue{}
  2591  	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
  2592  	if err != nil {
  2593  		t.Fatalf("NewEndpoint failed: %s", err)
  2594  	}
  2595  	defer ep.Close()
  2596  
  2597  	// Set the window size greater than the maximum non-scaled window.
  2598  	ep.SocketOptions().SetReceiveBufferSize(65535*3, true)
  2599  
  2600  	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  2601  		t.Fatalf("Bind failed: %s", err)
  2602  	}
  2603  
  2604  	if err := ep.Listen(10); err != nil {
  2605  		t.Fatalf("Listen failed: %s", err)
  2606  	}
  2607  
  2608  	// Do 3-way handshake w/ window scaling disabled. The SYN-ACK to the SYN
  2609  	// should not carry the window scaling option.
  2610  	c.PassiveConnect(100, -1, header.TCPSynOptions{MSS: defaultIPv4MSS})
  2611  
  2612  	// Try to accept the connection.
  2613  	we, ch := waiter.NewChannelEntry(nil)
  2614  	wq.EventRegister(&we, waiter.ReadableEvents)
  2615  	defer wq.EventUnregister(&we)
  2616  
  2617  	c.EP, _, err = ep.Accept(nil)
  2618  	if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
  2619  		// Wait for connection to be established.
  2620  		select {
  2621  		case <-ch:
  2622  			c.EP, _, err = ep.Accept(nil)
  2623  			if err != nil {
  2624  				t.Fatalf("Accept failed: %s", err)
  2625  			}
  2626  
  2627  		case <-time.After(1 * time.Second):
  2628  			t.Fatalf("Timed out waiting for accept")
  2629  		}
  2630  	}
  2631  
  2632  	data := []byte{1, 2, 3}
  2633  	var r bytes.Reader
  2634  	r.Reset(data)
  2635  	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  2636  		t.Fatalf("Write failed: %s", err)
  2637  	}
  2638  
  2639  	// Check that data is received, and that advertised window is 0xffff,
  2640  	// that is, that it's not scaled.
  2641  	b := c.GetPacket()
  2642  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  2643  	checker.IPv4(t, b,
  2644  		checker.PayloadLen(len(data)+header.TCPMinimumSize),
  2645  		checker.TCP(
  2646  			checker.DstPort(context.TestPort),
  2647  			checker.TCPSeqNum(uint32(c.IRS)+1),
  2648  			checker.TCPAckNum(uint32(iss)),
  2649  			checker.TCPWindow(0xffff),
  2650  			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  2651  		),
  2652  	)
  2653  }
  2654  
  2655  func TestZeroScaledWindowReceive(t *testing.T) {
  2656  	// This test ensures that the endpoint sends a non-zero window size
  2657  	// advertisement when the scaled window transitions from 0 to non-zero,
  2658  	// but the actual window (not scaled) hasn't gotten to zero.
  2659  	c := context.New(t, defaultMTU)
  2660  	defer c.Cleanup()
  2661  
  2662  	// Set the buffer size such that a window scale of 5 will be used.
  2663  	const bufSz = 65535 * 10
  2664  	const ws = uint32(5)
  2665  	c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, bufSz, []byte{
  2666  		header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
  2667  	})
  2668  
  2669  	// Write chunks of 50000 bytes.
  2670  	remain := 0
  2671  	sent := 0
  2672  	data := make([]byte, 50000)
  2673  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  2674  	// Keep writing till the window drops below len(data).
  2675  	for {
  2676  		c.SendPacket(data, &context.Headers{
  2677  			SrcPort: context.TestPort,
  2678  			DstPort: c.Port,
  2679  			Flags:   header.TCPFlagAck,
  2680  			SeqNum:  iss.Add(seqnum.Size(sent)),
  2681  			AckNum:  c.IRS.Add(1),
  2682  			RcvWnd:  30000,
  2683  		})
  2684  		sent += len(data)
  2685  		pkt := c.GetPacket()
  2686  		checker.IPv4(t, pkt,
  2687  			checker.PayloadLen(header.TCPMinimumSize),
  2688  			checker.TCP(
  2689  				checker.DstPort(context.TestPort),
  2690  				checker.TCPSeqNum(uint32(c.IRS)+1),
  2691  				checker.TCPAckNum(uint32(iss)+uint32(sent)),
  2692  				checker.TCPFlags(header.TCPFlagAck),
  2693  			),
  2694  		)
  2695  		// Don't reduce window to zero here.
  2696  		if wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()); wnd<<ws < len(data) {
  2697  			remain = wnd << ws
  2698  			break
  2699  		}
  2700  	}
  2701  
  2702  	// Make the window non-zero, but the scaled window zero.
  2703  	for remain >= 16 {
  2704  		data = data[:remain-15]
  2705  		c.SendPacket(data, &context.Headers{
  2706  			SrcPort: context.TestPort,
  2707  			DstPort: c.Port,
  2708  			Flags:   header.TCPFlagAck,
  2709  			SeqNum:  iss.Add(seqnum.Size(sent)),
  2710  			AckNum:  c.IRS.Add(1),
  2711  			RcvWnd:  30000,
  2712  		})
  2713  		sent += len(data)
  2714  		pkt := c.GetPacket()
  2715  		checker.IPv4(t, pkt,
  2716  			checker.PayloadLen(header.TCPMinimumSize),
  2717  			checker.TCP(
  2718  				checker.DstPort(context.TestPort),
  2719  				checker.TCPSeqNum(uint32(c.IRS)+1),
  2720  				checker.TCPAckNum(uint32(iss)+uint32(sent)),
  2721  				checker.TCPFlags(header.TCPFlagAck),
  2722  			),
  2723  		)
  2724  		// Since the receive buffer is split between window advertisement and
  2725  		// application data buffer the window does not always reflect the space
  2726  		// available and actual space available can be a bit more than what is
  2727  		// advertised in the window.
  2728  		wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize())
  2729  		if wnd == 0 {
  2730  			break
  2731  		}
  2732  		remain = wnd << ws
  2733  	}
  2734  
  2735  	// Read at least 2MSS of data. An ack should be sent in response to that.
  2736  	// Since buffer space is now split in half between window and application
  2737  	// data we need to read more than 1 MSS(65536) of data for a non-zero window
  2738  	// update to be sent. For 1MSS worth of window to be available we need to
  2739  	// read at least 128KB. Since our segments above were 50KB each it means
  2740  	// we need to read at 3 packets.
  2741  	w := tcpip.LimitedWriter{
  2742  		W: ioutil.Discard,
  2743  		N: defaultMTU * 2,
  2744  	}
  2745  	for w.N != 0 {
  2746  		res, err := c.EP.Read(&w, tcpip.ReadOptions{})
  2747  		t.Logf("err=%v res=%#v", err, res)
  2748  		if err != nil {
  2749  			t.Fatalf("Read failed: %s", err)
  2750  		}
  2751  	}
  2752  
  2753  	checker.IPv4(t, c.GetPacket(),
  2754  		checker.PayloadLen(header.TCPMinimumSize),
  2755  		checker.TCP(
  2756  			checker.DstPort(context.TestPort),
  2757  			checker.TCPSeqNum(uint32(c.IRS)+1),
  2758  			checker.TCPAckNum(uint32(iss)+uint32(sent)),
  2759  			checker.TCPWindowGreaterThanEq(uint16(defaultMTU>>ws)),
  2760  			checker.TCPFlags(header.TCPFlagAck),
  2761  		),
  2762  	)
  2763  }
  2764  
  2765  func TestSegmentMerging(t *testing.T) {
  2766  	tests := []struct {
  2767  		name   string
  2768  		stop   func(tcpip.Endpoint)
  2769  		resume func(tcpip.Endpoint)
  2770  	}{
  2771  		{
  2772  			"stop work",
  2773  			func(ep tcpip.Endpoint) {
  2774  				ep.(interface{ StopWork() }).StopWork()
  2775  			},
  2776  			func(ep tcpip.Endpoint) {
  2777  				ep.(interface{ ResumeWork() }).ResumeWork()
  2778  			},
  2779  		},
  2780  		{
  2781  			"cork",
  2782  			func(ep tcpip.Endpoint) {
  2783  				ep.SocketOptions().SetCorkOption(true)
  2784  			},
  2785  			func(ep tcpip.Endpoint) {
  2786  				ep.SocketOptions().SetCorkOption(false)
  2787  			},
  2788  		},
  2789  	}
  2790  
  2791  	for _, test := range tests {
  2792  		t.Run(test.name, func(t *testing.T) {
  2793  			c := context.New(t, defaultMTU)
  2794  			defer c.Cleanup()
  2795  
  2796  			c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  2797  
  2798  			// Send tcp.InitialCwnd number of segments to fill up
  2799  			// InitialWindow but don't ACK. That should prevent
  2800  			// anymore packets from going out.
  2801  			var r bytes.Reader
  2802  			for i := 0; i < tcp.InitialCwnd; i++ {
  2803  				r.Reset([]byte{0})
  2804  				if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  2805  					t.Fatalf("Write #%d failed: %s", i+1, err)
  2806  				}
  2807  			}
  2808  
  2809  			// Now send the segments that should get merged as the congestion
  2810  			// window is full and we won't be able to send any more packets.
  2811  			var allData []byte
  2812  			for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
  2813  				allData = append(allData, data...)
  2814  				r.Reset(data)
  2815  				if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  2816  					t.Fatalf("Write #%d failed: %s", i+1, err)
  2817  				}
  2818  			}
  2819  
  2820  			// Check that we get tcp.InitialCwnd packets.
  2821  			iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  2822  			for i := 0; i < tcp.InitialCwnd; i++ {
  2823  				b := c.GetPacket()
  2824  				checker.IPv4(t, b,
  2825  					checker.PayloadLen(header.TCPMinimumSize+1),
  2826  					checker.TCP(
  2827  						checker.DstPort(context.TestPort),
  2828  						checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1),
  2829  						checker.TCPAckNum(uint32(iss)),
  2830  						checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  2831  					),
  2832  				)
  2833  			}
  2834  
  2835  			// Acknowledge the data.
  2836  			c.SendPacket(nil, &context.Headers{
  2837  				SrcPort: context.TestPort,
  2838  				DstPort: c.Port,
  2839  				Flags:   header.TCPFlagAck,
  2840  				SeqNum:  iss,
  2841  				AckNum:  c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload.
  2842  				RcvWnd:  30000,
  2843  			})
  2844  
  2845  			// Check that data is received.
  2846  			b := c.GetPacket()
  2847  			checker.IPv4(t, b,
  2848  				checker.PayloadLen(len(allData)+header.TCPMinimumSize),
  2849  				checker.TCP(
  2850  					checker.DstPort(context.TestPort),
  2851  					checker.TCPSeqNum(uint32(c.IRS)+11),
  2852  					checker.TCPAckNum(uint32(iss)),
  2853  					checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  2854  				),
  2855  			)
  2856  
  2857  			if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, allData) {
  2858  				t.Fatalf("got data = %v, want = %v", got, allData)
  2859  			}
  2860  
  2861  			// Acknowledge the data.
  2862  			c.SendPacket(nil, &context.Headers{
  2863  				SrcPort: context.TestPort,
  2864  				DstPort: c.Port,
  2865  				Flags:   header.TCPFlagAck,
  2866  				SeqNum:  iss,
  2867  				AckNum:  c.IRS.Add(11 + seqnum.Size(len(allData))),
  2868  				RcvWnd:  30000,
  2869  			})
  2870  		})
  2871  	}
  2872  }
  2873  
  2874  func TestDelay(t *testing.T) {
  2875  	c := context.New(t, defaultMTU)
  2876  	defer c.Cleanup()
  2877  
  2878  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  2879  
  2880  	c.EP.SocketOptions().SetDelayOption(true)
  2881  
  2882  	var allData []byte
  2883  	for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
  2884  		allData = append(allData, data...)
  2885  		var r bytes.Reader
  2886  		r.Reset(data)
  2887  		if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  2888  			t.Fatalf("Write #%d failed: %s", i+1, err)
  2889  		}
  2890  	}
  2891  
  2892  	seq := c.IRS.Add(1)
  2893  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  2894  	for _, want := range [][]byte{allData[:1], allData[1:]} {
  2895  		// Check that data is received.
  2896  		b := c.GetPacket()
  2897  		checker.IPv4(t, b,
  2898  			checker.PayloadLen(len(want)+header.TCPMinimumSize),
  2899  			checker.TCP(
  2900  				checker.DstPort(context.TestPort),
  2901  				checker.TCPSeqNum(uint32(seq)),
  2902  				checker.TCPAckNum(uint32(iss)),
  2903  				checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  2904  			),
  2905  		)
  2906  
  2907  		if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, want) {
  2908  			t.Fatalf("got data = %v, want = %v", got, want)
  2909  		}
  2910  
  2911  		seq = seq.Add(seqnum.Size(len(want)))
  2912  		// Acknowledge the data.
  2913  		c.SendPacket(nil, &context.Headers{
  2914  			SrcPort: context.TestPort,
  2915  			DstPort: c.Port,
  2916  			Flags:   header.TCPFlagAck,
  2917  			SeqNum:  iss,
  2918  			AckNum:  seq,
  2919  			RcvWnd:  30000,
  2920  		})
  2921  	}
  2922  }
  2923  
  2924  func TestUndelay(t *testing.T) {
  2925  	c := context.New(t, defaultMTU)
  2926  	defer c.Cleanup()
  2927  
  2928  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  2929  
  2930  	c.EP.SocketOptions().SetDelayOption(true)
  2931  
  2932  	allData := [][]byte{{0}, {1, 2, 3}}
  2933  	for i, data := range allData {
  2934  		var r bytes.Reader
  2935  		r.Reset(data)
  2936  		if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  2937  			t.Fatalf("Write #%d failed: %s", i+1, err)
  2938  		}
  2939  	}
  2940  
  2941  	seq := c.IRS.Add(1)
  2942  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  2943  	// Check that data is received.
  2944  	first := c.GetPacket()
  2945  	checker.IPv4(t, first,
  2946  		checker.PayloadLen(len(allData[0])+header.TCPMinimumSize),
  2947  		checker.TCP(
  2948  			checker.DstPort(context.TestPort),
  2949  			checker.TCPSeqNum(uint32(seq)),
  2950  			checker.TCPAckNum(uint32(iss)),
  2951  			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  2952  		),
  2953  	)
  2954  
  2955  	if got, want := first[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[0]; !bytes.Equal(got, want) {
  2956  		t.Fatalf("got first packet's data = %v, want = %v", got, want)
  2957  	}
  2958  
  2959  	seq = seq.Add(seqnum.Size(len(allData[0])))
  2960  
  2961  	// Check that we don't get the second packet yet.
  2962  	c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond)
  2963  
  2964  	c.EP.SocketOptions().SetDelayOption(false)
  2965  
  2966  	// Check that data is received.
  2967  	second := c.GetPacket()
  2968  	checker.IPv4(t, second,
  2969  		checker.PayloadLen(len(allData[1])+header.TCPMinimumSize),
  2970  		checker.TCP(
  2971  			checker.DstPort(context.TestPort),
  2972  			checker.TCPSeqNum(uint32(seq)),
  2973  			checker.TCPAckNum(uint32(iss)),
  2974  			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  2975  		),
  2976  	)
  2977  
  2978  	if got, want := second[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[1]; !bytes.Equal(got, want) {
  2979  		t.Fatalf("got second packet's data = %v, want = %v", got, want)
  2980  	}
  2981  
  2982  	seq = seq.Add(seqnum.Size(len(allData[1])))
  2983  
  2984  	// Acknowledge the data.
  2985  	c.SendPacket(nil, &context.Headers{
  2986  		SrcPort: context.TestPort,
  2987  		DstPort: c.Port,
  2988  		Flags:   header.TCPFlagAck,
  2989  		SeqNum:  iss,
  2990  		AckNum:  seq,
  2991  		RcvWnd:  30000,
  2992  	})
  2993  }
  2994  
  2995  func TestMSSNotDelayed(t *testing.T) {
  2996  	tests := []struct {
  2997  		name string
  2998  		fn   func(tcpip.Endpoint)
  2999  	}{
  3000  		{"no-op", func(tcpip.Endpoint) {}},
  3001  		{"delay", func(ep tcpip.Endpoint) { ep.SocketOptions().SetDelayOption(true) }},
  3002  		{"cork", func(ep tcpip.Endpoint) { ep.SocketOptions().SetCorkOption(true) }},
  3003  	}
  3004  
  3005  	for _, test := range tests {
  3006  		t.Run(test.name, func(t *testing.T) {
  3007  			const maxPayload = 100
  3008  			c := context.New(t, defaultMTU)
  3009  			defer c.Cleanup()
  3010  
  3011  			c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{
  3012  				header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
  3013  			})
  3014  
  3015  			test.fn(c.EP)
  3016  
  3017  			allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)}
  3018  			for i, data := range allData {
  3019  				var r bytes.Reader
  3020  				r.Reset(data)
  3021  				if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  3022  					t.Fatalf("Write #%d failed: %s", i+1, err)
  3023  				}
  3024  			}
  3025  
  3026  			seq := c.IRS.Add(1)
  3027  			iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  3028  			for i, data := range allData {
  3029  				// Check that data is received.
  3030  				packet := c.GetPacket()
  3031  				checker.IPv4(t, packet,
  3032  					checker.PayloadLen(len(data)+header.TCPMinimumSize),
  3033  					checker.TCP(
  3034  						checker.DstPort(context.TestPort),
  3035  						checker.TCPSeqNum(uint32(seq)),
  3036  						checker.TCPAckNum(uint32(iss)),
  3037  						checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  3038  					),
  3039  				)
  3040  
  3041  				if got, want := packet[header.IPv4MinimumSize+header.TCPMinimumSize:], data; !bytes.Equal(got, want) {
  3042  					t.Fatalf("got packet #%d's data = %v, want = %v", i+1, got, want)
  3043  				}
  3044  
  3045  				seq = seq.Add(seqnum.Size(len(data)))
  3046  			}
  3047  
  3048  			// Acknowledge the data.
  3049  			c.SendPacket(nil, &context.Headers{
  3050  				SrcPort: context.TestPort,
  3051  				DstPort: c.Port,
  3052  				Flags:   header.TCPFlagAck,
  3053  				SeqNum:  iss,
  3054  				AckNum:  seq,
  3055  				RcvWnd:  30000,
  3056  			})
  3057  		})
  3058  	}
  3059  }
  3060  
  3061  func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
  3062  	payloadMultiplier := 10
  3063  	dataLen := payloadMultiplier * maxPayload
  3064  	data := make([]byte, dataLen)
  3065  	for i := range data {
  3066  		data[i] = byte(i)
  3067  	}
  3068  
  3069  	var r bytes.Reader
  3070  	r.Reset(data)
  3071  	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  3072  		t.Fatalf("Write failed: %s", err)
  3073  	}
  3074  
  3075  	// Check that data is received in chunks.
  3076  	bytesReceived := 0
  3077  	numPackets := 0
  3078  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  3079  	for bytesReceived != dataLen {
  3080  		b := c.GetPacket()
  3081  		numPackets++
  3082  		tcpHdr := header.TCP(header.IPv4(b).Payload())
  3083  		payloadLen := len(tcpHdr.Payload())
  3084  		checker.IPv4(t, b,
  3085  			checker.TCP(
  3086  				checker.DstPort(context.TestPort),
  3087  				checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)),
  3088  				checker.TCPAckNum(uint32(iss)),
  3089  				checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  3090  			),
  3091  		)
  3092  
  3093  		pdata := data[bytesReceived : bytesReceived+payloadLen]
  3094  		if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) {
  3095  			t.Fatalf("got data = %v, want = %v", p, pdata)
  3096  		}
  3097  		bytesReceived += payloadLen
  3098  		var options []byte
  3099  		if c.TimeStampEnabled {
  3100  			// If timestamp option is enabled, echo back the timestamp and increment
  3101  			// the TSEcr value included in the packet and send that back as the TSVal.
  3102  			parsedOpts := tcpHdr.ParsedOptions()
  3103  			tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
  3104  			header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:])
  3105  			options = tsOpt[:]
  3106  		}
  3107  		// Acknowledge the data.
  3108  		c.SendPacket(nil, &context.Headers{
  3109  			SrcPort: context.TestPort,
  3110  			DstPort: c.Port,
  3111  			Flags:   header.TCPFlagAck,
  3112  			SeqNum:  iss,
  3113  			AckNum:  c.IRS.Add(1 + seqnum.Size(bytesReceived)),
  3114  			RcvWnd:  30000,
  3115  			TCPOpts: options,
  3116  		})
  3117  	}
  3118  	if numPackets == 1 {
  3119  		t.Fatalf("expected write to be broken up into multiple packets, but got 1 packet")
  3120  	}
  3121  }
  3122  
  3123  func TestSendGreaterThanMTU(t *testing.T) {
  3124  	const maxPayload = 100
  3125  	c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
  3126  	defer c.Cleanup()
  3127  
  3128  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  3129  	testBrokenUpWrite(t, c, maxPayload)
  3130  }
  3131  
  3132  func TestSetTTL(t *testing.T) {
  3133  	for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} {
  3134  		t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) {
  3135  			c := context.New(t, 65535)
  3136  			defer c.Cleanup()
  3137  
  3138  			var err tcpip.Error
  3139  			c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
  3140  			if err != nil {
  3141  				t.Fatalf("NewEndpoint failed: %s", err)
  3142  			}
  3143  
  3144  			if err := c.EP.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil {
  3145  				t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err)
  3146  			}
  3147  
  3148  			{
  3149  				err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
  3150  				if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
  3151  					t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
  3152  				}
  3153  			}
  3154  
  3155  			// Receive SYN packet.
  3156  			b := c.GetPacket()
  3157  
  3158  			checker.IPv4(t, b, checker.TTL(wantTTL))
  3159  		})
  3160  	}
  3161  }
  3162  
  3163  func TestActiveSendMSSLessThanMTU(t *testing.T) {
  3164  	const maxPayload = 100
  3165  	c := context.New(t, 65535)
  3166  	defer c.Cleanup()
  3167  
  3168  	c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{
  3169  		header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
  3170  	})
  3171  	testBrokenUpWrite(t, c, maxPayload)
  3172  }
  3173  
  3174  func TestPassiveSendMSSLessThanMTU(t *testing.T) {
  3175  	const maxPayload = 100
  3176  	const mtu = 1200
  3177  	c := context.New(t, mtu)
  3178  	defer c.Cleanup()
  3179  
  3180  	// Create EP and start listening.
  3181  	wq := &waiter.Queue{}
  3182  	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
  3183  	if err != nil {
  3184  		t.Fatalf("NewEndpoint failed: %s", err)
  3185  	}
  3186  	defer ep.Close()
  3187  
  3188  	// Set the buffer size to a deterministic size so that we can check the
  3189  	// window scaling option.
  3190  	const rcvBufferSize = 0x20000
  3191  	ep.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true)
  3192  
  3193  	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  3194  		t.Fatalf("Bind failed: %s", err)
  3195  	}
  3196  
  3197  	if err := ep.Listen(10); err != nil {
  3198  		t.Fatalf("Listen failed: %s", err)
  3199  	}
  3200  
  3201  	// Do 3-way handshake.
  3202  	c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
  3203  
  3204  	// Try to accept the connection.
  3205  	we, ch := waiter.NewChannelEntry(nil)
  3206  	wq.EventRegister(&we, waiter.ReadableEvents)
  3207  	defer wq.EventUnregister(&we)
  3208  
  3209  	c.EP, _, err = ep.Accept(nil)
  3210  	if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
  3211  		// Wait for connection to be established.
  3212  		select {
  3213  		case <-ch:
  3214  			c.EP, _, err = ep.Accept(nil)
  3215  			if err != nil {
  3216  				t.Fatalf("Accept failed: %s", err)
  3217  			}
  3218  
  3219  		case <-time.After(1 * time.Second):
  3220  			t.Fatalf("Timed out waiting for accept")
  3221  		}
  3222  	}
  3223  
  3224  	// Check that data gets properly segmented.
  3225  	testBrokenUpWrite(t, c, maxPayload)
  3226  }
  3227  
  3228  func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
  3229  	const maxPayload = 536
  3230  	const mtu = 2000
  3231  	c := context.New(t, mtu)
  3232  	defer c.Cleanup()
  3233  
  3234  	opt := tcpip.TCPAlwaysUseSynCookies(true)
  3235  	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
  3236  		t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
  3237  	}
  3238  
  3239  	// Create EP and start listening.
  3240  	wq := &waiter.Queue{}
  3241  	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
  3242  	if err != nil {
  3243  		t.Fatalf("NewEndpoint failed: %s", err)
  3244  	}
  3245  	defer ep.Close()
  3246  
  3247  	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  3248  		t.Fatalf("Bind failed: %s", err)
  3249  	}
  3250  
  3251  	if err := ep.Listen(10); err != nil {
  3252  		t.Fatalf("Listen failed: %s", err)
  3253  	}
  3254  
  3255  	// Do 3-way handshake.
  3256  	c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
  3257  
  3258  	// Try to accept the connection.
  3259  	we, ch := waiter.NewChannelEntry(nil)
  3260  	wq.EventRegister(&we, waiter.ReadableEvents)
  3261  	defer wq.EventUnregister(&we)
  3262  
  3263  	c.EP, _, err = ep.Accept(nil)
  3264  	if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
  3265  		// Wait for connection to be established.
  3266  		select {
  3267  		case <-ch:
  3268  			c.EP, _, err = ep.Accept(nil)
  3269  			if err != nil {
  3270  				t.Fatalf("Accept failed: %s", err)
  3271  			}
  3272  
  3273  		case <-time.After(1 * time.Second):
  3274  			t.Fatalf("Timed out waiting for accept")
  3275  		}
  3276  	}
  3277  
  3278  	// Check that data gets properly segmented.
  3279  	testBrokenUpWrite(t, c, maxPayload)
  3280  }
  3281  
  3282  func TestForwarderSendMSSLessThanMTU(t *testing.T) {
  3283  	const maxPayload = 100
  3284  	const mtu = 1200
  3285  	c := context.New(t, mtu)
  3286  	defer c.Cleanup()
  3287  
  3288  	s := c.Stack()
  3289  	ch := make(chan tcpip.Error, 1)
  3290  	f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) {
  3291  		var err tcpip.Error
  3292  		c.EP, err = r.CreateEndpoint(&c.WQ)
  3293  		ch <- err
  3294  	})
  3295  	s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket)
  3296  
  3297  	// Do 3-way handshake.
  3298  	c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
  3299  
  3300  	// Wait for connection to be available.
  3301  	select {
  3302  	case err := <-ch:
  3303  		if err != nil {
  3304  			t.Fatalf("Error creating endpoint: %s", err)
  3305  		}
  3306  	case <-time.After(2 * time.Second):
  3307  		t.Fatalf("Timed out waiting for connection")
  3308  	}
  3309  
  3310  	// Check that data gets properly segmented.
  3311  	testBrokenUpWrite(t, c, maxPayload)
  3312  }
  3313  
  3314  func TestSynOptionsOnActiveConnect(t *testing.T) {
  3315  	const mtu = 1400
  3316  	c := context.New(t, mtu)
  3317  	defer c.Cleanup()
  3318  
  3319  	// Create TCP endpoint.
  3320  	var err tcpip.Error
  3321  	c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
  3322  	if err != nil {
  3323  		t.Fatalf("NewEndpoint failed: %s", err)
  3324  	}
  3325  
  3326  	// Set the buffer size to a deterministic size so that we can check the
  3327  	// window scaling option.
  3328  	const rcvBufferSize = 0x20000
  3329  	const wndScale = 3
  3330  	c.EP.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true)
  3331  
  3332  	// Start connection attempt.
  3333  	we, ch := waiter.NewChannelEntry(nil)
  3334  	c.WQ.EventRegister(&we, waiter.WritableEvents)
  3335  	defer c.WQ.EventUnregister(&we)
  3336  
  3337  	{
  3338  		err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
  3339  		if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
  3340  			t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
  3341  		}
  3342  	}
  3343  
  3344  	// Receive SYN packet.
  3345  	b := c.GetPacket()
  3346  	mss := uint16(mtu - header.IPv4MinimumSize - header.TCPMinimumSize)
  3347  	checker.IPv4(t, b,
  3348  		checker.TCP(
  3349  			checker.DstPort(context.TestPort),
  3350  			checker.TCPFlags(header.TCPFlagSyn),
  3351  			checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}),
  3352  		),
  3353  	)
  3354  
  3355  	tcpHdr := header.TCP(header.IPv4(b).Payload())
  3356  	c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
  3357  
  3358  	// Wait for retransmit.
  3359  	time.Sleep(1 * time.Second)
  3360  	checker.IPv4(t, c.GetPacket(),
  3361  		checker.TCP(
  3362  			checker.DstPort(context.TestPort),
  3363  			checker.TCPFlags(header.TCPFlagSyn),
  3364  			checker.SrcPort(tcpHdr.SourcePort()),
  3365  			checker.TCPSeqNum(tcpHdr.SequenceNumber()),
  3366  			checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}),
  3367  		),
  3368  	)
  3369  
  3370  	// Send SYN-ACK.
  3371  	iss := seqnum.Value(context.TestInitialSequenceNumber)
  3372  	c.SendPacket(nil, &context.Headers{
  3373  		SrcPort: tcpHdr.DestinationPort(),
  3374  		DstPort: tcpHdr.SourcePort(),
  3375  		Flags:   header.TCPFlagSyn | header.TCPFlagAck,
  3376  		SeqNum:  iss,
  3377  		AckNum:  c.IRS.Add(1),
  3378  		RcvWnd:  30000,
  3379  	})
  3380  
  3381  	// Receive ACK packet.
  3382  	checker.IPv4(t, c.GetPacket(),
  3383  		checker.TCP(
  3384  			checker.DstPort(context.TestPort),
  3385  			checker.TCPFlags(header.TCPFlagAck),
  3386  			checker.TCPSeqNum(uint32(c.IRS)+1),
  3387  			checker.TCPAckNum(uint32(iss)+1),
  3388  		),
  3389  	)
  3390  
  3391  	// Wait for connection to be established.
  3392  	select {
  3393  	case <-ch:
  3394  		if err := c.EP.LastError(); err != nil {
  3395  			t.Fatalf("Connect failed: %s", err)
  3396  		}
  3397  	case <-time.After(1 * time.Second):
  3398  		t.Fatalf("Timed out waiting for connection")
  3399  	}
  3400  }
  3401  
  3402  func TestCloseListener(t *testing.T) {
  3403  	c := context.New(t, defaultMTU)
  3404  	defer c.Cleanup()
  3405  
  3406  	// Create listener.
  3407  	var wq waiter.Queue
  3408  	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
  3409  	if err != nil {
  3410  		t.Fatalf("NewEndpoint failed: %s", err)
  3411  	}
  3412  
  3413  	if err := ep.Bind(tcpip.FullAddress{}); err != nil {
  3414  		t.Fatalf("Bind failed: %s", err)
  3415  	}
  3416  
  3417  	if err := ep.Listen(10); err != nil {
  3418  		t.Fatalf("Listen failed: %s", err)
  3419  	}
  3420  
  3421  	// Close the listener and measure how long it takes.
  3422  	t0 := time.Now()
  3423  	ep.Close()
  3424  	if diff := time.Now().Sub(t0); diff > 3*time.Second {
  3425  		t.Fatalf("Took too long to close: %s", diff)
  3426  	}
  3427  }
  3428  
  3429  func TestReceiveOnResetConnection(t *testing.T) {
  3430  	c := context.New(t, defaultMTU)
  3431  	defer c.Cleanup()
  3432  
  3433  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  3434  
  3435  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  3436  	// Send RST segment.
  3437  	c.SendPacket(nil, &context.Headers{
  3438  		SrcPort: context.TestPort,
  3439  		DstPort: c.Port,
  3440  		Flags:   header.TCPFlagRst,
  3441  		SeqNum:  iss,
  3442  		RcvWnd:  30000,
  3443  	})
  3444  
  3445  	// Try to read.
  3446  	we, ch := waiter.NewChannelEntry(nil)
  3447  	c.WQ.EventRegister(&we, waiter.ReadableEvents)
  3448  	defer c.WQ.EventUnregister(&we)
  3449  
  3450  loop:
  3451  	for {
  3452  		switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) {
  3453  		case *tcpip.ErrWouldBlock:
  3454  			<-ch
  3455  			// Expect the state to be StateError and subsequent Reads to fail with HardError.
  3456  			_, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
  3457  			if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" {
  3458  				t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d)
  3459  			}
  3460  			break loop
  3461  		case *tcpip.ErrConnectionReset:
  3462  			break loop
  3463  		default:
  3464  			t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, &tcpip.ErrConnectionReset{})
  3465  		}
  3466  	}
  3467  
  3468  	if tcp.EndpointState(c.EP.State()) != tcp.StateError {
  3469  		t.Fatalf("got EP state is not StateError")
  3470  	}
  3471  
  3472  	checkValid := func() []error {
  3473  		var errors []error
  3474  		if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
  3475  			errors = append(errors, fmt.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got))
  3476  		}
  3477  		if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
  3478  			errors = append(errors, fmt.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got))
  3479  		}
  3480  		if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
  3481  			errors = append(errors, fmt.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got))
  3482  		}
  3483  		return errors
  3484  	}
  3485  
  3486  	start := time.Now()
  3487  	for time.Since(start) < time.Minute && len(checkValid()) > 0 {
  3488  		time.Sleep(50 * time.Millisecond)
  3489  	}
  3490  	for _, err := range checkValid() {
  3491  		t.Error(err)
  3492  	}
  3493  }
  3494  
  3495  func TestSendOnResetConnection(t *testing.T) {
  3496  	c := context.New(t, defaultMTU)
  3497  	defer c.Cleanup()
  3498  
  3499  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  3500  
  3501  	// Send RST segment.
  3502  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  3503  	c.SendPacket(nil, &context.Headers{
  3504  		SrcPort: context.TestPort,
  3505  		DstPort: c.Port,
  3506  		Flags:   header.TCPFlagRst,
  3507  		SeqNum:  iss,
  3508  		RcvWnd:  30000,
  3509  	})
  3510  
  3511  	// Wait for the RST to be received.
  3512  	time.Sleep(1 * time.Second)
  3513  
  3514  	// Try to write.
  3515  	var r bytes.Reader
  3516  	r.Reset(make([]byte, 10))
  3517  	_, err := c.EP.Write(&r, tcpip.WriteOptions{})
  3518  	if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" {
  3519  		t.Fatalf("c.EP.Write(...) mismatch (-want +got):\n%s", d)
  3520  	}
  3521  }
  3522  
  3523  // TestMaxRetransmitsTimeout tests if the connection is timed out after
  3524  // a segment has been retransmitted MaxRetries times.
  3525  func TestMaxRetransmitsTimeout(t *testing.T) {
  3526  	c := context.New(t, defaultMTU)
  3527  	defer c.Cleanup()
  3528  
  3529  	const numRetries = 2
  3530  	opt := tcpip.TCPMaxRetriesOption(numRetries)
  3531  	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
  3532  		t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
  3533  	}
  3534  
  3535  	c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
  3536  
  3537  	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
  3538  	c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
  3539  	defer c.WQ.EventUnregister(&waitEntry)
  3540  
  3541  	var r bytes.Reader
  3542  	r.Reset(make([]byte, 1))
  3543  	_, err := c.EP.Write(&r, tcpip.WriteOptions{})
  3544  	if err != nil {
  3545  		t.Fatalf("Write failed: %s", err)
  3546  	}
  3547  
  3548  	// Expect first transmit and MaxRetries retransmits.
  3549  	for i := 0; i < numRetries+1; i++ {
  3550  		checker.IPv4(t, c.GetPacket(),
  3551  			checker.TCP(
  3552  				checker.DstPort(context.TestPort),
  3553  				checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh),
  3554  			),
  3555  		)
  3556  	}
  3557  	// Wait for the connection to timeout after MaxRetries retransmits.
  3558  	initRTO := 1 * time.Second
  3559  	select {
  3560  	case <-notifyCh:
  3561  	case <-time.After((2 << numRetries) * initRTO):
  3562  		t.Fatalf("connection still alive after maximum retransmits.\n")
  3563  	}
  3564  
  3565  	// Send an ACK and expect a RST as the connection would have been closed.
  3566  	c.SendPacket(nil, &context.Headers{
  3567  		SrcPort: context.TestPort,
  3568  		DstPort: c.Port,
  3569  		Flags:   header.TCPFlagAck,
  3570  	})
  3571  
  3572  	checker.IPv4(t, c.GetPacket(),
  3573  		checker.TCP(
  3574  			checker.DstPort(context.TestPort),
  3575  			checker.TCPFlags(header.TCPFlagRst),
  3576  		),
  3577  	)
  3578  
  3579  	if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
  3580  		t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
  3581  	}
  3582  	if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
  3583  		t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
  3584  	}
  3585  }
  3586  
  3587  // TestMaxRTO tests if the retransmit interval caps to MaxRTO.
  3588  func TestMaxRTO(t *testing.T) {
  3589  	c := context.New(t, defaultMTU)
  3590  	defer c.Cleanup()
  3591  
  3592  	rto := 1 * time.Second
  3593  	opt := tcpip.TCPMaxRTOOption(rto)
  3594  	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
  3595  		t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
  3596  	}
  3597  
  3598  	c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
  3599  
  3600  	var r bytes.Reader
  3601  	r.Reset(make([]byte, 1))
  3602  	_, err := c.EP.Write(&r, tcpip.WriteOptions{})
  3603  	if err != nil {
  3604  		t.Fatalf("Write failed: %s", err)
  3605  	}
  3606  	checker.IPv4(t, c.GetPacket(),
  3607  		checker.TCP(
  3608  			checker.DstPort(context.TestPort),
  3609  			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  3610  		),
  3611  	)
  3612  	const numRetransmits = 2
  3613  	for i := 0; i < numRetransmits; i++ {
  3614  		start := time.Now()
  3615  		checker.IPv4(t, c.GetPacket(),
  3616  			checker.TCP(
  3617  				checker.DstPort(context.TestPort),
  3618  				checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  3619  			),
  3620  		)
  3621  		if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() {
  3622  			t.Errorf("Retransmit interval not capped to MaxRTO.\n")
  3623  		}
  3624  	}
  3625  }
  3626  
  3627  // TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is
  3628  // unique on retransmits.
  3629  func TestRetransmitIPv4IDUniqueness(t *testing.T) {
  3630  	for _, tc := range []struct {
  3631  		name string
  3632  		size int
  3633  	}{
  3634  		{"1Byte", 1},
  3635  		{"512Bytes", 512},
  3636  	} {
  3637  		t.Run(tc.name, func(t *testing.T) {
  3638  			c := context.New(t, defaultMTU)
  3639  			defer c.Cleanup()
  3640  
  3641  			c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
  3642  
  3643  			// Disabling PMTU discovery causes all packets sent from this socket to
  3644  			// have DF=0. This needs to be done because the IPv4 ID uniqueness
  3645  			// applies only to non-atomic IPv4 datagrams as defined in RFC 6864
  3646  			// Section 4, and datagrams with DF=0 are non-atomic.
  3647  			if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil {
  3648  				t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err)
  3649  			}
  3650  
  3651  			var r bytes.Reader
  3652  			r.Reset(make([]byte, tc.size))
  3653  			if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  3654  				t.Fatalf("Write failed: %s", err)
  3655  			}
  3656  			pkt := c.GetPacket()
  3657  			checker.IPv4(t, pkt,
  3658  				checker.FragmentFlags(0),
  3659  				checker.TCP(
  3660  					checker.DstPort(context.TestPort),
  3661  					checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  3662  				),
  3663  			)
  3664  			idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): {}}
  3665  			// Expect two retransmitted packets, and that all packets received have
  3666  			// unique IPv4 ID values.
  3667  			for i := 0; i <= 2; i++ {
  3668  				pkt := c.GetPacket()
  3669  				checker.IPv4(t, pkt,
  3670  					checker.FragmentFlags(0),
  3671  					checker.TCP(
  3672  						checker.DstPort(context.TestPort),
  3673  						checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  3674  					),
  3675  				)
  3676  				id := header.IPv4(pkt).ID()
  3677  				if _, exists := idSet[id]; exists {
  3678  					t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id)
  3679  				}
  3680  				idSet[id] = struct{}{}
  3681  			}
  3682  		})
  3683  	}
  3684  }
  3685  
  3686  func TestFinImmediately(t *testing.T) {
  3687  	c := context.New(t, defaultMTU)
  3688  	defer c.Cleanup()
  3689  
  3690  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  3691  
  3692  	// Shutdown immediately, check that we get a FIN.
  3693  	if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
  3694  		t.Fatalf("Shutdown failed: %s", err)
  3695  	}
  3696  
  3697  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  3698  	checker.IPv4(t, c.GetPacket(),
  3699  		checker.PayloadLen(header.TCPMinimumSize),
  3700  		checker.TCP(
  3701  			checker.DstPort(context.TestPort),
  3702  			checker.TCPSeqNum(uint32(c.IRS)+1),
  3703  			checker.TCPAckNum(uint32(iss)),
  3704  			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
  3705  		),
  3706  	)
  3707  
  3708  	// Ack and send FIN as well.
  3709  	c.SendPacket(nil, &context.Headers{
  3710  		SrcPort: context.TestPort,
  3711  		DstPort: c.Port,
  3712  		Flags:   header.TCPFlagAck | header.TCPFlagFin,
  3713  		SeqNum:  iss,
  3714  		AckNum:  c.IRS.Add(2),
  3715  		RcvWnd:  30000,
  3716  	})
  3717  
  3718  	// Check that the stack acks the FIN.
  3719  	checker.IPv4(t, c.GetPacket(),
  3720  		checker.PayloadLen(header.TCPMinimumSize),
  3721  		checker.TCP(
  3722  			checker.DstPort(context.TestPort),
  3723  			checker.TCPSeqNum(uint32(c.IRS)+2),
  3724  			checker.TCPAckNum(uint32(iss)+1),
  3725  			checker.TCPFlags(header.TCPFlagAck),
  3726  		),
  3727  	)
  3728  }
  3729  
  3730  func TestFinRetransmit(t *testing.T) {
  3731  	c := context.New(t, defaultMTU)
  3732  	defer c.Cleanup()
  3733  
  3734  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  3735  
  3736  	// Shutdown immediately, check that we get a FIN.
  3737  	if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
  3738  		t.Fatalf("Shutdown failed: %s", err)
  3739  	}
  3740  
  3741  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  3742  	checker.IPv4(t, c.GetPacket(),
  3743  		checker.PayloadLen(header.TCPMinimumSize),
  3744  		checker.TCP(
  3745  			checker.DstPort(context.TestPort),
  3746  			checker.TCPSeqNum(uint32(c.IRS)+1),
  3747  			checker.TCPAckNum(uint32(iss)),
  3748  			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
  3749  		),
  3750  	)
  3751  
  3752  	// Don't acknowledge yet. We should get a retransmit of the FIN.
  3753  	checker.IPv4(t, c.GetPacket(),
  3754  		checker.PayloadLen(header.TCPMinimumSize),
  3755  		checker.TCP(
  3756  			checker.DstPort(context.TestPort),
  3757  			checker.TCPSeqNum(uint32(c.IRS)+1),
  3758  			checker.TCPAckNum(uint32(iss)),
  3759  			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
  3760  		),
  3761  	)
  3762  
  3763  	// Ack and send FIN as well.
  3764  	c.SendPacket(nil, &context.Headers{
  3765  		SrcPort: context.TestPort,
  3766  		DstPort: c.Port,
  3767  		Flags:   header.TCPFlagAck | header.TCPFlagFin,
  3768  		SeqNum:  iss,
  3769  		AckNum:  c.IRS.Add(2),
  3770  		RcvWnd:  30000,
  3771  	})
  3772  
  3773  	// Check that the stack acks the FIN.
  3774  	checker.IPv4(t, c.GetPacket(),
  3775  		checker.PayloadLen(header.TCPMinimumSize),
  3776  		checker.TCP(
  3777  			checker.DstPort(context.TestPort),
  3778  			checker.TCPSeqNum(uint32(c.IRS)+2),
  3779  			checker.TCPAckNum(uint32(iss)+1),
  3780  			checker.TCPFlags(header.TCPFlagAck),
  3781  		),
  3782  	)
  3783  }
  3784  
  3785  func TestFinWithNoPendingData(t *testing.T) {
  3786  	c := context.New(t, defaultMTU)
  3787  	defer c.Cleanup()
  3788  
  3789  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  3790  
  3791  	// Write something out, and have it acknowledged.
  3792  	view := make([]byte, 10)
  3793  	var r bytes.Reader
  3794  	r.Reset(view)
  3795  	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  3796  		t.Fatalf("Write failed: %s", err)
  3797  	}
  3798  
  3799  	next := uint32(c.IRS) + 1
  3800  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  3801  	checker.IPv4(t, c.GetPacket(),
  3802  		checker.PayloadLen(len(view)+header.TCPMinimumSize),
  3803  		checker.TCP(
  3804  			checker.DstPort(context.TestPort),
  3805  			checker.TCPSeqNum(next),
  3806  			checker.TCPAckNum(uint32(iss)),
  3807  			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  3808  		),
  3809  	)
  3810  	next += uint32(len(view))
  3811  
  3812  	c.SendPacket(nil, &context.Headers{
  3813  		SrcPort: context.TestPort,
  3814  		DstPort: c.Port,
  3815  		Flags:   header.TCPFlagAck,
  3816  		SeqNum:  iss,
  3817  		AckNum:  seqnum.Value(next),
  3818  		RcvWnd:  30000,
  3819  	})
  3820  
  3821  	// Shutdown, check that we get a FIN.
  3822  	if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
  3823  		t.Fatalf("Shutdown failed: %s", err)
  3824  	}
  3825  
  3826  	checker.IPv4(t, c.GetPacket(),
  3827  		checker.PayloadLen(header.TCPMinimumSize),
  3828  		checker.TCP(
  3829  			checker.DstPort(context.TestPort),
  3830  			checker.TCPSeqNum(next),
  3831  			checker.TCPAckNum(uint32(iss)),
  3832  			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
  3833  		),
  3834  	)
  3835  	next++
  3836  
  3837  	// Ack and send FIN as well.
  3838  	c.SendPacket(nil, &context.Headers{
  3839  		SrcPort: context.TestPort,
  3840  		DstPort: c.Port,
  3841  		Flags:   header.TCPFlagAck | header.TCPFlagFin,
  3842  		SeqNum:  iss,
  3843  		AckNum:  seqnum.Value(next),
  3844  		RcvWnd:  30000,
  3845  	})
  3846  
  3847  	// Check that the stack acks the FIN.
  3848  	checker.IPv4(t, c.GetPacket(),
  3849  		checker.PayloadLen(header.TCPMinimumSize),
  3850  		checker.TCP(
  3851  			checker.DstPort(context.TestPort),
  3852  			checker.TCPSeqNum(next),
  3853  			checker.TCPAckNum(uint32(iss)+1),
  3854  			checker.TCPFlags(header.TCPFlagAck),
  3855  		),
  3856  	)
  3857  }
  3858  
// TestFinWithPendingDataCwndFull checks that when tcp.InitialCwnd segments are
// outstanding and unacknowledged, Shutdown does not immediately emit a FIN:
// the next packet observed is a retransmit of the first data segment, and the
// FIN only appears after all data has been acknowledged.
func TestFinWithPendingDataCwndFull(t *testing.T) {
	c := context.New(t, defaultMTU)
	defer c.Cleanup()

	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)

	// Write enough segments to fill the congestion window before ACK'ing
	// any of them.
	view := make([]byte, 10)
	var r bytes.Reader
	for i := tcp.InitialCwnd; i > 0; i-- {
		r.Reset(view)
		if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
			t.Fatalf("Write failed: %s", err)
		}
	}

	// next tracks the sequence number expected on the endpoint's next new
	// segment; iss is the sequence number used by the test side.
	next := uint32(c.IRS) + 1
	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
	// Each write must arrive as its own segment, in order.
	for i := tcp.InitialCwnd; i > 0; i-- {
		checker.IPv4(t, c.GetPacket(),
			checker.PayloadLen(len(view)+header.TCPMinimumSize),
			checker.TCP(
				checker.DstPort(context.TestPort),
				checker.TCPSeqNum(next),
				checker.TCPAckNum(uint32(iss)),
				checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
			),
		)
		next += uint32(len(view))
	}

	// Shutdown the connection, check that the FIN segment isn't sent
	// because the congestion window doesn't allow it. Wait until a
	// retransmit is received.
	if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
		t.Fatalf("Shutdown failed: %s", err)
	}

	// The next packet must be the first data segment again (sequence number
	// IRS+1, with payload), not a FIN.
	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(len(view)+header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.TCPSeqNum(uint32(c.IRS)+1),
			checker.TCPAckNum(uint32(iss)),
			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
		),
	)

	// Send the ACK that will allow the FIN to be sent as well.
	c.SendPacket(nil, &context.Headers{
		SrcPort: context.TestPort,
		DstPort: c.Port,
		Flags:   header.TCPFlagAck,
		SeqNum:  iss,
		AckNum:  seqnum.Value(next),
		RcvWnd:  30000,
	})

	// With all data acknowledged, a payload-free FIN|ACK is expected.
	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.TCPSeqNum(next),
			checker.TCPAckNum(uint32(iss)),
			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
		),
	)
	// Account for the sequence number taken by the FIN.
	next++

	// Send a FIN that acknowledges everything. Get an ACK back.
	c.SendPacket(nil, &context.Headers{
		SrcPort: context.TestPort,
		DstPort: c.Port,
		Flags:   header.TCPFlagAck | header.TCPFlagFin,
		SeqNum:  iss,
		AckNum:  seqnum.Value(next),
		RcvWnd:  30000,
	})

	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.TCPSeqNum(next),
			checker.TCPAckNum(uint32(iss)+1),
			checker.TCPFlags(header.TCPFlagAck),
		),
	)
}
  3949  
  3950  func TestFinWithPendingData(t *testing.T) {
  3951  	c := context.New(t, defaultMTU)
  3952  	defer c.Cleanup()
  3953  
  3954  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  3955  
  3956  	// Write something out, and acknowledge it to get cwnd to 2.
  3957  	view := make([]byte, 10)
  3958  	var r bytes.Reader
  3959  	r.Reset(view)
  3960  	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  3961  		t.Fatalf("Write failed: %s", err)
  3962  	}
  3963  
  3964  	next := uint32(c.IRS) + 1
  3965  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  3966  	checker.IPv4(t, c.GetPacket(),
  3967  		checker.PayloadLen(len(view)+header.TCPMinimumSize),
  3968  		checker.TCP(
  3969  			checker.DstPort(context.TestPort),
  3970  			checker.TCPSeqNum(next),
  3971  			checker.TCPAckNum(uint32(iss)),
  3972  			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  3973  		),
  3974  	)
  3975  	next += uint32(len(view))
  3976  
  3977  	c.SendPacket(nil, &context.Headers{
  3978  		SrcPort: context.TestPort,
  3979  		DstPort: c.Port,
  3980  		Flags:   header.TCPFlagAck,
  3981  		SeqNum:  iss,
  3982  		AckNum:  seqnum.Value(next),
  3983  		RcvWnd:  30000,
  3984  	})
  3985  
  3986  	// Write new data, but don't acknowledge it.
  3987  	r.Reset(view)
  3988  	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  3989  		t.Fatalf("Write failed: %s", err)
  3990  	}
  3991  
  3992  	checker.IPv4(t, c.GetPacket(),
  3993  		checker.PayloadLen(len(view)+header.TCPMinimumSize),
  3994  		checker.TCP(
  3995  			checker.DstPort(context.TestPort),
  3996  			checker.TCPSeqNum(next),
  3997  			checker.TCPAckNum(uint32(iss)),
  3998  			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  3999  		),
  4000  	)
  4001  	next += uint32(len(view))
  4002  
  4003  	// Shutdown the connection, check that we do get a FIN.
  4004  	if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
  4005  		t.Fatalf("Shutdown failed: %s", err)
  4006  	}
  4007  
  4008  	checker.IPv4(t, c.GetPacket(),
  4009  		checker.PayloadLen(header.TCPMinimumSize),
  4010  		checker.TCP(
  4011  			checker.DstPort(context.TestPort),
  4012  			checker.TCPSeqNum(next),
  4013  			checker.TCPAckNum(uint32(iss)),
  4014  			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
  4015  		),
  4016  	)
  4017  	next++
  4018  
  4019  	// Send a FIN that acknowledges everything. Get an ACK back.
  4020  	c.SendPacket(nil, &context.Headers{
  4021  		SrcPort: context.TestPort,
  4022  		DstPort: c.Port,
  4023  		Flags:   header.TCPFlagAck | header.TCPFlagFin,
  4024  		SeqNum:  iss,
  4025  		AckNum:  seqnum.Value(next),
  4026  		RcvWnd:  30000,
  4027  	})
  4028  
  4029  	checker.IPv4(t, c.GetPacket(),
  4030  		checker.PayloadLen(header.TCPMinimumSize),
  4031  		checker.TCP(
  4032  			checker.DstPort(context.TestPort),
  4033  			checker.TCPSeqNum(next),
  4034  			checker.TCPAckNum(uint32(iss)+1),
  4035  			checker.TCPFlags(header.TCPFlagAck),
  4036  		),
  4037  	)
  4038  }
  4039  
// TestFinWithPartialAck checks that, after the endpoint has sent data followed
// by a FIN, an ACK that covers the data but not the FIN does not trigger a
// retransmission of the FIN within 100ms.
func TestFinWithPartialAck(t *testing.T) {
	c := context.New(t, defaultMTU)
	defer c.Cleanup()

	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)

	// Write something out, and acknowledge it to get cwnd to 2. Also send
	// FIN from the test side.
	view := make([]byte, 10)
	var r bytes.Reader
	r.Reset(view)
	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
		t.Fatalf("Write failed: %s", err)
	}

	// next tracks the sequence number expected on the endpoint's next
	// segment; iss is the sequence number used by the test side.
	next := uint32(c.IRS) + 1
	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(len(view)+header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.TCPSeqNum(next),
			checker.TCPAckNum(uint32(iss)),
			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
		),
	)
	next += uint32(len(view))

	// This segment acknowledges the data and carries the test side's FIN.
	c.SendPacket(nil, &context.Headers{
		SrcPort: context.TestPort,
		DstPort: c.Port,
		Flags:   header.TCPFlagAck | header.TCPFlagFin,
		SeqNum:  iss,
		AckNum:  seqnum.Value(next),
		RcvWnd:  30000,
	})

	// Check that we get an ACK for the fin.
	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.TCPSeqNum(next),
			checker.TCPAckNum(uint32(iss)+1),
			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
		),
	)

	// Write new data, but don't acknowledge it.
	r.Reset(view)
	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
		t.Fatalf("Write failed: %s", err)
	}

	// From here on the endpoint's ACK number is iss+1, since it has consumed
	// the test side's FIN.
	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(len(view)+header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.TCPSeqNum(next),
			checker.TCPAckNum(uint32(iss)+1),
			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
		),
	)
	next += uint32(len(view))

	// Shutdown the connection, check that we do get a FIN.
	if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
		t.Fatalf("Shutdown failed: %s", err)
	}

	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.TCPSeqNum(next),
			checker.TCPAckNum(uint32(iss)+1),
			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
		),
	)
	// Account for the sequence number taken by the endpoint's FIN.
	next++

	// Send an ACK for the data, but not for the FIN yet.
	// next-1 is the sequence number of the FIN itself, so this ACK covers
	// everything before it.
	c.SendPacket(nil, &context.Headers{
		SrcPort: context.TestPort,
		DstPort: c.Port,
		Flags:   header.TCPFlagAck,
		SeqNum:  iss.Add(1),
		AckNum:  seqnum.Value(next - 1),
		RcvWnd:  30000,
	})

	// Check that we don't get a retransmit of the FIN.
	c.CheckNoPacketTimeout("FIN retransmitted when data was ack'd", 100*time.Millisecond)

	// Ack the FIN.
	// The segment also has the FIN flag set, although the test side already
	// sent a FIN above; no response to it is checked.
	c.SendPacket(nil, &context.Headers{
		SrcPort: context.TestPort,
		DstPort: c.Port,
		Flags:   header.TCPFlagAck | header.TCPFlagFin,
		SeqNum:  iss.Add(1),
		AckNum:  seqnum.Value(next),
		RcvWnd:  30000,
	})
}
  4144  
  4145  func TestUpdateListenBacklog(t *testing.T) {
  4146  	c := context.New(t, defaultMTU)
  4147  	defer c.Cleanup()
  4148  
  4149  	// Create listener.
  4150  	var wq waiter.Queue
  4151  	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
  4152  	if err != nil {
  4153  		t.Fatalf("NewEndpoint failed: %s", err)
  4154  	}
  4155  
  4156  	if err := ep.Bind(tcpip.FullAddress{}); err != nil {
  4157  		t.Fatalf("Bind failed: %s", err)
  4158  	}
  4159  
  4160  	if err := ep.Listen(10); err != nil {
  4161  		t.Fatalf("Listen failed: %s", err)
  4162  	}
  4163  
  4164  	// Update the backlog with another Listen() on the same endpoint.
  4165  	if err := ep.Listen(20); err != nil {
  4166  		t.Fatalf("Listen failed to update backlog: %s", err)
  4167  	}
  4168  
  4169  	ep.Close()
  4170  }
  4171  
  4172  func scaledSendWindow(t *testing.T, scale uint8) {
  4173  	// This test ensures that the endpoint is using the right scaling by
  4174  	// sending a buffer that is larger than the window size, and ensuring
  4175  	// that the endpoint doesn't send more than allowed.
  4176  	c := context.New(t, defaultMTU)
  4177  	defer c.Cleanup()
  4178  
  4179  	maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize
  4180  	c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 0, -1 /* epRcvBuf */, []byte{
  4181  		header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
  4182  		header.TCPOptionWS, 3, scale, header.TCPOptionNOP,
  4183  	})
  4184  
  4185  	// Open up the window with a scaled value.
  4186  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  4187  	c.SendPacket(nil, &context.Headers{
  4188  		SrcPort: context.TestPort,
  4189  		DstPort: c.Port,
  4190  		Flags:   header.TCPFlagAck,
  4191  		SeqNum:  iss,
  4192  		AckNum:  c.IRS.Add(1),
  4193  		RcvWnd:  1,
  4194  	})
  4195  
  4196  	// Send some data. Check that it's capped by the window size.
  4197  	view := make([]byte, 65535)
  4198  	var r bytes.Reader
  4199  	r.Reset(view)
  4200  	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  4201  		t.Fatalf("Write failed: %s", err)
  4202  	}
  4203  
  4204  	// Check that only data that fits in the scaled window is sent.
  4205  	checker.IPv4(t, c.GetPacket(),
  4206  		checker.PayloadLen((1<<scale)+header.TCPMinimumSize),
  4207  		checker.TCP(
  4208  			checker.DstPort(context.TestPort),
  4209  			checker.TCPSeqNum(uint32(c.IRS)+1),
  4210  			checker.TCPAckNum(uint32(iss)),
  4211  			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  4212  		),
  4213  	)
  4214  
  4215  	// Reset the connection to free resources.
  4216  	c.SendPacket(nil, &context.Headers{
  4217  		SrcPort: context.TestPort,
  4218  		DstPort: c.Port,
  4219  		Flags:   header.TCPFlagRst,
  4220  		SeqNum:  iss,
  4221  	})
  4222  }
  4223  
  4224  func TestScaledSendWindow(t *testing.T) {
  4225  	for scale := uint8(0); scale <= 14; scale++ {
  4226  		scaledSendWindow(t, scale)
  4227  	}
  4228  }
  4229  
  4230  func TestReceivedValidSegmentCountIncrement(t *testing.T) {
  4231  	c := context.New(t, defaultMTU)
  4232  	defer c.Cleanup()
  4233  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  4234  	stats := c.Stack().Stats()
  4235  	want := stats.TCP.ValidSegmentsReceived.Value() + 1
  4236  
  4237  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  4238  	c.SendPacket(nil, &context.Headers{
  4239  		SrcPort: context.TestPort,
  4240  		DstPort: c.Port,
  4241  		Flags:   header.TCPFlagAck,
  4242  		SeqNum:  iss,
  4243  		AckNum:  c.IRS.Add(1),
  4244  		RcvWnd:  30000,
  4245  	})
  4246  
  4247  	if got := stats.TCP.ValidSegmentsReceived.Value(); got != want {
  4248  		t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %d, want = %d", got, want)
  4249  	}
  4250  	if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want {
  4251  		t.Errorf("got EP stats Stats.SegmentsReceived = %d, want = %d", got, want)
  4252  	}
  4253  	// Ensure there were no errors during handshake. If these stats have
  4254  	// incremented, then the connection should not have been established.
  4255  	if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 {
  4256  		t.Errorf("got EP stats Stats.SendErrors.NoRoute = %d, want = %d", got, 0)
  4257  	}
  4258  }
  4259  
// TestReceivedInvalidSegmentCountIncrement checks that a segment whose TCP
// data-offset field has been corrupted increments both the stack-wide
// InvalidSegmentsReceived counter and the endpoint's
// ReceiveErrors.MalformedPacketsReceived counter.
func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
	c := context.New(t, defaultMTU)
	defer c.Cleanup()
	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
	stats := c.Stack().Stats()
	// Snapshot the counter after the handshake; exactly one more is expected.
	want := stats.TCP.InvalidSegmentsReceived.Value() + 1
	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
	vv := c.BuildSegment(nil, &context.Headers{
		SrcPort: context.TestPort,
		DstPort: c.Port,
		Flags:   header.TCPFlagAck,
		SeqNum:  iss,
		AckNum:  c.IRS.Add(1),
		RcvWnd:  30000,
	})
	// Skip the IPv4 header to reach the TCP header bytes.
	// NOTE(review): this relies on vv.ToView() returning memory shared with
	// vv, so that the edit below reaches the segment that is sent — confirm.
	tcpbuf := vv.ToView()[header.IPv4MinimumSize:]
	// Store a data offset (upper four bits, counted in 4-byte words) computed
	// from TCPMinimumSize-1, i.e. a header length below the minimum.
	tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4

	c.SendSegment(vv)

	if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want {
		t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %d, want = %d", got, want)
	}
	if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
		t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
	}
}
  4287  
// TestReceivedIncorrectChecksumIncrement checks that a segment whose payload
// was altered after it was built increments both the stack-wide
// ChecksumErrors counter and the endpoint's ReceiveErrors.ChecksumErrors
// counter.
func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
	c := context.New(t, defaultMTU)
	defer c.Cleanup()
	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
	stats := c.Stack().Stats()
	// Snapshot the counter after the handshake; exactly one more is expected.
	want := stats.TCP.ChecksumErrors.Value() + 1
	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
	vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{
		SrcPort: context.TestPort,
		DstPort: c.Port,
		Flags:   header.TCPFlagAck,
		SeqNum:  iss,
		AckNum:  c.IRS.Add(1),
		RcvWnd:  30000,
	})
	// Skip the IPv4 header to reach the TCP header bytes.
	// NOTE(review): this relies on vv.ToView() returning memory shared with
	// vv, so that the edit below reaches the segment that is sent — confirm.
	tcpbuf := vv.ToView()[header.IPv4MinimumSize:]
	// Overwrite a byte in the payload which should cause checksum
	// verification to fail.
	// The index is the TCP header length read from the data-offset field
	// (upper four bits, in 4-byte words), i.e. the first payload byte.
	tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4

	c.SendSegment(vv)

	if got := stats.TCP.ChecksumErrors.Value(); got != want {
		t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want)
	}
	if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ChecksumErrors.Value(); got != want {
		t.Errorf("got EP stats Stats.ReceiveErrors.ChecksumErrors = %d, want = %d", got, want)
	}
}
  4317  
  4318  func TestReceivedSegmentQueuing(t *testing.T) {
  4319  	// This test sends 200 segments containing a few bytes each to an
  4320  	// endpoint and checks that they're all received and acknowledged by
  4321  	// the endpoint, that is, that none of the segments are dropped by
  4322  	// internal queues.
  4323  	c := context.New(t, defaultMTU)
  4324  	defer c.Cleanup()
  4325  
  4326  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  4327  
  4328  	// Send 200 segments.
  4329  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  4330  	data := []byte{1, 2, 3}
  4331  	for i := 0; i < 200; i++ {
  4332  		c.SendPacket(data, &context.Headers{
  4333  			SrcPort: context.TestPort,
  4334  			DstPort: c.Port,
  4335  			Flags:   header.TCPFlagAck,
  4336  			SeqNum:  iss.Add(seqnum.Size(i * len(data))),
  4337  			AckNum:  c.IRS.Add(1),
  4338  			RcvWnd:  30000,
  4339  		})
  4340  	}
  4341  
  4342  	// Receive ACKs for all segments.
  4343  	last := iss.Add(seqnum.Size(200 * len(data)))
  4344  	for {
  4345  		b := c.GetPacket()
  4346  		checker.IPv4(t, b,
  4347  			checker.TCP(
  4348  				checker.DstPort(context.TestPort),
  4349  				checker.TCPSeqNum(uint32(c.IRS)+1),
  4350  				checker.TCPFlags(header.TCPFlagAck),
  4351  			),
  4352  		)
  4353  		tcpHdr := header.TCP(header.IPv4(b).Payload())
  4354  		ack := seqnum.Value(tcpHdr.AckNumber())
  4355  		if ack == last {
  4356  			break
  4357  		}
  4358  
  4359  		if last.LessThan(ack) {
  4360  			t.Fatalf("Acknowledge (%v) beyond the expected (%v)", ack, last)
  4361  		}
  4362  	}
  4363  }
  4364  
  4365  func TestReadAfterClosedState(t *testing.T) {
  4366  	// This test ensures that calling Read() or Peek() after the endpoint
  4367  	// has transitioned to closedState still works if there is pending
  4368  	// data. To transition to stateClosed without calling Close(), we must
  4369  	// shutdown the send path and the peer must send its own FIN.
  4370  	c := context.New(t, defaultMTU)
  4371  	defer c.Cleanup()
  4372  
  4373  	// Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed
  4374  	// after 1 second in TIME_WAIT state.
  4375  	tcpTimeWaitTimeout := 1 * time.Second
  4376  	opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
  4377  	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
  4378  		t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
  4379  	}
  4380  
  4381  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  4382  
  4383  	we, ch := waiter.NewChannelEntry(nil)
  4384  	c.WQ.EventRegister(&we, waiter.ReadableEvents)
  4385  	defer c.WQ.EventUnregister(&we)
  4386  
  4387  	ept := endpointTester{c.EP}
  4388  	ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
  4389  
  4390  	// Shutdown immediately for write, check that we get a FIN.
  4391  	if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
  4392  		t.Fatalf("Shutdown failed: %s", err)
  4393  	}
  4394  
  4395  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  4396  	checker.IPv4(t, c.GetPacket(),
  4397  		checker.PayloadLen(header.TCPMinimumSize),
  4398  		checker.TCP(
  4399  			checker.DstPort(context.TestPort),
  4400  			checker.TCPSeqNum(uint32(c.IRS)+1),
  4401  			checker.TCPAckNum(uint32(iss)),
  4402  			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
  4403  		),
  4404  	)
  4405  
  4406  	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
  4407  		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
  4408  	}
  4409  
  4410  	// Send some data and acknowledge the FIN.
  4411  	data := []byte{1, 2, 3}
  4412  	c.SendPacket(data, &context.Headers{
  4413  		SrcPort: context.TestPort,
  4414  		DstPort: c.Port,
  4415  		Flags:   header.TCPFlagAck | header.TCPFlagFin,
  4416  		SeqNum:  iss,
  4417  		AckNum:  c.IRS.Add(2),
  4418  		RcvWnd:  30000,
  4419  	})
  4420  
  4421  	// Check that ACK is received.
  4422  	checker.IPv4(t, c.GetPacket(),
  4423  		checker.TCP(
  4424  			checker.DstPort(context.TestPort),
  4425  			checker.TCPSeqNum(uint32(c.IRS)+2),
  4426  			checker.TCPAckNum(uint32(iss)+uint32(len(data))+1),
  4427  			checker.TCPFlags(header.TCPFlagAck),
  4428  		),
  4429  	)
  4430  
  4431  	// Give the stack the chance to transition to closed state from
  4432  	// TIME_WAIT.
  4433  	time.Sleep(tcpTimeWaitTimeout * 2)
  4434  
  4435  	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want {
  4436  		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
  4437  	}
  4438  
  4439  	// Wait for receive to be notified.
  4440  	select {
  4441  	case <-ch:
  4442  	case <-time.After(1 * time.Second):
  4443  		t.Fatalf("Timed out waiting for data to arrive")
  4444  	}
  4445  
  4446  	// Check that peek works.
  4447  	var peekBuf bytes.Buffer
  4448  	res, err := c.EP.Read(&peekBuf, tcpip.ReadOptions{Peek: true})
  4449  	if err != nil {
  4450  		t.Fatalf("Peek failed: %s", err)
  4451  	}
  4452  
  4453  	if got, want := res.Count, len(data); got != want {
  4454  		t.Fatalf("res.Count = %d, want %d", got, want)
  4455  	}
  4456  	if !bytes.Equal(data, peekBuf.Bytes()) {
  4457  		t.Fatalf("got data = %v, want = %v", peekBuf.Bytes(), data)
  4458  	}
  4459  
  4460  	// Receive data.
  4461  	v := ept.CheckRead(t)
  4462  	if !bytes.Equal(data, v) {
  4463  		t.Fatalf("got data = %v, want = %v", v, data)
  4464  	}
  4465  
  4466  	// Now that we drained the queue, check that functions fail with the
  4467  	// right error code.
  4468  	ept.CheckReadError(t, &tcpip.ErrClosedForReceive{})
  4469  	var buf bytes.Buffer
  4470  	{
  4471  		_, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true})
  4472  		if d := cmp.Diff(&tcpip.ErrClosedForReceive{}, err); d != "" {
  4473  			t.Fatalf("c.EP.Read(_, {Peek: true}) mismatch (-want +got):\n%s", d)
  4474  		}
  4475  	}
  4476  }
  4477  
  4478  func TestReusePort(t *testing.T) {
  4479  	// This test ensures that ports are immediately available for reuse
  4480  	// after Close on the endpoints using them returns.
  4481  	c := context.New(t, defaultMTU)
  4482  	defer c.Cleanup()
  4483  
  4484  	// First case, just an endpoint that was bound.
  4485  	var err tcpip.Error
  4486  	c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
  4487  	if err != nil {
  4488  		t.Fatalf("NewEndpoint failed; %s", err)
  4489  	}
  4490  	c.EP.SocketOptions().SetReuseAddress(true)
  4491  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  4492  		t.Fatalf("Bind failed: %s", err)
  4493  	}
  4494  
  4495  	c.EP.Close()
  4496  	c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
  4497  	if err != nil {
  4498  		t.Fatalf("NewEndpoint failed; %s", err)
  4499  	}
  4500  	c.EP.SocketOptions().SetReuseAddress(true)
  4501  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  4502  		t.Fatalf("Bind failed: %s", err)
  4503  	}
  4504  	c.EP.Close()
  4505  
  4506  	// Second case, an endpoint that was bound and is connecting..
  4507  	c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
  4508  	if err != nil {
  4509  		t.Fatalf("NewEndpoint failed; %s", err)
  4510  	}
  4511  	c.EP.SocketOptions().SetReuseAddress(true)
  4512  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  4513  		t.Fatalf("Bind failed: %s", err)
  4514  	}
  4515  	{
  4516  		err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
  4517  		if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
  4518  			t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
  4519  		}
  4520  	}
  4521  	c.EP.Close()
  4522  
  4523  	c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
  4524  	if err != nil {
  4525  		t.Fatalf("NewEndpoint failed; %s", err)
  4526  	}
  4527  	c.EP.SocketOptions().SetReuseAddress(true)
  4528  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  4529  		t.Fatalf("Bind failed: %s", err)
  4530  	}
  4531  	c.EP.Close()
  4532  
  4533  	// Third case, an endpoint that was bound and is listening.
  4534  	c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
  4535  	if err != nil {
  4536  		t.Fatalf("NewEndpoint failed; %s", err)
  4537  	}
  4538  	c.EP.SocketOptions().SetReuseAddress(true)
  4539  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  4540  		t.Fatalf("Bind failed: %s", err)
  4541  	}
  4542  	if err := c.EP.Listen(10); err != nil {
  4543  		t.Fatalf("Listen failed: %s", err)
  4544  	}
  4545  	c.EP.Close()
  4546  
  4547  	c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
  4548  	if err != nil {
  4549  		t.Fatalf("NewEndpoint failed; %s", err)
  4550  	}
  4551  	c.EP.SocketOptions().SetReuseAddress(true)
  4552  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  4553  		t.Fatalf("Bind failed: %s", err)
  4554  	}
  4555  	if err := c.EP.Listen(10); err != nil {
  4556  		t.Fatalf("Listen failed: %s", err)
  4557  	}
  4558  }
  4559  
  4560  func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
  4561  	t.Helper()
  4562  
  4563  	s := ep.SocketOptions().GetReceiveBufferSize()
  4564  	if int(s) != v {
  4565  		t.Fatalf("got receive buffer size = %d, want = %d", s, v)
  4566  	}
  4567  }
  4568  
  4569  func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
  4570  	t.Helper()
  4571  
  4572  	if s := ep.SocketOptions().GetSendBufferSize(); int(s) != v {
  4573  		t.Fatalf("got send buffer size = %d, want = %d", s, v)
  4574  	}
  4575  }
  4576  
  4577  func TestDefaultBufferSizes(t *testing.T) {
  4578  	s := stack.New(stack.Options{
  4579  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol},
  4580  		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
  4581  	})
  4582  
  4583  	// Check the default values.
  4584  	ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
  4585  	if err != nil {
  4586  		t.Fatalf("NewEndpoint failed; %s", err)
  4587  	}
  4588  	defer func() {
  4589  		if ep != nil {
  4590  			ep.Close()
  4591  		}
  4592  	}()
  4593  
  4594  	checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize)
  4595  	checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
  4596  
  4597  	// Change the default send buffer size.
  4598  	{
  4599  		opt := tcpip.TCPSendBufferSizeRangeOption{
  4600  			Min:     1,
  4601  			Default: tcp.DefaultSendBufferSize * 2,
  4602  			Max:     tcp.DefaultSendBufferSize * 20,
  4603  		}
  4604  		if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
  4605  			t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
  4606  		}
  4607  	}
  4608  
  4609  	ep.Close()
  4610  	ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
  4611  	if err != nil {
  4612  		t.Fatalf("NewEndpoint failed; %s", err)
  4613  	}
  4614  
  4615  	checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
  4616  	checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
  4617  
  4618  	// Change the default receive buffer size.
  4619  	{
  4620  		opt := tcpip.TCPReceiveBufferSizeRangeOption{
  4621  			Min:     1,
  4622  			Default: tcp.DefaultReceiveBufferSize * 3,
  4623  			Max:     tcp.DefaultReceiveBufferSize * 30,
  4624  		}
  4625  		if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
  4626  			t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
  4627  		}
  4628  	}
  4629  
  4630  	ep.Close()
  4631  	ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
  4632  	if err != nil {
  4633  		t.Fatalf("NewEndpoint failed; %s", err)
  4634  	}
  4635  
  4636  	checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
  4637  	checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*3)
  4638  }
  4639  
  4640  func TestMinMaxBufferSizes(t *testing.T) {
  4641  	s := stack.New(stack.Options{
  4642  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol},
  4643  		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
  4644  	})
  4645  
  4646  	// Check the default values.
  4647  	ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
  4648  	if err != nil {
  4649  		t.Fatalf("NewEndpoint failed; %s", err)
  4650  	}
  4651  	defer ep.Close()
  4652  
  4653  	// Change the min/max values for send/receive
  4654  	{
  4655  		opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}
  4656  		if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
  4657  			t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
  4658  		}
  4659  	}
  4660  
  4661  	{
  4662  		opt := tcpip.TCPSendBufferSizeRangeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}
  4663  		if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
  4664  			t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
  4665  		}
  4666  	}
  4667  
  4668  	// Set values below the min/2.
  4669  	ep.SocketOptions().SetReceiveBufferSize(99, true)
  4670  	checkRecvBufferSize(t, ep, 200)
  4671  
  4672  	ep.SocketOptions().SetSendBufferSize(149, true)
  4673  
  4674  	checkSendBufferSize(t, ep, 300)
  4675  
  4676  	// Set values above the max.
  4677  	ep.SocketOptions().SetReceiveBufferSize(1+tcp.DefaultReceiveBufferSize*20, true)
  4678  	// Values above max are capped at max and then doubled.
  4679  	checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2)
  4680  
  4681  	ep.SocketOptions().SetSendBufferSize(1+tcp.DefaultSendBufferSize*30, true)
  4682  	// Values above max are capped at max and then doubled.
  4683  	checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2)
  4684  }
  4685  
  4686  func TestBindToDeviceOption(t *testing.T) {
  4687  	s := stack.New(stack.Options{
  4688  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol},
  4689  		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}})
  4690  
  4691  	ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
  4692  	if err != nil {
  4693  		t.Fatalf("NewEndpoint failed; %s", err)
  4694  	}
  4695  	defer ep.Close()
  4696  
  4697  	if err := s.CreateNIC(321, loopback.New()); err != nil {
  4698  		t.Errorf("CreateNIC failed: %s", err)
  4699  	}
  4700  
  4701  	// nicIDPtr is used instead of taking the address of NICID literals, which is
  4702  	// a compiler error.
  4703  	nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
  4704  		return &s
  4705  	}
  4706  
  4707  	testActions := []struct {
  4708  		name                 string
  4709  		setBindToDevice      *tcpip.NICID
  4710  		setBindToDeviceError tcpip.Error
  4711  		getBindToDevice      int32
  4712  	}{
  4713  		{"GetDefaultValue", nil, nil, 0},
  4714  		{"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0},
  4715  		{"BindToExistent", nicIDPtr(321), nil, 321},
  4716  		{"UnbindToDevice", nicIDPtr(0), nil, 0},
  4717  	}
  4718  	for _, testAction := range testActions {
  4719  		t.Run(testAction.name, func(t *testing.T) {
  4720  			if testAction.setBindToDevice != nil {
  4721  				bindToDevice := int32(*testAction.setBindToDevice)
  4722  				if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
  4723  					t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr)
  4724  				}
  4725  			}
  4726  			bindToDevice := ep.SocketOptions().GetBindToDevice()
  4727  			if bindToDevice != testAction.getBindToDevice {
  4728  				t.Errorf("got bindToDevice = %d, want %d", bindToDevice, testAction.getBindToDevice)
  4729  			}
  4730  		})
  4731  	}
  4732  }
  4733  
  4734  func makeStack() (*stack.Stack, tcpip.Error) {
  4735  	s := stack.New(stack.Options{
  4736  		NetworkProtocols: []stack.NetworkProtocolFactory{
  4737  			ipv4.NewProtocol,
  4738  			ipv6.NewProtocol,
  4739  		},
  4740  		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
  4741  	})
  4742  
  4743  	id := loopback.New()
  4744  	if testing.Verbose() {
  4745  		id = sniffer.New(id)
  4746  	}
  4747  
  4748  	if err := s.CreateNIC(1, id); err != nil {
  4749  		return nil, err
  4750  	}
  4751  
  4752  	for _, ct := range []struct {
  4753  		number  tcpip.NetworkProtocolNumber
  4754  		address tcpip.Address
  4755  	}{
  4756  		{ipv4.ProtocolNumber, context.StackAddr},
  4757  		{ipv6.ProtocolNumber, context.StackV6Addr},
  4758  	} {
  4759  		if err := s.AddAddress(1, ct.number, ct.address); err != nil {
  4760  			return nil, err
  4761  		}
  4762  	}
  4763  
  4764  	s.SetRouteTable([]tcpip.Route{
  4765  		{
  4766  			Destination: header.IPv4EmptySubnet,
  4767  			NIC:         1,
  4768  		},
  4769  		{
  4770  			Destination: header.IPv6EmptySubnet,
  4771  			NIC:         1,
  4772  		},
  4773  	})
  4774  
  4775  	return s, nil
  4776  }
  4777  
  4778  func TestSelfConnect(t *testing.T) {
  4779  	// This test ensures that intentional self-connects work. In particular,
  4780  	// it checks that if an endpoint binds to say 127.0.0.1:1000 then
  4781  	// connects to 127.0.0.1:1000, then it will be connected to itself, and
  4782  	// is able to send and receive data through the same endpoint.
  4783  	s, err := makeStack()
  4784  	if err != nil {
  4785  		t.Fatal(err)
  4786  	}
  4787  
  4788  	var wq waiter.Queue
  4789  	ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
  4790  	if err != nil {
  4791  		t.Fatalf("NewEndpoint failed: %s", err)
  4792  	}
  4793  	defer ep.Close()
  4794  
  4795  	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  4796  		t.Fatalf("Bind failed: %s", err)
  4797  	}
  4798  
  4799  	// Register for notification, then start connection attempt.
  4800  	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
  4801  	wq.EventRegister(&waitEntry, waiter.WritableEvents)
  4802  	defer wq.EventUnregister(&waitEntry)
  4803  
  4804  	{
  4805  		err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort})
  4806  		if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
  4807  			t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d)
  4808  		}
  4809  	}
  4810  
  4811  	<-notifyCh
  4812  	if err := ep.LastError(); err != nil {
  4813  		t.Fatalf("Connect failed: %s", err)
  4814  	}
  4815  
  4816  	// Write something.
  4817  	data := []byte{1, 2, 3}
  4818  	var r bytes.Reader
  4819  	r.Reset(data)
  4820  	if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil {
  4821  		t.Fatalf("Write failed: %s", err)
  4822  	}
  4823  
  4824  	// Read back what was written.
  4825  	wq.EventUnregister(&waitEntry)
  4826  	wq.EventRegister(&waitEntry, waiter.ReadableEvents)
  4827  	ept := endpointTester{ep}
  4828  	rd := ept.CheckReadFull(t, len(data), notifyCh, 5*time.Second)
  4829  
  4830  	if !bytes.Equal(data, rd) {
  4831  		t.Fatalf("got data = %v, want = %v", rd, data)
  4832  	}
  4833  }
  4834  
  4835  func TestConnectAvoidsBoundPorts(t *testing.T) {
  4836  	addressTypes := func(t *testing.T, network string) []string {
  4837  		switch network {
  4838  		case "ipv4":
  4839  			return []string{"v4"}
  4840  		case "ipv6":
  4841  			return []string{"v6"}
  4842  		case "dual":
  4843  			return []string{"v6", "mapped"}
  4844  		default:
  4845  			t.Fatalf("unknown network: '%s'", network)
  4846  		}
  4847  
  4848  		panic("unreachable")
  4849  	}
  4850  
  4851  	address := func(t *testing.T, addressType string, isAny bool) tcpip.Address {
  4852  		switch addressType {
  4853  		case "v4":
  4854  			if isAny {
  4855  				return ""
  4856  			}
  4857  			return context.StackAddr
  4858  		case "v6":
  4859  			if isAny {
  4860  				return ""
  4861  			}
  4862  			return context.StackV6Addr
  4863  		case "mapped":
  4864  			if isAny {
  4865  				return context.V4MappedWildcardAddr
  4866  			}
  4867  			return context.StackV4MappedAddr
  4868  		default:
  4869  			t.Fatalf("unknown address type: '%s'", addressType)
  4870  		}
  4871  
  4872  		panic("unreachable")
  4873  	}
  4874  	// This test ensures that Endpoint.Connect doesn't select already-bound ports.
  4875  	networks := []string{"ipv4", "ipv6", "dual"}
  4876  	for _, exhaustedNetwork := range networks {
  4877  		t.Run(fmt.Sprintf("exhaustedNetwork=%s", exhaustedNetwork), func(t *testing.T) {
  4878  			for _, exhaustedAddressType := range addressTypes(t, exhaustedNetwork) {
  4879  				t.Run(fmt.Sprintf("exhaustedAddressType=%s", exhaustedAddressType), func(t *testing.T) {
  4880  					for _, isAny := range []bool{false, true} {
  4881  						t.Run(fmt.Sprintf("isAny=%t", isAny), func(t *testing.T) {
  4882  							for _, candidateNetwork := range networks {
  4883  								t.Run(fmt.Sprintf("candidateNetwork=%s", candidateNetwork), func(t *testing.T) {
  4884  									for _, candidateAddressType := range addressTypes(t, candidateNetwork) {
  4885  										t.Run(fmt.Sprintf("candidateAddressType=%s", candidateAddressType), func(t *testing.T) {
  4886  											s, err := makeStack()
  4887  											if err != nil {
  4888  												t.Fatal(err)
  4889  											}
  4890  
  4891  											var wq waiter.Queue
  4892  											var eps []tcpip.Endpoint
  4893  											defer func() {
  4894  												for _, ep := range eps {
  4895  													ep.Close()
  4896  												}
  4897  											}()
  4898  											makeEP := func(network string) tcpip.Endpoint {
  4899  												var networkProtocolNumber tcpip.NetworkProtocolNumber
  4900  												switch network {
  4901  												case "ipv4":
  4902  													networkProtocolNumber = ipv4.ProtocolNumber
  4903  												case "ipv6", "dual":
  4904  													networkProtocolNumber = ipv6.ProtocolNumber
  4905  												default:
  4906  													t.Fatalf("unknown network: '%s'", network)
  4907  												}
  4908  												ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq)
  4909  												if err != nil {
  4910  													t.Fatalf("NewEndpoint failed: %s", err)
  4911  												}
  4912  												eps = append(eps, ep)
  4913  												switch network {
  4914  												case "ipv4":
  4915  												case "ipv6":
  4916  													ep.SocketOptions().SetV6Only(true)
  4917  												case "dual":
  4918  													ep.SocketOptions().SetV6Only(false)
  4919  												default:
  4920  													t.Fatalf("unknown network: '%s'", network)
  4921  												}
  4922  												return ep
  4923  											}
  4924  
  4925  											var v4reserved, v6reserved bool
  4926  											switch exhaustedAddressType {
  4927  											case "v4", "mapped":
  4928  												v4reserved = true
  4929  											case "v6":
  4930  												v6reserved = true
  4931  												// Dual stack sockets bound to v6 any reserve on v4 as
  4932  												// well.
  4933  												if isAny {
  4934  													switch exhaustedNetwork {
  4935  													case "ipv6":
  4936  													case "dual":
  4937  														v4reserved = true
  4938  													default:
  4939  														t.Fatalf("unknown address type: '%s'", exhaustedNetwork)
  4940  													}
  4941  												}
  4942  											default:
  4943  												t.Fatalf("unknown address type: '%s'", exhaustedAddressType)
  4944  											}
  4945  											var collides bool
  4946  											switch candidateAddressType {
  4947  											case "v4", "mapped":
  4948  												collides = v4reserved
  4949  											case "v6":
  4950  												collides = v6reserved
  4951  											default:
  4952  												t.Fatalf("unknown address type: '%s'", candidateAddressType)
  4953  											}
  4954  
  4955  											const (
  4956  												start = 16000
  4957  												end   = 16050
  4958  											)
  4959  											if err := s.SetPortRange(start, end); err != nil {
  4960  												t.Fatalf("got s.SetPortRange(%d, %d) = %s, want = nil", start, end, err)
  4961  											}
  4962  											for i := start; i <= end; i++ {
  4963  												if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
  4964  													t.Fatalf("Bind(%d) failed: %s", i, err)
  4965  												}
  4966  											}
  4967  											var want tcpip.Error = &tcpip.ErrConnectStarted{}
  4968  											if collides {
  4969  												want = &tcpip.ErrNoPortAvailable{}
  4970  											}
  4971  											if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want {
  4972  												t.Fatalf("got ep.Connect(..) = %s, want = %s", err, want)
  4973  											}
  4974  										})
  4975  									}
  4976  								})
  4977  							}
  4978  						})
  4979  					}
  4980  				})
  4981  			}
  4982  		})
  4983  	}
  4984  }
  4985  
  4986  func TestPathMTUDiscovery(t *testing.T) {
  4987  	// This test verifies the stack retransmits packets after it receives an
  4988  	// ICMP packet indicating that the path MTU has been exceeded.
  4989  	c := context.New(t, 1500)
  4990  	defer c.Cleanup()
  4991  
  4992  	// Create new connection with MSS of 1460.
  4993  	const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize
  4994  	c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{
  4995  		header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
  4996  	})
  4997  
  4998  	// Send 3200 bytes of data.
  4999  	const writeSize = 3200
  5000  	data := make([]byte, writeSize)
  5001  	for i := range data {
  5002  		data[i] = byte(i)
  5003  	}
  5004  	var r bytes.Reader
  5005  	r.Reset(data)
  5006  	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  5007  		t.Fatalf("Write failed: %s", err)
  5008  	}
  5009  
  5010  	receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte {
  5011  		var ret []byte
  5012  		iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  5013  		for i, size := range sizes {
  5014  			p := c.GetPacket()
  5015  			if i == which {
  5016  				ret = p
  5017  			}
  5018  			checker.IPv4(t, p,
  5019  				checker.PayloadLen(size+header.TCPMinimumSize),
  5020  				checker.TCP(
  5021  					checker.DstPort(context.TestPort),
  5022  					checker.TCPSeqNum(seqNum),
  5023  					checker.TCPAckNum(uint32(iss)),
  5024  					checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  5025  				),
  5026  			)
  5027  			seqNum += uint32(size)
  5028  		}
  5029  		return ret
  5030  	}
  5031  
  5032  	// Receive three packets.
  5033  	sizes := []int{maxPayload, maxPayload, writeSize - 2*maxPayload}
  5034  	first := receivePackets(c, sizes, 0, uint32(c.IRS)+1)
  5035  
  5036  	// Send "packet too big" messages back to netstack.
  5037  	const newMTU = 1200
  5038  	const newMaxPayload = newMTU - header.IPv4MinimumSize - header.TCPMinimumSize
  5039  	mtu := []byte{0, 0, newMTU / 256, newMTU % 256}
  5040  	c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, mtu, first, newMTU)
  5041  
  5042  	// See retransmitted packets. None exceeding the new max.
  5043  	sizes = []int{newMaxPayload, maxPayload - newMaxPayload, newMaxPayload, maxPayload - newMaxPayload, writeSize - 2*maxPayload}
  5044  	receivePackets(c, sizes, -1, uint32(c.IRS)+1)
  5045  }
  5046  
  5047  func TestTCPEndpointProbe(t *testing.T) {
  5048  	c := context.New(t, 1500)
  5049  	defer c.Cleanup()
  5050  
  5051  	invoked := make(chan struct{})
  5052  	c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
  5053  		// Validate that the endpoint ID is what we expect.
  5054  		//
  5055  		// We don't do an extensive validation of every field but a
  5056  		// basic sanity test.
  5057  		if got, want := state.ID.LocalAddress, tcpip.Address(context.StackAddr); got != want {
  5058  			t.Fatalf("got LocalAddress: %q, want: %q", got, want)
  5059  		}
  5060  		if got, want := state.ID.LocalPort, c.Port; got != want {
  5061  			t.Fatalf("got LocalPort: %d, want: %d", got, want)
  5062  		}
  5063  		if got, want := state.ID.RemoteAddress, tcpip.Address(context.TestAddr); got != want {
  5064  			t.Fatalf("got RemoteAddress: %q, want: %q", got, want)
  5065  		}
  5066  		if got, want := state.ID.RemotePort, uint16(context.TestPort); got != want {
  5067  			t.Fatalf("got RemotePort: %d, want: %d", got, want)
  5068  		}
  5069  
  5070  		invoked <- struct{}{}
  5071  	})
  5072  
  5073  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  5074  
  5075  	data := []byte{1, 2, 3}
  5076  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  5077  	c.SendPacket(data, &context.Headers{
  5078  		SrcPort: context.TestPort,
  5079  		DstPort: c.Port,
  5080  		Flags:   header.TCPFlagAck,
  5081  		SeqNum:  iss,
  5082  		AckNum:  c.IRS.Add(1),
  5083  		RcvWnd:  30000,
  5084  	})
  5085  
  5086  	select {
  5087  	case <-invoked:
  5088  	case <-time.After(100 * time.Millisecond):
  5089  		t.Fatalf("TCP Probe function was not called")
  5090  	}
  5091  }
  5092  
  5093  func TestStackSetCongestionControl(t *testing.T) {
  5094  	testCases := []struct {
  5095  		cc  tcpip.CongestionControlOption
  5096  		err tcpip.Error
  5097  	}{
  5098  		{"reno", nil},
  5099  		{"cubic", nil},
  5100  		{"blahblah", &tcpip.ErrNoSuchFile{}},
  5101  	}
  5102  
  5103  	for _, tc := range testCases {
  5104  		t.Run(fmt.Sprintf("SetTransportProtocolOption(.., %v)", tc.cc), func(t *testing.T) {
  5105  			c := context.New(t, 1500)
  5106  			defer c.Cleanup()
  5107  
  5108  			s := c.Stack()
  5109  
  5110  			var oldCC tcpip.CongestionControlOption
  5111  			if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil {
  5112  				t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err)
  5113  			}
  5114  
  5115  			if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &tc.cc); err != tc.err {
  5116  				t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = %s, want = %s", tcp.ProtocolNumber, tc.cc, tc.cc, err, tc.err)
  5117  			}
  5118  
  5119  			var cc tcpip.CongestionControlOption
  5120  			if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
  5121  				t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
  5122  			}
  5123  
  5124  			got, want := cc, oldCC
  5125  			// If SetTransportProtocolOption is expected to succeed
  5126  			// then the returned value for congestion control should
  5127  			// match the one specified in the
  5128  			// SetTransportProtocolOption call above, else it should
  5129  			// be what it was before the call to
  5130  			// SetTransportProtocolOption.
  5131  			if tc.err == nil {
  5132  				want = tc.cc
  5133  			}
  5134  			if got != want {
  5135  				t.Fatalf("got congestion control: %v, want: %v", got, want)
  5136  			}
  5137  		})
  5138  	}
  5139  }
  5140  
  5141  func TestStackAvailableCongestionControl(t *testing.T) {
  5142  	c := context.New(t, 1500)
  5143  	defer c.Cleanup()
  5144  
  5145  	s := c.Stack()
  5146  
  5147  	// Query permitted congestion control algorithms.
  5148  	var aCC tcpip.TCPAvailableCongestionControlOption
  5149  	if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil {
  5150  		t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err)
  5151  	}
  5152  	if got, want := aCC, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want {
  5153  		t.Fatalf("got tcpip.TCPAvailableCongestionControlOption: %v, want: %v", got, want)
  5154  	}
  5155  }
  5156  
  5157  func TestStackSetAvailableCongestionControl(t *testing.T) {
  5158  	c := context.New(t, 1500)
  5159  	defer c.Cleanup()
  5160  
  5161  	s := c.Stack()
  5162  
  5163  	// Setting AvailableCongestionControlOption should fail.
  5164  	aCC := tcpip.TCPAvailableCongestionControlOption("xyz")
  5165  	if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil {
  5166  		t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = nil, want non-nil", tcp.ProtocolNumber, aCC, aCC)
  5167  	}
  5168  
  5169  	// Verify that we still get the expected list of congestion control options.
  5170  	var cc tcpip.TCPAvailableCongestionControlOption
  5171  	if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
  5172  		t.Fatalf("s.TransportProtocolOptio(%d, &%T(%s)): %s", tcp.ProtocolNumber, cc, cc, err)
  5173  	}
  5174  	if got, want := cc, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want {
  5175  		t.Fatalf("got tcpip.TCPAvailableCongestionControlOption = %s, want = %s", got, want)
  5176  	}
  5177  }
  5178  
  5179  func TestEndpointSetCongestionControl(t *testing.T) {
  5180  	testCases := []struct {
  5181  		cc  tcpip.CongestionControlOption
  5182  		err tcpip.Error
  5183  	}{
  5184  		{"reno", nil},
  5185  		{"cubic", nil},
  5186  		{"blahblah", &tcpip.ErrNoSuchFile{}},
  5187  	}
  5188  
  5189  	for _, connected := range []bool{false, true} {
  5190  		for _, tc := range testCases {
  5191  			t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) {
  5192  				c := context.New(t, 1500)
  5193  				defer c.Cleanup()
  5194  
  5195  				// Create TCP endpoint.
  5196  				var err tcpip.Error
  5197  				c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
  5198  				if err != nil {
  5199  					t.Fatalf("NewEndpoint failed: %s", err)
  5200  				}
  5201  
  5202  				var oldCC tcpip.CongestionControlOption
  5203  				if err := c.EP.GetSockOpt(&oldCC); err != nil {
  5204  					t.Fatalf("c.EP.GetSockOpt(&%T) = %s", oldCC, err)
  5205  				}
  5206  
  5207  				if connected {
  5208  					c.Connect(context.TestInitialSequenceNumber, 32768 /* rcvWnd */, nil)
  5209  				}
  5210  
  5211  				if err := c.EP.SetSockOpt(&tc.cc); err != tc.err {
  5212  					t.Fatalf("got c.EP.SetSockOpt(&%#v) = %s, want %s", tc.cc, err, tc.err)
  5213  				}
  5214  
  5215  				var cc tcpip.CongestionControlOption
  5216  				if err := c.EP.GetSockOpt(&cc); err != nil {
  5217  					t.Fatalf("c.EP.GetSockOpt(&%T): %s", cc, err)
  5218  				}
  5219  
  5220  				got, want := cc, oldCC
  5221  				// If SetSockOpt is expected to succeed then the
  5222  				// returned value for congestion control should match
  5223  				// the one specified in the SetSockOpt above, else it
  5224  				// should be what it was before the call to SetSockOpt.
  5225  				if tc.err == nil {
  5226  					want = tc.cc
  5227  				}
  5228  				if got != want {
  5229  					t.Fatalf("got congestion control = %+v, want = %+v", got, want)
  5230  				}
  5231  			})
  5232  		}
  5233  	}
  5234  }
  5235  
  5236  func enableCUBIC(t *testing.T, c *context.Context) {
  5237  	t.Helper()
  5238  	opt := tcpip.CongestionControlOption("cubic")
  5239  	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
  5240  		t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)) %s", tcp.ProtocolNumber, opt, opt, err)
  5241  	}
  5242  }
  5243  
  5244  func TestKeepalive(t *testing.T) {
  5245  	c := context.New(t, defaultMTU)
  5246  	defer c.Cleanup()
  5247  
  5248  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  5249  
  5250  	const keepAliveIdle = 100 * time.Millisecond
  5251  	const keepAliveInterval = 3 * time.Second
  5252  	keepAliveIdleOpt := tcpip.KeepaliveIdleOption(keepAliveIdle)
  5253  	if err := c.EP.SetSockOpt(&keepAliveIdleOpt); err != nil {
  5254  		t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOpt, keepAliveIdle, err)
  5255  	}
  5256  	keepAliveIntervalOpt := tcpip.KeepaliveIntervalOption(keepAliveInterval)
  5257  	if err := c.EP.SetSockOpt(&keepAliveIntervalOpt); err != nil {
  5258  		t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOpt, keepAliveInterval, err)
  5259  	}
  5260  	c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5)
  5261  	if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5); err != nil {
  5262  		t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5): %s", err)
  5263  	}
  5264  	c.EP.SocketOptions().SetKeepAlive(true)
  5265  
  5266  	// 5 unacked keepalives are sent. ACK each one, and check that the
  5267  	// connection stays alive after 5.
  5268  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  5269  	for i := 0; i < 10; i++ {
  5270  		b := c.GetPacket()
  5271  		checker.IPv4(t, b,
  5272  			checker.TCP(
  5273  				checker.DstPort(context.TestPort),
  5274  				checker.TCPSeqNum(uint32(c.IRS)),
  5275  				checker.TCPAckNum(uint32(iss)),
  5276  				checker.TCPFlags(header.TCPFlagAck),
  5277  			),
  5278  		)
  5279  
  5280  		// Acknowledge the keepalive.
  5281  		c.SendPacket(nil, &context.Headers{
  5282  			SrcPort: context.TestPort,
  5283  			DstPort: c.Port,
  5284  			Flags:   header.TCPFlagAck,
  5285  			SeqNum:  iss,
  5286  			AckNum:  c.IRS,
  5287  			RcvWnd:  30000,
  5288  		})
  5289  	}
  5290  
  5291  	// Check that the connection is still alive.
  5292  	ept := endpointTester{c.EP}
  5293  	ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
  5294  
  5295  	// Send some data and wait before ACKing it. Keepalives should be disabled
  5296  	// during this period.
  5297  	view := make([]byte, 3)
  5298  	var r bytes.Reader
  5299  	r.Reset(view)
  5300  	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  5301  		t.Fatalf("Write failed: %s", err)
  5302  	}
  5303  
  5304  	next := uint32(c.IRS) + 1
  5305  	checker.IPv4(t, c.GetPacket(),
  5306  		checker.PayloadLen(len(view)+header.TCPMinimumSize),
  5307  		checker.TCP(
  5308  			checker.DstPort(context.TestPort),
  5309  			checker.TCPSeqNum(next),
  5310  			checker.TCPAckNum(uint32(iss)),
  5311  			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  5312  		),
  5313  	)
  5314  
  5315  	// Wait for the packet to be retransmitted. Verify that no keepalives
  5316  	// were sent.
  5317  	checker.IPv4(t, c.GetPacket(),
  5318  		checker.PayloadLen(len(view)+header.TCPMinimumSize),
  5319  		checker.TCP(
  5320  			checker.DstPort(context.TestPort),
  5321  			checker.TCPSeqNum(next),
  5322  			checker.TCPAckNum(uint32(iss)),
  5323  			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh),
  5324  		),
  5325  	)
  5326  	c.CheckNoPacket("Keepalive packet received while unACKed data is pending")
  5327  
  5328  	next += uint32(len(view))
  5329  
  5330  	// Send ACK. Keepalives should start sending again.
  5331  	c.SendPacket(nil, &context.Headers{
  5332  		SrcPort: context.TestPort,
  5333  		DstPort: c.Port,
  5334  		Flags:   header.TCPFlagAck,
  5335  		SeqNum:  iss,
  5336  		AckNum:  seqnum.Value(next),
  5337  		RcvWnd:  30000,
  5338  	})
  5339  
  5340  	// Now receive 5 keepalives, but don't ACK them. The connection
  5341  	// should be reset after 5.
  5342  	for i := 0; i < 5; i++ {
  5343  		b := c.GetPacket()
  5344  		checker.IPv4(t, b,
  5345  			checker.TCP(
  5346  				checker.DstPort(context.TestPort),
  5347  				checker.TCPSeqNum(next-1),
  5348  				checker.TCPAckNum(uint32(iss)),
  5349  				checker.TCPFlags(header.TCPFlagAck),
  5350  			),
  5351  		)
  5352  	}
  5353  
  5354  	// Sleep for a litte over the KeepAlive interval to make sure
  5355  	// the timer has time to fire after the last ACK and close the
  5356  	// close the socket.
  5357  	time.Sleep(keepAliveInterval + keepAliveInterval/2)
  5358  
  5359  	// The connection should be terminated after 5 unacked keepalives.
  5360  	// Send an ACK to trigger a RST from the stack as the endpoint should
  5361  	// be dead.
  5362  	c.SendPacket(nil, &context.Headers{
  5363  		SrcPort: context.TestPort,
  5364  		DstPort: c.Port,
  5365  		Flags:   header.TCPFlagAck,
  5366  		SeqNum:  iss,
  5367  		AckNum:  seqnum.Value(next),
  5368  		RcvWnd:  30000,
  5369  	})
  5370  
  5371  	checker.IPv4(t, c.GetPacket(),
  5372  		checker.TCP(checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst)),
  5373  	)
  5374  
  5375  	if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
  5376  		t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
  5377  	}
  5378  
  5379  	ept.CheckReadError(t, &tcpip.ErrTimeout{})
  5380  
  5381  	if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
  5382  		t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
  5383  	}
  5384  	if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
  5385  		t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
  5386  	}
  5387  }
  5388  
  5389  func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
  5390  	t.Helper()
  5391  	// Send a SYN request.
  5392  	irs = seqnum.Value(context.TestInitialSequenceNumber)
  5393  	c.SendPacket(nil, &context.Headers{
  5394  		SrcPort: srcPort,
  5395  		DstPort: context.StackPort,
  5396  		Flags:   header.TCPFlagSyn,
  5397  		SeqNum:  irs,
  5398  		RcvWnd:  30000,
  5399  	})
  5400  
  5401  	// Receive the SYN-ACK reply.
  5402  	b := c.GetPacket()
  5403  	tcp := header.TCP(header.IPv4(b).Payload())
  5404  	iss = seqnum.Value(tcp.SequenceNumber())
  5405  	tcpCheckers := []checker.TransportChecker{
  5406  		checker.SrcPort(context.StackPort),
  5407  		checker.DstPort(srcPort),
  5408  		checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
  5409  		checker.TCPAckNum(uint32(irs) + 1),
  5410  	}
  5411  
  5412  	if synCookieInUse {
  5413  		// When cookies are in use window scaling is disabled.
  5414  		tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{
  5415  			WS:  -1,
  5416  			MSS: c.MSSWithoutOptions(),
  5417  		}))
  5418  	}
  5419  
  5420  	checker.IPv4(t, b, checker.TCP(tcpCheckers...))
  5421  
  5422  	// Send ACK.
  5423  	c.SendPacket(nil, &context.Headers{
  5424  		SrcPort: srcPort,
  5425  		DstPort: context.StackPort,
  5426  		Flags:   header.TCPFlagAck,
  5427  		SeqNum:  irs + 1,
  5428  		AckNum:  iss + 1,
  5429  		RcvWnd:  30000,
  5430  	})
  5431  	return irs, iss
  5432  }
  5433  
  5434  func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
  5435  	t.Helper()
  5436  	// Send a SYN request.
  5437  	irs = seqnum.Value(context.TestInitialSequenceNumber)
  5438  	c.SendV6Packet(nil, &context.Headers{
  5439  		SrcPort: srcPort,
  5440  		DstPort: context.StackPort,
  5441  		Flags:   header.TCPFlagSyn,
  5442  		SeqNum:  irs,
  5443  		RcvWnd:  30000,
  5444  	})
  5445  
  5446  	// Receive the SYN-ACK reply.
  5447  	b := c.GetV6Packet()
  5448  	tcp := header.TCP(header.IPv6(b).Payload())
  5449  	iss = seqnum.Value(tcp.SequenceNumber())
  5450  	tcpCheckers := []checker.TransportChecker{
  5451  		checker.SrcPort(context.StackPort),
  5452  		checker.DstPort(srcPort),
  5453  		checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
  5454  		checker.TCPAckNum(uint32(irs) + 1),
  5455  	}
  5456  
  5457  	if synCookieInUse {
  5458  		// When cookies are in use window scaling is disabled.
  5459  		tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{
  5460  			WS:  -1,
  5461  			MSS: c.MSSWithoutOptionsV6(),
  5462  		}))
  5463  	}
  5464  
  5465  	checker.IPv6(t, b, checker.TCP(tcpCheckers...))
  5466  
  5467  	// Send ACK.
  5468  	c.SendV6Packet(nil, &context.Headers{
  5469  		SrcPort: srcPort,
  5470  		DstPort: context.StackPort,
  5471  		Flags:   header.TCPFlagAck,
  5472  		SeqNum:  irs + 1,
  5473  		AckNum:  iss + 1,
  5474  		RcvWnd:  30000,
  5475  	})
  5476  	return irs, iss
  5477  }
  5478  
  5479  // TestListenBacklogFull tests that netstack does not complete handshakes if the
  5480  // listen backlog for the endpoint is full.
  5481  func TestListenBacklogFull(t *testing.T) {
  5482  	c := context.New(t, defaultMTU)
  5483  	defer c.Cleanup()
  5484  
  5485  	// Create TCP endpoint.
  5486  	var err tcpip.Error
  5487  	c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
  5488  	if err != nil {
  5489  		t.Fatalf("NewEndpoint failed: %s", err)
  5490  	}
  5491  
  5492  	// Bind to wildcard.
  5493  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  5494  		t.Fatalf("Bind failed: %s", err)
  5495  	}
  5496  
  5497  	// Test acceptance.
  5498  	// Start listening.
  5499  	listenBacklog := 10
  5500  	if err := c.EP.Listen(listenBacklog); err != nil {
  5501  		t.Fatalf("Listen failed: %s", err)
  5502  	}
  5503  
  5504  	lastPortOffset := uint16(0)
  5505  	for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ {
  5506  		executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */)
  5507  	}
  5508  
  5509  	time.Sleep(50 * time.Millisecond)
  5510  
  5511  	// Now execute send one more SYN. The stack should not respond as the backlog
  5512  	// is full at this point.
  5513  	c.SendPacket(nil, &context.Headers{
  5514  		SrcPort: context.TestPort + lastPortOffset,
  5515  		DstPort: context.StackPort,
  5516  		Flags:   header.TCPFlagSyn,
  5517  		SeqNum:  seqnum.Value(context.TestInitialSequenceNumber),
  5518  		RcvWnd:  30000,
  5519  	})
  5520  	c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
  5521  
  5522  	// Try to accept the connections in the backlog.
  5523  	we, ch := waiter.NewChannelEntry(nil)
  5524  	c.WQ.EventRegister(&we, waiter.ReadableEvents)
  5525  	defer c.WQ.EventUnregister(&we)
  5526  
  5527  	for i := 0; i < listenBacklog; i++ {
  5528  		_, _, err = c.EP.Accept(nil)
  5529  		if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
  5530  			// Wait for connection to be established.
  5531  			select {
  5532  			case <-ch:
  5533  				_, _, err = c.EP.Accept(nil)
  5534  				if err != nil {
  5535  					t.Fatalf("Accept failed: %s", err)
  5536  				}
  5537  
  5538  			case <-time.After(1 * time.Second):
  5539  				t.Fatalf("Timed out waiting for accept")
  5540  			}
  5541  		}
  5542  	}
  5543  
  5544  	// Now verify that there are no more connections that can be accepted.
  5545  	_, _, err = c.EP.Accept(nil)
  5546  	if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
  5547  		select {
  5548  		case <-ch:
  5549  			t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP)
  5550  		case <-time.After(1 * time.Second):
  5551  		}
  5552  	}
  5553  
  5554  	// Now a new handshake must succeed.
  5555  	executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */)
  5556  
  5557  	newEP, _, err := c.EP.Accept(nil)
  5558  	if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
  5559  		// Wait for connection to be established.
  5560  		select {
  5561  		case <-ch:
  5562  			newEP, _, err = c.EP.Accept(nil)
  5563  			if err != nil {
  5564  				t.Fatalf("Accept failed: %s", err)
  5565  			}
  5566  
  5567  		case <-time.After(1 * time.Second):
  5568  			t.Fatalf("Timed out waiting for accept")
  5569  		}
  5570  	}
  5571  
  5572  	// Now verify that the TCP socket is usable and in a connected state.
  5573  	data := "Don't panic"
  5574  	var r strings.Reader
  5575  	r.Reset(data)
  5576  	newEP.Write(&r, tcpip.WriteOptions{})
  5577  	b := c.GetPacket()
  5578  	tcp := header.TCP(header.IPv4(b).Payload())
  5579  	if string(tcp.Payload()) != data {
  5580  		t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data)
  5581  	}
  5582  }
  5583  
  5584  // TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a
  5585  // non unicast IPv4 address are not accepted.
  5586  func TestListenNoAcceptNonUnicastV4(t *testing.T) {
  5587  	multicastAddr := tcpiptestutil.MustParse4("224.0.1.2")
  5588  	otherMulticastAddr := tcpiptestutil.MustParse4("224.0.1.3")
  5589  	subnet := context.StackAddrWithPrefix.Subnet()
  5590  	subnetBroadcastAddr := subnet.Broadcast()
  5591  
  5592  	tests := []struct {
  5593  		name    string
  5594  		srcAddr tcpip.Address
  5595  		dstAddr tcpip.Address
  5596  	}{
  5597  		{
  5598  			name:    "SourceUnspecified",
  5599  			srcAddr: header.IPv4Any,
  5600  			dstAddr: context.StackAddr,
  5601  		},
  5602  		{
  5603  			name:    "SourceBroadcast",
  5604  			srcAddr: header.IPv4Broadcast,
  5605  			dstAddr: context.StackAddr,
  5606  		},
  5607  		{
  5608  			name:    "SourceOurMulticast",
  5609  			srcAddr: multicastAddr,
  5610  			dstAddr: context.StackAddr,
  5611  		},
  5612  		{
  5613  			name:    "SourceOtherMulticast",
  5614  			srcAddr: otherMulticastAddr,
  5615  			dstAddr: context.StackAddr,
  5616  		},
  5617  		{
  5618  			name:    "DestUnspecified",
  5619  			srcAddr: context.TestAddr,
  5620  			dstAddr: header.IPv4Any,
  5621  		},
  5622  		{
  5623  			name:    "DestBroadcast",
  5624  			srcAddr: context.TestAddr,
  5625  			dstAddr: header.IPv4Broadcast,
  5626  		},
  5627  		{
  5628  			name:    "DestOurMulticast",
  5629  			srcAddr: context.TestAddr,
  5630  			dstAddr: multicastAddr,
  5631  		},
  5632  		{
  5633  			name:    "DestOtherMulticast",
  5634  			srcAddr: context.TestAddr,
  5635  			dstAddr: otherMulticastAddr,
  5636  		},
  5637  		{
  5638  			name:    "SrcSubnetBroadcast",
  5639  			srcAddr: subnetBroadcastAddr,
  5640  			dstAddr: context.StackAddr,
  5641  		},
  5642  		{
  5643  			name:    "DestSubnetBroadcast",
  5644  			srcAddr: context.TestAddr,
  5645  			dstAddr: subnetBroadcastAddr,
  5646  		},
  5647  	}
  5648  
  5649  	for _, test := range tests {
  5650  		t.Run(test.name, func(t *testing.T) {
  5651  			c := context.New(t, defaultMTU)
  5652  			defer c.Cleanup()
  5653  
  5654  			c.Create(-1)
  5655  
  5656  			if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil {
  5657  				t.Fatalf("JoinGroup failed: %s", err)
  5658  			}
  5659  
  5660  			if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  5661  				t.Fatalf("Bind failed: %s", err)
  5662  			}
  5663  
  5664  			if err := c.EP.Listen(1); err != nil {
  5665  				t.Fatalf("Listen failed: %s", err)
  5666  			}
  5667  
  5668  			irs := seqnum.Value(context.TestInitialSequenceNumber)
  5669  			c.SendPacketWithAddrs(nil, &context.Headers{
  5670  				SrcPort: context.TestPort,
  5671  				DstPort: context.StackPort,
  5672  				Flags:   header.TCPFlagSyn,
  5673  				SeqNum:  irs,
  5674  				RcvWnd:  30000,
  5675  			}, test.srcAddr, test.dstAddr)
  5676  			c.CheckNoPacket("Should not have received a response")
  5677  
  5678  			// Handle normal packet.
  5679  			c.SendPacketWithAddrs(nil, &context.Headers{
  5680  				SrcPort: context.TestPort,
  5681  				DstPort: context.StackPort,
  5682  				Flags:   header.TCPFlagSyn,
  5683  				SeqNum:  irs,
  5684  				RcvWnd:  30000,
  5685  			}, context.TestAddr, context.StackAddr)
  5686  			checker.IPv4(t, c.GetPacket(),
  5687  				checker.TCP(
  5688  					checker.SrcPort(context.StackPort),
  5689  					checker.DstPort(context.TestPort),
  5690  					checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
  5691  					checker.TCPAckNum(uint32(irs)+1)))
  5692  		})
  5693  	}
  5694  }
  5695  
  5696  // TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a
  5697  // non unicast IPv6 address are not accepted.
  5698  func TestListenNoAcceptNonUnicastV6(t *testing.T) {
  5699  	multicastAddr := tcpiptestutil.MustParse6("ff0e::101")
  5700  	otherMulticastAddr := tcpiptestutil.MustParse6("ff0e::102")
  5701  
  5702  	tests := []struct {
  5703  		name    string
  5704  		srcAddr tcpip.Address
  5705  		dstAddr tcpip.Address
  5706  	}{
  5707  		{
  5708  			"SourceUnspecified",
  5709  			header.IPv6Any,
  5710  			context.StackV6Addr,
  5711  		},
  5712  		{
  5713  			"SourceAllNodes",
  5714  			header.IPv6AllNodesMulticastAddress,
  5715  			context.StackV6Addr,
  5716  		},
  5717  		{
  5718  			"SourceOurMulticast",
  5719  			multicastAddr,
  5720  			context.StackV6Addr,
  5721  		},
  5722  		{
  5723  			"SourceOtherMulticast",
  5724  			otherMulticastAddr,
  5725  			context.StackV6Addr,
  5726  		},
  5727  		{
  5728  			"DestUnspecified",
  5729  			context.TestV6Addr,
  5730  			header.IPv6Any,
  5731  		},
  5732  		{
  5733  			"DestAllNodes",
  5734  			context.TestV6Addr,
  5735  			header.IPv6AllNodesMulticastAddress,
  5736  		},
  5737  		{
  5738  			"DestOurMulticast",
  5739  			context.TestV6Addr,
  5740  			multicastAddr,
  5741  		},
  5742  		{
  5743  			"DestOtherMulticast",
  5744  			context.TestV6Addr,
  5745  			otherMulticastAddr,
  5746  		},
  5747  	}
  5748  
  5749  	for _, test := range tests {
  5750  		t.Run(test.name, func(t *testing.T) {
  5751  			c := context.New(t, defaultMTU)
  5752  			defer c.Cleanup()
  5753  
  5754  			c.CreateV6Endpoint(true)
  5755  
  5756  			if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil {
  5757  				t.Fatalf("JoinGroup failed: %s", err)
  5758  			}
  5759  
  5760  			if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  5761  				t.Fatalf("Bind failed: %s", err)
  5762  			}
  5763  
  5764  			if err := c.EP.Listen(1); err != nil {
  5765  				t.Fatalf("Listen failed: %s", err)
  5766  			}
  5767  
  5768  			irs := seqnum.Value(context.TestInitialSequenceNumber)
  5769  			c.SendV6PacketWithAddrs(nil, &context.Headers{
  5770  				SrcPort: context.TestPort,
  5771  				DstPort: context.StackPort,
  5772  				Flags:   header.TCPFlagSyn,
  5773  				SeqNum:  irs,
  5774  				RcvWnd:  30000,
  5775  			}, test.srcAddr, test.dstAddr)
  5776  			c.CheckNoPacket("Should not have received a response")
  5777  
  5778  			// Handle normal packet.
  5779  			c.SendV6PacketWithAddrs(nil, &context.Headers{
  5780  				SrcPort: context.TestPort,
  5781  				DstPort: context.StackPort,
  5782  				Flags:   header.TCPFlagSyn,
  5783  				SeqNum:  irs,
  5784  				RcvWnd:  30000,
  5785  			}, context.TestV6Addr, context.StackV6Addr)
  5786  			checker.IPv6(t, c.GetV6Packet(),
  5787  				checker.TCP(
  5788  					checker.SrcPort(context.StackPort),
  5789  					checker.DstPort(context.TestPort),
  5790  					checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
  5791  					checker.TCPAckNum(uint32(irs)+1)))
  5792  		})
  5793  	}
  5794  }
  5795  
  5796  func TestListenSynRcvdQueueFull(t *testing.T) {
  5797  	c := context.New(t, defaultMTU)
  5798  	defer c.Cleanup()
  5799  
  5800  	// Create TCP endpoint.
  5801  	var err tcpip.Error
  5802  	c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
  5803  	if err != nil {
  5804  		t.Fatalf("NewEndpoint failed: %s", err)
  5805  	}
  5806  
  5807  	// Bind to wildcard.
  5808  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  5809  		t.Fatalf("Bind failed: %s", err)
  5810  	}
  5811  
  5812  	// Test acceptance.
  5813  	if err := c.EP.Listen(1); err != nil {
  5814  		t.Fatalf("Listen failed: %s", err)
  5815  	}
  5816  
  5817  	// Send two SYN's the first one should get a SYN-ACK, the
  5818  	// second one should not get any response and is dropped as
  5819  	// the accept queue is full.
  5820  	irs := seqnum.Value(context.TestInitialSequenceNumber)
  5821  	c.SendPacket(nil, &context.Headers{
  5822  		SrcPort: context.TestPort,
  5823  		DstPort: context.StackPort,
  5824  		Flags:   header.TCPFlagSyn,
  5825  		SeqNum:  irs,
  5826  		RcvWnd:  30000,
  5827  	})
  5828  
  5829  	// Receive the SYN-ACK reply.
  5830  	b := c.GetPacket()
  5831  	tcp := header.TCP(header.IPv4(b).Payload())
  5832  	iss := seqnum.Value(tcp.SequenceNumber())
  5833  	tcpCheckers := []checker.TransportChecker{
  5834  		checker.SrcPort(context.StackPort),
  5835  		checker.DstPort(context.TestPort),
  5836  		checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
  5837  		checker.TCPAckNum(uint32(irs) + 1),
  5838  	}
  5839  	checker.IPv4(t, b, checker.TCP(tcpCheckers...))
  5840  
  5841  	// Now complete the previous connection.
  5842  	// Send ACK.
  5843  	c.SendPacket(nil, &context.Headers{
  5844  		SrcPort: context.TestPort,
  5845  		DstPort: context.StackPort,
  5846  		Flags:   header.TCPFlagAck,
  5847  		SeqNum:  irs + 1,
  5848  		AckNum:  iss + 1,
  5849  		RcvWnd:  30000,
  5850  	})
  5851  
  5852  	// Verify if that is delivered to the accept queue.
  5853  	we, ch := waiter.NewChannelEntry(nil)
  5854  	c.WQ.EventRegister(&we, waiter.ReadableEvents)
  5855  	defer c.WQ.EventUnregister(&we)
  5856  	<-ch
  5857  
  5858  	// Now execute send one more SYN. The stack should not respond as the backlog
  5859  	// is full at this point.
  5860  	c.SendPacket(nil, &context.Headers{
  5861  		SrcPort: context.TestPort + 1,
  5862  		DstPort: context.StackPort,
  5863  		Flags:   header.TCPFlagSyn,
  5864  		SeqNum:  seqnum.Value(889),
  5865  		RcvWnd:  30000,
  5866  	})
  5867  	c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
  5868  
  5869  	// Try to accept the connections in the backlog.
  5870  	newEP, _, err := c.EP.Accept(nil)
  5871  	if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
  5872  		// Wait for connection to be established.
  5873  		select {
  5874  		case <-ch:
  5875  			newEP, _, err = c.EP.Accept(nil)
  5876  			if err != nil {
  5877  				t.Fatalf("Accept failed: %s", err)
  5878  			}
  5879  
  5880  		case <-time.After(1 * time.Second):
  5881  			t.Fatalf("Timed out waiting for accept")
  5882  		}
  5883  	}
  5884  
  5885  	// Now verify that the TCP socket is usable and in a connected state.
  5886  	data := "Don't panic"
  5887  	var r strings.Reader
  5888  	r.Reset(data)
  5889  	newEP.Write(&r, tcpip.WriteOptions{})
  5890  	pkt := c.GetPacket()
  5891  	tcp = header.IPv4(pkt).Payload()
  5892  	if string(tcp.Payload()) != data {
  5893  		t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data)
  5894  	}
  5895  }
  5896  
  5897  func TestListenBacklogFullSynCookieInUse(t *testing.T) {
  5898  	c := context.New(t, defaultMTU)
  5899  	defer c.Cleanup()
  5900  
  5901  	// Create TCP endpoint.
  5902  	var err tcpip.Error
  5903  	c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
  5904  	if err != nil {
  5905  		t.Fatalf("NewEndpoint failed: %s", err)
  5906  	}
  5907  
  5908  	// Bind to wildcard.
  5909  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  5910  		t.Fatalf("Bind failed: %s", err)
  5911  	}
  5912  
  5913  	// Test for SynCookies usage after filling up the backlog.
  5914  	if err := c.EP.Listen(1); err != nil {
  5915  		t.Fatalf("Listen failed: %s", err)
  5916  	}
  5917  
  5918  	executeHandshake(t, c, context.TestPort, false)
  5919  
  5920  	// Wait for this to be delivered to the accept queue.
  5921  	time.Sleep(50 * time.Millisecond)
  5922  
  5923  	// Send a SYN request.
  5924  	irs := seqnum.Value(context.TestInitialSequenceNumber)
  5925  	c.SendPacket(nil, &context.Headers{
  5926  		// pick a different src port for new SYN.
  5927  		SrcPort: context.TestPort + 1,
  5928  		DstPort: context.StackPort,
  5929  		Flags:   header.TCPFlagSyn,
  5930  		SeqNum:  irs,
  5931  		RcvWnd:  30000,
  5932  	})
  5933  	// The Syn should be dropped as the endpoint's backlog is full.
  5934  	c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
  5935  
  5936  	// Verify that there is only one acceptable connection at this point.
  5937  	we, ch := waiter.NewChannelEntry(nil)
  5938  	c.WQ.EventRegister(&we, waiter.ReadableEvents)
  5939  	defer c.WQ.EventUnregister(&we)
  5940  
  5941  	_, _, err = c.EP.Accept(nil)
  5942  	if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
  5943  		// Wait for connection to be established.
  5944  		select {
  5945  		case <-ch:
  5946  			_, _, err = c.EP.Accept(nil)
  5947  			if err != nil {
  5948  				t.Fatalf("Accept failed: %s", err)
  5949  			}
  5950  
  5951  		case <-time.After(1 * time.Second):
  5952  			t.Fatalf("Timed out waiting for accept")
  5953  		}
  5954  	}
  5955  
  5956  	// Now verify that there are no more connections that can be accepted.
  5957  	_, _, err = c.EP.Accept(nil)
  5958  	if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
  5959  		select {
  5960  		case <-ch:
  5961  			t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP)
  5962  		case <-time.After(1 * time.Second):
  5963  		}
  5964  	}
  5965  }
  5966  
  5967  func TestSYNRetransmit(t *testing.T) {
  5968  	c := context.New(t, defaultMTU)
  5969  	defer c.Cleanup()
  5970  
  5971  	// Create TCP endpoint.
  5972  	var err tcpip.Error
  5973  	c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
  5974  	if err != nil {
  5975  		t.Fatalf("NewEndpoint failed: %s", err)
  5976  	}
  5977  
  5978  	// Bind to wildcard.
  5979  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  5980  		t.Fatalf("Bind failed: %s", err)
  5981  	}
  5982  
  5983  	// Start listening.
  5984  	if err := c.EP.Listen(10); err != nil {
  5985  		t.Fatalf("Listen failed: %s", err)
  5986  	}
  5987  
  5988  	// Send the same SYN packet multiple times. We should still get a valid SYN-ACK
  5989  	// reply.
  5990  	irs := seqnum.Value(context.TestInitialSequenceNumber)
  5991  	for i := 0; i < 5; i++ {
  5992  		c.SendPacket(nil, &context.Headers{
  5993  			SrcPort: context.TestPort,
  5994  			DstPort: context.StackPort,
  5995  			Flags:   header.TCPFlagSyn,
  5996  			SeqNum:  irs,
  5997  			RcvWnd:  30000,
  5998  		})
  5999  	}
  6000  
  6001  	// Receive the SYN-ACK reply.
  6002  	tcpCheckers := []checker.TransportChecker{
  6003  		checker.SrcPort(context.StackPort),
  6004  		checker.DstPort(context.TestPort),
  6005  		checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
  6006  		checker.TCPAckNum(uint32(irs) + 1),
  6007  	}
  6008  	checker.IPv4(t, c.GetPacket(), checker.TCP(tcpCheckers...))
  6009  }
  6010  
  6011  func TestSynRcvdBadSeqNumber(t *testing.T) {
  6012  	c := context.New(t, defaultMTU)
  6013  	defer c.Cleanup()
  6014  
  6015  	// Create TCP endpoint.
  6016  	var err tcpip.Error
  6017  	c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
  6018  	if err != nil {
  6019  		t.Fatalf("NewEndpoint failed: %s", err)
  6020  	}
  6021  
  6022  	// Bind to wildcard.
  6023  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  6024  		t.Fatalf("Bind failed: %s", err)
  6025  	}
  6026  
  6027  	// Start listening.
  6028  	if err := c.EP.Listen(10); err != nil {
  6029  		t.Fatalf("Listen failed: %s", err)
  6030  	}
  6031  
  6032  	// Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state
  6033  	irs := seqnum.Value(context.TestInitialSequenceNumber)
  6034  	c.SendPacket(nil, &context.Headers{
  6035  		SrcPort: context.TestPort,
  6036  		DstPort: context.StackPort,
  6037  		Flags:   header.TCPFlagSyn,
  6038  		SeqNum:  irs,
  6039  		RcvWnd:  30000,
  6040  	})
  6041  
  6042  	// Receive the SYN-ACK reply.
  6043  	b := c.GetPacket()
  6044  	tcpHdr := header.TCP(header.IPv4(b).Payload())
  6045  	iss := seqnum.Value(tcpHdr.SequenceNumber())
  6046  	tcpCheckers := []checker.TransportChecker{
  6047  		checker.SrcPort(context.StackPort),
  6048  		checker.DstPort(context.TestPort),
  6049  		checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
  6050  		checker.TCPAckNum(uint32(irs) + 1),
  6051  	}
  6052  	checker.IPv4(t, b, checker.TCP(tcpCheckers...))
  6053  
  6054  	// Now send a packet with an out-of-window sequence number
  6055  	largeSeqnum := irs + seqnum.Value(tcpHdr.WindowSize()) + 1
  6056  	c.SendPacket(nil, &context.Headers{
  6057  		SrcPort: context.TestPort,
  6058  		DstPort: context.StackPort,
  6059  		Flags:   header.TCPFlagAck,
  6060  		SeqNum:  largeSeqnum,
  6061  		AckNum:  iss + 1,
  6062  		RcvWnd:  30000,
  6063  	})
  6064  
  6065  	// Should receive an ACK with the expected SEQ number
  6066  	b = c.GetPacket()
  6067  	tcpCheckers = []checker.TransportChecker{
  6068  		checker.SrcPort(context.StackPort),
  6069  		checker.DstPort(context.TestPort),
  6070  		checker.TCPFlags(header.TCPFlagAck),
  6071  		checker.TCPAckNum(uint32(irs) + 1),
  6072  		checker.TCPSeqNum(uint32(iss + 1)),
  6073  	}
  6074  	checker.IPv4(t, b, checker.TCP(tcpCheckers...))
  6075  
  6076  	// Now that the socket replied appropriately with the ACK,
  6077  	// complete the connection to test that the large SEQ num
  6078  	// did not change the state from SYN-RCVD.
  6079  
  6080  	// Get setup to be notified about connection establishment.
  6081  	we, ch := waiter.NewChannelEntry(nil)
  6082  	c.WQ.EventRegister(&we, waiter.ReadableEvents)
  6083  	defer c.WQ.EventUnregister(&we)
  6084  
  6085  	// Send ACK to move to ESTABLISHED state.
  6086  	c.SendPacket(nil, &context.Headers{
  6087  		SrcPort: context.TestPort,
  6088  		DstPort: context.StackPort,
  6089  		Flags:   header.TCPFlagAck,
  6090  		SeqNum:  irs + 1,
  6091  		AckNum:  iss + 1,
  6092  		RcvWnd:  30000,
  6093  	})
  6094  
  6095  	<-ch
  6096  	newEP, _, err := c.EP.Accept(nil)
  6097  	if err != nil {
  6098  		t.Fatalf("Accept failed: %s", err)
  6099  	}
  6100  
  6101  	// Now verify that the TCP socket is usable and in a connected state.
  6102  	data := "Don't panic"
  6103  	var r strings.Reader
  6104  	r.Reset(data)
  6105  	if _, err := newEP.Write(&r, tcpip.WriteOptions{}); err != nil {
  6106  		t.Fatalf("Write failed: %s", err)
  6107  	}
  6108  
  6109  	pkt := c.GetPacket()
  6110  	tcpHdr = header.IPv4(pkt).Payload()
  6111  	if string(tcpHdr.Payload()) != data {
  6112  		t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data)
  6113  	}
  6114  }
  6115  
  6116  func TestPassiveConnectionAttemptIncrement(t *testing.T) {
  6117  	c := context.New(t, defaultMTU)
  6118  	defer c.Cleanup()
  6119  
  6120  	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
  6121  	if err != nil {
  6122  		t.Fatalf("NewEndpoint failed: %s", err)
  6123  	}
  6124  	c.EP = ep
  6125  	if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
  6126  		t.Fatalf("Bind failed: %s", err)
  6127  	}
  6128  	if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
  6129  		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
  6130  	}
  6131  	if err := c.EP.Listen(1); err != nil {
  6132  		t.Fatalf("Listen failed: %s", err)
  6133  	}
  6134  	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want {
  6135  		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
  6136  	}
  6137  
  6138  	stats := c.Stack().Stats()
  6139  	want := stats.TCP.PassiveConnectionOpenings.Value() + 1
  6140  
  6141  	srcPort := uint16(context.TestPort)
  6142  	executeHandshake(t, c, srcPort+1, false)
  6143  
  6144  	we, ch := waiter.NewChannelEntry(nil)
  6145  	c.WQ.EventRegister(&we, waiter.ReadableEvents)
  6146  	defer c.WQ.EventUnregister(&we)
  6147  
  6148  	// Verify that there is only one acceptable connection at this point.
  6149  	_, _, err = c.EP.Accept(nil)
  6150  	if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
  6151  		// Wait for connection to be established.
  6152  		select {
  6153  		case <-ch:
  6154  			_, _, err = c.EP.Accept(nil)
  6155  			if err != nil {
  6156  				t.Fatalf("Accept failed: %s", err)
  6157  			}
  6158  
  6159  		case <-time.After(1 * time.Second):
  6160  			t.Fatalf("Timed out waiting for accept")
  6161  		}
  6162  	}
  6163  
  6164  	if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want {
  6165  		t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %d, want = %d", got, want)
  6166  	}
  6167  }
  6168  
  6169  func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
  6170  	c := context.New(t, defaultMTU)
  6171  	defer c.Cleanup()
  6172  
  6173  	stats := c.Stack().Stats()
  6174  	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
  6175  	if err != nil {
  6176  		t.Fatalf("NewEndpoint failed: %s", err)
  6177  	}
  6178  	c.EP = ep
  6179  	if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
  6180  		t.Fatalf("Bind failed: %s", err)
  6181  	}
  6182  	if err := c.EP.Listen(1); err != nil {
  6183  		t.Fatalf("Listen failed: %s", err)
  6184  	}
  6185  
  6186  	srcPort := uint16(context.TestPort)
  6187  	// Now attempt a handshakes it will fill up the accept backlog.
  6188  	executeHandshake(t, c, srcPort, false)
  6189  
  6190  	// Give time for the final ACK to be processed as otherwise the next handshake could
  6191  	// get accepted before the previous one based on goroutine scheduling.
  6192  	time.Sleep(50 * time.Millisecond)
  6193  
  6194  	want := stats.TCP.ListenOverflowSynDrop.Value() + 1
  6195  
  6196  	// Now we will send one more SYN and this one should get dropped
  6197  	// Send a SYN request.
  6198  	c.SendPacket(nil, &context.Headers{
  6199  		SrcPort: srcPort + 2,
  6200  		DstPort: context.StackPort,
  6201  		Flags:   header.TCPFlagSyn,
  6202  		SeqNum:  seqnum.Value(context.TestInitialSequenceNumber),
  6203  		RcvWnd:  30000,
  6204  	})
  6205  
  6206  	checkValid := func() []error {
  6207  		var errors []error
  6208  		if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want {
  6209  			errors = append(errors, fmt.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want))
  6210  		}
  6211  		if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want {
  6212  			errors = append(errors, fmt.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want))
  6213  		}
  6214  		return errors
  6215  	}
  6216  
  6217  	start := time.Now()
  6218  	for time.Since(start) < time.Minute && len(checkValid()) > 0 {
  6219  		time.Sleep(50 * time.Millisecond)
  6220  	}
  6221  	for _, err := range checkValid() {
  6222  		t.Error(err)
  6223  	}
  6224  	if t.Failed() {
  6225  		t.FailNow()
  6226  	}
  6227  
  6228  	we, ch := waiter.NewChannelEntry(nil)
  6229  	c.WQ.EventRegister(&we, waiter.ReadableEvents)
  6230  	defer c.WQ.EventUnregister(&we)
  6231  
  6232  	// Now check that there is one acceptable connections.
  6233  	_, _, err = c.EP.Accept(nil)
  6234  	if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
  6235  		// Wait for connection to be established.
  6236  		<-ch
  6237  		_, _, err = c.EP.Accept(nil)
  6238  		if err != nil {
  6239  			t.Fatalf("Accept failed: %s", err)
  6240  		}
  6241  	}
  6242  }
  6243  
  6244  func TestListenDropIncrement(t *testing.T) {
  6245  	c := context.New(t, defaultMTU)
  6246  	defer c.Cleanup()
  6247  
  6248  	stats := c.Stack().Stats()
  6249  	c.Create(-1 /*epRcvBuf*/)
  6250  
  6251  	if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
  6252  		t.Fatalf("Bind failed: %s", err)
  6253  	}
  6254  	if err := c.EP.Listen(1 /*backlog*/); err != nil {
  6255  		t.Fatalf("Listen failed: %s", err)
  6256  	}
  6257  
  6258  	initialDropped := stats.DroppedPackets.Value()
  6259  
  6260  	// Send RST, FIN segments, that are expected to be dropped by the listener.
  6261  	c.SendPacket(nil, &context.Headers{
  6262  		SrcPort: context.TestPort,
  6263  		DstPort: context.StackPort,
  6264  		Flags:   header.TCPFlagRst,
  6265  	})
  6266  	c.SendPacket(nil, &context.Headers{
  6267  		SrcPort: context.TestPort,
  6268  		DstPort: context.StackPort,
  6269  		Flags:   header.TCPFlagFin,
  6270  	})
  6271  
  6272  	// To ensure that the RST, FIN sent earlier are indeed received and ignored
  6273  	// by the listener, send a SYN and wait for the SYN to be ACKd.
  6274  	irs := seqnum.Value(context.TestInitialSequenceNumber)
  6275  	c.SendPacket(nil, &context.Headers{
  6276  		SrcPort: context.TestPort,
  6277  		DstPort: context.StackPort,
  6278  		Flags:   header.TCPFlagSyn,
  6279  		SeqNum:  irs,
  6280  	})
  6281  	checker.IPv4(t, c.GetPacket(), checker.TCP(checker.SrcPort(context.StackPort),
  6282  		checker.DstPort(context.TestPort),
  6283  		checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
  6284  		checker.TCPAckNum(uint32(irs)+1),
  6285  	))
  6286  
  6287  	if got, want := stats.DroppedPackets.Value(), initialDropped+2; got != want {
  6288  		t.Fatalf("got stats.DroppedPackets.Value() = %d, want = %d", got, want)
  6289  	}
  6290  }
  6291  
  6292  func TestEndpointBindListenAcceptState(t *testing.T) {
  6293  	c := context.New(t, defaultMTU)
  6294  	defer c.Cleanup()
  6295  	wq := &waiter.Queue{}
  6296  	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
  6297  	if err != nil {
  6298  		t.Fatalf("NewEndpoint failed: %s", err)
  6299  	}
  6300  
  6301  	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  6302  		t.Fatalf("Bind failed: %s", err)
  6303  	}
  6304  	if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
  6305  		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
  6306  	}
  6307  
  6308  	ept := endpointTester{ep}
  6309  	ept.CheckReadError(t, &tcpip.ErrNotConnected{})
  6310  	if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 {
  6311  		t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1)
  6312  	}
  6313  
  6314  	if err := ep.Listen(10); err != nil {
  6315  		t.Fatalf("Listen failed: %s", err)
  6316  	}
  6317  	if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
  6318  		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
  6319  	}
  6320  
  6321  	c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS})
  6322  
  6323  	// Try to accept the connection.
  6324  	we, ch := waiter.NewChannelEntry(nil)
  6325  	wq.EventRegister(&we, waiter.ReadableEvents)
  6326  	defer wq.EventUnregister(&we)
  6327  
  6328  	aep, _, err := ep.Accept(nil)
  6329  	if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
  6330  		// Wait for connection to be established.
  6331  		select {
  6332  		case <-ch:
  6333  			aep, _, err = ep.Accept(nil)
  6334  			if err != nil {
  6335  				t.Fatalf("Accept failed: %s", err)
  6336  			}
  6337  
  6338  		case <-time.After(1 * time.Second):
  6339  			t.Fatalf("Timed out waiting for accept")
  6340  		}
  6341  	}
  6342  	if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
  6343  		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
  6344  	}
  6345  	{
  6346  		err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
  6347  		if d := cmp.Diff(&tcpip.ErrAlreadyConnected{}, err); d != "" {
  6348  			t.Errorf("Connect(...) mismatch (-want +got):\n%s", d)
  6349  		}
  6350  	}
  6351  	// Listening endpoint remains in listen state.
  6352  	if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
  6353  		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
  6354  	}
  6355  
  6356  	ep.Close()
  6357  	// Give worker goroutines time to receive the close notification.
  6358  	time.Sleep(1 * time.Second)
  6359  	if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
  6360  		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
  6361  	}
  6362  	// Accepted endpoint remains open when the listen endpoint is closed.
  6363  	if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
  6364  		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
  6365  	}
  6366  
  6367  }
  6368  
  6369  // This test verifies that the auto tuning does not grow the receive buffer if
  6370  // the application is not reading the data actively.
  6371  func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
  6372  	const mtu = 1500
  6373  	const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
  6374  
  6375  	c := context.New(t, mtu)
  6376  	defer c.Cleanup()
  6377  
  6378  	stk := c.Stack()
  6379  	// Set lower limits for auto-tuning tests. This is required because the
  6380  	// test stops the worker which can cause packets to be dropped because
  6381  	// the segment queue holding unprocessed packets is limited to 500.
  6382  	const receiveBufferSize = 80 << 10 // 80KB.
  6383  	const maxReceiveBufferSize = receiveBufferSize * 10
  6384  	{
  6385  		opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}
  6386  		if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
  6387  			t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
  6388  		}
  6389  	}
  6390  
  6391  	// Enable auto-tuning.
  6392  	{
  6393  		opt := tcpip.TCPModerateReceiveBufferOption(true)
  6394  		if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
  6395  			t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
  6396  		}
  6397  	}
  6398  	// Change the expected window scale to match the value needed for the
  6399  	// maximum buffer size defined above.
  6400  	c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
  6401  
  6402  	rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
  6403  
  6404  	// NOTE: The timestamp values in the sent packets are meaningless to the
  6405  	// peer so we just increment the timestamp value by 1 every batch as we
  6406  	// are not really using them for anything. Send a single byte to verify
  6407  	// the advertised window.
  6408  	tsVal := rawEP.TSVal + 1
  6409  
  6410  	// Introduce a 25ms latency by delaying the first byte.
  6411  	latency := 25 * time.Millisecond
  6412  	time.Sleep(latency)
  6413  	// Send an initial payload with atleast segment overhead size. The receive
  6414  	// window would not grow for smaller segments.
  6415  	rawEP.SendPacketWithTS(make([]byte, tcp.SegSize), tsVal)
  6416  
  6417  	pkt := rawEP.VerifyAndReturnACKWithTS(tsVal)
  6418  	rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize()
  6419  
  6420  	time.Sleep(25 * time.Millisecond)
  6421  
  6422  	// Allocate a large enough payload for the test.
  6423  	payloadSize := receiveBufferSize * 2
  6424  	b := make([]byte, payloadSize)
  6425  
  6426  	worker := (c.EP).(interface {
  6427  		StopWork()
  6428  		ResumeWork()
  6429  	})
  6430  	tsVal++
  6431  
  6432  	// Stop the worker goroutine.
  6433  	worker.StopWork()
  6434  	start := 0
  6435  	end := payloadSize / 2
  6436  	packetsSent := 0
  6437  	for ; start < end; start += mss {
  6438  		packetEnd := start + mss
  6439  		if start+mss > end {
  6440  			packetEnd = end
  6441  		}
  6442  		rawEP.SendPacketWithTS(b[start:packetEnd], tsVal)
  6443  		packetsSent++
  6444  	}
  6445  
  6446  	// Resume the worker so that it only sees the packets once all of them
  6447  	// are waiting to be read.
  6448  	worker.ResumeWork()
  6449  
  6450  	// Since we sent almost the full receive buffer worth of data (some may have
  6451  	// been dropped due to segment overheads), we should get a zero window back.
  6452  	pkt = c.GetPacket()
  6453  	tcpHdr := header.TCP(header.IPv4(pkt).Payload())
  6454  	gotRcvWnd := tcpHdr.WindowSize()
  6455  	wantAckNum := tcpHdr.AckNumber()
  6456  	if got, want := int(gotRcvWnd), 0; got != want {
  6457  		t.Fatalf("got rcvWnd: %d, want: %d", got, want)
  6458  	}
  6459  
  6460  	time.Sleep(25 * time.Millisecond)
  6461  	// Verify that sending more data when receiveBuffer is exhausted.
  6462  	rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
  6463  
  6464  	// Now read all the data from the endpoint and verify that advertised
  6465  	// window increases to the full available buffer size.
  6466  	for {
  6467  		_, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
  6468  		if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
  6469  			break
  6470  		}
  6471  	}
  6472  
  6473  	// Verify that we receive a non-zero window update ACK. When running
  6474  	// under thread santizer this test can end up sending more than 1
  6475  	// ack, 1 for the non-zero window
  6476  	p := c.GetPacket()
  6477  	checker.IPv4(t, p, checker.TCP(
  6478  		checker.TCPAckNum(wantAckNum),
  6479  		func(t *testing.T, h header.Transport) {
  6480  			tcp, ok := h.(header.TCP)
  6481  			if !ok {
  6482  				return
  6483  			}
  6484  			// We use 10% here as the error margin upwards as the initial window we
  6485  			// got was afer 1 segment was already in the receive buffer queue.
  6486  			tolerance := 1.1
  6487  			if w := tcp.WindowSize(); w == 0 || w > uint16(float64(rcvWnd)*tolerance) {
  6488  				t.Errorf("expected a non-zero window: got %d, want <= %d", w, uint16(float64(rcvWnd)*tolerance))
  6489  			}
  6490  		},
  6491  	))
  6492  }
  6493  
  6494  // This test verifies that the advertised window is auto-tuned up as the
  6495  // application is reading the data that is being received.
  6496  func TestReceiveBufferAutoTuning(t *testing.T) {
  6497  	const mtu = 1500
  6498  	const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
  6499  
  6500  	c := context.New(t, mtu)
  6501  	defer c.Cleanup()
  6502  
  6503  	// Enable Auto-tuning.
  6504  	stk := c.Stack()
  6505  	// Disable out of window rate limiting for this test by setting it to 0 as we
  6506  	// use out of window ACKs to measure the advertised window.
  6507  	var tcpInvalidRateLimit stack.TCPInvalidRateLimitOption
  6508  	if err := stk.SetOption(tcpInvalidRateLimit); err != nil {
  6509  		t.Fatalf("e.stack.SetOption(%#v) = %s", tcpInvalidRateLimit, err)
  6510  	}
  6511  
  6512  	const receiveBufferSize = 80 << 10 // 80KB.
  6513  	const maxReceiveBufferSize = receiveBufferSize * 10
  6514  	{
  6515  		opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}
  6516  		if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
  6517  			t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
  6518  		}
  6519  	}
  6520  
  6521  	// Enable auto-tuning.
  6522  	{
  6523  		opt := tcpip.TCPModerateReceiveBufferOption(true)
  6524  		if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
  6525  			t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
  6526  		}
  6527  	}
  6528  	// Change the expected window scale to match the value needed for the
  6529  	// maximum buffer size used by stack.
  6530  	c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
  6531  
  6532  	rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
  6533  	tsVal := rawEP.TSVal
  6534  	rawEP.NextSeqNum--
  6535  	rawEP.SendPacketWithTS(nil, tsVal)
  6536  	rawEP.NextSeqNum++
  6537  	pkt := rawEP.VerifyAndReturnACKWithTS(tsVal)
  6538  	curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale
  6539  	scaleRcvWnd := func(rcvWnd int) uint16 {
  6540  		return uint16(rcvWnd >> c.WindowScale)
  6541  	}
  6542  	// Allocate a large array to send to the endpoint.
  6543  	b := make([]byte, receiveBufferSize*48)
  6544  
  6545  	// In every iteration we will send double the number of bytes sent in
  6546  	// the previous iteration and read the same from the app. The received
  6547  	// window should grow by at least 2x of bytes read by the app in every
  6548  	// RTT.
  6549  	offset := 0
  6550  	payloadSize := receiveBufferSize / 8
  6551  	worker := (c.EP).(interface {
  6552  		StopWork()
  6553  		ResumeWork()
  6554  	})
  6555  	latency := 1 * time.Millisecond
  6556  	for i := 0; i < 5; i++ {
  6557  		tsVal++
  6558  
  6559  		// Stop the worker goroutine.
  6560  		worker.StopWork()
  6561  		start := offset
  6562  		end := offset + payloadSize
  6563  		totalSent := 0
  6564  		packetsSent := 0
  6565  		for ; start < end; start += mss {
  6566  			rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
  6567  			totalSent += mss
  6568  			packetsSent++
  6569  		}
  6570  
  6571  		// Resume it so that it only sees the packets once all of them
  6572  		// are waiting to be read.
  6573  		worker.ResumeWork()
  6574  
  6575  		// Give 1ms for the worker to process the packets.
  6576  		time.Sleep(1 * time.Millisecond)
  6577  
  6578  		lastACK := c.GetPacket()
  6579  		// Discard any intermediate ACKs and only check the last ACK we get in a
  6580  		// short time period of few ms.
  6581  		for {
  6582  			time.Sleep(1 * time.Millisecond)
  6583  			pkt := c.GetPacketNonBlocking()
  6584  			if pkt == nil {
  6585  				break
  6586  			}
  6587  			lastACK = pkt
  6588  		}
  6589  		if got, want := int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()), int(scaleRcvWnd(curRcvWnd)); got > want {
  6590  			t.Fatalf("advertised window got: %d, want <= %d", got, want)
  6591  		}
  6592  
  6593  		// Now read all the data from the endpoint and invoke the
  6594  		// moderation API to allow for receive buffer auto-tuning
  6595  		// to happen before we measure the new window.
  6596  		totalCopied := 0
  6597  		for {
  6598  			res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
  6599  			if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
  6600  				break
  6601  			}
  6602  			totalCopied += res.Count
  6603  		}
  6604  
  6605  		// Invoke the moderation API. This is required for auto-tuning
  6606  		// to happen. This method is normally expected to be invoked
  6607  		// from a higher layer than tcpip.Endpoint. So we simulate
  6608  		// copying to userspace by invoking it explicitly here.
  6609  		c.EP.ModerateRecvBuf(totalCopied)
  6610  
  6611  		// Now send a keep-alive packet to trigger an ACK so that we can
  6612  		// measure the new window.
  6613  		rawEP.NextSeqNum--
  6614  		rawEP.SendPacketWithTS(nil, tsVal)
  6615  		rawEP.NextSeqNum++
  6616  
  6617  		if i == 0 {
  6618  			// In the first iteration the receiver based RTT is not
  6619  			// yet known as a result the moderation code should not
  6620  			// increase the advertised window.
  6621  			rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd))
  6622  		} else {
  6623  			// Read loop above could generate an ACK if the window had dropped to
  6624  			// zero and then read had opened it up.
  6625  			lastACK := c.GetPacket()
  6626  			// Discard any intermediate ACKs and only check the last ACK we get in a
  6627  			// short time period of few ms.
  6628  			for {
  6629  				time.Sleep(1 * time.Millisecond)
  6630  				pkt := c.GetPacketNonBlocking()
  6631  				if pkt == nil {
  6632  					break
  6633  				}
  6634  				lastACK = pkt
  6635  			}
  6636  			curRcvWnd = int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()) << c.WindowScale
  6637  			// If thew new current window is close maxReceiveBufferSize then terminate
  6638  			// the loop. This can happen before all iterations are done due to timing
  6639  			// differences when running the test.
  6640  			if int(float64(curRcvWnd)*1.1) > maxReceiveBufferSize/2 {
  6641  				break
  6642  			}
  6643  			// Increase the latency after first two iterations to
  6644  			// establish a low RTT value in the receiver since it
  6645  			// only tracks the lowest value. This ensures that when
  6646  			// ModerateRcvBuf is called the elapsed time is always >
  6647  			// rtt. Without this the test is flaky due to delays due
  6648  			// to scheduling/wakeup etc.
  6649  			latency += 50 * time.Millisecond
  6650  		}
  6651  		time.Sleep(latency)
  6652  		offset += payloadSize
  6653  		payloadSize *= 2
  6654  	}
  6655  	// Check that at the end of our iterations the receive window grew close to the maximum
  6656  	// permissible size of maxReceiveBufferSize/2
  6657  	if got, want := int(float64(curRcvWnd)*1.1), maxReceiveBufferSize/2; got < want {
  6658  		t.Fatalf("unexpected rcvWnd got: %d, want > %d", got, want)
  6659  	}
  6660  
  6661  }
  6662  
  6663  func TestDelayEnabled(t *testing.T) {
  6664  	c := context.New(t, defaultMTU)
  6665  	defer c.Cleanup()
  6666  	checkDelayOption(t, c, false, false) // Delay is disabled by default.
  6667  
  6668  	for _, delayEnabled := range []bool{false, true} {
  6669  		t.Run(fmt.Sprintf("delayEnabled=%t", delayEnabled), func(t *testing.T) {
  6670  			c := context.New(t, defaultMTU)
  6671  			defer c.Cleanup()
  6672  			opt := tcpip.TCPDelayEnabled(delayEnabled)
  6673  			if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
  6674  				t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, delayEnabled, err)
  6675  			}
  6676  			checkDelayOption(t, c, opt, delayEnabled)
  6677  		})
  6678  	}
  6679  }
  6680  
  6681  func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.TCPDelayEnabled, wantDelayOption bool) {
  6682  	t.Helper()
  6683  
  6684  	var gotDelayEnabled tcpip.TCPDelayEnabled
  6685  	if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil {
  6686  		t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err)
  6687  	}
  6688  	if gotDelayEnabled != wantDelayEnabled {
  6689  		t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled)
  6690  	}
  6691  
  6692  	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue))
  6693  	if err != nil {
  6694  		t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err)
  6695  	}
  6696  	gotDelayOption := ep.SocketOptions().GetDelayOption()
  6697  	if gotDelayOption != wantDelayOption {
  6698  		t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption)
  6699  	}
  6700  }
  6701  
  6702  func TestTCPLingerTimeout(t *testing.T) {
  6703  	c := context.New(t, 1500 /* mtu */)
  6704  	defer c.Cleanup()
  6705  
  6706  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  6707  
  6708  	testCases := []struct {
  6709  		name             string
  6710  		tcpLingerTimeout time.Duration
  6711  		want             time.Duration
  6712  	}{
  6713  		{"NegativeLingerTimeout", -123123, -1},
  6714  		// Zero is treated same as the stack's default TCP_LINGER2 timeout.
  6715  		{"ZeroLingerTimeout", 0, tcp.DefaultTCPLingerTimeout},
  6716  		{"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second},
  6717  		// Values > stack's TCPLingerTimeout are capped to the stack's
  6718  		// value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds)
  6719  		{"AboveMaxLingerTimeout", tcp.MaxTCPLingerTimeout + 5*time.Second, tcp.MaxTCPLingerTimeout},
  6720  	}
  6721  	for _, tc := range testCases {
  6722  		t.Run(tc.name, func(t *testing.T) {
  6723  			v := tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)
  6724  			if err := c.EP.SetSockOpt(&v); err != nil {
  6725  				t.Fatalf("SetSockOpt(&%T(%s)) = %s", v, tc.tcpLingerTimeout, err)
  6726  			}
  6727  
  6728  			v = 0
  6729  			if err := c.EP.GetSockOpt(&v); err != nil {
  6730  				t.Fatalf("GetSockOpt(&%T) = %s", v, err)
  6731  			}
  6732  			if got, want := time.Duration(v), tc.want; got != want {
  6733  				t.Fatalf("got linger timeout = %s, want = %s", got, want)
  6734  			}
  6735  		})
  6736  	}
  6737  }
  6738  
  6739  func TestTCPTimeWaitRSTIgnored(t *testing.T) {
  6740  	c := context.New(t, defaultMTU)
  6741  	defer c.Cleanup()
  6742  
  6743  	wq := &waiter.Queue{}
  6744  	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
  6745  	if err != nil {
  6746  		t.Fatalf("NewEndpoint failed: %s", err)
  6747  	}
  6748  	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  6749  		t.Fatalf("Bind failed: %s", err)
  6750  	}
  6751  
  6752  	if err := ep.Listen(10); err != nil {
  6753  		t.Fatalf("Listen failed: %s", err)
  6754  	}
  6755  
  6756  	// Send a SYN request.
  6757  	iss := seqnum.Value(context.TestInitialSequenceNumber)
  6758  	c.SendPacket(nil, &context.Headers{
  6759  		SrcPort: context.TestPort,
  6760  		DstPort: context.StackPort,
  6761  		Flags:   header.TCPFlagSyn,
  6762  		SeqNum:  iss,
  6763  		RcvWnd:  30000,
  6764  	})
  6765  
  6766  	// Receive the SYN-ACK reply.
  6767  	b := c.GetPacket()
  6768  	tcpHdr := header.TCP(header.IPv4(b).Payload())
  6769  	c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
  6770  
  6771  	ackHeaders := &context.Headers{
  6772  		SrcPort: context.TestPort,
  6773  		DstPort: context.StackPort,
  6774  		Flags:   header.TCPFlagAck,
  6775  		SeqNum:  iss + 1,
  6776  		AckNum:  c.IRS + 1,
  6777  	}
  6778  
  6779  	// Send ACK.
  6780  	c.SendPacket(nil, ackHeaders)
  6781  
  6782  	// Try to accept the connection.
  6783  	we, ch := waiter.NewChannelEntry(nil)
  6784  	wq.EventRegister(&we, waiter.ReadableEvents)
  6785  	defer wq.EventUnregister(&we)
  6786  
  6787  	c.EP, _, err = ep.Accept(nil)
  6788  	if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
  6789  		// Wait for connection to be established.
  6790  		select {
  6791  		case <-ch:
  6792  			c.EP, _, err = ep.Accept(nil)
  6793  			if err != nil {
  6794  				t.Fatalf("Accept failed: %s", err)
  6795  			}
  6796  
  6797  		case <-time.After(1 * time.Second):
  6798  			t.Fatalf("Timed out waiting for accept")
  6799  		}
  6800  	}
  6801  
  6802  	c.EP.Close()
  6803  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  6804  		checker.SrcPort(context.StackPort),
  6805  		checker.DstPort(context.TestPort),
  6806  		checker.TCPSeqNum(uint32(c.IRS+1)),
  6807  		checker.TCPAckNum(uint32(iss)+1),
  6808  		checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
  6809  
  6810  	finHeaders := &context.Headers{
  6811  		SrcPort: context.TestPort,
  6812  		DstPort: context.StackPort,
  6813  		Flags:   header.TCPFlagAck | header.TCPFlagFin,
  6814  		SeqNum:  iss + 1,
  6815  		AckNum:  c.IRS + 2,
  6816  	}
  6817  
  6818  	c.SendPacket(nil, finHeaders)
  6819  
  6820  	// Get the ACK to the FIN we just sent.
  6821  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  6822  		checker.SrcPort(context.StackPort),
  6823  		checker.DstPort(context.TestPort),
  6824  		checker.TCPSeqNum(uint32(c.IRS+2)),
  6825  		checker.TCPAckNum(uint32(iss)+2),
  6826  		checker.TCPFlags(header.TCPFlagAck)))
  6827  
  6828  	// Now send a RST and this should be ignored and not
  6829  	// generate an ACK.
  6830  	c.SendPacket(nil, &context.Headers{
  6831  		SrcPort: context.TestPort,
  6832  		DstPort: context.StackPort,
  6833  		Flags:   header.TCPFlagRst,
  6834  		SeqNum:  iss + 1,
  6835  		AckNum:  c.IRS + 2,
  6836  	})
  6837  
  6838  	c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second)
  6839  
  6840  	// Out of order ACK should generate an immediate ACK in
  6841  	// TIME_WAIT.
  6842  	c.SendPacket(nil, &context.Headers{
  6843  		SrcPort: context.TestPort,
  6844  		DstPort: context.StackPort,
  6845  		Flags:   header.TCPFlagAck,
  6846  		SeqNum:  iss + 1,
  6847  		AckNum:  c.IRS + 3,
  6848  	})
  6849  
  6850  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  6851  		checker.SrcPort(context.StackPort),
  6852  		checker.DstPort(context.TestPort),
  6853  		checker.TCPSeqNum(uint32(c.IRS+2)),
  6854  		checker.TCPAckNum(uint32(iss)+2),
  6855  		checker.TCPFlags(header.TCPFlagAck)))
  6856  }
  6857  
  6858  func TestTCPTimeWaitOutOfOrder(t *testing.T) {
  6859  	c := context.New(t, defaultMTU)
  6860  	defer c.Cleanup()
  6861  
  6862  	wq := &waiter.Queue{}
  6863  	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
  6864  	if err != nil {
  6865  		t.Fatalf("NewEndpoint failed: %s", err)
  6866  	}
  6867  	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  6868  		t.Fatalf("Bind failed: %s", err)
  6869  	}
  6870  
  6871  	if err := ep.Listen(10); err != nil {
  6872  		t.Fatalf("Listen failed: %s", err)
  6873  	}
  6874  
  6875  	// Send a SYN request.
  6876  	iss := seqnum.Value(context.TestInitialSequenceNumber)
  6877  	c.SendPacket(nil, &context.Headers{
  6878  		SrcPort: context.TestPort,
  6879  		DstPort: context.StackPort,
  6880  		Flags:   header.TCPFlagSyn,
  6881  		SeqNum:  iss,
  6882  		RcvWnd:  30000,
  6883  	})
  6884  
  6885  	// Receive the SYN-ACK reply.
  6886  	b := c.GetPacket()
  6887  	tcpHdr := header.TCP(header.IPv4(b).Payload())
  6888  	c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
  6889  
  6890  	ackHeaders := &context.Headers{
  6891  		SrcPort: context.TestPort,
  6892  		DstPort: context.StackPort,
  6893  		Flags:   header.TCPFlagAck,
  6894  		SeqNum:  iss + 1,
  6895  		AckNum:  c.IRS + 1,
  6896  	}
  6897  
  6898  	// Send ACK.
  6899  	c.SendPacket(nil, ackHeaders)
  6900  
  6901  	// Try to accept the connection.
  6902  	we, ch := waiter.NewChannelEntry(nil)
  6903  	wq.EventRegister(&we, waiter.ReadableEvents)
  6904  	defer wq.EventUnregister(&we)
  6905  
  6906  	c.EP, _, err = ep.Accept(nil)
  6907  	if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
  6908  		// Wait for connection to be established.
  6909  		select {
  6910  		case <-ch:
  6911  			c.EP, _, err = ep.Accept(nil)
  6912  			if err != nil {
  6913  				t.Fatalf("Accept failed: %s", err)
  6914  			}
  6915  
  6916  		case <-time.After(1 * time.Second):
  6917  			t.Fatalf("Timed out waiting for accept")
  6918  		}
  6919  	}
  6920  
  6921  	c.EP.Close()
  6922  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  6923  		checker.SrcPort(context.StackPort),
  6924  		checker.DstPort(context.TestPort),
  6925  		checker.TCPSeqNum(uint32(c.IRS+1)),
  6926  		checker.TCPAckNum(uint32(iss)+1),
  6927  		checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
  6928  
  6929  	finHeaders := &context.Headers{
  6930  		SrcPort: context.TestPort,
  6931  		DstPort: context.StackPort,
  6932  		Flags:   header.TCPFlagAck | header.TCPFlagFin,
  6933  		SeqNum:  iss + 1,
  6934  		AckNum:  c.IRS + 2,
  6935  	}
  6936  
  6937  	c.SendPacket(nil, finHeaders)
  6938  
  6939  	// Get the ACK to the FIN we just sent.
  6940  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  6941  		checker.SrcPort(context.StackPort),
  6942  		checker.DstPort(context.TestPort),
  6943  		checker.TCPSeqNum(uint32(c.IRS+2)),
  6944  		checker.TCPAckNum(uint32(iss)+2),
  6945  		checker.TCPFlags(header.TCPFlagAck)))
  6946  
  6947  	// Out of order ACK should generate an immediate ACK in
  6948  	// TIME_WAIT.
  6949  	c.SendPacket(nil, &context.Headers{
  6950  		SrcPort: context.TestPort,
  6951  		DstPort: context.StackPort,
  6952  		Flags:   header.TCPFlagAck,
  6953  		SeqNum:  iss + 1,
  6954  		AckNum:  c.IRS + 3,
  6955  	})
  6956  
  6957  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  6958  		checker.SrcPort(context.StackPort),
  6959  		checker.DstPort(context.TestPort),
  6960  		checker.TCPSeqNum(uint32(c.IRS+2)),
  6961  		checker.TCPAckNum(uint32(iss)+2),
  6962  		checker.TCPFlags(header.TCPFlagAck)))
  6963  }
  6964  
  6965  func TestTCPTimeWaitNewSyn(t *testing.T) {
  6966  	c := context.New(t, defaultMTU)
  6967  	defer c.Cleanup()
  6968  
  6969  	wq := &waiter.Queue{}
  6970  	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
  6971  	if err != nil {
  6972  		t.Fatalf("NewEndpoint failed: %s", err)
  6973  	}
  6974  	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  6975  		t.Fatalf("Bind failed: %s", err)
  6976  	}
  6977  
  6978  	if err := ep.Listen(10); err != nil {
  6979  		t.Fatalf("Listen failed: %s", err)
  6980  	}
  6981  
  6982  	// Send a SYN request.
  6983  	iss := seqnum.Value(context.TestInitialSequenceNumber)
  6984  	c.SendPacket(nil, &context.Headers{
  6985  		SrcPort: context.TestPort,
  6986  		DstPort: context.StackPort,
  6987  		Flags:   header.TCPFlagSyn,
  6988  		SeqNum:  iss,
  6989  		RcvWnd:  30000,
  6990  	})
  6991  
  6992  	// Receive the SYN-ACK reply.
  6993  	b := c.GetPacket()
  6994  	tcpHdr := header.TCP(header.IPv4(b).Payload())
  6995  	c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
  6996  
  6997  	ackHeaders := &context.Headers{
  6998  		SrcPort: context.TestPort,
  6999  		DstPort: context.StackPort,
  7000  		Flags:   header.TCPFlagAck,
  7001  		SeqNum:  iss + 1,
  7002  		AckNum:  c.IRS + 1,
  7003  	}
  7004  
  7005  	// Send ACK.
  7006  	c.SendPacket(nil, ackHeaders)
  7007  
  7008  	// Try to accept the connection.
  7009  	we, ch := waiter.NewChannelEntry(nil)
  7010  	wq.EventRegister(&we, waiter.ReadableEvents)
  7011  	defer wq.EventUnregister(&we)
  7012  
  7013  	c.EP, _, err = ep.Accept(nil)
  7014  	if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
  7015  		// Wait for connection to be established.
  7016  		select {
  7017  		case <-ch:
  7018  			c.EP, _, err = ep.Accept(nil)
  7019  			if err != nil {
  7020  				t.Fatalf("Accept failed: %s", err)
  7021  			}
  7022  
  7023  		case <-time.After(1 * time.Second):
  7024  			t.Fatalf("Timed out waiting for accept")
  7025  		}
  7026  	}
  7027  
  7028  	c.EP.Close()
  7029  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  7030  		checker.SrcPort(context.StackPort),
  7031  		checker.DstPort(context.TestPort),
  7032  		checker.TCPSeqNum(uint32(c.IRS+1)),
  7033  		checker.TCPAckNum(uint32(iss)+1),
  7034  		checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
  7035  
  7036  	finHeaders := &context.Headers{
  7037  		SrcPort: context.TestPort,
  7038  		DstPort: context.StackPort,
  7039  		Flags:   header.TCPFlagAck | header.TCPFlagFin,
  7040  		SeqNum:  iss + 1,
  7041  		AckNum:  c.IRS + 2,
  7042  	}
  7043  
  7044  	c.SendPacket(nil, finHeaders)
  7045  
  7046  	// Get the ACK to the FIN we just sent.
  7047  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  7048  		checker.SrcPort(context.StackPort),
  7049  		checker.DstPort(context.TestPort),
  7050  		checker.TCPSeqNum(uint32(c.IRS+2)),
  7051  		checker.TCPAckNum(uint32(iss)+2),
  7052  		checker.TCPFlags(header.TCPFlagAck)))
  7053  
  7054  	// Send a SYN request w/ sequence number lower than
  7055  	// the highest sequence number sent. We just reuse
  7056  	// the same number.
  7057  	iss = seqnum.Value(context.TestInitialSequenceNumber)
  7058  	c.SendPacket(nil, &context.Headers{
  7059  		SrcPort: context.TestPort,
  7060  		DstPort: context.StackPort,
  7061  		Flags:   header.TCPFlagSyn,
  7062  		SeqNum:  iss,
  7063  		RcvWnd:  30000,
  7064  	})
  7065  
  7066  	c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second)
  7067  
  7068  	// drain any older notifications from the notification channel before attempting
  7069  	// 2nd connection.
  7070  	select {
  7071  	case <-ch:
  7072  	default:
  7073  	}
  7074  
  7075  	// Send a SYN request w/ sequence number higher than
  7076  	// the highest sequence number sent.
  7077  	iss = iss.Add(3)
  7078  	c.SendPacket(nil, &context.Headers{
  7079  		SrcPort: context.TestPort,
  7080  		DstPort: context.StackPort,
  7081  		Flags:   header.TCPFlagSyn,
  7082  		SeqNum:  iss,
  7083  		RcvWnd:  30000,
  7084  	})
  7085  
  7086  	// Receive the SYN-ACK reply.
  7087  	b = c.GetPacket()
  7088  	tcpHdr = header.IPv4(b).Payload()
  7089  	c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
  7090  
  7091  	ackHeaders = &context.Headers{
  7092  		SrcPort: context.TestPort,
  7093  		DstPort: context.StackPort,
  7094  		Flags:   header.TCPFlagAck,
  7095  		SeqNum:  iss + 1,
  7096  		AckNum:  c.IRS + 1,
  7097  	}
  7098  
  7099  	// Send ACK.
  7100  	c.SendPacket(nil, ackHeaders)
  7101  
  7102  	// Try to accept the connection.
  7103  	c.EP, _, err = ep.Accept(nil)
  7104  	if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
  7105  		// Wait for connection to be established.
  7106  		select {
  7107  		case <-ch:
  7108  			c.EP, _, err = ep.Accept(nil)
  7109  			if err != nil {
  7110  				t.Fatalf("Accept failed: %s", err)
  7111  			}
  7112  
  7113  		case <-time.After(1 * time.Second):
  7114  			t.Fatalf("Timed out waiting for accept")
  7115  		}
  7116  	}
  7117  }
  7118  
  7119  func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
  7120  	c := context.New(t, defaultMTU)
  7121  	defer c.Cleanup()
  7122  
  7123  	// Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
  7124  	// after 5 seconds in TIME_WAIT state.
  7125  	tcpTimeWaitTimeout := 5 * time.Second
  7126  	opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
  7127  	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
  7128  		t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err)
  7129  	}
  7130  
  7131  	want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1
  7132  
  7133  	wq := &waiter.Queue{}
  7134  	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
  7135  	if err != nil {
  7136  		t.Fatalf("NewEndpoint failed: %s", err)
  7137  	}
  7138  	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  7139  		t.Fatalf("Bind failed: %s", err)
  7140  	}
  7141  
  7142  	if err := ep.Listen(10); err != nil {
  7143  		t.Fatalf("Listen failed: %s", err)
  7144  	}
  7145  
  7146  	// Send a SYN request.
  7147  	iss := seqnum.Value(context.TestInitialSequenceNumber)
  7148  	c.SendPacket(nil, &context.Headers{
  7149  		SrcPort: context.TestPort,
  7150  		DstPort: context.StackPort,
  7151  		Flags:   header.TCPFlagSyn,
  7152  		SeqNum:  iss,
  7153  		RcvWnd:  30000,
  7154  	})
  7155  
  7156  	// Receive the SYN-ACK reply.
  7157  	b := c.GetPacket()
  7158  	tcpHdr := header.TCP(header.IPv4(b).Payload())
  7159  	c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
  7160  
  7161  	ackHeaders := &context.Headers{
  7162  		SrcPort: context.TestPort,
  7163  		DstPort: context.StackPort,
  7164  		Flags:   header.TCPFlagAck,
  7165  		SeqNum:  iss + 1,
  7166  		AckNum:  c.IRS + 1,
  7167  	}
  7168  
  7169  	// Send ACK.
  7170  	c.SendPacket(nil, ackHeaders)
  7171  
  7172  	// Try to accept the connection.
  7173  	we, ch := waiter.NewChannelEntry(nil)
  7174  	wq.EventRegister(&we, waiter.ReadableEvents)
  7175  	defer wq.EventUnregister(&we)
  7176  
  7177  	c.EP, _, err = ep.Accept(nil)
  7178  	if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
  7179  		// Wait for connection to be established.
  7180  		select {
  7181  		case <-ch:
  7182  			c.EP, _, err = ep.Accept(nil)
  7183  			if err != nil {
  7184  				t.Fatalf("Accept failed: %s", err)
  7185  			}
  7186  
  7187  		case <-time.After(1 * time.Second):
  7188  			t.Fatalf("Timed out waiting for accept")
  7189  		}
  7190  	}
  7191  
  7192  	c.EP.Close()
  7193  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  7194  		checker.SrcPort(context.StackPort),
  7195  		checker.DstPort(context.TestPort),
  7196  		checker.TCPSeqNum(uint32(c.IRS+1)),
  7197  		checker.TCPAckNum(uint32(iss)+1),
  7198  		checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
  7199  
  7200  	finHeaders := &context.Headers{
  7201  		SrcPort: context.TestPort,
  7202  		DstPort: context.StackPort,
  7203  		Flags:   header.TCPFlagAck | header.TCPFlagFin,
  7204  		SeqNum:  iss + 1,
  7205  		AckNum:  c.IRS + 2,
  7206  	}
  7207  
  7208  	c.SendPacket(nil, finHeaders)
  7209  
  7210  	// Get the ACK to the FIN we just sent.
  7211  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  7212  		checker.SrcPort(context.StackPort),
  7213  		checker.DstPort(context.TestPort),
  7214  		checker.TCPSeqNum(uint32(c.IRS+2)),
  7215  		checker.TCPAckNum(uint32(iss)+2),
  7216  		checker.TCPFlags(header.TCPFlagAck)))
  7217  
  7218  	time.Sleep(2 * time.Second)
  7219  
  7220  	// Now send a duplicate FIN. This should cause the TIME_WAIT to extend
  7221  	// by another 5 seconds and also send us a duplicate ACK as it should
  7222  	// indicate that the final ACK was potentially lost.
  7223  	c.SendPacket(nil, finHeaders)
  7224  
  7225  	// Get the ACK to the FIN we just sent.
  7226  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  7227  		checker.SrcPort(context.StackPort),
  7228  		checker.DstPort(context.TestPort),
  7229  		checker.TCPSeqNum(uint32(c.IRS+2)),
  7230  		checker.TCPAckNum(uint32(iss)+2),
  7231  		checker.TCPFlags(header.TCPFlagAck)))
  7232  
  7233  	// Sleep for 4 seconds so at this point we are 1 second past the
  7234  	// original tcpLingerTimeout of 5 seconds.
  7235  	time.Sleep(4 * time.Second)
  7236  
  7237  	// Send an ACK and it should not generate any packet as the socket
  7238  	// should still be in TIME_WAIT for another another 5 seconds due
  7239  	// to the duplicate FIN we sent earlier.
  7240  	*ackHeaders = *finHeaders
  7241  	ackHeaders.SeqNum = ackHeaders.SeqNum + 1
  7242  	ackHeaders.Flags = header.TCPFlagAck
  7243  	c.SendPacket(nil, ackHeaders)
  7244  
  7245  	c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second)
  7246  	// Now sleep for another 2 seconds so that we are past the
  7247  	// extended TIME_WAIT of 7 seconds (2 + 5).
  7248  	time.Sleep(2 * time.Second)
  7249  
  7250  	// Resend the same ACK.
  7251  	c.SendPacket(nil, ackHeaders)
  7252  
  7253  	// Receive the RST that should be generated as there is no valid
  7254  	// endpoint.
  7255  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  7256  		checker.SrcPort(context.StackPort),
  7257  		checker.DstPort(context.TestPort),
  7258  		checker.TCPSeqNum(uint32(ackHeaders.AckNum)),
  7259  		checker.TCPAckNum(0),
  7260  		checker.TCPFlags(header.TCPFlagRst)))
  7261  
  7262  	if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want {
  7263  		t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = %d", got, want)
  7264  	}
  7265  	if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
  7266  		t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
  7267  	}
  7268  }
  7269  
  7270  func TestTCPCloseWithData(t *testing.T) {
  7271  	c := context.New(t, defaultMTU)
  7272  	defer c.Cleanup()
  7273  
  7274  	// Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
  7275  	// after 5 seconds in TIME_WAIT state.
  7276  	tcpTimeWaitTimeout := 5 * time.Second
  7277  	opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
  7278  	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
  7279  		t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err)
  7280  	}
  7281  
  7282  	wq := &waiter.Queue{}
  7283  	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
  7284  	if err != nil {
  7285  		t.Fatalf("NewEndpoint failed: %s", err)
  7286  	}
  7287  	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  7288  		t.Fatalf("Bind failed: %s", err)
  7289  	}
  7290  
  7291  	if err := ep.Listen(10); err != nil {
  7292  		t.Fatalf("Listen failed: %s", err)
  7293  	}
  7294  
  7295  	// Send a SYN request.
  7296  	iss := seqnum.Value(context.TestInitialSequenceNumber)
  7297  	c.SendPacket(nil, &context.Headers{
  7298  		SrcPort: context.TestPort,
  7299  		DstPort: context.StackPort,
  7300  		Flags:   header.TCPFlagSyn,
  7301  		SeqNum:  iss,
  7302  		RcvWnd:  30000,
  7303  	})
  7304  
  7305  	// Receive the SYN-ACK reply.
  7306  	b := c.GetPacket()
  7307  	tcpHdr := header.TCP(header.IPv4(b).Payload())
  7308  	c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
  7309  
  7310  	ackHeaders := &context.Headers{
  7311  		SrcPort: context.TestPort,
  7312  		DstPort: context.StackPort,
  7313  		Flags:   header.TCPFlagAck,
  7314  		SeqNum:  iss + 1,
  7315  		AckNum:  c.IRS + 1,
  7316  		RcvWnd:  30000,
  7317  	}
  7318  
  7319  	// Send ACK.
  7320  	c.SendPacket(nil, ackHeaders)
  7321  
  7322  	// Try to accept the connection.
  7323  	we, ch := waiter.NewChannelEntry(nil)
  7324  	wq.EventRegister(&we, waiter.ReadableEvents)
  7325  	defer wq.EventUnregister(&we)
  7326  
  7327  	c.EP, _, err = ep.Accept(nil)
  7328  	if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
  7329  		// Wait for connection to be established.
  7330  		select {
  7331  		case <-ch:
  7332  			c.EP, _, err = ep.Accept(nil)
  7333  			if err != nil {
  7334  				t.Fatalf("Accept failed: %s", err)
  7335  			}
  7336  
  7337  		case <-time.After(1 * time.Second):
  7338  			t.Fatalf("Timed out waiting for accept")
  7339  		}
  7340  	}
  7341  
  7342  	// Now trigger a passive close by sending a FIN.
  7343  	finHeaders := &context.Headers{
  7344  		SrcPort: context.TestPort,
  7345  		DstPort: context.StackPort,
  7346  		Flags:   header.TCPFlagAck | header.TCPFlagFin,
  7347  		SeqNum:  iss + 1,
  7348  		AckNum:  c.IRS + 2,
  7349  		RcvWnd:  30000,
  7350  	}
  7351  
  7352  	c.SendPacket(nil, finHeaders)
  7353  
  7354  	// Get the ACK to the FIN we just sent.
  7355  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  7356  		checker.SrcPort(context.StackPort),
  7357  		checker.DstPort(context.TestPort),
  7358  		checker.TCPSeqNum(uint32(c.IRS+1)),
  7359  		checker.TCPAckNum(uint32(iss)+2),
  7360  		checker.TCPFlags(header.TCPFlagAck)))
  7361  
  7362  	// Now write a few bytes and then close the endpoint.
  7363  	data := []byte{1, 2, 3}
  7364  
  7365  	var r bytes.Reader
  7366  	r.Reset(data)
  7367  	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  7368  		t.Fatalf("Write failed: %s", err)
  7369  	}
  7370  
  7371  	// Check that data is received.
  7372  	b = c.GetPacket()
  7373  	checker.IPv4(t, b,
  7374  		checker.PayloadLen(len(data)+header.TCPMinimumSize),
  7375  		checker.TCP(
  7376  			checker.DstPort(context.TestPort),
  7377  			checker.TCPSeqNum(uint32(c.IRS)+1),
  7378  			checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1
  7379  			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  7380  		),
  7381  	)
  7382  
  7383  	if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
  7384  		t.Errorf("got data = %x, want = %x", p, data)
  7385  	}
  7386  
  7387  	c.EP.Close()
  7388  	// Check the FIN.
  7389  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  7390  		checker.SrcPort(context.StackPort),
  7391  		checker.DstPort(context.TestPort),
  7392  		checker.TCPSeqNum(uint32(c.IRS+1)+uint32(len(data))),
  7393  		checker.TCPAckNum(uint32(iss+2)),
  7394  		checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
  7395  
  7396  	// First send a partial ACK.
  7397  	ackHeaders = &context.Headers{
  7398  		SrcPort: context.TestPort,
  7399  		DstPort: context.StackPort,
  7400  		Flags:   header.TCPFlagAck,
  7401  		SeqNum:  iss + 2,
  7402  		AckNum:  c.IRS + 1 + seqnum.Value(len(data)-1),
  7403  		RcvWnd:  30000,
  7404  	}
  7405  	c.SendPacket(nil, ackHeaders)
  7406  
  7407  	// Now send a full ACK.
  7408  	ackHeaders = &context.Headers{
  7409  		SrcPort: context.TestPort,
  7410  		DstPort: context.StackPort,
  7411  		Flags:   header.TCPFlagAck,
  7412  		SeqNum:  iss + 2,
  7413  		AckNum:  c.IRS + 1 + seqnum.Value(len(data)),
  7414  		RcvWnd:  30000,
  7415  	}
  7416  	c.SendPacket(nil, ackHeaders)
  7417  
  7418  	// Now ACK the FIN.
  7419  	ackHeaders.AckNum++
  7420  	c.SendPacket(nil, ackHeaders)
  7421  
  7422  	// Now send an ACK and we should get a RST back as the endpoint should
  7423  	// be in CLOSED state.
  7424  	ackHeaders = &context.Headers{
  7425  		SrcPort: context.TestPort,
  7426  		DstPort: context.StackPort,
  7427  		Flags:   header.TCPFlagAck,
  7428  		SeqNum:  iss + 2,
  7429  		AckNum:  c.IRS + 1 + seqnum.Value(len(data)),
  7430  		RcvWnd:  30000,
  7431  	}
  7432  	c.SendPacket(nil, ackHeaders)
  7433  
  7434  	// Check the RST.
  7435  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  7436  		checker.SrcPort(context.StackPort),
  7437  		checker.DstPort(context.TestPort),
  7438  		checker.TCPSeqNum(uint32(ackHeaders.AckNum)),
  7439  		checker.TCPAckNum(0),
  7440  		checker.TCPFlags(header.TCPFlagRst)))
  7441  }
  7442  
  7443  func TestTCPUserTimeout(t *testing.T) {
  7444  	c := context.New(t, defaultMTU)
  7445  	defer c.Cleanup()
  7446  
  7447  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  7448  
  7449  	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
  7450  	c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
  7451  	defer c.WQ.EventUnregister(&waitEntry)
  7452  
  7453  	origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
  7454  
  7455  	// Ensure that on the next retransmit timer fire, the user timeout has
  7456  	// expired.
  7457  	initRTO := 1 * time.Second
  7458  	userTimeout := initRTO / 2
  7459  	v := tcpip.TCPUserTimeoutOption(userTimeout)
  7460  	if err := c.EP.SetSockOpt(&v); err != nil {
  7461  		t.Fatalf("c.EP.SetSockOpt(&%T(%s): %s", v, userTimeout, err)
  7462  	}
  7463  
  7464  	// Send some data and wait before ACKing it.
  7465  	view := make([]byte, 3)
  7466  	var r bytes.Reader
  7467  	r.Reset(view)
  7468  	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
  7469  		t.Fatalf("Write failed: %s", err)
  7470  	}
  7471  
  7472  	next := uint32(c.IRS) + 1
  7473  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  7474  	checker.IPv4(t, c.GetPacket(),
  7475  		checker.PayloadLen(len(view)+header.TCPMinimumSize),
  7476  		checker.TCP(
  7477  			checker.DstPort(context.TestPort),
  7478  			checker.TCPSeqNum(next),
  7479  			checker.TCPAckNum(uint32(iss)),
  7480  			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
  7481  		),
  7482  	)
  7483  
  7484  	// Wait for the retransmit timer to be fired and the user timeout to cause
  7485  	// close of the connection.
  7486  	select {
  7487  	case <-notifyCh:
  7488  	case <-time.After(2 * initRTO):
  7489  		t.Fatalf("connection still alive after %s, should have been closed after %s", 2*initRTO, userTimeout)
  7490  	}
  7491  
  7492  	// No packet should be received as the connection should be silently
  7493  	// closed due to timeout.
  7494  	c.CheckNoPacket("unexpected packet received after userTimeout has expired")
  7495  
  7496  	next += uint32(len(view))
  7497  
  7498  	// The connection should be terminated after userTimeout has expired.
  7499  	// Send an ACK to trigger a RST from the stack as the endpoint should
  7500  	// be dead.
  7501  	c.SendPacket(nil, &context.Headers{
  7502  		SrcPort: context.TestPort,
  7503  		DstPort: c.Port,
  7504  		Flags:   header.TCPFlagAck,
  7505  		SeqNum:  iss,
  7506  		AckNum:  seqnum.Value(next),
  7507  		RcvWnd:  30000,
  7508  	})
  7509  
  7510  	checker.IPv4(t, c.GetPacket(),
  7511  		checker.TCP(
  7512  			checker.DstPort(context.TestPort),
  7513  			checker.TCPSeqNum(next),
  7514  			checker.TCPAckNum(uint32(0)),
  7515  			checker.TCPFlags(header.TCPFlagRst),
  7516  		),
  7517  	)
  7518  
  7519  	ept := endpointTester{c.EP}
  7520  	ept.CheckReadError(t, &tcpip.ErrTimeout{})
  7521  
  7522  	if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
  7523  		t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
  7524  	}
  7525  	if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
  7526  		t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
  7527  	}
  7528  }
  7529  
  7530  func TestKeepaliveWithUserTimeout(t *testing.T) {
  7531  	c := context.New(t, defaultMTU)
  7532  	defer c.Cleanup()
  7533  
  7534  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
  7535  
  7536  	origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
  7537  
  7538  	const keepAliveIdle = 100 * time.Millisecond
  7539  	const keepAliveInterval = 3 * time.Second
  7540  	keepAliveIdleOption := tcpip.KeepaliveIdleOption(keepAliveIdle)
  7541  	if err := c.EP.SetSockOpt(&keepAliveIdleOption); err != nil {
  7542  		t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOption, keepAliveIdle, err)
  7543  	}
  7544  	keepAliveIntervalOption := tcpip.KeepaliveIntervalOption(keepAliveInterval)
  7545  	if err := c.EP.SetSockOpt(&keepAliveIntervalOption); err != nil {
  7546  		t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOption, keepAliveInterval, err)
  7547  	}
  7548  	if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10); err != nil {
  7549  		t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10): %s", err)
  7550  	}
  7551  	c.EP.SocketOptions().SetKeepAlive(true)
  7552  
  7553  	// Set userTimeout to be the duration to be 1 keepalive
  7554  	// probes. Which means that after the first probe is sent
  7555  	// the second one should cause the connection to be
  7556  	// closed due to userTimeout being hit.
  7557  	userTimeout := tcpip.TCPUserTimeoutOption(keepAliveInterval)
  7558  	if err := c.EP.SetSockOpt(&userTimeout); err != nil {
  7559  		t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", userTimeout, keepAliveInterval, err)
  7560  	}
  7561  
  7562  	// Check that the connection is still alive.
  7563  	ept := endpointTester{c.EP}
  7564  	ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
  7565  
  7566  	// Now receive 1 keepalives, but don't ACK it.
  7567  	b := c.GetPacket()
  7568  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  7569  	checker.IPv4(t, b,
  7570  		checker.TCP(
  7571  			checker.DstPort(context.TestPort),
  7572  			checker.TCPSeqNum(uint32(c.IRS)),
  7573  			checker.TCPAckNum(uint32(iss)),
  7574  			checker.TCPFlags(header.TCPFlagAck),
  7575  		),
  7576  	)
  7577  
  7578  	// Sleep for a litte over the KeepAlive interval to make sure
  7579  	// the timer has time to fire after the last ACK and close the
  7580  	// close the socket.
  7581  	time.Sleep(keepAliveInterval + keepAliveInterval/2)
  7582  
  7583  	// The connection should be closed with a timeout.
  7584  	// Send an ACK to trigger a RST from the stack as the endpoint should
  7585  	// be dead.
  7586  	c.SendPacket(nil, &context.Headers{
  7587  		SrcPort: context.TestPort,
  7588  		DstPort: c.Port,
  7589  		Flags:   header.TCPFlagAck,
  7590  		SeqNum:  iss,
  7591  		AckNum:  c.IRS + 1,
  7592  		RcvWnd:  30000,
  7593  	})
  7594  
  7595  	checker.IPv4(t, c.GetPacket(),
  7596  		checker.TCP(
  7597  			checker.DstPort(context.TestPort),
  7598  			checker.TCPSeqNum(uint32(c.IRS+1)),
  7599  			checker.TCPAckNum(uint32(0)),
  7600  			checker.TCPFlags(header.TCPFlagRst),
  7601  		),
  7602  	)
  7603  
  7604  	ept.CheckReadError(t, &tcpip.ErrTimeout{})
  7605  	if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
  7606  		t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
  7607  	}
  7608  	if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
  7609  		t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
  7610  	}
  7611  }
  7612  
  7613  func TestIncreaseWindowOnRead(t *testing.T) {
  7614  	// This test ensures that the endpoint sends an ack,
  7615  	// after read() when the window grows by more than 1 MSS.
  7616  	c := context.New(t, defaultMTU)
  7617  	defer c.Cleanup()
  7618  
  7619  	const rcvBuf = 65535 * 10
  7620  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf)
  7621  
  7622  	// Write chunks of ~30000 bytes. It's important that two
  7623  	// payloads make it equal or longer than MSS.
  7624  	remain := rcvBuf * 2
  7625  	sent := 0
  7626  	data := make([]byte, defaultMTU/2)
  7627  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  7628  	for remain > len(data) {
  7629  		c.SendPacket(data, &context.Headers{
  7630  			SrcPort: context.TestPort,
  7631  			DstPort: c.Port,
  7632  			Flags:   header.TCPFlagAck,
  7633  			SeqNum:  iss.Add(seqnum.Size(sent)),
  7634  			AckNum:  c.IRS.Add(1),
  7635  			RcvWnd:  30000,
  7636  		})
  7637  		sent += len(data)
  7638  		remain -= len(data)
  7639  		pkt := c.GetPacket()
  7640  		checker.IPv4(t, pkt,
  7641  			checker.PayloadLen(header.TCPMinimumSize),
  7642  			checker.TCP(
  7643  				checker.DstPort(context.TestPort),
  7644  				checker.TCPSeqNum(uint32(c.IRS)+1),
  7645  				checker.TCPAckNum(uint32(iss)+uint32(sent)),
  7646  				checker.TCPFlags(header.TCPFlagAck),
  7647  			),
  7648  		)
  7649  		// Break once the window drops below defaultMTU/2
  7650  		if wnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize(); wnd < defaultMTU/2 {
  7651  			break
  7652  		}
  7653  	}
  7654  
  7655  	// We now have < 1 MSS in the buffer space. Read at least > 2 MSS
  7656  	// worth of data as receive buffer space
  7657  	w := tcpip.LimitedWriter{
  7658  		W: ioutil.Discard,
  7659  		// defaultMTU is a good enough estimate for the MSS used for this
  7660  		// connection.
  7661  		N: defaultMTU * 2,
  7662  	}
  7663  	for w.N != 0 {
  7664  		_, err := c.EP.Read(&w, tcpip.ReadOptions{})
  7665  		if err != nil {
  7666  			t.Fatalf("Read failed: %s", err)
  7667  		}
  7668  	}
  7669  
  7670  	// After reading > MSS worth of data, we surely crossed MSS. See the ack:
  7671  	checker.IPv4(t, c.GetPacket(),
  7672  		checker.PayloadLen(header.TCPMinimumSize),
  7673  		checker.TCP(
  7674  			checker.DstPort(context.TestPort),
  7675  			checker.TCPSeqNum(uint32(c.IRS)+1),
  7676  			checker.TCPAckNum(uint32(iss)+uint32(sent)),
  7677  			checker.TCPWindow(uint16(0xffff)),
  7678  			checker.TCPFlags(header.TCPFlagAck),
  7679  		),
  7680  	)
  7681  }
  7682  
  7683  func TestIncreaseWindowOnBufferResize(t *testing.T) {
  7684  	// This test ensures that the endpoint sends an ack,
  7685  	// after available recv buffer grows to more than 1 MSS.
  7686  	c := context.New(t, defaultMTU)
  7687  	defer c.Cleanup()
  7688  
  7689  	const rcvBuf = 65535 * 10
  7690  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf)
  7691  
  7692  	// Write chunks of ~30000 bytes. It's important that two
  7693  	// payloads make it equal or longer than MSS.
  7694  	remain := rcvBuf
  7695  	sent := 0
  7696  	data := make([]byte, defaultMTU/2)
  7697  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  7698  	for remain > len(data) {
  7699  		c.SendPacket(data, &context.Headers{
  7700  			SrcPort: context.TestPort,
  7701  			DstPort: c.Port,
  7702  			Flags:   header.TCPFlagAck,
  7703  			SeqNum:  iss.Add(seqnum.Size(sent)),
  7704  			AckNum:  c.IRS.Add(1),
  7705  			RcvWnd:  30000,
  7706  		})
  7707  		sent += len(data)
  7708  		remain -= len(data)
  7709  		checker.IPv4(t, c.GetPacket(),
  7710  			checker.PayloadLen(header.TCPMinimumSize),
  7711  			checker.TCP(
  7712  				checker.DstPort(context.TestPort),
  7713  				checker.TCPSeqNum(uint32(c.IRS)+1),
  7714  				checker.TCPAckNum(uint32(iss)+uint32(sent)),
  7715  				checker.TCPWindowLessThanEq(0xffff),
  7716  				checker.TCPFlags(header.TCPFlagAck),
  7717  			),
  7718  		)
  7719  	}
  7720  
  7721  	// Increasing the buffer from should generate an ACK,
  7722  	// since window grew from small value to larger equal MSS
  7723  	c.EP.SocketOptions().SetReceiveBufferSize(rcvBuf*2, true)
  7724  	checker.IPv4(t, c.GetPacket(),
  7725  		checker.PayloadLen(header.TCPMinimumSize),
  7726  		checker.TCP(
  7727  			checker.DstPort(context.TestPort),
  7728  			checker.TCPSeqNum(uint32(c.IRS)+1),
  7729  			checker.TCPAckNum(uint32(iss)+uint32(sent)),
  7730  			checker.TCPWindow(uint16(0xffff)),
  7731  			checker.TCPFlags(header.TCPFlagAck),
  7732  		),
  7733  	)
  7734  }
  7735  
  7736  func TestTCPDeferAccept(t *testing.T) {
  7737  	c := context.New(t, defaultMTU)
  7738  	defer c.Cleanup()
  7739  
  7740  	c.Create(-1)
  7741  
  7742  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  7743  		t.Fatal("Bind failed:", err)
  7744  	}
  7745  
  7746  	if err := c.EP.Listen(10); err != nil {
  7747  		t.Fatal("Listen failed:", err)
  7748  	}
  7749  
  7750  	const tcpDeferAccept = 1 * time.Second
  7751  	tcpDeferAcceptOption := tcpip.TCPDeferAcceptOption(tcpDeferAccept)
  7752  	if err := c.EP.SetSockOpt(&tcpDeferAcceptOption); err != nil {
  7753  		t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", tcpDeferAcceptOption, tcpDeferAccept, err)
  7754  	}
  7755  
  7756  	irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
  7757  
  7758  	_, _, err := c.EP.Accept(nil)
  7759  	if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" {
  7760  		t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d)
  7761  	}
  7762  
  7763  	// Send data. This should result in an acceptable endpoint.
  7764  	c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
  7765  		SrcPort: context.TestPort,
  7766  		DstPort: context.StackPort,
  7767  		Flags:   header.TCPFlagAck,
  7768  		SeqNum:  irs + 1,
  7769  		AckNum:  iss + 1,
  7770  	})
  7771  
  7772  	// Receive ACK for the data we sent.
  7773  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  7774  		checker.DstPort(context.TestPort),
  7775  		checker.TCPFlags(header.TCPFlagAck),
  7776  		checker.TCPSeqNum(uint32(iss+1)),
  7777  		checker.TCPAckNum(uint32(irs+5))))
  7778  
  7779  	// Give a bit of time for the socket to be delivered to the accept queue.
  7780  	time.Sleep(50 * time.Millisecond)
  7781  	aep, _, err := c.EP.Accept(nil)
  7782  	if err != nil {
  7783  		t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err)
  7784  	}
  7785  
  7786  	aep.Close()
  7787  	// Closing aep without reading the data should trigger a RST.
  7788  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  7789  		checker.DstPort(context.TestPort),
  7790  		checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
  7791  		checker.TCPSeqNum(uint32(iss+1)),
  7792  		checker.TCPAckNum(uint32(irs+5))))
  7793  }
  7794  
  7795  func TestTCPDeferAcceptTimeout(t *testing.T) {
  7796  	c := context.New(t, defaultMTU)
  7797  	defer c.Cleanup()
  7798  
  7799  	c.Create(-1)
  7800  
  7801  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  7802  		t.Fatal("Bind failed:", err)
  7803  	}
  7804  
  7805  	if err := c.EP.Listen(10); err != nil {
  7806  		t.Fatal("Listen failed:", err)
  7807  	}
  7808  
  7809  	const tcpDeferAccept = 1 * time.Second
  7810  	tcpDeferAcceptOpt := tcpip.TCPDeferAcceptOption(tcpDeferAccept)
  7811  	if err := c.EP.SetSockOpt(&tcpDeferAcceptOpt); err != nil {
  7812  		t.Fatalf("c.EP.SetSockOpt(&%T(%s)) failed: %s", tcpDeferAcceptOpt, tcpDeferAccept, err)
  7813  	}
  7814  
  7815  	irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
  7816  
  7817  	_, _, err := c.EP.Accept(nil)
  7818  	if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" {
  7819  		t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d)
  7820  	}
  7821  
  7822  	// Sleep for a little of the tcpDeferAccept timeout.
  7823  	time.Sleep(tcpDeferAccept + 100*time.Millisecond)
  7824  
  7825  	// On timeout expiry we should get a SYN-ACK retransmission.
  7826  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  7827  		checker.SrcPort(context.StackPort),
  7828  		checker.DstPort(context.TestPort),
  7829  		checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
  7830  		checker.TCPAckNum(uint32(irs)+1)))
  7831  
  7832  	// Send data. This should result in an acceptable endpoint.
  7833  	c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
  7834  		SrcPort: context.TestPort,
  7835  		DstPort: context.StackPort,
  7836  		Flags:   header.TCPFlagAck,
  7837  		SeqNum:  irs + 1,
  7838  		AckNum:  iss + 1,
  7839  	})
  7840  
  7841  	// Receive ACK for the data we sent.
  7842  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  7843  		checker.SrcPort(context.StackPort),
  7844  		checker.DstPort(context.TestPort),
  7845  		checker.TCPFlags(header.TCPFlagAck),
  7846  		checker.TCPSeqNum(uint32(iss+1)),
  7847  		checker.TCPAckNum(uint32(irs+5))))
  7848  
  7849  	// Give sometime for the endpoint to be delivered to the accept queue.
  7850  	time.Sleep(50 * time.Millisecond)
  7851  	aep, _, err := c.EP.Accept(nil)
  7852  	if err != nil {
  7853  		t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err)
  7854  	}
  7855  
  7856  	aep.Close()
  7857  	// Closing aep without reading the data should trigger a RST.
  7858  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  7859  		checker.SrcPort(context.StackPort),
  7860  		checker.DstPort(context.TestPort),
  7861  		checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
  7862  		checker.TCPSeqNum(uint32(iss+1)),
  7863  		checker.TCPAckNum(uint32(irs+5))))
  7864  }
  7865  
  7866  func TestResetDuringClose(t *testing.T) {
  7867  	c := context.New(t, defaultMTU)
  7868  	defer c.Cleanup()
  7869  
  7870  	c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRecvBuf */)
  7871  	// Send some data to make sure there is some unread
  7872  	// data to trigger a reset on c.Close.
  7873  	irs := c.IRS
  7874  	iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
  7875  	c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
  7876  		SrcPort: context.TestPort,
  7877  		DstPort: c.Port,
  7878  		Flags:   header.TCPFlagAck,
  7879  		SeqNum:  iss,
  7880  		AckNum:  irs.Add(1),
  7881  		RcvWnd:  30000,
  7882  	})
  7883  
  7884  	// Receive ACK for the data we sent.
  7885  	checker.IPv4(t, c.GetPacket(), checker.TCP(
  7886  		checker.DstPort(context.TestPort),
  7887  		checker.TCPFlags(header.TCPFlagAck),
  7888  		checker.TCPSeqNum(uint32(irs.Add(1))),
  7889  		checker.TCPAckNum(uint32(iss)+4)))
  7890  
  7891  	// Close in a separate goroutine so that we can trigger
  7892  	// a race with the RST we send below. This should not
  7893  	// panic due to the route being released depeding on
  7894  	// whether Close() sends an active RST or the RST sent
  7895  	// below is processed by the worker first.
  7896  	var wg sync.WaitGroup
  7897  
  7898  	wg.Add(1)
  7899  	go func() {
  7900  		defer wg.Done()
  7901  		c.SendPacket(nil, &context.Headers{
  7902  			SrcPort: context.TestPort,
  7903  			DstPort: c.Port,
  7904  			SeqNum:  iss.Add(4),
  7905  			AckNum:  c.IRS.Add(5),
  7906  			RcvWnd:  30000,
  7907  			Flags:   header.TCPFlagRst,
  7908  		})
  7909  	}()
  7910  
  7911  	wg.Add(1)
  7912  	go func() {
  7913  		defer wg.Done()
  7914  		c.EP.Close()
  7915  	}()
  7916  
  7917  	wg.Wait()
  7918  }
  7919  
  7920  func TestStackTimeWaitReuse(t *testing.T) {
  7921  	c := context.New(t, defaultMTU)
  7922  	defer c.Cleanup()
  7923  
  7924  	s := c.Stack()
  7925  	var twReuse tcpip.TCPTimeWaitReuseOption
  7926  	if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil {
  7927  		t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &twReuse, err)
  7928  	}
  7929  	if got, want := twReuse, tcpip.TCPTimeWaitReuseLoopbackOnly; got != want {
  7930  		t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want)
  7931  	}
  7932  }
  7933  
  7934  func TestSetStackTimeWaitReuse(t *testing.T) {
  7935  	c := context.New(t, defaultMTU)
  7936  	defer c.Cleanup()
  7937  
  7938  	s := c.Stack()
  7939  	testCases := []struct {
  7940  		v   int
  7941  		err tcpip.Error
  7942  	}{
  7943  		{int(tcpip.TCPTimeWaitReuseDisabled), nil},
  7944  		{int(tcpip.TCPTimeWaitReuseGlobal), nil},
  7945  		{int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil},
  7946  		{int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, &tcpip.ErrInvalidOptionValue{}},
  7947  		{int(tcpip.TCPTimeWaitReuseDisabled) - 1, &tcpip.ErrInvalidOptionValue{}},
  7948  	}
  7949  
  7950  	for _, tc := range testCases {
  7951  		opt := tcpip.TCPTimeWaitReuseOption(tc.v)
  7952  		err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt)
  7953  		if got, want := err, tc.err; got != want {
  7954  			t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)) = %s, want = %s", tcp.ProtocolNumber, tc.v, tc.v, err, tc.err)
  7955  		}
  7956  		if tc.err != nil {
  7957  			continue
  7958  		}
  7959  
  7960  		var twReuse tcpip.TCPTimeWaitReuseOption
  7961  		if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil {
  7962  			t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want nil", tcp.ProtocolNumber, &twReuse, err)
  7963  		}
  7964  
  7965  		if got, want := twReuse, tcpip.TCPTimeWaitReuseOption(tc.v); got != want {
  7966  			t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want)
  7967  		}
  7968  	}
  7969  }
  7970  
  7971  // generateRandomPayload generates a random byte slice of the specified length
  7972  // causing a fatal test failure if it is unable to do so.
  7973  func generateRandomPayload(t *testing.T, n int) []byte {
  7974  	t.Helper()
  7975  	buf := make([]byte, n)
  7976  	if _, err := rand.Read(buf); err != nil {
  7977  		t.Fatalf("rand.Read(buf) failed: %s", err)
  7978  	}
  7979  	return buf
  7980  }