gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/stack/conntrack_test.go (about)

     1  // Copyright 2021 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package stack
    16  
    17  import (
    18  	"testing"
    19  
    20  	"gvisor.dev/gvisor/pkg/buffer"
    21  	"gvisor.dev/gvisor/pkg/tcpip"
    22  	"gvisor.dev/gvisor/pkg/tcpip/faketime"
    23  	"gvisor.dev/gvisor/pkg/tcpip/header"
    24  	"gvisor.dev/gvisor/pkg/tcpip/seqnum"
    25  	"gvisor.dev/gvisor/pkg/tcpip/testutil"
    26  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack"
    27  )
    28  
    29  func TestReap(t *testing.T) {
    30  	// Initialize conntrack.
    31  	clock := faketime.NewManualClock()
    32  	ct := ConnTrack{
    33  		clock: clock,
    34  	}
    35  	ct.init()
    36  	ct.checkNumTuples(t, 0)
    37  
    38  	// We set rt.routeInfo.Loop to avoid a panic when handlePacket calls
    39  	// rt.RequiresTXTransportChecksum.
    40  	var rt Route
    41  	rt.routeInfo.Loop = PacketLoop
    42  
    43  	// Simulate sending a SYN. This will get the connection into conntrack, but
    44  	// the connection won't be considered established. Thus the timeout for
    45  	// reaping is unestablishedTimeout.
    46  	pkt1 := genTCPPacket(genTCPOpts{})
    47  	pkt1.tuple = ct.getConnAndUpdate(pkt1, true /* skipChecksumValidation */)
    48  	if pkt1.tuple.conn.handlePacket(pkt1, Output, &rt) {
    49  		t.Fatal("handlePacket() shouldn't perform any NAT")
    50  	}
    51  	ct.checkNumTuples(t, 1)
    52  
    53  	// Travel a little into the future and send the same SYN. This should update
    54  	// lastUsed, but per #6748 didn't.
    55  	clock.Advance(unestablishedTimeout / 2)
    56  	pkt2 := genTCPPacket(genTCPOpts{})
    57  	pkt2.tuple = ct.getConnAndUpdate(pkt2, true /* skipChecksumValidation */)
    58  	if pkt2.tuple.conn.handlePacket(pkt2, Output, &rt) {
    59  		t.Fatal("handlePacket() shouldn't perform any NAT")
    60  	}
    61  	ct.checkNumTuples(t, 1)
    62  
    63  	// Travel farther into the future - enough that failing to update lastUsed
    64  	// would cause a reaping - and reap the whole table. Make sure the connection
    65  	// hasn't been reaped.
    66  	clock.Advance(unestablishedTimeout * 3 / 4)
    67  	ct.reapEverything()
    68  	ct.checkNumTuples(t, 1)
    69  
    70  	// Travel past unestablishedTimeout to confirm the tuple is gone.
    71  	clock.Advance(unestablishedTimeout / 2)
    72  	ct.reapEverything()
    73  	ct.checkNumTuples(t, 0)
    74  }
    75  
    76  func TestWindowScaling(t *testing.T) {
    77  	tcs := []struct {
    78  		name        string
    79  		windowSize  uint16
    80  		synScale    uint8
    81  		synAckScale uint8
    82  		dataLen     int
    83  		finalSeq    uint32
    84  	}{
    85  		{
    86  			name:       "no scale, full overlap",
    87  			windowSize: 4,
    88  			dataLen:    2,
    89  			finalSeq:   2,
    90  		},
    91  		{
    92  			name:       "no scale, partial overlap",
    93  			windowSize: 4,
    94  			dataLen:    8,
    95  			finalSeq:   4,
    96  		},
    97  		{
    98  			name:        "scale, full overlap",
    99  			windowSize:  4,
   100  			synScale:    1,
   101  			synAckScale: 1,
   102  			dataLen:     6,
   103  			finalSeq:    6,
   104  		},
   105  		{
   106  			name:        "scale, partial overlap",
   107  			windowSize:  4,
   108  			synScale:    1,
   109  			synAckScale: 1,
   110  			dataLen:     10,
   111  			finalSeq:    8,
   112  		},
   113  		{
   114  			name:        "SYN scale larger",
   115  			windowSize:  4,
   116  			synScale:    2,
   117  			synAckScale: 1,
   118  			dataLen:     10,
   119  			finalSeq:    8,
   120  		},
   121  		{
   122  			name:        "SYN/ACK scale larger",
   123  			windowSize:  4,
   124  			synScale:    1,
   125  			synAckScale: 2,
   126  			dataLen:     10,
   127  			finalSeq:    10,
   128  		},
   129  	}
   130  
   131  	for _, tc := range tcs {
   132  		t.Run(tc.name, func(t *testing.T) {
   133  			testWindowScaling(t, tc.windowSize, tc.synScale, tc.synAckScale, tc.dataLen, tc.finalSeq)
   134  		})
   135  	}
   136  }
   137  
   138  // testWindowScaling performs a TCP handshake with the given parameters,
   139  // attaching dataLen bytes as the payload to the final ACK.
   140  func testWindowScaling(t *testing.T, windowSize uint16, synScale, synAckScale uint8, dataLen int, finalSeq uint32) {
   141  	// Initialize conntrack.
   142  	clock := faketime.NewManualClock()
   143  	ct := ConnTrack{
   144  		clock: clock,
   145  	}
   146  	ct.init()
   147  	ct.checkNumTuples(t, 0)
   148  
   149  	// We set rt.routeInfo.Loop to avoid a panic when handlePacket calls
   150  	// rt.RequiresTXTransportChecksum.
   151  	var rt Route
   152  	rt.routeInfo.Loop = PacketLoop
   153  
   154  	var (
   155  		rwnd           = windowSize
   156  		seqOrig        = uint32(10)
   157  		seqRepl        = uint32(20)
   158  		flags          = header.TCPFlags(header.TCPFlagSyn)
   159  		originatorAddr = testutil.MustParse4("1.0.0.1")
   160  		responderAddr  = testutil.MustParse4("1.0.0.2")
   161  		originatorPort = uint16(5555)
   162  		responderPort  = uint16(6666)
   163  	)
   164  
   165  	// Send SYN outbound through conntrack, simulating the Output hook.
   166  	synPkt := genTCPPacket(genTCPOpts{
   167  		windowSize:  &rwnd,
   168  		windowScale: synScale,
   169  		seqNum:      &seqOrig,
   170  		flags:       &flags,
   171  		srcAddr:     &originatorAddr,
   172  		dstAddr:     &responderAddr,
   173  		srcPort:     &originatorPort,
   174  		dstPort:     &responderPort,
   175  	})
   176  	synPkt.tuple = ct.getConnAndUpdate(synPkt, true /* skipChecksumValidation */)
   177  	if synPkt.tuple.conn.handlePacket(synPkt, Output, &rt) {
   178  		t.Fatal("handlePacket() shouldn't perform any NAT")
   179  	}
   180  	ct.checkNumTuples(t, 1)
   181  
   182  	// Simulate the Postrouting hook.
   183  	synPkt.tuple.conn.finalize()
   184  	conn := synPkt.tuple.conn
   185  	synPkt.tuple = nil
   186  	ct.checkNumTuples(t, 2)
   187  	conn.stateMu.Lock()
   188  	if got, want := conn.tcb.State(), tcpconntrack.ResultConnecting; got != want {
   189  		t.Fatalf("connection in state %v, but wanted %v", got, want)
   190  	}
   191  	conn.stateMu.Unlock()
   192  	conn.checkOriginalSeq(t, seqOrig+1)
   193  
   194  	// Send SYN/ACK, simulating the Prerouting hook.
   195  	seqOrig++
   196  	flags |= header.TCPFlagAck
   197  	synAckPkt := genTCPPacket(genTCPOpts{
   198  		windowSize:  &windowSize,
   199  		windowScale: synAckScale,
   200  		seqNum:      &seqRepl,
   201  		ackNum:      &seqOrig,
   202  		flags:       &flags,
   203  		srcAddr:     &responderAddr,
   204  		dstAddr:     &originatorAddr,
   205  		srcPort:     &responderPort,
   206  		dstPort:     &originatorPort,
   207  	})
   208  	synAckPkt.tuple = ct.getConnAndUpdate(synAckPkt, true /* skipChecksumValidation */)
   209  	if synAckPkt.tuple.conn.handlePacket(synAckPkt, Prerouting, &rt) {
   210  		t.Fatal("handlePacket() shouldn't perform any NAT")
   211  	}
   212  	ct.checkNumTuples(t, 2)
   213  
   214  	// Simulate the Input hook.
   215  	synAckPkt.tuple.conn.finalize()
   216  	synAckPkt.tuple = nil
   217  	ct.checkNumTuples(t, 2)
   218  	conn.stateMu.Lock()
   219  	if got, want := conn.tcb.State(), tcpconntrack.ResultAlive; got != want {
   220  		t.Fatalf("connection in state %v, but wanted %v", got, want)
   221  	}
   222  	conn.stateMu.Unlock()
   223  	conn.checkReplySeq(t, seqRepl+1)
   224  
   225  	// Send ACK with a payload, simulating the Output hook.
   226  	seqRepl++
   227  	flags = header.TCPFlagAck
   228  	ackPkt := genTCPPacket(genTCPOpts{
   229  		windowSize: &windowSize,
   230  		seqNum:     &seqOrig,
   231  		ackNum:     &seqRepl,
   232  		flags:      &flags,
   233  		data:       make([]byte, dataLen),
   234  		srcAddr:    &originatorAddr,
   235  		dstAddr:    &responderAddr,
   236  		srcPort:    &originatorPort,
   237  		dstPort:    &responderPort,
   238  	})
   239  	ackPkt.tuple = ct.getConnAndUpdate(ackPkt, true /* skipChecksumValidation */)
   240  	if ackPkt.tuple.conn.handlePacket(ackPkt, Output, &rt) {
   241  		t.Fatal("handlePacket() shouldn't perform any NAT")
   242  	}
   243  	ct.checkNumTuples(t, 2)
   244  
   245  	// Simulate the Postrouting hook.
   246  	ackPkt.tuple.conn.finalize()
   247  	ackPkt.tuple = nil
   248  	ct.checkNumTuples(t, 2)
   249  	conn.stateMu.Lock()
   250  	if got, want := conn.tcb.State(), tcpconntrack.ResultAlive; got != want {
   251  		t.Fatalf("connection in state %v, but wanted %v", got, want)
   252  	}
   253  	conn.stateMu.Unlock()
   254  	// Depending on the test, all or a fraction of dataLen will go towards
   255  	// advancing the sequence number.
   256  	conn.checkOriginalSeq(t, finalSeq+seqOrig)
   257  
   258  	// Go into the future to make sure we don't reap active connections quickly.
   259  	clock.Advance(unestablishedTimeout * 2)
   260  	ct.reapEverything()
   261  	ct.checkNumTuples(t, 2)
   262  
   263  	// Go way into the future to make sure we eventually reap active connections.
   264  	clock.Advance(establishedTimeout)
   265  	ct.reapEverything()
   266  	ct.checkNumTuples(t, 0)
   267  }
   268  
