gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/stack/conntrack_test.go (about) 1 // Copyright 2021 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package stack 16 17 import ( 18 "testing" 19 20 "gvisor.dev/gvisor/pkg/buffer" 21 "gvisor.dev/gvisor/pkg/tcpip" 22 "gvisor.dev/gvisor/pkg/tcpip/faketime" 23 "gvisor.dev/gvisor/pkg/tcpip/header" 24 "gvisor.dev/gvisor/pkg/tcpip/seqnum" 25 "gvisor.dev/gvisor/pkg/tcpip/testutil" 26 "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" 27 ) 28 29 func TestReap(t *testing.T) { 30 // Initialize conntrack. 31 clock := faketime.NewManualClock() 32 ct := ConnTrack{ 33 clock: clock, 34 } 35 ct.init() 36 ct.checkNumTuples(t, 0) 37 38 // We set rt.routeInfo.Loop to avoid a panic when handlePacket calls 39 // rt.RequiresTXTransportChecksum. 40 var rt Route 41 rt.routeInfo.Loop = PacketLoop 42 43 // Simulate sending a SYN. This will get the connection into conntrack, but 44 // the connection won't be considered established. Thus the timeout for 45 // reaping is unestablishedTimeout. 46 pkt1 := genTCPPacket(genTCPOpts{}) 47 pkt1.tuple = ct.getConnAndUpdate(pkt1, true /* skipChecksumValidation */) 48 if pkt1.tuple.conn.handlePacket(pkt1, Output, &rt) { 49 t.Fatal("handlePacket() shouldn't perform any NAT") 50 } 51 ct.checkNumTuples(t, 1) 52 53 // Travel a little into the future and send the same SYN. This should update 54 // lastUsed, but per #6748 didn't. 55 clock.Advance(unestablishedTimeout / 2) 56 pkt2 := genTCPPacket(genTCPOpts{}) 57 pkt2.tuple = ct.getConnAndUpdate(pkt2, true /* skipChecksumValidation */) 58 if pkt2.tuple.conn.handlePacket(pkt2, Output, &rt) { 59 t.Fatal("handlePacket() shouldn't perform any NAT") 60 } 61 ct.checkNumTuples(t, 1) 62 63 // Travel farther into the future - enough that failing to update lastUsed 64 // would cause a reaping - and reap the whole table. Make sure the connection 65 // hasn't been reaped. 66 clock.Advance(unestablishedTimeout * 3 / 4) 67 ct.reapEverything() 68 ct.checkNumTuples(t, 1) 69 70 // Travel past unestablishedTimeout to confirm the tuple is gone. 71 clock.Advance(unestablishedTimeout / 2) 72 ct.reapEverything() 73 ct.checkNumTuples(t, 0) 74 } 75 76 func TestWindowScaling(t *testing.T) { 77 tcs := []struct { 78 name string 79 windowSize uint16 80 synScale uint8 81 synAckScale uint8 82 dataLen int 83 finalSeq uint32 84 }{ 85 { 86 name: "no scale, full overlap", 87 windowSize: 4, 88 dataLen: 2, 89 finalSeq: 2, 90 }, 91 { 92 name: "no scale, partial overlap", 93 windowSize: 4, 94 dataLen: 8, 95 finalSeq: 4, 96 }, 97 { 98 name: "scale, full overlap", 99 windowSize: 4, 100 synScale: 1, 101 synAckScale: 1, 102 dataLen: 6, 103 finalSeq: 6, 104 }, 105 { 106 name: "scale, partial overlap", 107 windowSize: 4, 108 synScale: 1, 109 synAckScale: 1, 110 dataLen: 10, 111 finalSeq: 8, 112 }, 113 { 114 name: "SYN scale larger", 115 windowSize: 4, 116 synScale: 2, 117 synAckScale: 1, 118 dataLen: 10, 119 finalSeq: 8, 120 }, 121 { 122 name: "SYN/ACK scale larger", 123 windowSize: 4, 124 synScale: 1, 125 synAckScale: 2, 126 dataLen: 10, 127 finalSeq: 10, 128 }, 129 } 130 131 for _, tc := range tcs { 132 t.Run(tc.name, func(t *testing.T) { 133 testWindowScaling(t, tc.windowSize, tc.synScale, tc.synAckScale, tc.dataLen, tc.finalSeq) 134 }) 135 } 136 } 137 138 // testWindowScaling performs a TCP handshake with the given parameters, 139 // attaching dataLen bytes as the payload to the final ACK. 140 func testWindowScaling(t *testing.T, windowSize uint16, synScale, synAckScale uint8, dataLen int, finalSeq uint32) { 141 // Initialize conntrack. 142 clock := faketime.NewManualClock() 143 ct := ConnTrack{ 144 clock: clock, 145 } 146 ct.init() 147 ct.checkNumTuples(t, 0) 148 149 // We set rt.routeInfo.Loop to avoid a panic when handlePacket calls 150 // rt.RequiresTXTransportChecksum. 151 var rt Route 152 rt.routeInfo.Loop = PacketLoop 153 154 var ( 155 rwnd = windowSize 156 seqOrig = uint32(10) 157 seqRepl = uint32(20) 158 flags = header.TCPFlags(header.TCPFlagSyn) 159 originatorAddr = testutil.MustParse4("1.0.0.1") 160 responderAddr = testutil.MustParse4("1.0.0.2") 161 originatorPort = uint16(5555) 162 responderPort = uint16(6666) 163 ) 164 165 // Send SYN outbound through conntrack, simulating the Output hook. 166 synPkt := genTCPPacket(genTCPOpts{ 167 windowSize: &rwnd, 168 windowScale: synScale, 169 seqNum: &seqOrig, 170 flags: &flags, 171 srcAddr: &originatorAddr, 172 dstAddr: &responderAddr, 173 srcPort: &originatorPort, 174 dstPort: &responderPort, 175 }) 176 synPkt.tuple = ct.getConnAndUpdate(synPkt, true /* skipChecksumValidation */) 177 if synPkt.tuple.conn.handlePacket(synPkt, Output, &rt) { 178 t.Fatal("handlePacket() shouldn't perform any NAT") 179 } 180 ct.checkNumTuples(t, 1) 181 182 // Simulate the Postrouting hook. 183 synPkt.tuple.conn.finalize() 184 conn := synPkt.tuple.conn 185 synPkt.tuple = nil 186 ct.checkNumTuples(t, 2) 187 conn.stateMu.Lock() 188 if got, want := conn.tcb.State(), tcpconntrack.ResultConnecting; got != want { 189 t.Fatalf("connection in state %v, but wanted %v", got, want) 190 } 191 conn.stateMu.Unlock() 192 conn.checkOriginalSeq(t, seqOrig+1) 193 194 // Send SYN/ACK, simulating the Prerouting hook. 195 seqOrig++ 196 flags |= header.TCPFlagAck 197 synAckPkt := genTCPPacket(genTCPOpts{ 198 windowSize: &windowSize, 199 windowScale: synAckScale, 200 seqNum: &seqRepl, 201 ackNum: &seqOrig, 202 flags: &flags, 203 srcAddr: &responderAddr, 204 dstAddr: &originatorAddr, 205 srcPort: &responderPort, 206 dstPort: &originatorPort, 207 }) 208 synAckPkt.tuple = ct.getConnAndUpdate(synAckPkt, true /* skipChecksumValidation */) 209 if synAckPkt.tuple.conn.handlePacket(synAckPkt, Prerouting, &rt) { 210 t.Fatal("handlePacket() shouldn't perform any NAT") 211 } 212 ct.checkNumTuples(t, 2) 213 214 // Simulate the Input hook. 215 synAckPkt.tuple.conn.finalize() 216 synAckPkt.tuple = nil 217 ct.checkNumTuples(t, 2) 218 conn.stateMu.Lock() 219 if got, want := conn.tcb.State(), tcpconntrack.ResultAlive; got != want { 220 t.Fatalf("connection in state %v, but wanted %v", got, want) 221 } 222 conn.stateMu.Unlock() 223 conn.checkReplySeq(t, seqRepl+1) 224 225 // Send ACK with a payload, simulating the Output hook. 226 seqRepl++ 227 flags = header.TCPFlagAck 228 ackPkt := genTCPPacket(genTCPOpts{ 229 windowSize: &windowSize, 230 seqNum: &seqOrig, 231 ackNum: &seqRepl, 232 flags: &flags, 233 data: make([]byte, dataLen), 234 srcAddr: &originatorAddr, 235 dstAddr: &responderAddr, 236 srcPort: &originatorPort, 237 dstPort: &responderPort, 238 }) 239 ackPkt.tuple = ct.getConnAndUpdate(ackPkt, true /* skipChecksumValidation */) 240 if ackPkt.tuple.conn.handlePacket(ackPkt, Output, &rt) { 241 t.Fatal("handlePacket() shouldn't perform any NAT") 242 } 243 ct.checkNumTuples(t, 2) 244 245 // Simulate the Postrouting hook. 246 ackPkt.tuple.conn.finalize() 247 ackPkt.tuple = nil 248 ct.checkNumTuples(t, 2) 249 conn.stateMu.Lock() 250 if got, want := conn.tcb.State(), tcpconntrack.ResultAlive; got != want { 251 t.Fatalf("connection in state %v, but wanted %v", got, want) 252 } 253 conn.stateMu.Unlock() 254 // Depending on the test, all or a fraction of dataLen will go towards 255 // advancing the sequence number. 256 conn.checkOriginalSeq(t, finalSeq+seqOrig) 257 258 // Go into the future to make sure we don't reap active connections quickly. 259 clock.Advance(unestablishedTimeout * 2) 260 ct.reapEverything() 261 ct.checkNumTuples(t, 2) 262 263 // Go way into the future to make sure we eventually reap active connections. 264 clock.Advance(establishedTimeout) 265 ct.reapEverything() 266 ct.checkNumTuples(t, 0) 267 } 268 269 type genTCPOpts struct { 270 windowSize *uint16 271 windowScale uint8 272 seqNum *uint32 273 ackNum *uint32 274 flags *header.TCPFlags 275 data []byte 276 srcAddr *tcpip.Address 277 dstAddr *tcpip.Address 278 srcPort *uint16 279 dstPort *uint16 280 } 281 282 // genTCPPacket returns an initialized IPv4 TCP packet. 283 func genTCPPacket(opts genTCPOpts) *PacketBuffer { 284 // Get values from opts. 285 windowSize := uint16(50000) 286 if opts.windowSize != nil { 287 windowSize = *opts.windowSize 288 } 289 tcpHdrSize := uint8(header.TCPMinimumSize) 290 if opts.windowScale != 0 { 291 tcpHdrSize += 4 // 3 bytes of window scale plus 1 of padding. 292 } 293 seqNum := uint32(7777) 294 if opts.seqNum != nil { 295 seqNum = *opts.seqNum 296 } 297 ackNum := uint32(8888) 298 if opts.ackNum != nil { 299 ackNum = *opts.ackNum 300 } 301 flags := header.TCPFlagSyn 302 if opts.flags != nil { 303 flags = *opts.flags 304 } 305 srcAddr := testutil.MustParse4("1.0.0.1") 306 if opts.srcAddr != nil { 307 srcAddr = *opts.srcAddr 308 } 309 dstAddr := testutil.MustParse4("1.0.0.2") 310 if opts.dstAddr != nil { 311 dstAddr = *opts.dstAddr 312 } 313 srcPort := uint16(5555) 314 if opts.srcPort != nil { 315 srcPort = *opts.srcPort 316 } 317 dstPort := uint16(6666) 318 if opts.dstPort != nil { 319 dstPort = *opts.dstPort 320 } 321 322 // Initialize the PacketBuffer. 323 packetLen := header.IPv4MinimumSize + uint16(tcpHdrSize) 324 pkt := NewPacketBuffer(PacketBufferOptions{ 325 ReserveHeaderBytes: int(packetLen), 326 Payload: buffer.MakeWithData(opts.data), 327 }) 328 pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber 329 pkt.TransportProtocolNumber = header.TCPProtocolNumber 330 331 // Craft the TCP header, including the window scale option if necessary. 332 tcpHdr := header.TCP(pkt.TransportHeader().Push(int(tcpHdrSize))) 333 tcpHdr[:header.TCPMinimumSize].Encode(&header.TCPFields{ 334 SrcPort: srcPort, 335 DstPort: dstPort, 336 SeqNum: seqNum, 337 AckNum: ackNum, 338 DataOffset: tcpHdrSize, 339 Flags: flags, 340 WindowSize: windowSize, 341 Checksum: 0, // Conntrack doesn't verify the checksum. 342 }) 343 if opts.windowScale != 0 { 344 // Set the window scale option, which is 3 bytes long. The option is 345 // properly padded because the final remaining byte is already zeroed. 346 _ = header.EncodeWSOption(int(opts.windowScale), tcpHdr[header.TCPMinimumSize:]) 347 } 348 349 // Craft an IPv4 header. 350 ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize)) 351 ipHdr.Encode(&header.IPv4Fields{ 352 TotalLength: packetLen, 353 Protocol: uint8(header.TCPProtocolNumber), 354 SrcAddr: srcAddr, 355 DstAddr: dstAddr, 356 Checksum: 0, // Conntrack doesn't verify the checksum. 357 }) 358 359 return pkt 360 } 361 362 // checkNumTuples checks that there are exactly want tuples tracked by 363 // conntrack. 364 func (ct *ConnTrack) checkNumTuples(t *testing.T, want int) { 365 t.Helper() 366 ct.mu.RLock() 367 defer ct.mu.RUnlock() 368 369 var total int 370 for idx := range ct.buckets { 371 ct.buckets[idx].mu.RLock() 372 total += ct.buckets[idx].tuples.Len() 373 ct.buckets[idx].mu.RUnlock() 374 } 375 376 if total != want { 377 t.Fatalf("checkNumTuples: got %d, wanted %d", total, want) 378 } 379 } 380 381 func (ct *ConnTrack) reapEverything() { 382 var bucket int 383 for { 384 newBucket, _ := ct.reapUnused(bucket, 0 /* ignored */) 385 // We started reaping at bucket 0. If the next bucket isn't after our 386 // current bucket, we've gone through them all. 387 if newBucket <= bucket { 388 break 389 } 390 bucket = newBucket 391 } 392 } 393 394 func (cn *conn) checkOriginalSeq(t *testing.T, seq uint32) { 395 t.Helper() 396 cn.stateMu.Lock() 397 defer cn.stateMu.Unlock() 398 399 if got, want := cn.tcb.OriginalSendSequenceNumber(), seqnum.Value(seq); got != want { 400 t.Fatalf("checkOriginalSeq: got %d, wanted %d", got, want) 401 } 402 } 403 404 func (cn *conn) checkReplySeq(t *testing.T, seq uint32) { 405 t.Helper() 406 cn.stateMu.Lock() 407 defer cn.stateMu.Unlock() 408 409 if got, want := cn.tcb.ReplySendSequenceNumber(), seqnum.Value(seq); got != want { 410 t.Fatalf("checkReplySeq: got %d, wanted %d", got, want) 411 } 412 }