gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/test/packetimpact/tests/tcp_zero_receive_window_test.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tcp_zero_receive_window_test
    16  
    17  import (
    18  	"flag"
    19  	"fmt"
    20  	"math"
    21  	"testing"
    22  	"time"
    23  
    24  	"golang.org/x/sys/unix"
    25  	"gvisor.dev/gvisor/pkg/tcpip/header"
    26  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
    27  	"gvisor.dev/gvisor/test/packetimpact/testbench"
    28  )
    29  
    30  func init() {
    31  	testbench.Initialize(flag.CommandLine)
    32  }
    33  
    34  // TestZeroReceiveWindow tests if the DUT sends a zero receive window eventually.
    35  func TestZeroReceiveWindow(t *testing.T) {
    36  	// minPayloadLen is the smallest size we can use for a payload in this test.
    37  	// Any smaller than this and the receive buffer will fill up before the
    38  	// receive window can shrink to zero.
    39  
    40  	// To solve for minPayloadLen: minPayloadLen(DefaultReceiveBufferSize) =
    41  	// 	maxWndSize(minPayloadLen + segOverheadSize)
    42  	maxWndSize := math.MaxUint16
    43  	minPayloadLen := int(math.Ceil(float64(maxWndSize*tcp.SegOverheadSize) / float64(tcp.DefaultReceiveBufferSize-maxWndSize)))
    44  	for _, payloadLen := range []int{minPayloadLen, 512, 1024} {
    45  		t.Run(fmt.Sprintf("TestZeroReceiveWindow_with_%dbytes_payload", payloadLen), func(t *testing.T) {
    46  			dut := testbench.NewDUT(t)
    47  			listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
    48  			defer dut.Close(t, listenFd)
    49  			conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
    50  			defer conn.Close(t)
    51  
    52  			conn.Connect(t)
    53  			acceptFd, _ := dut.Accept(t, listenFd)
    54  			defer dut.Close(t, acceptFd)
    55  
    56  			dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
    57  
    58  			fillRecvBuffer(t, &conn, &dut, acceptFd, payloadLen)
    59  		})
    60  	}
    61  }
    62  
    63  func fillRecvBuffer(t *testing.T, conn *testbench.TCPIPv4, dut *testbench.DUT, acceptFd int32, payloadLen int) {
    64  	// Expect the DUT to eventually advertise zero receive window.
    65  	// The test would timeout otherwise.
    66  	for readOnce := false; ; {
    67  		samplePayload := &testbench.Payload{Bytes: testbench.GenerateRandomPayload(t, payloadLen)}
    68  		conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
    69  		gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second)
    70  		if err != nil {
    71  			t.Fatalf("expected packet was not received: %s", err)
    72  		}
    73  		// Read once to trigger the subsequent window update from the
    74  		// DUT to grow the right edge of the receive window from what
    75  		// was advertised in the SYN-ACK. This ensures that we test
    76  		// for the full default buffer size (1MB on gVisor at the time
    77  		// of writing this comment), thus testing for cases when the
    78  		// scaled receive window size ends up > 65535 (0xffff).
    79  		if !readOnce {
    80  			if got := dut.Recv(t, acceptFd, int32(payloadLen), 0); len(got) != payloadLen {
    81  				t.Fatalf("got dut.Recv(t, %d, %d, 0) = %d, want %d", acceptFd, payloadLen, len(got), payloadLen)
    82  			}
    83  			readOnce = true
    84  		}
    85  		windowSize := *gotTCP.WindowSize
    86  		t.Logf("got window size = %d", windowSize)
    87  		if windowSize == 0 {
    88  			break
    89  		}
    90  		if payloadLen > int(windowSize) {
    91  			payloadLen = int(windowSize)
    92  		}
    93  	}
    94  }
    95  
    96  func TestZeroToNonZeroWindowUpdate(t *testing.T) {
    97  	dut := testbench.NewDUT(t)
    98  	listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
    99  	defer dut.Close(t, listenFd)
   100  	conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
   101  	defer conn.Close(t)
   102  
   103  	conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)})
   104  	synAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
   105  	if err != nil {
   106  		t.Fatalf("didn't get synack during handshake: %s", err)
   107  	}
   108  	conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
   109  
   110  	acceptFd, _ := dut.Accept(t, listenFd)
   111  	defer dut.Close(t, acceptFd)
   112  
   113  	dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
   114  
   115  	mss := header.ParseSynOptions(synAck.Options, true).MSS
   116  	fillRecvBuffer(t, &conn, &dut, acceptFd, int(mss))
   117  
   118  	// Read < mss worth of data from the receive buffer and expect the DUT to
   119  	// not send a non-zero window update.
   120  	payloadLen := mss - 1
   121  	if got := dut.Recv(t, acceptFd, int32(payloadLen), 0); len(got) != int(payloadLen) {
   122  		t.Fatalf("got dut.Recv(t, %d, %d, 0) = %d, want %d", acceptFd, payloadLen, len(got), payloadLen)
   123  	}
   124  	// Send a zero-window-probe to force an ACK from the receiver with any
   125  	// window updates.
   126  	conn.Send(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(*conn.LocalSeqNum(t) - 1)), Flags: testbench.TCPFlags(header.TCPFlagAck)})
   127  	gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second)
   128  	if err != nil {
   129  		t.Fatalf("expected packet was not received: %s", err)
   130  	}
   131  	if windowSize := *gotTCP.WindowSize; windowSize != 0 {
   132  		t.Fatalf("got non zero window = %d", windowSize)
   133  	}
   134  
   135  	// Now, ensure that the DUT eventually sends non-zero window update.
   136  	seqNum := testbench.Uint32(uint32(*conn.LocalSeqNum(t) - 1))
   137  	ackNum := testbench.Uint32(uint32(*conn.LocalSeqNum(t)))
   138  	recvCheckWindowUpdate := func(readLen int) uint16 {
   139  		if got := dut.Recv(t, acceptFd, int32(readLen), 0); len(got) != readLen {
   140  			t.Fatalf("got dut.Recv(t, %d, %d, 0) = %d, want %d", acceptFd, readLen, len(got), readLen)
   141  		}
   142  		conn.Send(t, testbench.TCP{SeqNum: seqNum, Flags: testbench.TCPFlags(header.TCPFlagPsh | header.TCPFlagAck)}, &testbench.Payload{Bytes: make([]byte, 1)})
   143  		gotTCP, err := conn.Expect(t, testbench.TCP{AckNum: ackNum, Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second)
   144  		if err != nil {
   145  			t.Fatalf("expected packet was not received: %s", err)
   146  		}
   147  		return *gotTCP.WindowSize
   148  	}
   149  
   150  	if !dut.Uname.IsLinux() {
   151  		if win := recvCheckWindowUpdate(1); win == 0 {
   152  			t.Fatal("expected non-zero window update")
   153  		}
   154  	} else {
   155  		// Linux stack takes additional socket reads to send out window update,
   156  		// its a function of sysctl_tcp_rmem among other things.
   157  		// https://github.com/torvalds/linux/blob/7acac4b3196/net/ipv4/tcp_input.c#L687
   158  		for {
   159  			if win := recvCheckWindowUpdate(int(payloadLen)); win != 0 {
   160  				break
   161  			}
   162  		}
   163  	}
   164  }
   165  
   166  // TestNonZeroReceiveWindow tests for the DUT to never send a zero receive
   167  // window when the data is being read from the socket buffer.
   168  func TestNonZeroReceiveWindow(t *testing.T) {
   169  	for _, payloadLen := range []int{64, 512, 1024} {
   170  		t.Run(fmt.Sprintf("TestZeroReceiveWindow_with_%dbytes_payload", payloadLen), func(t *testing.T) {
   171  			dut := testbench.NewDUT(t)
   172  			listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
   173  			defer dut.Close(t, listenFd)
   174  			conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
   175  			defer conn.Close(t)
   176  
   177  			conn.Connect(t)
   178  			acceptFd, _ := dut.Accept(t, listenFd)
   179  			defer dut.Close(t, acceptFd)
   180  
   181  			dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
   182  
   183  			samplePayload := &testbench.Payload{Bytes: testbench.GenerateRandomPayload(t, payloadLen)}
   184  			var rcvWindow uint16
   185  			initRcv := false
   186  			// This loop keeps a running rcvWindow value from the initial ACK for the data
   187  			// we sent. Once we have received ACKs with non-zero receive windows, we break
   188  			// the loop.
   189  			for {
   190  				conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
   191  				gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second)
   192  				if err != nil {
   193  					t.Fatalf("expected packet was not received: %s", err)
   194  				}
   195  				if got := dut.Recv(t, acceptFd, int32(payloadLen), 0); len(got) != payloadLen {
   196  					t.Fatalf("got dut.Recv(t, %d, %d, 0) = %d, want %d", acceptFd, payloadLen, len(got), payloadLen)
   197  				}
   198  				if *gotTCP.WindowSize == 0 {
   199  					t.Fatalf("expected non-zero receive window.")
   200  				}
   201  				if !initRcv {
   202  					rcvWindow = uint16(*gotTCP.WindowSize)
   203  					initRcv = true
   204  				}
   205  				if rcvWindow <= uint16(payloadLen) {
   206  					break
   207  				}
   208  				rcvWindow -= uint16(payloadLen)
   209  			}
   210  		})
   211  	}
   212  }