gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/transport/tcp/test/e2e/forwarder_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package forwarder_test
    16  
    17  import (
    18  	"os"
    19  	"testing"
    20  	"time"
    21  
    22  	"gvisor.dev/gvisor/pkg/atomicbitops"
    23  	"gvisor.dev/gvisor/pkg/refs"
    24  	"gvisor.dev/gvisor/pkg/tcpip"
    25  	"gvisor.dev/gvisor/pkg/tcpip/checker"
    26  	"gvisor.dev/gvisor/pkg/tcpip/header"
    27  	"gvisor.dev/gvisor/pkg/tcpip/seqnum"
    28  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
    29  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/test/e2e"
    30  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
    31  )
    32  
    33  func TestForwarderSendMSSLessThanMTU(t *testing.T) {
    34  	const maxPayload = 100
    35  	const mtu = 1200
    36  	c := context.New(t, mtu)
    37  	defer c.Cleanup()
    38  
    39  	s := c.Stack()
    40  	ch := make(chan tcpip.Error, 1)
    41  	f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) {
    42  		var err tcpip.Error
    43  		c.EP, err = r.CreateEndpoint(&c.WQ)
    44  		ch <- err
    45  		close(ch)
    46  		r.Complete(false)
    47  	})
    48  	s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket)
    49  
    50  	// Do 3-way handshake.
    51  	c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
    52  
    53  	// Wait for connection to be available.
    54  	select {
    55  	case err := <-ch:
    56  		if err != nil {
    57  			t.Fatalf("Error creating endpoint: %s", err)
    58  		}
    59  	case <-time.After(2 * time.Second):
    60  		t.Fatalf("Timed out waiting for connection")
    61  	}
    62  
    63  	// Check that data gets properly segmented.
    64  	e2e.CheckBrokenUpWrite(t, c, maxPayload)
    65  }
    66  
    67  func TestForwarderDoesNotRejectECNFlags(t *testing.T) {
    68  	testCases := []struct {
    69  		name  string
    70  		flags header.TCPFlags
    71  	}{
    72  		{name: "non-setup ECN SYN w/ ECE", flags: header.TCPFlagEce},
    73  		{name: "non-setup ECN SYN w/ CWR", flags: header.TCPFlagCwr},
    74  		{name: "setup ECN SYN", flags: header.TCPFlagEce | header.TCPFlagCwr},
    75  	}
    76  
    77  	for _, tc := range testCases {
    78  		t.Run(tc.name, func(t *testing.T) {
    79  			const maxPayload = 100
    80  			const mtu = 1200
    81  			c := context.New(t, mtu)
    82  			defer c.Cleanup()
    83  
    84  			s := c.Stack()
    85  			ch := make(chan tcpip.Error, 1)
    86  			f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) {
    87  				var err tcpip.Error
    88  				c.EP, err = r.CreateEndpoint(&c.WQ)
    89  				ch <- err
    90  				close(ch)
    91  				r.Complete(false)
    92  			})
    93  			s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket)
    94  
    95  			// Do 3-way handshake.
    96  			c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, Flags: tc.flags})
    97  
    98  			// Wait for connection to be available.
    99  			select {
   100  			case err := <-ch:
   101  				if err != nil {
   102  					t.Fatalf("Error creating endpoint: %s", err)
   103  				}
   104  			case <-time.After(2 * time.Second):
   105  				t.Fatalf("Timed out waiting for connection")
   106  			}
   107  		})
   108  	}
   109  }
   110  
   111  func TestForwarderFailedConnect(t *testing.T) {
   112  	const mtu = 1200
   113  	c := context.New(t, mtu)
   114  	defer c.Cleanup()
   115  
   116  	s := c.Stack()
   117  	ch := make(chan tcpip.Error, 1)
   118  	f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) {
   119  		var err tcpip.Error
   120  		c.EP, err = r.CreateEndpoint(&c.WQ)
   121  		ch <- err
   122  		close(ch)
   123  		r.Complete(false)
   124  	})
   125  	s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket)
   126  
   127  	// Initiate a connection that will be forwarded by the Forwarder.
   128  	// Send a SYN request.
   129  	iss := seqnum.Value(context.TestInitialSequenceNumber)
   130  	c.SendPacket(nil, &context.Headers{
   131  		SrcPort: context.TestPort,
   132  		DstPort: context.StackPort,
   133  		Flags:   header.TCPFlagSyn,
   134  		SeqNum:  iss,
   135  		RcvWnd:  30000,
   136  	})
   137  
   138  	// Receive the SYN-ACK reply. Make sure MSS and other expected options
   139  	// are present.
   140  	v := c.GetPacket()
   141  	defer v.Release()
   142  	tcp := header.TCP(header.IPv4(v.AsSlice()).Payload())
   143  	c.IRS = seqnum.Value(tcp.SequenceNumber())
   144  
   145  	tcpCheckers := []checker.TransportChecker{
   146  		checker.SrcPort(context.StackPort),
   147  		checker.DstPort(context.TestPort),
   148  		checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
   149  		checker.TCPAckNum(uint32(iss) + 1),
   150  	}
   151  	checker.IPv4(t, v, checker.TCP(tcpCheckers...))
   152  
   153  	// Now send an active RST to abort the handshake.
   154  	c.SendPacket(nil, &context.Headers{
   155  		SrcPort: context.TestPort,
   156  		DstPort: context.StackPort,
   157  		Flags:   header.TCPFlagRst,
   158  		SeqNum:  iss + 1,
   159  		RcvWnd:  0,
   160  	})
   161  
   162  	// Wait for connect to fail.
   163  	select {
   164  	case err := <-ch:
   165  		if err == nil {
   166  			t.Fatalf("endpoint creation should have failed")
   167  		}
   168  	case <-time.After(2 * time.Second):
   169  		t.Fatalf("Timed out waiting for connection to fail")
   170  	}
   171  }
   172  
   173  func TestForwarderDroppedStats(t *testing.T) {
   174  	const maxPayload = 100
   175  	const mtu = 1200
   176  	c := context.New(t, mtu)
   177  	defer c.Cleanup()
   178  
   179  	const maxInFlight = 2
   180  	iters := atomicbitops.FromInt64(maxInFlight)
   181  	s := c.Stack()
   182  	checkedStats := make(chan struct{})
   183  	done := make(chan struct{})
   184  	f := tcp.NewForwarder(s, 65536, maxInFlight, func(r *tcp.ForwarderRequest) {
   185  		<-checkedStats
   186  		// Complete all requests without doing anything
   187  		r.Complete(false)
   188  		if iter := iters.Add(-1); iter == 0 {
   189  			close(done)
   190  		}
   191  	})
   192  	s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket)
   193  
   194  	for i := 0; i < maxInFlight+1; i++ {
   195  		iss := seqnum.Value(context.TestInitialSequenceNumber + i)
   196  		c.SendPacket(nil, &context.Headers{
   197  			SrcPort: uint16(context.TestPort + i),
   198  			DstPort: context.StackPort,
   199  			Flags:   header.TCPFlagSyn,
   200  			SeqNum:  iss,
   201  			RcvWnd:  30000,
   202  		})
   203  	}
   204  
   205  	// Verify that we got one ignored packet.
   206  	if curr := s.Stats().TCP.ForwardMaxInFlightDrop.Value(); curr != 1 {
   207  		t.Errorf("Expected one dropped connection, but got %d", curr)
   208  	}
   209  	close(checkedStats)
   210  	<-done
   211  }
   212  
   213  func TestMain(m *testing.M) {
   214  	refs.SetLeakMode(refs.LeaksPanic)
   215  	code := m.Run()
   216  	// Allow TCP async work to complete to avoid false reports of leaks.
   217  	// TODO(gvisor.dev/issue/5940): Use fake clock in tests.
   218  	time.Sleep(1 * time.Second)
   219  	refs.DoLeakCheck()
   220  	os.Exit(code)
   221  }