gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/transport/tcp/test/e2e/tcp_sack_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tcp_sack_test
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"log"
    21  	"os"
    22  	"slices"
    23  	"testing"
    24  	"time"
    25  
    26  	"gvisor.dev/gvisor/pkg/buffer"
    27  	"gvisor.dev/gvisor/pkg/refs"
    28  	"gvisor.dev/gvisor/pkg/tcpip"
    29  	"gvisor.dev/gvisor/pkg/tcpip/checker"
    30  	"gvisor.dev/gvisor/pkg/tcpip/header"
    31  	"gvisor.dev/gvisor/pkg/tcpip/seqnum"
    32  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    33  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
    34  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/test/e2e"
    35  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
    36  	"gvisor.dev/gvisor/pkg/test/testutil"
    37  )
    38  
    39  const (
    40  	maxPayload   = 10
    41  	tsOptionSize = 12
    42  	mtu          = header.TCPMinimumSize + header.IPv4MinimumSize + e2e.MaxTCPOptionSize + maxPayload
    43  )
    44  
    45  // TestSackPermittedConnect establishes a connection with the SACK option
    46  // enabled.
    47  func TestSackPermittedConnect(t *testing.T) {
    48  	for _, sackEnabled := range []bool{false, true} {
    49  		t.Run(fmt.Sprintf("stack.sackEnabled: %v", sackEnabled), func(t *testing.T) {
    50  			c := context.New(t, e2e.DefaultMTU)
    51  			defer c.Cleanup()
    52  
    53  			e2e.SetStackSACKPermitted(t, c, sackEnabled)
    54  			e2e.SetStackTCPRecovery(t, c, 0)
    55  			rep := e2e.CreateConnectedWithSACKPermittedOption(c)
    56  			data := []byte{1, 2, 3}
    57  
    58  			rep.SendPacket(data, nil)
    59  			savedSeqNum := rep.NextSeqNum
    60  			rep.VerifyACKNoSACK()
    61  
    62  			// Make an out of order packet and send it.
    63  			rep.NextSeqNum += 3
    64  			sackBlocks := []header.SACKBlock{
    65  				{rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))},
    66  			}
    67  			rep.SendPacket(data, nil)
    68  
    69  			// Restore the saved sequence number so that the
    70  			// VerifyXXX calls use the right sequence number for
    71  			// checking ACK numbers.
    72  			rep.NextSeqNum = savedSeqNum
    73  			if sackEnabled {
    74  				rep.VerifyACKHasSACK(sackBlocks)
    75  			} else {
    76  				rep.VerifyACKNoSACK()
    77  			}
    78  
    79  			// Send the missing segment.
    80  			rep.SendPacket(data, nil)
    81  			// The ACK should contain the cumulative ACK for all 9
    82  			// bytes sent and no SACK blocks.
    83  			rep.NextSeqNum += 3
    84  			// Check that no SACK block is returned in the ACK.
    85  			rep.VerifyACKNoSACK()
    86  		})
    87  	}
    88  }
    89  
    90  // TestSackDisabledConnect establishes a connection with the SACK option
    91  // disabled and verifies that no SACKs are sent for out of order segments.
    92  func TestSackDisabledConnect(t *testing.T) {
    93  	for _, sackEnabled := range []bool{false, true} {
    94  		t.Run(fmt.Sprintf("sackEnabled: %v", sackEnabled), func(t *testing.T) {
    95  			c := context.New(t, e2e.DefaultMTU)
    96  			defer c.Cleanup()
    97  
    98  			e2e.SetStackSACKPermitted(t, c, sackEnabled)
    99  			e2e.SetStackTCPRecovery(t, c, 0)
   100  
   101  			rep := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{})
   102  
   103  			data := []byte{1, 2, 3}
   104  
   105  			rep.SendPacket(data, nil)
   106  			savedSeqNum := rep.NextSeqNum
   107  			rep.VerifyACKNoSACK()
   108  
   109  			// Make an out of order packet and send it.
   110  			rep.NextSeqNum += 3
   111  			rep.SendPacket(data, nil)
   112  
   113  			// The ACK should contain the older sequence number and
   114  			// no SACK blocks.
   115  			rep.NextSeqNum = savedSeqNum
   116  			rep.VerifyACKNoSACK()
   117  
   118  			// Send the missing segment.
   119  			rep.SendPacket(data, nil)
   120  			// The ACK should contain the cumulative ACK for all 9
   121  			// bytes sent and no SACK blocks.
   122  			rep.NextSeqNum += 3
   123  			// Check that no SACK block is returned in the ACK.
   124  			rep.VerifyACKNoSACK()
   125  		})
   126  	}
   127  }
   128  
   129  // TestSackPermittedAccept accepts and establishes a connection with the
   130  // SACKPermitted option enabled if the connection request specifies the
   131  // SACKPermitted option. In case of SYN cookies SACK should be disabled as we
   132  // don't encode the SACK information in the cookie.
   133  func TestSackPermittedAccept(t *testing.T) {
   134  	type testCase struct {
   135  		cookieEnabled bool
   136  		sackPermitted bool
   137  		wndScale      int
   138  		wndSize       uint16
   139  	}
   140  
   141  	testCases := []testCase{
   142  		// When cookie is used window scaling is disabled.
   143  		{true, false, -1, 0xffff}, // When cookie is used window scaling is disabled.
   144  		{false, true, 5, 0x8000},  // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
   145  	}
   146  
   147  	for _, tc := range testCases {
   148  		t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
   149  			for _, sackEnabled := range []bool{false, true} {
   150  				t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) {
   151  					c := context.New(t, e2e.DefaultMTU)
   152  					defer c.Cleanup()
   153  
   154  					if tc.cookieEnabled {
   155  						opt := tcpip.TCPAlwaysUseSynCookies(true)
   156  						if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
   157  							t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
   158  						}
   159  					}
   160  					e2e.SetStackSACKPermitted(t, c, sackEnabled)
   161  					e2e.SetStackTCPRecovery(t, c, 0)
   162  
   163  					rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: e2e.DefaultIPv4MSS, SACKPermitted: tc.sackPermitted})
   164  					//  Now verify no SACK blocks are
   165  					//  received when sack is disabled.
   166  					data := []byte{1, 2, 3}
   167  					rep.SendPacket(data, nil)
   168  					rep.VerifyACKNoSACK()
   169  
   170  					savedSeqNum := rep.NextSeqNum
   171  
   172  					// Make an out of order packet and send
   173  					// it.
   174  					rep.NextSeqNum += 3
   175  					sackBlocks := []header.SACKBlock{
   176  						{rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))},
   177  					}
   178  					rep.SendPacket(data, nil)
   179  
   180  					// The ACK should contain the older
   181  					// sequence number.
   182  					rep.NextSeqNum = savedSeqNum
   183  					if sackEnabled && tc.sackPermitted {
   184  						rep.VerifyACKHasSACK(sackBlocks)
   185  					} else {
   186  						rep.VerifyACKNoSACK()
   187  					}
   188  
   189  					// Send the missing segment.
   190  					rep.SendPacket(data, nil)
   191  					// The ACK should contain the cumulative
   192  					// ACK for all 9 bytes sent and no SACK
   193  					// blocks.
   194  					rep.NextSeqNum += 3
   195  					// Check that no SACK block is returned
   196  					// in the ACK.
   197  					rep.VerifyACKNoSACK()
   198  				})
   199  			}
   200  		})
   201  	}
   202  }
   203  
   204  // TestSackDisabledAccept accepts and establishes a connection with
   205  // the SACKPermitted option disabled and verifies that no SACKs are
   206  // sent for out of order packets.
   207  func TestSackDisabledAccept(t *testing.T) {
   208  	type testCase struct {
   209  		cookieEnabled bool
   210  		wndScale      int
   211  		wndSize       uint16
   212  	}
   213  
   214  	testCases := []testCase{
   215  		// When cookie is used window scaling is disabled.
   216  		{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
   217  		{false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
   218  	}
   219  
   220  	for _, tc := range testCases {
   221  		t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
   222  			for _, sackEnabled := range []bool{false, true} {
   223  				t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) {
   224  					c := context.New(t, e2e.DefaultMTU)
   225  					defer c.Cleanup()
   226  
   227  					if tc.cookieEnabled {
   228  						opt := tcpip.TCPAlwaysUseSynCookies(true)
   229  						if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
   230  							t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
   231  						}
   232  					}
   233  
   234  					e2e.SetStackSACKPermitted(t, c, sackEnabled)
   235  					e2e.SetStackTCPRecovery(t, c, 0)
   236  
   237  					rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: e2e.DefaultIPv4MSS})
   238  
   239  					//  Now verify no SACK blocks are
   240  					//  received when sack is disabled.
   241  					data := []byte{1, 2, 3}
   242  					rep.SendPacket(data, nil)
   243  					rep.VerifyACKNoSACK()
   244  					savedSeqNum := rep.NextSeqNum
   245  
   246  					// Make an out of order packet and send
   247  					// it.
   248  					rep.NextSeqNum += 3
   249  					rep.SendPacket(data, nil)
   250  
   251  					// The ACK should contain the older
   252  					// sequence number and no SACK blocks.
   253  					rep.NextSeqNum = savedSeqNum
   254  					rep.VerifyACKNoSACK()
   255  
   256  					// Send the missing segment.
   257  					rep.SendPacket(data, nil)
   258  					// The ACK should contain the cumulative
   259  					// ACK for all 9 bytes sent and no SACK
   260  					// blocks.
   261  					rep.NextSeqNum += 3
   262  					// Check that no SACK block is returned
   263  					// in the ACK.
   264  					rep.VerifyACKNoSACK()
   265  				})
   266  			}
   267  		})
   268  	}
   269  }
   270  
   271  func TestUpdateSACKBlocks(t *testing.T) {
   272  	testCases := []struct {
   273  		segStart   seqnum.Value
   274  		segEnd     seqnum.Value
   275  		rcvNxt     seqnum.Value
   276  		sackBlocks []header.SACKBlock
   277  		updated    []header.SACKBlock
   278  	}{
   279  		// Trivial cases where current SACK block list is empty and we
   280  		// have an out of order delivery.
   281  		{10, 11, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 11}}},
   282  		{10, 12, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 12}}},
   283  		{10, 20, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 20}}},
   284  
   285  		// Cases where current SACK block list is not empty and we have
   286  		// an out of order delivery. Tests that the updated SACK block
   287  		// list has the first block as the one that contains the new
   288  		// SACK block representing the segment that was just delivered.
   289  		{10, 11, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 11}, {12, 20}}},
   290  		{24, 30, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{24, 30}, {12, 20}}},
   291  		{24, 30, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}}},
   292  
   293  		// Ensure that we only retain header.MaxSACKBlocks and drop the
   294  		// oldest one if adding a new block exceeds
   295  		// header.MaxSACKBlocks.
   296  		{24, 30, 9,
   297  			[]header.SACKBlock{{12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}, {72, 80}},
   298  			[]header.SACKBlock{{24, 30}, {12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}},
   299  
   300  		// Cases where segment extends an existing SACK block.
   301  		{10, 12, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 20}}},
   302  		{10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}},
   303  		{10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}},
   304  		{15, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 22}}},
   305  		{15, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 25}}},
   306  		{11, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{11, 25}}},
   307  		{10, 12, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 20}, {32, 40}}},
   308  		{10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}},
   309  		{10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}},
   310  		{15, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 22}, {32, 40}}},
   311  		{15, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 25}, {32, 40}}},
   312  		{11, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{11, 25}, {32, 40}}},
   313  
   314  		// Cases where segment contains rcvNxt.
   315  		{10, 20, 15, []header.SACKBlock{{20, 30}, {40, 50}}, []header.SACKBlock{{40, 50}}},
   316  	}
   317  
   318  	for _, tc := range testCases {
   319  		var sack tcp.SACKInfo
   320  		copy(sack.Blocks[:], tc.sackBlocks)
   321  		sack.NumBlocks = len(tc.sackBlocks)
   322  		tcp.UpdateSACKBlocks(&sack, tc.segStart, tc.segEnd, tc.rcvNxt)
   323  		if got, want := sack.Blocks[:sack.NumBlocks], tc.updated; !slices.Equal(got, want) {
   324  			t.Errorf("UpdateSACKBlocks(%v, %v, %v, %v), got: %v, want: %v", tc.sackBlocks, tc.segStart, tc.segEnd, tc.rcvNxt, got, want)
   325  		}
   326  
   327  	}
   328  }
   329  
   330  func TestTrimSackBlockList(t *testing.T) {
   331  	testCases := []struct {
   332  		rcvNxt     seqnum.Value
   333  		sackBlocks []header.SACKBlock
   334  		trimmed    []header.SACKBlock
   335  	}{
   336  		// Simple cases where we trim whole entries.
   337  		{2, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}},
   338  		{21, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{22, 30}, {32, 40}}},
   339  		{31, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{32, 40}}},
   340  		{40, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}},
   341  		// Cases where we need to update a block.
   342  		{12, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{12, 20}, {22, 30}, {32, 40}}},
   343  		{23, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{23, 30}, {32, 40}}},
   344  		{33, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{33, 40}}},
   345  		{41, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}},
   346  	}
   347  	for _, tc := range testCases {
   348  		var sack tcp.SACKInfo
   349  		copy(sack.Blocks[:], tc.sackBlocks)
   350  		sack.NumBlocks = len(tc.sackBlocks)
   351  		tcp.TrimSACKBlockList(&sack, tc.rcvNxt)
   352  		if got, want := sack.Blocks[:sack.NumBlocks], tc.trimmed; !slices.Equal(got, want) {
   353  			t.Errorf("TrimSackBlockList(%v, %v), got: %v, want: %v", tc.sackBlocks, tc.rcvNxt, got, want)
   354  		}
   355  	}
   356  }
   357  
   358  func TestSACKRecovery(t *testing.T) {
   359  	const maxPayload = 10
   360  	// See: tcp.makeOptions for why tsOptionSize is set to 12 here.
   361  	const tsOptionSize = 12
   362  	// Enabling SACK means the payload size is reduced to account
   363  	// for the extra space required for the TCP options.
   364  	//
   365  	// We increase the MTU by e2e.MaxTCPOptionSize bytes to account for SACK
   366  	// and Timestamp options.
   367  	c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+e2e.MaxTCPOptionSize+maxPayload))
   368  	defer c.Cleanup()
   369  
   370  	c.Stack().AddTCPProbe(func(s *stack.TCPEndpointState) {
   371  		// We use log.Printf instead of t.Logf here because this probe
   372  		// can fire even when the test function has finished. This is
   373  		// because closing the endpoint in cleanup() does not mean the
   374  		// actual worker loop terminates immediately as it still has to
   375  		// do a full TCP shutdown. But this test can finish running
   376  		// before the shutdown is done. Using t.Logf in such a case
   377  		// causes the test to panic due to logging after test finished.
   378  		log.Printf("state: %+v\n", s)
   379  	})
   380  	e2e.SetStackSACKPermitted(t, c, true)
   381  	e2e.SetStackTCPRecovery(t, c, 0)
   382  	e2e.CreateConnectedWithSACKAndTS(c)
   383  
   384  	const iterations = 3
   385  	data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1)))
   386  	for i := range data {
   387  		data[i] = byte(i)
   388  	}
   389  
   390  	// Write all the data in one shot. Packets will only be written at the
   391  	// MTU size though.
   392  	var r bytes.Reader
   393  	r.Reset(data)
   394  	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
   395  		t.Fatalf("Write failed: %s", err)
   396  	}
   397  
   398  	// Do slow start for a few iterations.
   399  	seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
   400  	expected := tcp.InitialCwnd
   401  	bytesRead := 0
   402  	for i := 0; i < iterations; i++ {
   403  		expected = tcp.InitialCwnd << uint(i)
   404  		if i > 0 {
   405  			// Acknowledge all the data received so far if not on
   406  			// first iteration.
   407  			c.SendAck(seq, bytesRead)
   408  		}
   409  
   410  		// Read all packets expected on this iteration. Don't
   411  		// acknowledge any of them just yet, so that we can measure the
   412  		// congestion window.
   413  		for j := 0; j < expected; j++ {
   414  			c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
   415  			bytesRead += maxPayload
   416  		}
   417  
   418  		// Check we don't receive any more packets on this iteration.
   419  		// The timeout can't be too high or we'll trigger a timeout.
   420  		c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
   421  	}
   422  
   423  	// Send 3 duplicate acks. This should force an immediate retransmit of
   424  	// the pending packet and put the sender into fast recovery.
   425  	rtxOffset := bytesRead - maxPayload*expected
   426  	start := c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1)
   427  	end := start.Add(10)
   428  	for i := 0; i < 3; i++ {
   429  		c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}})
   430  		end = end.Add(10)
   431  	}
   432  
   433  	// Receive the retransmitted packet.
   434  	c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize)
   435  
   436  	metricPollFn := func() error {
   437  		tcpStats := c.Stack().Stats().TCP
   438  		stats := []struct {
   439  			stat *tcpip.StatCounter
   440  			name string
   441  			want uint64
   442  		}{
   443  			{tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
   444  			{tcpStats.Retransmits, "stats.TCP.Retransmits", 1},
   445  			{tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
   446  			{tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0},
   447  		}
   448  		for _, s := range stats {
   449  			if got, want := s.stat.Value(), s.want; got != want {
   450  				return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
   451  			}
   452  		}
   453  		return nil
   454  	}
   455  
   456  	if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
   457  		t.Error(err)
   458  	}
   459  
   460  	// Now send 7 mode duplicate ACKs. In SACK TCP dupAcks do not cause
   461  	// window inflation and sending of packets is completely handled by the
   462  	// SACK Recovery algorithm. We should see no packets being released, as
   463  	// the cwnd at this point after entering recovery should be half of the
   464  	// outstanding number of packets in flight.
   465  	for i := 0; i < 7; i++ {
   466  		c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}})
   467  		end = end.Add(10)
   468  	}
   469  
   470  	recover := bytesRead
   471  
   472  	// Ensure no new packets arrive.
   473  	c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.",
   474  		50*time.Millisecond)
   475  
   476  	// Acknowledge half of the pending data. This along with the 10 sacked
   477  	// segments above should reduce the outstanding below the current
   478  	// congestion window allowing the sender to transmit data.
   479  	rtxOffset = bytesRead - expected*maxPayload/2
   480  
   481  	// Now send a partial ACK w/ a SACK block that indicates that the next 3
   482  	// segments are lost and we have received 6 segments after the lost
   483  	// segments. This should cause the sender to immediately transmit all 3
   484  	// segments in response to this ACK unlike in FastRecovery where only 1
   485  	// segment is retransmitted per ACK.
   486  	start = c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1)
   487  	end = start.Add(60)
   488  	c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}})
   489  
   490  	// At this point, we acked expected/2 packets and we SACKED 6 packets and
   491  	// 3 segments were considered lost due to the SACK block we sent.
   492  	//
   493  	// So total packets outstanding can be calculated as follows after 7
   494  	// iterations of slow start -> 10/20/40/80/160/320/640. So expected
   495  	// should be 640 at start, then we went to recover at which point the
   496  	// cwnd should be set to 320 + 3 (for the 3 dupAcks which have left the
   497  	// network).
   498  	// Outstanding at this point after acking half the window
   499  	// (320 packets) will be:
   500  	//    outstanding = 640-320-6(due to SACK block)-3 = 311
   501  	//
   502  	// The last 3 is due to the fact that the first 3 packets after
   503  	// rtxOffset will be considered lost due to the SACK blocks sent.
   504  	// Receive the retransmit due to partial ack.
   505  
   506  	c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize)
   507  	// Receive the 2 extra packets that should have been retransmitted as
   508  	// those should be considered lost and immediately retransmitted based
   509  	// on the SACK information in the previous ACK sent above.
   510  	for i := 0; i < 2; i++ {
   511  		c.ReceiveAndCheckPacketWithOptions(data, rtxOffset+maxPayload*(i+1), maxPayload, tsOptionSize)
   512  	}
   513  
   514  	// Now we should get 9 more new unsent packets as the cwnd is 323 and
   515  	// outstanding is 311.
   516  	for i := 0; i < 9; i++ {
   517  		c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
   518  		bytesRead += maxPayload
   519  	}
   520  
   521  	metricPollFn = func() error {
   522  		// In SACK recovery only the first segment is fast retransmitted when
   523  		// entering recovery.
   524  		if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
   525  			return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
   526  		}
   527  
   528  		if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want {
   529  			return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %d, want = %d", got, want)
   530  		}
   531  
   532  		if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want {
   533  			return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want)
   534  		}
   535  
   536  		if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want {
   537  			return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %d, want = %d", got, want)
   538  		}
   539  		return nil
   540  	}
   541  	if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
   542  		t.Error(err)
   543  	}
   544  
   545  	c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond)
   546  
   547  	// Acknowledge all pending data to recover point.
   548  	c.SendAck(seq, recover)
   549  
   550  	// At this point, the cwnd should reset to expected/2 and there are 9
   551  	// packets outstanding.
   552  	//
   553  	// Now in the first iteration since there are 9 packets outstanding.
   554  	// We would expect to get expected/2  - 9 packets. But subsequent
   555  	// iterations will send us expected/2  + 1 (per iteration).
   556  	expected = expected/2 - 9
   557  	for i := 0; i < iterations; i++ {
   558  		// Read all packets expected on this iteration. Don't
   559  		// acknowledge any of them just yet, so that we can measure the
   560  		// congestion window.
   561  		for j := 0; j < expected; j++ {
   562  			c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
   563  			bytesRead += maxPayload
   564  		}
   565  		// Check we don't receive any more packets on this iteration.
   566  		// The timeout can't be too high or we'll trigger a timeout.
   567  		c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd and iteration: %d.", expected, i), 50*time.Millisecond)
   568  
   569  		// Acknowledge all the data received so far.
   570  		c.SendAck(seq, bytesRead)
   571  
   572  		// In cogestion avoidance, the packets trains increase by 1 in
   573  		// each iteration.
   574  		if i == 0 {
   575  			// After the first iteration we expect to get the full
   576  			// congestion window worth of packets in every
   577  			// iteration.
   578  			expected += 9
   579  		}
   580  		expected++
   581  	}
   582  }
   583  
   584  // TestRecoveryEntry tests the following two properties of entering recovery:
   585  //   - Fast SACK recovery is entered when SND.UNA is considered lost by the SACK
   586  //     scoreboard but dupack count is still below threshold.
   587  //   - Only enter recovery when at least one more byte of data beyond the highest
   588  //     byte that was outstanding when fast retransmit was last entered is acked.
   589  func TestRecoveryEntry(t *testing.T) {
   590  	c := context.New(t, uint32(mtu))
   591  	defer c.Cleanup()
   592  
   593  	numPackets := 5
   594  	data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, false /* enableRACK */)
   595  
   596  	// Ack #1 packet.
   597  	seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
   598  	c.SendAck(seq, maxPayload)
   599  
   600  	// Now SACK #3, #4 and #5 packets. This will simulate a situation where
   601  	// SND.UNA should be considered lost and the sender should enter fast recovery
   602  	// (even though dupack count is still below threshold).
   603  	p3Start := c.IRS.Add(1 + seqnum.Size(2*maxPayload))
   604  	p3End := p3Start.Add(maxPayload)
   605  	p4Start := p3End
   606  	p4End := p4Start.Add(maxPayload)
   607  	p5Start := p4End
   608  	p5End := p5Start.Add(maxPayload)
   609  	c.SendAckWithSACK(seq, maxPayload, []header.SACKBlock{{p3Start, p3End}, {p4Start, p4End}, {p5Start, p5End}})
   610  
   611  	// Expect #2 to be retransmitted.
   612  	c.ReceiveAndCheckPacketWithOptions(data, maxPayload, maxPayload, tsOptionSize)
   613  
   614  	metricPollFn := func() error {
   615  		tcpStats := c.Stack().Stats().TCP
   616  		stats := []struct {
   617  			stat *tcpip.StatCounter
   618  			name string
   619  			want uint64
   620  		}{
   621  			// SACK recovery must have happened.
   622  			{tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
   623  			{tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
   624  			// #2 was retransmitted.
   625  			{tcpStats.Retransmits, "stats.TCP.Retransmits", 1},
   626  			// No RTOs should have fired yet.
   627  			{tcpStats.Timeouts, "stats.TCP.Timeouts", 0},
   628  		}
   629  		for _, s := range stats {
   630  			if got, want := s.stat.Value(), s.want; got != want {
   631  				return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
   632  			}
   633  		}
   634  		return nil
   635  	}
   636  	if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
   637  		t.Error(err)
   638  	}
   639  
   640  	// Send 4 more packets.
   641  	var r bytes.Reader
   642  	data = append(data, data...)
   643  	r.Reset(data[5*maxPayload : 9*maxPayload])
   644  	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
   645  		t.Fatalf("Write failed: %s", err)
   646  	}
   647  
   648  	var sackBlocks []header.SACKBlock
   649  	bytesRead := numPackets * maxPayload
   650  	for i := 0; i < 4; i++ {
   651  		c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
   652  		if i > 0 {
   653  			pStart := c.IRS.Add(1 + seqnum.Size(bytesRead))
   654  			sackBlocks = append(sackBlocks, header.SACKBlock{pStart, pStart.Add(maxPayload)})
   655  			c.SendAckWithSACK(seq, 5*maxPayload, sackBlocks)
   656  		}
   657  		bytesRead += maxPayload
   658  	}
   659  
   660  	// #6 should be retransmitted after RTO. The sender should NOT enter fast
   661  	// recovery because the highest byte that was outstanding when fast recovery
   662  	// was last entered is #5 packet's end. And the sender requires at least one
   663  	// more byte beyond that (#6 packet start) to be acked to enter recovery.
   664  	c.ReceiveAndCheckPacketWithOptions(data, 5*maxPayload, maxPayload, tsOptionSize)
   665  	c.SendAck(seq, 9*maxPayload)
   666  
   667  	metricPollFn = func() error {
   668  		tcpStats := c.Stack().Stats().TCP
   669  		stats := []struct {
   670  			stat *tcpip.StatCounter
   671  			name string
   672  			want uint64
   673  		}{
   674  			// Only 1 SACK recovery must have happened.
   675  			{tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
   676  			{tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
   677  			// #2 and #6 were retransmitted.
   678  			{tcpStats.Retransmits, "stats.TCP.Retransmits", 2},
   679  			// RTO should have fired once.
   680  			{tcpStats.Timeouts, "stats.TCP.Timeouts", 1},
   681  		}
   682  		for _, s := range stats {
   683  			if got, want := s.stat.Value(), s.want; got != want {
   684  				return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
   685  			}
   686  		}
   687  		return nil
   688  	}
   689  	if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
   690  		t.Error(err)
   691  	}
   692  }
   693  
   694  func verifySpuriousRecoveryMetric(t *testing.T, c *context.Context, numSpuriousRecovery, numSpuriousRTO uint64) {
   695  	t.Helper()
   696  
   697  	metricPollFn := func() error {
   698  		tcpStats := c.Stack().Stats().TCP
   699  		stats := []struct {
   700  			stat *tcpip.StatCounter
   701  			name string
   702  			want uint64
   703  		}{
   704  			{tcpStats.SpuriousRecovery, "stats.TCP.SpuriousRecovery", numSpuriousRecovery},
   705  			{tcpStats.SpuriousRTORecovery, "stats.TCP.SpuriousRTORecovery", numSpuriousRTO},
   706  		}
   707  		for _, s := range stats {
   708  			if got, want := s.stat.Value(), s.want; got != want {
   709  				return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
   710  			}
   711  		}
   712  		return nil
   713  	}
   714  
   715  	if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
   716  		t.Error(err)
   717  	}
   718  }
   719  
   720  func checkReceivedPacket(t *testing.T, c *context.Context, tcpHdr header.TCP, bytesRead uint32, b *buffer.View, data []byte) {
   721  	payloadLen := uint32(len(tcpHdr.Payload()))
   722  	checker.IPv4(t, b,
   723  		checker.TCP(
   724  			checker.DstPort(context.TestPort),
   725  			checker.TCPSeqNum(uint32(c.IRS)+1+bytesRead),
   726  			checker.TCPAckNum(context.TestInitialSequenceNumber+1),
   727  			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
   728  		),
   729  	)
   730  	pdata := data[bytesRead : bytesRead+payloadLen]
   731  	if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) {
   732  		t.Fatalf("got data = %v, want = %v", p, pdata)
   733  	}
   734  }
   735  
   736  func buildTSOptionFromHeader(tcpHdr header.TCP) []byte {
   737  	parsedOpts := tcpHdr.ParsedOptions()
   738  	tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
   739  	header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:])
   740  	return tsOpt[:]
   741  }
   742  
   743  func TestDetectSpuriousRecoveryWithRTO(t *testing.T) {
   744  	c := context.New(t, uint32(mtu))
   745  	defer c.Cleanup()
   746  
   747  	probeDone := make(chan struct{})
   748  	c.Stack().AddTCPProbe(func(s *stack.TCPEndpointState) {
   749  		if s.Sender.RetransmitTS == 0 {
   750  			t.Fatalf("RetransmitTS did not get updated, got: 0 want > 0")
   751  		}
   752  		if !s.Sender.SpuriousRecovery {
   753  			t.Fatalf("Spurious recovery was not detected")
   754  		}
   755  		close(probeDone)
   756  	})
   757  
   758  	e2e.SetStackSACKPermitted(t, c, true)
   759  	e2e.CreateConnectedWithSACKAndTS(c)
   760  	numPackets := 5
   761  	data := make([]byte, numPackets*maxPayload)
   762  	for i := range data {
   763  		data[i] = byte(i)
   764  	}
   765  	// Write the data.
   766  	var r bytes.Reader
   767  	r.Reset(data)
   768  	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
   769  		t.Fatalf("Write failed: %s", err)
   770  	}
   771  
   772  	var options []byte
   773  	var bytesRead uint32
   774  	for i := 0; i < numPackets; i++ {
   775  		b := c.GetPacket()
   776  		defer b.Release()
   777  		tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload())
   778  		checkReceivedPacket(t, c, tcpHdr, bytesRead, b, data)
   779  
   780  		// Get options only for the first packet. This will be sent with
   781  		// the ACK to indicate the acknowledgement is for the original
   782  		// packet.
   783  		if i == 0 && c.TimeStampEnabled {
   784  			options = buildTSOptionFromHeader(tcpHdr)
   785  		}
   786  		bytesRead += uint32(len(tcpHdr.Payload()))
   787  	}
   788  
   789  	seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
   790  	// Expect #5 segment with TLP.
   791  	c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize)
   792  
   793  	// Expect #1 segment because of RTO.
   794  	c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize)
   795  
   796  	info := tcpip.TCPInfoOption{}
   797  	if err := c.EP.GetSockOpt(&info); err != nil {
   798  		t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err)
   799  	}
   800  
   801  	if info.CcState != tcpip.RTORecovery {
   802  		t.Fatalf("Loss recovery did not happen, got: %v want: %v", info.CcState, tcpip.RTORecovery)
   803  	}
   804  
   805  	// Acknowledge the data.
   806  	rcvWnd := seqnum.Size(30000)
   807  	c.SendPacket(nil, &context.Headers{
   808  		SrcPort: context.TestPort,
   809  		DstPort: c.Port,
   810  		Flags:   header.TCPFlagAck,
   811  		SeqNum:  seq,
   812  		AckNum:  c.IRS.Add(1 + seqnum.Size(maxPayload)),
   813  		RcvWnd:  rcvWnd,
   814  		TCPOpts: options,
   815  	})
   816  
   817  	// Wait for the probe function to finish processing the
   818  	// ACK before the test completes.
   819  	<-probeDone
   820  
   821  	verifySpuriousRecoveryMetric(t, c, 1 /* numSpuriousRecovery */, 1 /* numSpuriousRTO */)
   822  }
   823  
   824  func TestSACKDetectSpuriousRecoveryWithDupACK(t *testing.T) {
   825  	c := context.New(t, uint32(mtu))
   826  	defer c.Cleanup()
   827  
   828  	numAck := 0
   829  	probeDone := make(chan struct{})
   830  	c.Stack().AddTCPProbe(func(s *stack.TCPEndpointState) {
   831  		if numAck < 3 {
   832  			numAck++
   833  			return
   834  		}
   835  
   836  		if s.Sender.RetransmitTS == 0 {
   837  			t.Fatalf("RetransmitTS did not get updated, got: 0 want > 0")
   838  		}
   839  		if !s.Sender.SpuriousRecovery {
   840  			t.Fatalf("Spurious recovery was not detected")
   841  		}
   842  		close(probeDone)
   843  	})
   844  
   845  	e2e.SetStackSACKPermitted(t, c, true)
   846  	e2e.CreateConnectedWithSACKAndTS(c)
   847  	numPackets := 5
   848  	data := make([]byte, numPackets*maxPayload)
   849  	for i := range data {
   850  		data[i] = byte(i)
   851  	}
   852  	// Write the data.
   853  	var r bytes.Reader
   854  	r.Reset(data)
   855  	if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
   856  		t.Fatalf("Write failed: %s", err)
   857  	}
   858  
   859  	var options []byte
   860  	var bytesRead uint32
   861  	for i := 0; i < numPackets; i++ {
   862  		b := c.GetPacket()
   863  		defer b.Release()
   864  		tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload())
   865  		checkReceivedPacket(t, c, tcpHdr, bytesRead, b, data)
   866  
   867  		// Get options only for the first packet. This will be sent with
   868  		// the ACK to indicate the acknowledgement is for the original
   869  		// packet.
   870  		if i == 0 && c.TimeStampEnabled {
   871  			options = buildTSOptionFromHeader(tcpHdr)
   872  		}
   873  		bytesRead += uint32(len(tcpHdr.Payload()))
   874  	}
   875  
   876  	// Receive the retransmitted packet after TLP.
   877  	c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize)
   878  
   879  	seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
   880  	// Send ACK for #3 and #4 segments to avoid entering TLP.
   881  	start := c.IRS.Add(3*maxPayload + 1)
   882  	end := start.Add(2 * maxPayload)
   883  	c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}})
   884  
   885  	c.SendAck(seq, 0 /* bytesReceived */)
   886  	c.SendAck(seq, 0 /* bytesReceived */)
   887  
   888  	// Receive the retransmitted packet after three duplicate ACKs.
   889  	c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize)
   890  
   891  	info := tcpip.TCPInfoOption{}
   892  	if err := c.EP.GetSockOpt(&info); err != nil {
   893  		t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err)
   894  	}
   895  
   896  	if info.CcState != tcpip.SACKRecovery {
   897  		t.Fatalf("Loss recovery did not happen, got: %v want: %v", info.CcState, tcpip.SACKRecovery)
   898  	}
   899  
   900  	// Acknowledge the data.
   901  	rcvWnd := seqnum.Size(30000)
   902  	c.SendPacket(nil, &context.Headers{
   903  		SrcPort: context.TestPort,
   904  		DstPort: c.Port,
   905  		Flags:   header.TCPFlagAck,
   906  		SeqNum:  seq,
   907  		AckNum:  c.IRS.Add(1 + seqnum.Size(maxPayload)),
   908  		RcvWnd:  rcvWnd,
   909  		TCPOpts: options,
   910  	})
   911  
   912  	// Wait for the probe function to finish processing the
   913  	// ACK before the test completes.
   914  	<-probeDone
   915  
   916  	verifySpuriousRecoveryMetric(t, c, 1 /* numSpuriousRecovery */, 0 /* numSpuriousRTO */)
   917  }
   918  
   919  func TestNoSpuriousRecoveryWithDSACK(t *testing.T) {
   920  	c := context.New(t, uint32(mtu))
   921  	defer c.Cleanup()
   922  	e2e.SetStackSACKPermitted(t, c, true)
   923  	e2e.CreateConnectedWithSACKAndTS(c)
   924  	numPackets := 5
   925  	data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */)
   926  
   927  	// Receive the retransmitted packet after TLP.
   928  	c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize)
   929  
   930  	// Send ACK for #3 and #4 segments to avoid entering TLP.
   931  	start := c.IRS.Add(3*maxPayload + 1)
   932  	end := start.Add(2 * maxPayload)
   933  	seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
   934  	c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}})
   935  
   936  	c.SendAck(seq, 0 /* bytesReceived */)
   937  	c.SendAck(seq, 0 /* bytesReceived */)
   938  
   939  	// Receive the retransmitted packet after three duplicate ACKs.
   940  	c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize)
   941  
   942  	// Acknowledge the data with DSACK for #1 segment.
   943  	start = c.IRS.Add(maxPayload + 1)
   944  	end = start.Add(2 * maxPayload)
   945  	seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1)
   946  	c.SendAckWithSACK(seq, 6*maxPayload, []header.SACKBlock{{start, end}})
   947  
   948  	verifySpuriousRecoveryMetric(t, c, 0 /* numSpuriousRecovery */, 0 /* numSpuriousRTO */)
   949  }
   950  
   951  func TestMain(m *testing.M) {
   952  	refs.SetLeakMode(refs.LeaksPanic)
   953  	code := m.Run()
   954  	// Allow TCP async work to complete to avoid false reports of leaks.
   955  	// TODO(gvisor.dev/issue/5940): Use fake clock in tests.
   956  	time.Sleep(1 * time.Second)
   957  	refs.DoLeakCheck()
   958  	os.Exit(code)
   959  }