// genTCPOpts holds the optional parameters accepted by genTCPPacket. Pointer
// fields left nil are replaced by genTCPPacket's built-in defaults.
type genTCPOpts struct {
	// windowSize is the TCP window size; 50000 is used when nil.
	windowSize *uint16
	// windowScale, when non-zero, causes a window scale option with this value
	// to be appended to the TCP header.
	windowScale uint8
	// seqNum is the TCP sequence number; 7777 is used when nil.
	seqNum *uint32
	// ackNum is the TCP acknowledgement number; 8888 is used when nil.
	ackNum *uint32
	// flags are the TCP flags; header.TCPFlagSyn is used when nil.
	flags *header.TCPFlags
	// data is the packet payload; it may be nil.
	data []byte
	// srcAddr is the IPv4 source address; 1.0.0.1 is used when nil.
	srcAddr *tcpip.Address
	// dstAddr is the IPv4 destination address; 1.0.0.2 is used when nil.
	dstAddr *tcpip.Address
	// srcPort is the TCP source port; 5555 is used when nil.
	srcPort *uint16
	// dstPort is the TCP destination port; 6666 is used when nil.
	dstPort *uint16
}
   281  
   282  // genTCPPacket returns an initialized IPv4 TCP packet.
   283  func genTCPPacket(opts genTCPOpts) *PacketBuffer {
   284  	// Get values from opts.
   285  	windowSize := uint16(50000)
   286  	if opts.windowSize != nil {
   287  		windowSize = *opts.windowSize
   288  	}
   289  	tcpHdrSize := uint8(header.TCPMinimumSize)
   290  	if opts.windowScale != 0 {
   291  		tcpHdrSize += 4 // 3 bytes of window scale plus 1 of padding.
   292  	}
   293  	seqNum := uint32(7777)
   294  	if opts.seqNum != nil {
   295  		seqNum = *opts.seqNum
   296  	}
   297  	ackNum := uint32(8888)
   298  	if opts.ackNum != nil {
   299  		ackNum = *opts.ackNum
   300  	}
   301  	flags := header.TCPFlagSyn
   302  	if opts.flags != nil {
   303  		flags = *opts.flags
   304  	}
   305  	srcAddr := testutil.MustParse4("1.0.0.1")
   306  	if opts.srcAddr != nil {
   307  		srcAddr = *opts.srcAddr
   308  	}
   309  	dstAddr := testutil.MustParse4("1.0.0.2")
   310  	if opts.dstAddr != nil {
   311  		dstAddr = *opts.dstAddr
   312  	}
   313  	srcPort := uint16(5555)
   314  	if opts.srcPort != nil {
   315  		srcPort = *opts.srcPort
   316  	}
   317  	dstPort := uint16(6666)
   318  	if opts.dstPort != nil {
   319  		dstPort = *opts.dstPort
   320  	}
   321  
   322  	// Initialize the PacketBuffer.
   323  	packetLen := header.IPv4MinimumSize + uint16(tcpHdrSize)
   324  	pkt := NewPacketBuffer(PacketBufferOptions{
   325  		ReserveHeaderBytes: int(packetLen),
   326  		Payload:            buffer.MakeWithData(opts.data),
   327  	})
   328  	pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
   329  	pkt.TransportProtocolNumber = header.TCPProtocolNumber
   330  
   331  	// Craft the TCP header, including the window scale option if necessary.
   332  	tcpHdr := header.TCP(pkt.TransportHeader().Push(int(tcpHdrSize)))
   333  	tcpHdr[:header.TCPMinimumSize].Encode(&header.TCPFields{
   334  		SrcPort:    srcPort,
   335  		DstPort:    dstPort,
   336  		SeqNum:     seqNum,
   337  		AckNum:     ackNum,
   338  		DataOffset: tcpHdrSize,
   339  		Flags:      flags,
   340  		WindowSize: windowSize,
   341  		Checksum:   0, // Conntrack doesn't verify the checksum.
   342  	})
   343  	if opts.windowScale != 0 {
   344  		// Set the window scale option, which is 3 bytes long. The option is
   345  		// properly padded because the final remaining byte is already zeroed.
   346  		_ = header.EncodeWSOption(int(opts.windowScale), tcpHdr[header.TCPMinimumSize:])
   347  	}
   348  
   349  	// Craft an IPv4 header.
   350  	ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
   351  	ipHdr.Encode(&header.IPv4Fields{
   352  		TotalLength: packetLen,
   353  		Protocol:    uint8(header.TCPProtocolNumber),
   354  		SrcAddr:     srcAddr,
   355  		DstAddr:     dstAddr,
   356  		Checksum:    0, // Conntrack doesn't verify the checksum.
   357  	})
   358  
   359  	return pkt
   360  }
   361  
   362  // checkNumTuples checks that there are exactly want tuples tracked by
   363  // conntrack.
   364  func (ct *ConnTrack) checkNumTuples(t *testing.T, want int) {
   365  	t.Helper()
   366  	ct.mu.RLock()
   367  	defer ct.mu.RUnlock()
   368  
   369  	var total int
   370  	for idx := range ct.buckets {
   371  		ct.buckets[idx].mu.RLock()
   372  		total += ct.buckets[idx].tuples.Len()
   373  		ct.buckets[idx].mu.RUnlock()
   374  	}
   375  
   376  	if total != want {
   377  		t.Fatalf("checkNumTuples: got %d, wanted %d", total, want)
   378  	}
   379  }
   380  
   381  func (ct *ConnTrack) reapEverything() {
   382  	var bucket int
   383  	for {
   384  		newBucket, _ := ct.reapUnused(bucket, 0 /* ignored */)
   385  		// We started reaping at bucket 0. If the next bucket isn't after our
   386  		// current bucket, we've gone through them all.
   387  		if newBucket <= bucket {
   388  			break
   389  		}
   390  		bucket = newBucket
   391  	}
   392  }
   393  
   394  func (cn *conn) checkOriginalSeq(t *testing.T, seq uint32) {
   395  	t.Helper()
   396  	cn.stateMu.Lock()
   397  	defer cn.stateMu.Unlock()
   398  
   399  	if got, want := cn.tcb.OriginalSendSequenceNumber(), seqnum.Value(seq); got != want {
   400  		t.Fatalf("checkOriginalSeq: got %d, wanted %d", got, want)
   401  	}
   402  }
   403  
   404  func (cn *conn) checkReplySeq(t *testing.T, seq uint32) {
   405  	t.Helper()
   406  	cn.stateMu.Lock()
   407  	defer cn.stateMu.Unlock()
   408  
   409  	if got, want := cn.tcb.ReplySendSequenceNumber(), seqnum.Value(seq); got != want {
   410  		t.Fatalf("checkReplySeq: got %d, wanted %d", got, want)
   411  	}
   412  }