github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/transport/tcp/tcp_test.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package tcp_test 16 17 import ( 18 "bytes" 19 "fmt" 20 "io/ioutil" 21 "math" 22 "strings" 23 "testing" 24 "time" 25 26 "github.com/google/go-cmp/cmp" 27 "github.com/SagerNet/gvisor/pkg/rand" 28 "github.com/SagerNet/gvisor/pkg/sync" 29 "github.com/SagerNet/gvisor/pkg/tcpip" 30 "github.com/SagerNet/gvisor/pkg/tcpip/checker" 31 "github.com/SagerNet/gvisor/pkg/tcpip/header" 32 "github.com/SagerNet/gvisor/pkg/tcpip/link/loopback" 33 "github.com/SagerNet/gvisor/pkg/tcpip/link/sniffer" 34 "github.com/SagerNet/gvisor/pkg/tcpip/network/ipv4" 35 "github.com/SagerNet/gvisor/pkg/tcpip/network/ipv6" 36 "github.com/SagerNet/gvisor/pkg/tcpip/seqnum" 37 "github.com/SagerNet/gvisor/pkg/tcpip/stack" 38 tcpiptestutil "github.com/SagerNet/gvisor/pkg/tcpip/testutil" 39 "github.com/SagerNet/gvisor/pkg/tcpip/transport/tcp" 40 "github.com/SagerNet/gvisor/pkg/tcpip/transport/tcp/testing/context" 41 "github.com/SagerNet/gvisor/pkg/test/testutil" 42 "github.com/SagerNet/gvisor/pkg/waiter" 43 ) 44 45 // endpointTester provides helper functions to test a tcpip.Endpoint. 46 type endpointTester struct { 47 ep tcpip.Endpoint 48 } 49 50 // CheckReadError issues a read to the endpoint and checking for an error. 
51 func (e *endpointTester) CheckReadError(t *testing.T, want tcpip.Error) { 52 t.Helper() 53 res, got := e.ep.Read(ioutil.Discard, tcpip.ReadOptions{}) 54 if got != want { 55 t.Fatalf("ep.Read = %s, want %s", got, want) 56 } 57 if diff := cmp.Diff(tcpip.ReadResult{}, res); diff != "" { 58 t.Errorf("ep.Read: unexpected non-zero result (-want +got):\n%s", diff) 59 } 60 } 61 62 // CheckRead issues a read to the endpoint and checking for a success, returning 63 // the data read. 64 func (e *endpointTester) CheckRead(t *testing.T) []byte { 65 t.Helper() 66 var buf bytes.Buffer 67 res, err := e.ep.Read(&buf, tcpip.ReadOptions{}) 68 if err != nil { 69 t.Fatalf("ep.Read = _, %s; want _, nil", err) 70 } 71 if diff := cmp.Diff(tcpip.ReadResult{ 72 Count: buf.Len(), 73 Total: buf.Len(), 74 }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" { 75 t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) 76 } 77 return buf.Bytes() 78 } 79 80 // CheckReadFull reads from the endpoint for exactly count bytes. 81 func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-chan struct{}, timeout time.Duration) []byte { 82 t.Helper() 83 var buf bytes.Buffer 84 w := tcpip.LimitedWriter{ 85 W: &buf, 86 N: int64(count), 87 } 88 for w.N != 0 { 89 _, err := e.ep.Read(&w, tcpip.ReadOptions{}) 90 if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 91 // Wait for receive to be notified. 92 select { 93 case <-notifyRead: 94 case <-time.After(timeout): 95 t.Fatalf("Timed out waiting for data to arrive") 96 } 97 continue 98 } else if err != nil { 99 t.Fatalf("ep.Read = _, %s; want _, nil", err) 100 } 101 } 102 return buf.Bytes() 103 } 104 105 const ( 106 // defaultMTU is the MTU, in bytes, used throughout the tests, except 107 // where another value is explicitly used. It is chosen to match the MTU 108 // of loopback interfaces on linux systems. 
109 defaultMTU = 65535 110 111 // defaultIPv4MSS is the MSS sent by the network stack in SYN/SYN-ACK for an 112 // IPv4 endpoint when the MTU is set to defaultMTU in the test. 113 defaultIPv4MSS = defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize 114 ) 115 116 func TestGiveUpConnect(t *testing.T) { 117 c := context.New(t, defaultMTU) 118 defer c.Cleanup() 119 120 var wq waiter.Queue 121 ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) 122 if err != nil { 123 t.Fatalf("NewEndpoint failed: %s", err) 124 } 125 126 // Register for notification, then start connection attempt. 127 waitEntry, notifyCh := waiter.NewChannelEntry(nil) 128 wq.EventRegister(&waitEntry, waiter.EventHUp) 129 defer wq.EventUnregister(&waitEntry) 130 131 { 132 err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) 133 if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { 134 t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) 135 } 136 } 137 138 // Close the connection, wait for completion. 139 ep.Close() 140 141 // Wait for ep to become writable. 142 <-notifyCh 143 144 // Call Connect again to retreive the handshake failure status 145 // and stats updates. 146 { 147 err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) 148 if d := cmp.Diff(&tcpip.ErrAborted{}, err); d != "" { 149 t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) 150 } 151 } 152 153 if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 { 154 t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = 1", got) 155 } 156 157 if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { 158 t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) 159 } 160 } 161 162 // Test for ICMP error handling without completing handshake. 
163 func TestConnectICMPError(t *testing.T) { 164 c := context.New(t, defaultMTU) 165 defer c.Cleanup() 166 167 var wq waiter.Queue 168 ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) 169 if err != nil { 170 t.Fatalf("NewEndpoint failed: %s", err) 171 } 172 173 waitEntry, notifyCh := waiter.NewChannelEntry(nil) 174 wq.EventRegister(&waitEntry, waiter.EventHUp) 175 defer wq.EventUnregister(&waitEntry) 176 177 { 178 err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) 179 if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { 180 t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) 181 } 182 } 183 184 syn := c.GetPacket() 185 checker.IPv4(t, syn, checker.TCP(checker.TCPFlags(header.TCPFlagSyn))) 186 187 wep := ep.(interface { 188 StopWork() 189 ResumeWork() 190 LastErrorLocked() tcpip.Error 191 }) 192 193 // Stop the protocol loop, ensure that the ICMP error is processed and 194 // the last ICMP error is read before the loop is resumed. This sanity 195 // tests the handshake completion logic on ICMP errors. 196 wep.StopWork() 197 198 c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, nil, syn, defaultMTU) 199 200 for { 201 if err := wep.LastErrorLocked(); err != nil { 202 if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" { 203 t.Errorf("ep.LastErrorLocked() mismatch (-want +got):\n%s", d) 204 } 205 break 206 } 207 time.Sleep(time.Millisecond) 208 } 209 210 wep.ResumeWork() 211 212 <-notifyCh 213 214 // The stack would have unregistered the endpoint because of the ICMP error. 215 // Expect a RST for any subsequent packets sent to the endpoint. 
216 c.SendPacket(nil, &context.Headers{ 217 SrcPort: context.TestPort, 218 DstPort: context.StackPort, 219 Flags: header.TCPFlagAck, 220 SeqNum: seqnum.Value(context.TestInitialSequenceNumber) + 1, 221 AckNum: c.IRS + 1, 222 }) 223 224 checker.IPv4(t, c.GetPacket(), checker.TCP( 225 checker.SrcPort(context.StackPort), 226 checker.DstPort(context.TestPort), 227 checker.TCPSeqNum(uint32(c.IRS+1)), 228 checker.TCPAckNum(0), 229 checker.TCPFlags(header.TCPFlagRst))) 230 } 231 232 func TestConnectIncrementActiveConnection(t *testing.T) { 233 c := context.New(t, defaultMTU) 234 defer c.Cleanup() 235 236 stats := c.Stack().Stats() 237 want := stats.TCP.ActiveConnectionOpenings.Value() + 1 238 239 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 240 if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want { 241 t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %d, want = %d", got, want) 242 } 243 } 244 245 func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { 246 c := context.New(t, defaultMTU) 247 defer c.Cleanup() 248 249 stats := c.Stack().Stats() 250 want := stats.TCP.FailedConnectionAttempts.Value() 251 252 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 253 if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { 254 t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want) 255 } 256 if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { 257 t.Errorf("got EP stats.FailedConnectionAttempts = %d, want = %d", got, want) 258 } 259 } 260 261 func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { 262 c := context.New(t, defaultMTU) 263 defer c.Cleanup() 264 265 stats := c.Stack().Stats() 266 ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) 267 if err != nil { 268 t.Fatalf("NewEndpoint failed: %s", err) 269 } 270 c.EP = ep 271 want := 
stats.TCP.FailedConnectionAttempts.Value() + 1 272 273 { 274 err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}) 275 if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" { 276 t.Errorf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) 277 } 278 } 279 280 if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { 281 t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want) 282 } 283 if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { 284 t.Errorf("got EP stats FailedConnectionAttempts = %d, want = %d", got, want) 285 } 286 } 287 288 func TestCloseWithoutConnect(t *testing.T) { 289 c := context.New(t, defaultMTU) 290 defer c.Cleanup() 291 292 // Create TCP endpoint. 293 var err tcpip.Error 294 c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) 295 if err != nil { 296 t.Fatalf("NewEndpoint failed: %s", err) 297 } 298 299 c.EP.Close() 300 301 if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { 302 t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) 303 } 304 } 305 306 func TestTCPSegmentsSentIncrement(t *testing.T) { 307 c := context.New(t, defaultMTU) 308 defer c.Cleanup() 309 310 stats := c.Stack().Stats() 311 // SYN and ACK 312 want := stats.TCP.SegmentsSent.Value() + 2 313 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 314 315 if got := stats.TCP.SegmentsSent.Value(); got != want { 316 t.Errorf("got stats.TCP.SegmentsSent.Value() = %d, want = %d", got, want) 317 } 318 if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want { 319 t.Errorf("got EP stats SegmentsSent.Value() = %d, want = %d", got, want) 320 } 321 } 322 323 func TestTCPResetsSentIncrement(t *testing.T) { 324 c := context.New(t, defaultMTU) 325 defer c.Cleanup() 326 stats := c.Stack().Stats() 327 wq := &waiter.Queue{} 328 ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, 
ipv4.ProtocolNumber, wq) 329 if err != nil { 330 t.Fatalf("NewEndpoint failed: %s", err) 331 } 332 want := stats.TCP.SegmentsSent.Value() + 1 333 334 if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 335 t.Fatalf("Bind failed: %s", err) 336 } 337 338 if err := ep.Listen(10); err != nil { 339 t.Fatalf("Listen failed: %s", err) 340 } 341 342 // Send a SYN request. 343 iss := seqnum.Value(context.TestInitialSequenceNumber) 344 c.SendPacket(nil, &context.Headers{ 345 SrcPort: context.TestPort, 346 DstPort: context.StackPort, 347 Flags: header.TCPFlagSyn, 348 SeqNum: iss, 349 }) 350 351 // Receive the SYN-ACK reply. 352 b := c.GetPacket() 353 tcpHdr := header.TCP(header.IPv4(b).Payload()) 354 c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) 355 356 ackHeaders := &context.Headers{ 357 SrcPort: context.TestPort, 358 DstPort: context.StackPort, 359 Flags: header.TCPFlagAck, 360 SeqNum: iss + 1, 361 // If the AckNum is not the increment of the last sequence number, a RST 362 // segment is sent back in response. 363 AckNum: c.IRS + 2, 364 } 365 366 // Send ACK. 367 c.SendPacket(nil, ackHeaders) 368 369 c.GetPacket() 370 371 metricPollFn := func() error { 372 if got := stats.TCP.ResetsSent.Value(); got != want { 373 return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %d, want = %d", got, want) 374 } 375 return nil 376 } 377 if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { 378 t.Error(err) 379 } 380 } 381 382 // TestTCPResetsSentNoICMP confirms that we don't get an ICMP 383 // DstUnreachable packet when we try send a packet which is not part 384 // of an active session. 385 func TestTCPResetsSentNoICMP(t *testing.T) { 386 c := context.New(t, defaultMTU) 387 defer c.Cleanup() 388 stats := c.Stack().Stats() 389 390 // Send a SYN request for a closed port. This should elicit an RST 391 // but NOT an ICMPv4 DstUnreachable packet. 
392 iss := seqnum.Value(context.TestInitialSequenceNumber) 393 c.SendPacket(nil, &context.Headers{ 394 SrcPort: context.TestPort, 395 DstPort: context.StackPort, 396 Flags: header.TCPFlagSyn, 397 SeqNum: iss, 398 }) 399 400 // Receive whatever comes back. 401 b := c.GetPacket() 402 ipHdr := header.IPv4(b) 403 if got, want := ipHdr.Protocol(), uint8(header.TCPProtocolNumber); got != want { 404 t.Errorf("unexpected protocol, got = %d, want = %d", got, want) 405 } 406 407 // Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded. 408 sent := stats.ICMP.V4.PacketsSent 409 if got, want := sent.DstUnreachable.Value(), uint64(0); got != want { 410 t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want) 411 } 412 } 413 414 // TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates 415 // a RST if an ACK is received on the listening socket for which there is no 416 // active handshake in progress and we are not using SYN cookies. 417 func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { 418 c := context.New(t, defaultMTU) 419 defer c.Cleanup() 420 421 // Set TCPLingerTimeout to 5 seconds so that sockets are marked closed 422 wq := &waiter.Queue{} 423 ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) 424 if err != nil { 425 t.Fatalf("NewEndpoint failed: %s", err) 426 } 427 if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 428 t.Fatalf("Bind failed: %s", err) 429 } 430 431 if err := ep.Listen(10); err != nil { 432 t.Fatalf("Listen failed: %s", err) 433 } 434 435 // Send a SYN request. 436 iss := seqnum.Value(context.TestInitialSequenceNumber) 437 c.SendPacket(nil, &context.Headers{ 438 SrcPort: context.TestPort, 439 DstPort: context.StackPort, 440 Flags: header.TCPFlagSyn, 441 SeqNum: iss, 442 }) 443 444 // Receive the SYN-ACK reply. 
445 b := c.GetPacket() 446 tcpHdr := header.TCP(header.IPv4(b).Payload()) 447 c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) 448 449 ackHeaders := &context.Headers{ 450 SrcPort: context.TestPort, 451 DstPort: context.StackPort, 452 Flags: header.TCPFlagAck, 453 SeqNum: iss + 1, 454 AckNum: c.IRS + 1, 455 } 456 457 // Send ACK. 458 c.SendPacket(nil, ackHeaders) 459 460 // Try to accept the connection. 461 we, ch := waiter.NewChannelEntry(nil) 462 wq.EventRegister(&we, waiter.ReadableEvents) 463 defer wq.EventUnregister(&we) 464 465 c.EP, _, err = ep.Accept(nil) 466 if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 467 // Wait for connection to be established. 468 select { 469 case <-ch: 470 c.EP, _, err = ep.Accept(nil) 471 if err != nil { 472 t.Fatalf("Accept failed: %s", err) 473 } 474 475 case <-time.After(1 * time.Second): 476 t.Fatalf("Timed out waiting for accept") 477 } 478 } 479 480 // Lower stackwide TIME_WAIT timeout so that the reservations 481 // are released instantly on Close. 482 tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) 483 if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTW); err != nil { 484 t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, tcpTW, tcpTW, err) 485 } 486 487 c.EP.Close() 488 checker.IPv4(t, c.GetPacket(), checker.TCP( 489 checker.SrcPort(context.StackPort), 490 checker.DstPort(context.TestPort), 491 checker.TCPSeqNum(uint32(c.IRS+1)), 492 checker.TCPAckNum(uint32(iss)+1), 493 checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) 494 finHeaders := &context.Headers{ 495 SrcPort: context.TestPort, 496 DstPort: context.StackPort, 497 Flags: header.TCPFlagAck | header.TCPFlagFin, 498 SeqNum: iss + 1, 499 AckNum: c.IRS + 2, 500 } 501 502 c.SendPacket(nil, finHeaders) 503 504 // Get the ACK to the FIN we just sent. 
505 c.GetPacket() 506 507 // Since an active close was done we need to wait for a little more than 508 // tcpLingerTimeout for the port reservations to be released and the 509 // socket to move to a CLOSED state. 510 time.Sleep(20 * time.Millisecond) 511 512 // Now resend the same ACK, this ACK should generate a RST as there 513 // should be no endpoint in SYN-RCVD state and we are not using 514 // syn-cookies yet. The reason we send the same ACK is we need a valid 515 // cookie(IRS) generated by the netstack without which the ACK will be 516 // rejected. 517 c.SendPacket(nil, ackHeaders) 518 519 checker.IPv4(t, c.GetPacket(), checker.TCP( 520 checker.SrcPort(context.StackPort), 521 checker.DstPort(context.TestPort), 522 checker.TCPSeqNum(uint32(c.IRS+1)), 523 checker.TCPAckNum(0), 524 checker.TCPFlags(header.TCPFlagRst))) 525 } 526 527 func TestTCPResetsReceivedIncrement(t *testing.T) { 528 c := context.New(t, defaultMTU) 529 defer c.Cleanup() 530 531 stats := c.Stack().Stats() 532 want := stats.TCP.ResetsReceived.Value() + 1 533 iss := seqnum.Value(context.TestInitialSequenceNumber) 534 rcvWnd := seqnum.Size(30000) 535 c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) 536 537 c.SendPacket(nil, &context.Headers{ 538 SrcPort: context.TestPort, 539 DstPort: c.Port, 540 SeqNum: iss.Add(1), 541 AckNum: c.IRS.Add(1), 542 RcvWnd: rcvWnd, 543 Flags: header.TCPFlagRst, 544 }) 545 546 if got := stats.TCP.ResetsReceived.Value(); got != want { 547 t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want) 548 } 549 } 550 551 func TestTCPResetsDoNotGenerateResets(t *testing.T) { 552 c := context.New(t, defaultMTU) 553 defer c.Cleanup() 554 555 stats := c.Stack().Stats() 556 want := stats.TCP.ResetsReceived.Value() + 1 557 iss := seqnum.Value(context.TestInitialSequenceNumber) 558 rcvWnd := seqnum.Size(30000) 559 c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) 560 561 c.SendPacket(nil, &context.Headers{ 562 SrcPort: context.TestPort, 563 DstPort: c.Port, 564 
SeqNum: iss.Add(1), 565 AckNum: c.IRS.Add(1), 566 RcvWnd: rcvWnd, 567 Flags: header.TCPFlagRst, 568 }) 569 570 if got := stats.TCP.ResetsReceived.Value(); got != want { 571 t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want) 572 } 573 c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond) 574 } 575 576 func TestActiveHandshake(t *testing.T) { 577 c := context.New(t, defaultMTU) 578 defer c.Cleanup() 579 580 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 581 } 582 583 func TestNonBlockingClose(t *testing.T) { 584 c := context.New(t, defaultMTU) 585 defer c.Cleanup() 586 587 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 588 ep := c.EP 589 c.EP = nil 590 591 // Close the endpoint and measure how long it takes. 592 t0 := time.Now() 593 ep.Close() 594 if diff := time.Now().Sub(t0); diff > 3*time.Second { 595 t.Fatalf("Took too long to close: %s", diff) 596 } 597 } 598 599 func TestConnectResetAfterClose(t *testing.T) { 600 c := context.New(t, defaultMTU) 601 defer c.Cleanup() 602 603 // Set TCPLinger to 3 seconds so that sockets are marked closed 604 // after 3 second in FIN_WAIT2 state. 605 tcpLingerTimeout := 3 * time.Second 606 opt := tcpip.TCPLingerTimeoutOption(tcpLingerTimeout) 607 if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { 608 t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) 609 } 610 611 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 612 ep := c.EP 613 c.EP = nil 614 615 // Close the endpoint, make sure we get a FIN segment, then acknowledge 616 // to complete closure of sender, but don't send our own FIN. 
617 ep.Close() 618 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 619 checker.IPv4(t, c.GetPacket(), 620 checker.TCP( 621 checker.DstPort(context.TestPort), 622 checker.TCPSeqNum(uint32(c.IRS)+1), 623 checker.TCPAckNum(uint32(iss)), 624 checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), 625 ), 626 ) 627 c.SendPacket(nil, &context.Headers{ 628 SrcPort: context.TestPort, 629 DstPort: c.Port, 630 Flags: header.TCPFlagAck, 631 SeqNum: iss, 632 AckNum: c.IRS.Add(2), 633 RcvWnd: 30000, 634 }) 635 636 // Wait for the ep to give up waiting for a FIN. 637 time.Sleep(tcpLingerTimeout + 1*time.Second) 638 639 // Now send an ACK and it should trigger a RST as the endpoint should 640 // not exist anymore. 641 c.SendPacket(nil, &context.Headers{ 642 SrcPort: context.TestPort, 643 DstPort: c.Port, 644 Flags: header.TCPFlagAck, 645 SeqNum: iss, 646 AckNum: c.IRS.Add(2), 647 RcvWnd: 30000, 648 }) 649 650 for { 651 b := c.GetPacket() 652 tcpHdr := header.TCP(header.IPv4(b).Payload()) 653 if tcpHdr.Flags() == header.TCPFlagAck|header.TCPFlagFin { 654 // This is a retransmit of the FIN, ignore it. 655 continue 656 } 657 658 checker.IPv4(t, b, 659 checker.TCP( 660 checker.DstPort(context.TestPort), 661 // RST is always generated with sndNxt which if the FIN 662 // has been sent will be 1 higher than the sequence number 663 // of the FIN itself. 664 checker.TCPSeqNum(uint32(c.IRS)+2), 665 checker.TCPAckNum(0), 666 checker.TCPFlags(header.TCPFlagRst), 667 ), 668 ) 669 break 670 } 671 } 672 673 // TestCurrentConnectedIncrement tests increment of the current 674 // established and connected counters. 675 func TestCurrentConnectedIncrement(t *testing.T) { 676 c := context.New(t, defaultMTU) 677 defer c.Cleanup() 678 679 // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed 680 // after 1 second in TIME_WAIT state. 
681 tcpTimeWaitTimeout := 1 * time.Second 682 opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) 683 if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { 684 t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) 685 } 686 687 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 688 ep := c.EP 689 c.EP = nil 690 691 if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 { 692 t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 1", got) 693 } 694 gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value() 695 if gotConnected != 1 { 696 t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 1", gotConnected) 697 } 698 699 ep.Close() 700 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 701 checker.IPv4(t, c.GetPacket(), 702 checker.TCP( 703 checker.DstPort(context.TestPort), 704 checker.TCPSeqNum(uint32(c.IRS)+1), 705 checker.TCPAckNum(uint32(iss)), 706 checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), 707 ), 708 ) 709 c.SendPacket(nil, &context.Headers{ 710 SrcPort: context.TestPort, 711 DstPort: c.Port, 712 Flags: header.TCPFlagAck, 713 SeqNum: iss, 714 AckNum: c.IRS.Add(2), 715 RcvWnd: 30000, 716 }) 717 718 if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { 719 t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) 720 } 721 if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected { 722 t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = %d", got, gotConnected) 723 } 724 725 // Ack and send FIN as well. 726 c.SendPacket(nil, &context.Headers{ 727 SrcPort: context.TestPort, 728 DstPort: c.Port, 729 Flags: header.TCPFlagAck | header.TCPFlagFin, 730 SeqNum: iss, 731 AckNum: c.IRS.Add(2), 732 RcvWnd: 30000, 733 }) 734 735 // Check that the stack acks the FIN. 
736 checker.IPv4(t, c.GetPacket(), 737 checker.PayloadLen(header.TCPMinimumSize), 738 checker.TCP( 739 checker.DstPort(context.TestPort), 740 checker.TCPSeqNum(uint32(c.IRS)+2), 741 checker.TCPAckNum(uint32(iss)+1), 742 checker.TCPFlags(header.TCPFlagAck), 743 ), 744 ) 745 746 // Wait for a little more than the TIME-WAIT duration for the socket to 747 // transition to CLOSED state. 748 time.Sleep(1200 * time.Millisecond) 749 750 if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { 751 t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) 752 } 753 if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { 754 t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) 755 } 756 } 757 758 // TestClosingWithEnqueuedSegments tests handling of still enqueued segments 759 // when the endpoint transitions to StateClose. The in-flight segments would be 760 // re-enqueued to a any listening endpoint. 761 func TestClosingWithEnqueuedSegments(t *testing.T) { 762 c := context.New(t, defaultMTU) 763 defer c.Cleanup() 764 765 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 766 ep := c.EP 767 c.EP = nil 768 769 if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want { 770 t.Errorf("unexpected endpoint state: want %d, got %d", want, got) 771 } 772 773 // Send a FIN for ESTABLISHED --> CLOSED-WAIT 774 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 775 c.SendPacket(nil, &context.Headers{ 776 SrcPort: context.TestPort, 777 DstPort: c.Port, 778 Flags: header.TCPFlagFin | header.TCPFlagAck, 779 SeqNum: iss, 780 AckNum: c.IRS.Add(1), 781 RcvWnd: 30000, 782 }) 783 784 // Get the ACK for the FIN we sent. 
785 checker.IPv4(t, c.GetPacket(), 786 checker.TCP( 787 checker.DstPort(context.TestPort), 788 checker.TCPSeqNum(uint32(c.IRS)+1), 789 checker.TCPAckNum(uint32(iss)+1), 790 checker.TCPFlags(header.TCPFlagAck), 791 ), 792 ) 793 794 // Give the stack a few ms to transition the endpoint out of ESTABLISHED 795 // state. 796 time.Sleep(10 * time.Millisecond) 797 798 if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want { 799 t.Errorf("unexpected endpoint state: want %d, got %d", want, got) 800 } 801 802 // Close the application endpoint for CLOSE_WAIT --> LAST_ACK 803 ep.Close() 804 805 // Get the FIN 806 checker.IPv4(t, c.GetPacket(), 807 checker.TCP( 808 checker.DstPort(context.TestPort), 809 checker.TCPSeqNum(uint32(c.IRS)+1), 810 checker.TCPAckNum(uint32(iss)+1), 811 checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), 812 ), 813 ) 814 815 if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want { 816 t.Errorf("unexpected endpoint state: want %s, got %s", want, got) 817 } 818 819 // Pause the endpoint`s protocolMainLoop. 820 ep.(interface{ StopWork() }).StopWork() 821 822 // Enqueue last ACK followed by an ACK matching the endpoint 823 // 824 // Send Last ACK for LAST_ACK --> CLOSED 825 c.SendPacket(nil, &context.Headers{ 826 SrcPort: context.TestPort, 827 DstPort: c.Port, 828 Flags: header.TCPFlagAck, 829 SeqNum: iss.Add(1), 830 AckNum: c.IRS.Add(2), 831 RcvWnd: 30000, 832 }) 833 834 // Send a packet with ACK set, this would generate RST when 835 // not using SYN cookies as in this test. 836 c.SendPacket(nil, &context.Headers{ 837 SrcPort: context.TestPort, 838 DstPort: c.Port, 839 Flags: header.TCPFlagAck | header.TCPFlagFin, 840 SeqNum: iss.Add(2), 841 AckNum: c.IRS.Add(2), 842 RcvWnd: 30000, 843 }) 844 845 // Unpause endpoint`s protocolMainLoop. 846 ep.(interface{ ResumeWork() }).ResumeWork() 847 848 // Wait for the protocolMainLoop to resume and update state. 
849 time.Sleep(10 * time.Millisecond) 850 851 // Expect the endpoint to be closed. 852 if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { 853 t.Errorf("unexpected endpoint state: want %s, got %s", want, got) 854 } 855 856 if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 { 857 t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = 1", got) 858 } 859 860 if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { 861 t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) 862 } 863 864 // Check if the endpoint was moved to CLOSED and netstack a reset in 865 // response to the ACK packet that we sent after last-ACK. 866 checker.IPv4(t, c.GetPacket(), 867 checker.TCP( 868 checker.DstPort(context.TestPort), 869 checker.TCPSeqNum(uint32(c.IRS)+2), 870 checker.TCPAckNum(0), 871 checker.TCPFlags(header.TCPFlagRst), 872 ), 873 ) 874 } 875 876 func TestSimpleReceive(t *testing.T) { 877 c := context.New(t, defaultMTU) 878 defer c.Cleanup() 879 880 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 881 882 we, ch := waiter.NewChannelEntry(nil) 883 c.WQ.EventRegister(&we, waiter.ReadableEvents) 884 defer c.WQ.EventUnregister(&we) 885 886 ept := endpointTester{c.EP} 887 888 data := []byte{1, 2, 3} 889 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 890 c.SendPacket(data, &context.Headers{ 891 SrcPort: context.TestPort, 892 DstPort: c.Port, 893 Flags: header.TCPFlagAck, 894 SeqNum: iss, 895 AckNum: c.IRS.Add(1), 896 RcvWnd: 30000, 897 }) 898 899 // Wait for receive to be notified. 900 select { 901 case <-ch: 902 case <-time.After(1 * time.Second): 903 t.Fatalf("Timed out waiting for data to arrive") 904 } 905 906 // Receive data. 907 v := ept.CheckRead(t) 908 if !bytes.Equal(data, v) { 909 t.Fatalf("got data = %v, want = %v", v, data) 910 } 911 912 // Check that ACK is received. 
913 checker.IPv4(t, c.GetPacket(), 914 checker.TCP( 915 checker.DstPort(context.TestPort), 916 checker.TCPSeqNum(uint32(c.IRS)+1), 917 checker.TCPAckNum(uint32(iss)+uint32(len(data))), 918 checker.TCPFlags(header.TCPFlagAck), 919 ), 920 ) 921 } 922 923 // TestUserSuppliedMSSOnConnect tests that the user supplied MSS is used when 924 // creating a new active TCP socket. It should be present in the sent TCP 925 // SYN segment. 926 func TestUserSuppliedMSSOnConnect(t *testing.T) { 927 const mtu = 5000 928 929 ips := []struct { 930 name string 931 createEP func(*context.Context) 932 connectAddr tcpip.Address 933 checker func(*testing.T, *context.Context, uint16, int) 934 maxMSS uint16 935 }{ 936 { 937 name: "IPv4", 938 createEP: func(c *context.Context) { 939 c.Create(-1) 940 }, 941 connectAddr: context.TestAddr, 942 checker: func(t *testing.T, c *context.Context, mss uint16, ws int) { 943 checker.IPv4(t, c.GetPacket(), checker.TCP( 944 checker.DstPort(context.TestPort), 945 checker.TCPFlags(header.TCPFlagSyn), 946 checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws}))) 947 }, 948 maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, 949 }, 950 { 951 name: "IPv6", 952 createEP: func(c *context.Context) { 953 c.CreateV6Endpoint(true) 954 }, 955 connectAddr: context.TestV6Addr, 956 checker: func(t *testing.T, c *context.Context, mss uint16, ws int) { 957 checker.IPv6(t, c.GetV6Packet(), checker.TCP( 958 checker.DstPort(context.TestPort), 959 checker.TCPFlags(header.TCPFlagSyn), 960 checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws}))) 961 }, 962 maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize, 963 }, 964 } 965 966 for _, ip := range ips { 967 t.Run(ip.name, func(t *testing.T) { 968 tests := []struct { 969 name string 970 setMSS uint16 971 expMSS uint16 972 }{ 973 { 974 name: "EqualToMaxMSS", 975 setMSS: ip.maxMSS, 976 expMSS: ip.maxMSS, 977 }, 978 { 979 name: "LessThanMaxMSS", 980 setMSS: ip.maxMSS - 1, 981 expMSS: ip.maxMSS - 1, 
982 }, 983 { 984 name: "GreaterThanMaxMSS", 985 setMSS: ip.maxMSS + 1, 986 expMSS: ip.maxMSS, 987 }, 988 } 989 990 for _, test := range tests { 991 t.Run(test.name, func(t *testing.T) { 992 c := context.New(t, mtu) 993 defer c.Cleanup() 994 995 ip.createEP(c) 996 997 // Set the MSS socket option. 998 if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { 999 t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) 1000 } 1001 1002 // Get expected window size. 1003 rcvBufSize := c.EP.SocketOptions().GetReceiveBufferSize() 1004 ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) 1005 1006 connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort} 1007 { 1008 err := c.EP.Connect(connectAddr) 1009 if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { 1010 t.Fatalf("Connect(%+v) mismatch (-want +got):\n%s", connectAddr, d) 1011 } 1012 } 1013 1014 // Receive SYN packet with our user supplied MSS. 1015 ip.checker(t, c, test.expMSS, ws) 1016 }) 1017 } 1018 }) 1019 } 1020 } 1021 1022 // TestUserSuppliedMSSOnListenAccept tests that the user supplied MSS is used 1023 // when completing the handshake for a new TCP connection from a TCP 1024 // listening socket. It should be present in the sent TCP SYN-ACK segment. 
1025 func TestUserSuppliedMSSOnListenAccept(t *testing.T) { 1026 const mtu = 5000 1027 1028 ips := []struct { 1029 name string 1030 createEP func(*context.Context) 1031 sendPkt func(*context.Context, *context.Headers) 1032 checker func(*testing.T, *context.Context, uint16, uint16) 1033 maxMSS uint16 1034 }{ 1035 { 1036 name: "IPv4", 1037 createEP: func(c *context.Context) { 1038 c.Create(-1) 1039 }, 1040 sendPkt: func(c *context.Context, h *context.Headers) { 1041 c.SendPacket(nil, h) 1042 }, 1043 checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) { 1044 checker.IPv4(t, c.GetPacket(), checker.TCP( 1045 checker.DstPort(srcPort), 1046 checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck), 1047 checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1}))) 1048 }, 1049 maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, 1050 }, 1051 { 1052 name: "IPv6", 1053 createEP: func(c *context.Context) { 1054 c.CreateV6Endpoint(false) 1055 }, 1056 sendPkt: func(c *context.Context, h *context.Headers) { 1057 c.SendV6Packet(nil, h) 1058 }, 1059 checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) { 1060 checker.IPv6(t, c.GetV6Packet(), checker.TCP( 1061 checker.DstPort(srcPort), 1062 checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck), 1063 checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1}))) 1064 }, 1065 maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize, 1066 }, 1067 } 1068 1069 for _, ip := range ips { 1070 t.Run(ip.name, func(t *testing.T) { 1071 tests := []struct { 1072 name string 1073 setMSS uint16 1074 expMSS uint16 1075 }{ 1076 { 1077 name: "EqualToMaxMSS", 1078 setMSS: ip.maxMSS, 1079 expMSS: ip.maxMSS, 1080 }, 1081 { 1082 name: "LessThanMaxMSS", 1083 setMSS: ip.maxMSS - 1, 1084 expMSS: ip.maxMSS - 1, 1085 }, 1086 { 1087 name: "GreaterThanMaxMSS", 1088 setMSS: ip.maxMSS + 1, 1089 expMSS: ip.maxMSS, 1090 }, 1091 } 1092 1093 for _, test := range tests { 1094 t.Run(test.name, func(t *testing.T) { 1095 c := 
context.New(t, mtu) 1096 defer c.Cleanup() 1097 1098 ip.createEP(c) 1099 1100 if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { 1101 t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) 1102 } 1103 1104 bindAddr := tcpip.FullAddress{Port: context.StackPort} 1105 if err := c.EP.Bind(bindAddr); err != nil { 1106 t.Fatalf("Bind(%+v): %s:", bindAddr, err) 1107 } 1108 1109 backlog := 5 1110 // Keep the number of client requests twice to the backlog 1111 // such that half of the connections do not use syncookies 1112 // and the other half does. 1113 clientConnects := backlog * 2 1114 1115 if err := c.EP.Listen(backlog); err != nil { 1116 t.Fatalf("Listen(%d): %s:", backlog, err) 1117 } 1118 1119 for i := 0; i < clientConnects; i++ { 1120 // Send a SYN requests. 1121 iss := seqnum.Value(i) 1122 srcPort := context.TestPort + uint16(i) 1123 ip.sendPkt(c, &context.Headers{ 1124 SrcPort: srcPort, 1125 DstPort: context.StackPort, 1126 Flags: header.TCPFlagSyn, 1127 SeqNum: iss, 1128 }) 1129 1130 // Receive the SYN-ACK reply. 
1131 ip.checker(t, c, srcPort, test.expMSS) 1132 } 1133 }) 1134 } 1135 }) 1136 } 1137 } 1138 func TestSendRstOnListenerRxSynAckV4(t *testing.T) { 1139 c := context.New(t, defaultMTU) 1140 defer c.Cleanup() 1141 1142 c.Create(-1) 1143 1144 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 1145 t.Fatal("Bind failed:", err) 1146 } 1147 1148 if err := c.EP.Listen(10); err != nil { 1149 t.Fatal("Listen failed:", err) 1150 } 1151 1152 c.SendPacket(nil, &context.Headers{ 1153 SrcPort: context.TestPort, 1154 DstPort: context.StackPort, 1155 Flags: header.TCPFlagSyn | header.TCPFlagAck, 1156 SeqNum: 100, 1157 AckNum: 200, 1158 }) 1159 1160 checker.IPv4(t, c.GetPacket(), checker.TCP( 1161 checker.DstPort(context.TestPort), 1162 checker.TCPFlags(header.TCPFlagRst), 1163 checker.TCPSeqNum(200))) 1164 } 1165 1166 func TestSendRstOnListenerRxSynAckV6(t *testing.T) { 1167 c := context.New(t, defaultMTU) 1168 defer c.Cleanup() 1169 1170 c.CreateV6Endpoint(true) 1171 1172 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 1173 t.Fatal("Bind failed:", err) 1174 } 1175 1176 if err := c.EP.Listen(10); err != nil { 1177 t.Fatal("Listen failed:", err) 1178 } 1179 1180 c.SendV6Packet(nil, &context.Headers{ 1181 SrcPort: context.TestPort, 1182 DstPort: context.StackPort, 1183 Flags: header.TCPFlagSyn | header.TCPFlagAck, 1184 SeqNum: 100, 1185 AckNum: 200, 1186 }) 1187 1188 checker.IPv6(t, c.GetV6Packet(), checker.TCP( 1189 checker.DstPort(context.TestPort), 1190 checker.TCPFlags(header.TCPFlagRst), 1191 checker.TCPSeqNum(200))) 1192 } 1193 1194 // TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete, 1195 // peers can send data and expect a response within a reasonable ammount of time 1196 // without calling Accept on the listening endpoint first. 1197 // 1198 // This test uses IPv4. 
1199 func TestTCPAckBeforeAcceptV4(t *testing.T) { 1200 c := context.New(t, defaultMTU) 1201 defer c.Cleanup() 1202 1203 c.Create(-1) 1204 1205 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 1206 t.Fatal("Bind failed:", err) 1207 } 1208 1209 if err := c.EP.Listen(10); err != nil { 1210 t.Fatal("Listen failed:", err) 1211 } 1212 1213 irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) 1214 1215 // Send data before accepting the connection. 1216 c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ 1217 SrcPort: context.TestPort, 1218 DstPort: context.StackPort, 1219 Flags: header.TCPFlagAck, 1220 SeqNum: irs + 1, 1221 AckNum: iss + 1, 1222 }) 1223 1224 // Receive ACK for the data we sent. 1225 checker.IPv4(t, c.GetPacket(), checker.TCP( 1226 checker.DstPort(context.TestPort), 1227 checker.TCPFlags(header.TCPFlagAck), 1228 checker.TCPSeqNum(uint32(iss+1)), 1229 checker.TCPAckNum(uint32(irs+5)))) 1230 } 1231 1232 // TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete, 1233 // peers can send data and expect a response within a reasonable ammount of time 1234 // without calling Accept on the listening endpoint first. 1235 // 1236 // This test uses IPv6. 1237 func TestTCPAckBeforeAcceptV6(t *testing.T) { 1238 c := context.New(t, defaultMTU) 1239 defer c.Cleanup() 1240 1241 c.CreateV6Endpoint(true) 1242 1243 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 1244 t.Fatal("Bind failed:", err) 1245 } 1246 1247 if err := c.EP.Listen(10); err != nil { 1248 t.Fatal("Listen failed:", err) 1249 } 1250 1251 irs, iss := executeV6Handshake(t, c, context.TestPort, false /* synCookiesInUse */) 1252 1253 // Send data before accepting the connection. 
1254 c.SendV6Packet([]byte{1, 2, 3, 4}, &context.Headers{ 1255 SrcPort: context.TestPort, 1256 DstPort: context.StackPort, 1257 Flags: header.TCPFlagAck, 1258 SeqNum: irs + 1, 1259 AckNum: iss + 1, 1260 }) 1261 1262 // Receive ACK for the data we sent. 1263 checker.IPv6(t, c.GetV6Packet(), checker.TCP( 1264 checker.DstPort(context.TestPort), 1265 checker.TCPFlags(header.TCPFlagAck), 1266 checker.TCPSeqNum(uint32(iss+1)), 1267 checker.TCPAckNum(uint32(irs+5)))) 1268 } 1269 1270 func TestSendRstOnListenerRxAckV4(t *testing.T) { 1271 c := context.New(t, defaultMTU) 1272 defer c.Cleanup() 1273 1274 c.Create(-1 /* epRcvBuf */) 1275 1276 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 1277 t.Fatal("Bind failed:", err) 1278 } 1279 1280 if err := c.EP.Listen(10 /* backlog */); err != nil { 1281 t.Fatal("Listen failed:", err) 1282 } 1283 1284 c.SendPacket(nil, &context.Headers{ 1285 SrcPort: context.TestPort, 1286 DstPort: context.StackPort, 1287 Flags: header.TCPFlagFin | header.TCPFlagAck, 1288 SeqNum: 100, 1289 AckNum: 200, 1290 }) 1291 1292 checker.IPv4(t, c.GetPacket(), checker.TCP( 1293 checker.DstPort(context.TestPort), 1294 checker.TCPFlags(header.TCPFlagRst), 1295 checker.TCPSeqNum(200))) 1296 } 1297 1298 func TestSendRstOnListenerRxAckV6(t *testing.T) { 1299 c := context.New(t, defaultMTU) 1300 defer c.Cleanup() 1301 1302 c.CreateV6Endpoint(true /* v6Only */) 1303 1304 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 1305 t.Fatal("Bind failed:", err) 1306 } 1307 1308 if err := c.EP.Listen(10 /* backlog */); err != nil { 1309 t.Fatal("Listen failed:", err) 1310 } 1311 1312 c.SendV6Packet(nil, &context.Headers{ 1313 SrcPort: context.TestPort, 1314 DstPort: context.StackPort, 1315 Flags: header.TCPFlagFin | header.TCPFlagAck, 1316 SeqNum: 100, 1317 AckNum: 200, 1318 }) 1319 1320 checker.IPv6(t, c.GetV6Packet(), checker.TCP( 1321 checker.DstPort(context.TestPort), 1322 checker.TCPFlags(header.TCPFlagRst), 1323 
checker.TCPSeqNum(200))) 1324 } 1325 1326 // TestListenShutdown tests for the listening endpoint replying with RST 1327 // on read shutdown. 1328 func TestListenShutdown(t *testing.T) { 1329 c := context.New(t, defaultMTU) 1330 defer c.Cleanup() 1331 1332 c.Create(-1 /* epRcvBuf */) 1333 1334 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 1335 t.Fatal("Bind failed:", err) 1336 } 1337 1338 if err := c.EP.Listen(1 /* backlog */); err != nil { 1339 t.Fatal("Listen failed:", err) 1340 } 1341 1342 if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { 1343 t.Fatal("Shutdown failed:", err) 1344 } 1345 1346 c.SendPacket(nil, &context.Headers{ 1347 SrcPort: context.TestPort, 1348 DstPort: context.StackPort, 1349 Flags: header.TCPFlagSyn, 1350 SeqNum: 100, 1351 AckNum: 200, 1352 }) 1353 1354 // Expect the listening endpoint to reset the connection. 1355 checker.IPv4(t, c.GetPacket(), 1356 checker.TCP( 1357 checker.DstPort(context.TestPort), 1358 checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), 1359 )) 1360 } 1361 1362 var _ waiter.EntryCallback = (callback)(nil) 1363 1364 type callback func(*waiter.Entry, waiter.EventMask) 1365 1366 func (cb callback) Callback(entry *waiter.Entry, mask waiter.EventMask) { 1367 cb(entry, mask) 1368 } 1369 1370 func TestListenerReadinessOnEvent(t *testing.T) { 1371 s := stack.New(stack.Options{ 1372 TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, 1373 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, 1374 }) 1375 { 1376 ep := loopback.New() 1377 if testing.Verbose() { 1378 ep = sniffer.New(ep) 1379 } 1380 const id = 1 1381 if err := s.CreateNIC(id, ep); err != nil { 1382 t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err) 1383 } 1384 if err := s.AddAddress(id, ipv4.ProtocolNumber, context.StackAddr); err != nil { 1385 t.Fatalf("AddAddress(%d, ipv4.ProtocolNumber, %s): %s", id, context.StackAddr, err) 1386 } 1387 s.SetRouteTable([]tcpip.Route{ 1388 {Destination: 
header.IPv4EmptySubnet, NIC: id}, 1389 }) 1390 } 1391 1392 var wq waiter.Queue 1393 ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) 1394 if err != nil { 1395 t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err) 1396 } 1397 defer ep.Close() 1398 1399 if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr}); err != nil { 1400 t.Fatalf("Bind(%s): %s", context.StackAddr, err) 1401 } 1402 const backlog = 1 1403 if err := ep.Listen(backlog); err != nil { 1404 t.Fatalf("Listen(%d): %s", backlog, err) 1405 } 1406 1407 address, err := ep.GetLocalAddress() 1408 if err != nil { 1409 t.Fatalf("GetLocalAddress(): %s", err) 1410 } 1411 1412 conn, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) 1413 if err != nil { 1414 t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err) 1415 } 1416 defer conn.Close() 1417 1418 events := make(chan waiter.EventMask) 1419 // Scope `entry` to allow a binding of the same name below. 1420 { 1421 entry := waiter.Entry{Callback: callback(func(_ *waiter.Entry, mask waiter.EventMask) { 1422 events <- ep.Readiness(mask) 1423 })} 1424 wq.EventRegister(&entry, waiter.EventIn) 1425 defer wq.EventUnregister(&entry) 1426 } 1427 1428 entry, ch := waiter.NewChannelEntry(nil) 1429 wq.EventRegister(&entry, waiter.EventOut) 1430 defer wq.EventUnregister(&entry) 1431 1432 switch err := conn.Connect(address).(type) { 1433 case *tcpip.ErrConnectStarted: 1434 default: 1435 t.Fatalf("Connect(%#v): %v", address, err) 1436 } 1437 1438 // Read at least one event. 1439 got := <-events 1440 for { 1441 select { 1442 case event := <-events: 1443 got |= event 1444 continue 1445 case <-ch: 1446 if want := waiter.ReadableEvents; got != want { 1447 t.Errorf("observed events = %b, want %b", got, want) 1448 } 1449 } 1450 break 1451 } 1452 } 1453 1454 // TestListenCloseWhileConnect tests for the listening endpoint to 1455 // drain the accept-queue when closed. 
This should reset all of the 1456 // pending connections that are waiting to be accepted. 1457 func TestListenCloseWhileConnect(t *testing.T) { 1458 c := context.New(t, defaultMTU) 1459 defer c.Cleanup() 1460 1461 c.Create(-1 /* epRcvBuf */) 1462 1463 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 1464 t.Fatal("Bind failed:", err) 1465 } 1466 1467 if err := c.EP.Listen(1 /* backlog */); err != nil { 1468 t.Fatal("Listen failed:", err) 1469 } 1470 1471 waitEntry, notifyCh := waiter.NewChannelEntry(nil) 1472 c.WQ.EventRegister(&waitEntry, waiter.ReadableEvents) 1473 defer c.WQ.EventUnregister(&waitEntry) 1474 1475 executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) 1476 // Wait for the new endpoint created because of handshake to be delivered 1477 // to the listening endpoint's accept queue. 1478 <-notifyCh 1479 1480 // Close the listening endpoint. 1481 c.EP.Close() 1482 1483 // Expect the listening endpoint to reset the connection. 1484 checker.IPv4(t, c.GetPacket(), 1485 checker.TCP( 1486 checker.DstPort(context.TestPort), 1487 checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), 1488 )) 1489 } 1490 1491 func TestTOSV4(t *testing.T) { 1492 c := context.New(t, defaultMTU) 1493 defer c.Cleanup() 1494 1495 ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) 1496 if err != nil { 1497 t.Fatalf("NewEndpoint failed: %s", err) 1498 } 1499 c.EP = ep 1500 1501 const tos = 0xC0 1502 if err := c.EP.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { 1503 t.Errorf("SetSockOptInt(IPv4TOSOption, %d) failed: %s", tos, err) 1504 } 1505 1506 v, err := c.EP.GetSockOptInt(tcpip.IPv4TOSOption) 1507 if err != nil { 1508 t.Errorf("GetSockoptInt(IPv4TOSOption) failed: %s", err) 1509 } 1510 1511 if v != tos { 1512 t.Errorf("got GetSockOptInt(IPv4TOSOption) = %d, want = %d", v, tos) 1513 } 1514 1515 testV4Connect(t, c, checker.TOS(tos, 0)) 1516 1517 data := []byte{1, 2, 3} 1518 var r bytes.Reader 1519 
r.Reset(data) 1520 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 1521 t.Fatalf("Write failed: %s", err) 1522 } 1523 1524 // Check that data is received. 1525 b := c.GetPacket() 1526 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 1527 checker.IPv4(t, b, 1528 checker.PayloadLen(len(data)+header.TCPMinimumSize), 1529 checker.TCP( 1530 checker.DstPort(context.TestPort), 1531 checker.TCPSeqNum(uint32(c.IRS)+1), 1532 checker.TCPAckNum(uint32(iss)), // Acknum is initial sequence number + 1 1533 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 1534 ), 1535 checker.TOS(tos, 0), 1536 ) 1537 1538 if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { 1539 t.Errorf("got data = %x, want = %x", p, data) 1540 } 1541 } 1542 1543 func TestTrafficClassV6(t *testing.T) { 1544 c := context.New(t, defaultMTU) 1545 defer c.Cleanup() 1546 1547 c.CreateV6Endpoint(false) 1548 1549 const tos = 0xC0 1550 if err := c.EP.SetSockOptInt(tcpip.IPv6TrafficClassOption, tos); err != nil { 1551 t.Errorf("SetSockOpInt(IPv6TrafficClassOption, %d) failed: %s", tos, err) 1552 } 1553 1554 v, err := c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption) 1555 if err != nil { 1556 t.Fatalf("GetSockoptInt(IPv6TrafficClassOption) failed: %s", err) 1557 } 1558 1559 if v != tos { 1560 t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = %d, want = %d", v, tos) 1561 } 1562 1563 // Test the connection request. 1564 testV6Connect(t, c, checker.TOS(tos, 0)) 1565 1566 data := []byte{1, 2, 3} 1567 var r bytes.Reader 1568 r.Reset(data) 1569 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 1570 t.Fatalf("Write failed: %s", err) 1571 } 1572 1573 // Check that data is received. 
1574 b := c.GetV6Packet() 1575 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 1576 checker.IPv6(t, b, 1577 checker.PayloadLen(len(data)+header.TCPMinimumSize), 1578 checker.TCP( 1579 checker.DstPort(context.TestPort), 1580 checker.TCPSeqNum(uint32(c.IRS)+1), 1581 checker.TCPAckNum(uint32(iss)), 1582 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 1583 ), 1584 checker.TOS(tos, 0), 1585 ) 1586 1587 if p := b[header.IPv6MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { 1588 t.Errorf("got data = %x, want = %x", p, data) 1589 } 1590 } 1591 1592 func TestConnectBindToDevice(t *testing.T) { 1593 for _, test := range []struct { 1594 name string 1595 device tcpip.NICID 1596 want tcp.EndpointState 1597 }{ 1598 {"RightDevice", 1, tcp.StateEstablished}, 1599 {"WrongDevice", 2, tcp.StateSynSent}, 1600 {"AnyDevice", 0, tcp.StateEstablished}, 1601 } { 1602 t.Run(test.name, func(t *testing.T) { 1603 c := context.New(t, defaultMTU) 1604 defer c.Cleanup() 1605 1606 c.Create(-1) 1607 if err := c.EP.SocketOptions().SetBindToDevice(int32(test.device)); err != nil { 1608 t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", test.device, test.device, err) 1609 } 1610 // Start connection attempt. 1611 waitEntry, _ := waiter.NewChannelEntry(nil) 1612 c.WQ.EventRegister(&waitEntry, waiter.WritableEvents) 1613 defer c.WQ.EventUnregister(&waitEntry) 1614 1615 err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) 1616 if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { 1617 t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) 1618 } 1619 1620 // Receive SYN packet. 
1621 b := c.GetPacket() 1622 checker.IPv4(t, b, 1623 checker.TCP( 1624 checker.DstPort(context.TestPort), 1625 checker.TCPFlags(header.TCPFlagSyn), 1626 ), 1627 ) 1628 if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { 1629 t.Fatalf("unexpected endpoint state: want %s, got %s", want, got) 1630 } 1631 tcpHdr := header.TCP(header.IPv4(b).Payload()) 1632 c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) 1633 1634 iss := seqnum.Value(context.TestInitialSequenceNumber) 1635 rcvWnd := seqnum.Size(30000) 1636 c.SendPacket(nil, &context.Headers{ 1637 SrcPort: tcpHdr.DestinationPort(), 1638 DstPort: tcpHdr.SourcePort(), 1639 Flags: header.TCPFlagSyn | header.TCPFlagAck, 1640 SeqNum: iss, 1641 AckNum: c.IRS.Add(1), 1642 RcvWnd: rcvWnd, 1643 TCPOpts: nil, 1644 }) 1645 1646 c.GetPacket() 1647 if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want { 1648 t.Fatalf("unexpected endpoint state: want %s, got %s", want, got) 1649 } 1650 }) 1651 } 1652 } 1653 1654 func TestSynSent(t *testing.T) { 1655 for _, test := range []struct { 1656 name string 1657 reset bool 1658 }{ 1659 {"RstOnSynSent", true}, 1660 {"CloseOnSynSent", false}, 1661 } { 1662 t.Run(test.name, func(t *testing.T) { 1663 c := context.New(t, defaultMTU) 1664 defer c.Cleanup() 1665 1666 // Create an endpoint, don't handshake because we want to interfere with the 1667 // handshake process. 1668 c.Create(-1) 1669 1670 // Start connection attempt. 1671 waitEntry, ch := waiter.NewChannelEntry(nil) 1672 c.WQ.EventRegister(&waitEntry, waiter.EventHUp) 1673 defer c.WQ.EventUnregister(&waitEntry) 1674 1675 addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} 1676 err := c.EP.Connect(addr) 1677 if d := cmp.Diff(err, &tcpip.ErrConnectStarted{}); d != "" { 1678 t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) 1679 } 1680 1681 // Receive SYN packet. 
1682 b := c.GetPacket() 1683 checker.IPv4(t, b, 1684 checker.TCP( 1685 checker.DstPort(context.TestPort), 1686 checker.TCPFlags(header.TCPFlagSyn), 1687 ), 1688 ) 1689 1690 if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { 1691 t.Fatalf("got State() = %s, want %s", got, want) 1692 } 1693 tcpHdr := header.TCP(header.IPv4(b).Payload()) 1694 c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) 1695 1696 if test.reset { 1697 // Send a packet with a proper ACK and a RST flag to cause the socket 1698 // to error and close out. 1699 iss := seqnum.Value(context.TestInitialSequenceNumber) 1700 rcvWnd := seqnum.Size(30000) 1701 c.SendPacket(nil, &context.Headers{ 1702 SrcPort: tcpHdr.DestinationPort(), 1703 DstPort: tcpHdr.SourcePort(), 1704 Flags: header.TCPFlagRst | header.TCPFlagAck, 1705 SeqNum: iss, 1706 AckNum: c.IRS.Add(1), 1707 RcvWnd: rcvWnd, 1708 TCPOpts: nil, 1709 }) 1710 } else { 1711 c.EP.Close() 1712 } 1713 1714 // Wait for receive to be notified. 1715 select { 1716 case <-ch: 1717 case <-time.After(3 * time.Second): 1718 t.Fatal("timed out waiting for packet to arrive") 1719 } 1720 1721 ept := endpointTester{c.EP} 1722 if test.reset { 1723 ept.CheckReadError(t, &tcpip.ErrConnectionRefused{}) 1724 } else { 1725 ept.CheckReadError(t, &tcpip.ErrAborted{}) 1726 } 1727 1728 if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { 1729 t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) 1730 } 1731 1732 // Due to the RST the endpoint should be in an error state. 
1733 if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { 1734 t.Fatalf("got State() = %s, want %s", got, want) 1735 } 1736 }) 1737 } 1738 } 1739 1740 func TestOutOfOrderReceive(t *testing.T) { 1741 c := context.New(t, defaultMTU) 1742 defer c.Cleanup() 1743 1744 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 1745 1746 we, ch := waiter.NewChannelEntry(nil) 1747 c.WQ.EventRegister(&we, waiter.ReadableEvents) 1748 defer c.WQ.EventUnregister(&we) 1749 1750 ept := endpointTester{c.EP} 1751 ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) 1752 1753 // Send second half of data first, with seqnum 3 ahead of expected. 1754 data := []byte{1, 2, 3, 4, 5, 6} 1755 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 1756 c.SendPacket(data[3:], &context.Headers{ 1757 SrcPort: context.TestPort, 1758 DstPort: c.Port, 1759 Flags: header.TCPFlagAck, 1760 SeqNum: iss.Add(3), 1761 AckNum: c.IRS.Add(1), 1762 RcvWnd: 30000, 1763 }) 1764 1765 // Check that we get an ACK specifying which seqnum is expected. 1766 checker.IPv4(t, c.GetPacket(), 1767 checker.TCP( 1768 checker.DstPort(context.TestPort), 1769 checker.TCPSeqNum(uint32(c.IRS)+1), 1770 checker.TCPAckNum(uint32(iss)), 1771 checker.TCPFlags(header.TCPFlagAck), 1772 ), 1773 ) 1774 1775 // Wait 200ms and check that no data has been received. 1776 time.Sleep(200 * time.Millisecond) 1777 ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) 1778 1779 // Send the first 3 bytes now. 1780 c.SendPacket(data[:3], &context.Headers{ 1781 SrcPort: context.TestPort, 1782 DstPort: c.Port, 1783 Flags: header.TCPFlagAck, 1784 SeqNum: iss, 1785 AckNum: c.IRS.Add(1), 1786 RcvWnd: 30000, 1787 }) 1788 1789 // Receive data. 1790 read := ept.CheckReadFull(t, 6, ch, 5*time.Second) 1791 1792 // Check that we received the data in proper order. 1793 if !bytes.Equal(data, read) { 1794 t.Fatalf("got data = %v, want = %v", read, data) 1795 } 1796 1797 // Check that the whole data is acknowledged. 
1798 checker.IPv4(t, c.GetPacket(), 1799 checker.TCP( 1800 checker.DstPort(context.TestPort), 1801 checker.TCPSeqNum(uint32(c.IRS)+1), 1802 checker.TCPAckNum(uint32(iss)+uint32(len(data))), 1803 checker.TCPFlags(header.TCPFlagAck), 1804 ), 1805 ) 1806 } 1807 1808 func TestOutOfOrderFlood(t *testing.T) { 1809 c := context.New(t, defaultMTU) 1810 defer c.Cleanup() 1811 1812 rcvBufSz := math.MaxUint16 1813 c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz) 1814 1815 ept := endpointTester{c.EP} 1816 ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) 1817 1818 // Send 100 packets before the actual one that is expected. 1819 data := []byte{1, 2, 3, 4, 5, 6} 1820 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 1821 for i := 0; i < 100; i++ { 1822 c.SendPacket(data[3:], &context.Headers{ 1823 SrcPort: context.TestPort, 1824 DstPort: c.Port, 1825 Flags: header.TCPFlagAck, 1826 SeqNum: iss.Add(6), 1827 AckNum: c.IRS.Add(1), 1828 RcvWnd: 30000, 1829 }) 1830 1831 checker.IPv4(t, c.GetPacket(), 1832 checker.TCP( 1833 checker.DstPort(context.TestPort), 1834 checker.TCPSeqNum(uint32(c.IRS)+1), 1835 checker.TCPAckNum(uint32(iss)), 1836 checker.TCPFlags(header.TCPFlagAck), 1837 ), 1838 ) 1839 } 1840 1841 // Send packet with seqnum as initial + 3. It must be discarded because the 1842 // out-of-order buffer was filled by the previous packets. 1843 c.SendPacket(data[3:], &context.Headers{ 1844 SrcPort: context.TestPort, 1845 DstPort: c.Port, 1846 Flags: header.TCPFlagAck, 1847 SeqNum: iss.Add(3), 1848 AckNum: c.IRS.Add(1), 1849 RcvWnd: 30000, 1850 }) 1851 1852 checker.IPv4(t, c.GetPacket(), 1853 checker.TCP( 1854 checker.DstPort(context.TestPort), 1855 checker.TCPSeqNum(uint32(c.IRS)+1), 1856 checker.TCPAckNum(uint32(iss)), 1857 checker.TCPFlags(header.TCPFlagAck), 1858 ), 1859 ) 1860 1861 // Now send the expected packet with initial sequence number. 
1862 c.SendPacket(data[:3], &context.Headers{ 1863 SrcPort: context.TestPort, 1864 DstPort: c.Port, 1865 Flags: header.TCPFlagAck, 1866 SeqNum: iss, 1867 AckNum: c.IRS.Add(1), 1868 RcvWnd: 30000, 1869 }) 1870 1871 // Check that only packet with initial sequence number is acknowledged. 1872 checker.IPv4(t, c.GetPacket(), 1873 checker.TCP( 1874 checker.DstPort(context.TestPort), 1875 checker.TCPSeqNum(uint32(c.IRS)+1), 1876 checker.TCPAckNum(uint32(iss)+3), 1877 checker.TCPFlags(header.TCPFlagAck), 1878 ), 1879 ) 1880 } 1881 1882 func TestRstOnCloseWithUnreadData(t *testing.T) { 1883 c := context.New(t, defaultMTU) 1884 defer c.Cleanup() 1885 1886 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 1887 1888 we, ch := waiter.NewChannelEntry(nil) 1889 c.WQ.EventRegister(&we, waiter.ReadableEvents) 1890 defer c.WQ.EventUnregister(&we) 1891 1892 ept := endpointTester{c.EP} 1893 ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) 1894 1895 data := []byte{1, 2, 3} 1896 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 1897 c.SendPacket(data, &context.Headers{ 1898 SrcPort: context.TestPort, 1899 DstPort: c.Port, 1900 Flags: header.TCPFlagAck, 1901 SeqNum: iss, 1902 AckNum: c.IRS.Add(1), 1903 RcvWnd: 30000, 1904 }) 1905 1906 // Wait for receive to be notified. 1907 select { 1908 case <-ch: 1909 case <-time.After(3 * time.Second): 1910 t.Fatalf("Timed out waiting for data to arrive") 1911 } 1912 1913 // Check that ACK is received, this happens regardless of the read. 1914 checker.IPv4(t, c.GetPacket(), 1915 checker.TCP( 1916 checker.DstPort(context.TestPort), 1917 checker.TCPSeqNum(uint32(c.IRS)+1), 1918 checker.TCPAckNum(uint32(iss)+uint32(len(data))), 1919 checker.TCPFlags(header.TCPFlagAck), 1920 ), 1921 ) 1922 1923 // Now that we know we have unread data, let's just close the connection 1924 // and verify that netstack sends an RST rather than a FIN. 
1925 c.EP.Close() 1926 1927 checker.IPv4(t, c.GetPacket(), 1928 checker.TCP( 1929 checker.DstPort(context.TestPort), 1930 checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), 1931 // We shouldn't consume a sequence number on RST. 1932 checker.TCPSeqNum(uint32(c.IRS)+1), 1933 )) 1934 // The RST puts the endpoint into an error state. 1935 if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { 1936 t.Errorf("unexpected endpoint state: want %s, got %s", want, got) 1937 } 1938 1939 // This final ACK should be ignored because an ACK on a reset doesn't mean 1940 // anything. 1941 c.SendPacket(nil, &context.Headers{ 1942 SrcPort: context.TestPort, 1943 DstPort: c.Port, 1944 Flags: header.TCPFlagAck, 1945 SeqNum: iss.Add(seqnum.Size(len(data))), 1946 AckNum: c.IRS.Add(seqnum.Size(2)), 1947 RcvWnd: 30000, 1948 }) 1949 } 1950 1951 func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { 1952 c := context.New(t, defaultMTU) 1953 defer c.Cleanup() 1954 1955 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 1956 1957 we, ch := waiter.NewChannelEntry(nil) 1958 c.WQ.EventRegister(&we, waiter.ReadableEvents) 1959 defer c.WQ.EventUnregister(&we) 1960 1961 ept := endpointTester{c.EP} 1962 ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) 1963 1964 data := []byte{1, 2, 3} 1965 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 1966 c.SendPacket(data, &context.Headers{ 1967 SrcPort: context.TestPort, 1968 DstPort: c.Port, 1969 Flags: header.TCPFlagAck, 1970 SeqNum: iss, 1971 AckNum: c.IRS.Add(1), 1972 RcvWnd: 30000, 1973 }) 1974 1975 // Wait for receive to be notified. 1976 select { 1977 case <-ch: 1978 case <-time.After(3 * time.Second): 1979 t.Fatalf("Timed out waiting for data to arrive") 1980 } 1981 1982 // Check that ACK is received, this happens regardless of the read. 
1983 checker.IPv4(t, c.GetPacket(), 1984 checker.TCP( 1985 checker.DstPort(context.TestPort), 1986 checker.TCPSeqNum(uint32(c.IRS)+1), 1987 checker.TCPAckNum(uint32(iss)+uint32(len(data))), 1988 checker.TCPFlags(header.TCPFlagAck), 1989 ), 1990 ) 1991 1992 // Cause a FIN to be generated. 1993 c.EP.Shutdown(tcpip.ShutdownWrite) 1994 1995 // Make sure we get the FIN but DON't ACK IT. 1996 checker.IPv4(t, c.GetPacket(), 1997 checker.TCP( 1998 checker.DstPort(context.TestPort), 1999 checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), 2000 checker.TCPSeqNum(uint32(c.IRS)+1), 2001 )) 2002 2003 if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { 2004 t.Errorf("unexpected endpoint state: want %s, got %s", want, got) 2005 } 2006 2007 // Cause a RST to be generated by closing the read end now since we have 2008 // unread data. 2009 c.EP.Shutdown(tcpip.ShutdownRead) 2010 2011 // Make sure we get the RST 2012 checker.IPv4(t, c.GetPacket(), 2013 checker.TCP( 2014 checker.DstPort(context.TestPort), 2015 checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), 2016 // RST is always generated with sndNxt which if the FIN 2017 // has been sent will be 1 higher than the sequence 2018 // number of the FIN itself. 2019 checker.TCPSeqNum(uint32(c.IRS)+2), 2020 )) 2021 // The RST puts the endpoint into an error state. 2022 if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { 2023 t.Errorf("unexpected endpoint state: want %s, got %s", want, got) 2024 } 2025 2026 // The ACK to the FIN should now be rejected since the connection has been 2027 // closed by a RST. 
2028 c.SendPacket(nil, &context.Headers{ 2029 SrcPort: context.TestPort, 2030 DstPort: c.Port, 2031 Flags: header.TCPFlagAck, 2032 SeqNum: iss.Add(seqnum.Size(len(data))), 2033 AckNum: c.IRS.Add(seqnum.Size(2)), 2034 RcvWnd: 30000, 2035 }) 2036 } 2037 2038 func TestShutdownRead(t *testing.T) { 2039 c := context.New(t, defaultMTU) 2040 defer c.Cleanup() 2041 2042 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 2043 2044 ept := endpointTester{c.EP} 2045 ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) 2046 2047 if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { 2048 t.Fatalf("Shutdown failed: %s", err) 2049 } 2050 2051 ept.CheckReadError(t, &tcpip.ErrClosedForReceive{}) 2052 var want uint64 = 1 2053 if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want { 2054 t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want) 2055 } 2056 } 2057 2058 func TestFullWindowReceive(t *testing.T) { 2059 c := context.New(t, defaultMTU) 2060 defer c.Cleanup() 2061 2062 const rcvBufSz = 10 2063 c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz) 2064 2065 we, ch := waiter.NewChannelEntry(nil) 2066 c.WQ.EventRegister(&we, waiter.ReadableEvents) 2067 defer c.WQ.EventUnregister(&we) 2068 2069 ept := endpointTester{c.EP} 2070 ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) 2071 2072 // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies 2073 // the provided buffer value by tcp.SegOverheadFactor to calculate the actual 2074 // receive buffer size. 2075 data := make([]byte, tcp.SegOverheadFactor*rcvBufSz) 2076 for i := range data { 2077 data[i] = byte(i % 255) 2078 } 2079 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 2080 c.SendPacket(data, &context.Headers{ 2081 SrcPort: context.TestPort, 2082 DstPort: c.Port, 2083 Flags: header.TCPFlagAck, 2084 SeqNum: iss, 2085 AckNum: c.IRS.Add(1), 2086 RcvWnd: 30000, 2087 }) 2088 2089 // Wait for receive to be notified. 
2090 select { 2091 case <-ch: 2092 case <-time.After(5 * time.Second): 2093 t.Fatalf("Timed out waiting for data to arrive") 2094 } 2095 2096 // Check that data is acknowledged, and window goes to zero. 2097 checker.IPv4(t, c.GetPacket(), 2098 checker.TCP( 2099 checker.DstPort(context.TestPort), 2100 checker.TCPSeqNum(uint32(c.IRS)+1), 2101 checker.TCPAckNum(uint32(iss)+uint32(len(data))), 2102 checker.TCPFlags(header.TCPFlagAck), 2103 checker.TCPWindow(0), 2104 ), 2105 ) 2106 2107 // Receive data and check it. 2108 v := ept.CheckRead(t) 2109 if !bytes.Equal(data, v) { 2110 t.Fatalf("got data = %v, want = %v", v, data) 2111 } 2112 2113 var want uint64 = 1 2114 if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want { 2115 t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %d want %d", got, want) 2116 } 2117 2118 // Check that we get an ACK for the newly non-zero window. 2119 checker.IPv4(t, c.GetPacket(), 2120 checker.TCP( 2121 checker.DstPort(context.TestPort), 2122 checker.TCPSeqNum(uint32(c.IRS)+1), 2123 checker.TCPAckNum(uint32(iss)+uint32(len(data))), 2124 checker.TCPFlags(header.TCPFlagAck), 2125 checker.TCPWindow(10), 2126 ), 2127 ) 2128 } 2129 2130 // Test the stack receive window advertisement on receiving segments smaller than 2131 // segment overhead. It tests for the right edge of the window to not grow when 2132 // the endpoint is not being read from. 
2133 func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { 2134 c := context.New(t, defaultMTU) 2135 defer c.Cleanup() 2136 2137 opt := tcpip.TCPReceiveBufferSizeRangeOption{ 2138 Min: 1, 2139 Default: tcp.DefaultReceiveBufferSize, 2140 Max: tcp.DefaultReceiveBufferSize << tcp.FindWndScale(seqnum.Size(tcp.DefaultReceiveBufferSize)), 2141 } 2142 if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { 2143 t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) 2144 } 2145 2146 c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS}) 2147 2148 // Bump up the receive buffer size such that, when the receive window grows, 2149 // the scaled window exceeds maxUint16. 2150 c.EP.SocketOptions().SetReceiveBufferSize(int64(opt.Max), true) 2151 2152 // Keep the payload size < segment overhead and such that it is a multiple 2153 // of the window scaled value. This enables the test to perform equality 2154 // checks on the incoming receive window. 2155 payloadSize := 1 << c.RcvdWindowScale 2156 if payloadSize >= tcp.SegSize { 2157 t.Fatalf("payload size of %d is not less than the segment overhead of %d", payloadSize, tcp.SegSize) 2158 } 2159 payload := generateRandomPayload(t, payloadSize) 2160 payloadLen := seqnum.Size(len(payload)) 2161 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 2162 2163 // Send payload to the endpoint and return the advertised receive window 2164 // from the endpoint. 
2165 getIncomingRcvWnd := func() uint32 { 2166 c.SendPacket(payload, &context.Headers{ 2167 SrcPort: context.TestPort, 2168 DstPort: c.Port, 2169 SeqNum: iss, 2170 AckNum: c.IRS.Add(1), 2171 Flags: header.TCPFlagAck, 2172 RcvWnd: 30000, 2173 }) 2174 iss = iss.Add(payloadLen) 2175 2176 pkt := c.GetPacket() 2177 return uint32(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.RcvdWindowScale 2178 } 2179 2180 // Read the advertised receive window with the ACK for payload. 2181 rcvWnd := getIncomingRcvWnd() 2182 2183 // Check if the subsequent ACK to our send has not grown the right edge of 2184 // the window. 2185 if got, want := getIncomingRcvWnd(), rcvWnd-uint32(len(payload)); got != want { 2186 t.Fatalf("got incomingRcvwnd %d want %d", got, want) 2187 } 2188 2189 // Read the data so that the subsequent ACK from the endpoint 2190 // grows the right edge of the window. 2191 var buf bytes.Buffer 2192 if _, err := c.EP.Read(&buf, tcpip.ReadOptions{}); err != nil { 2193 t.Fatalf("c.EP.Read: %s", err) 2194 } 2195 2196 // Check if we have received max uint16 as our advertised 2197 // scaled window now after a read above. 2198 maxRcv := uint32(math.MaxUint16 << c.RcvdWindowScale) 2199 if got, want := getIncomingRcvWnd(), maxRcv; got != want { 2200 t.Fatalf("got incomingRcvwnd %d want %d", got, want) 2201 } 2202 2203 // Check if the subsequent ACK to our send has not grown the right edge of 2204 // the window. 2205 if got, want := getIncomingRcvWnd(), maxRcv-uint32(len(payload)); got != want { 2206 t.Fatalf("got incomingRcvwnd %d want %d", got, want) 2207 } 2208 } 2209 2210 func TestNoWindowShrinking(t *testing.T) { 2211 c := context.New(t, defaultMTU) 2212 defer c.Cleanup() 2213 2214 // Start off with a certain receive buffer then cut it in half and verify that 2215 // the right edge of the window does not shrink. 2216 // NOTE: Netstack doubles the value specified here. 2217 rcvBufSize := 65536 2218 // Enable window scaling with a scale of zero from our end. 
2219 c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, rcvBufSize, []byte{ 2220 header.TCPOptionWS, 3, 0, header.TCPOptionNOP, 2221 }) 2222 2223 we, ch := waiter.NewChannelEntry(nil) 2224 c.WQ.EventRegister(&we, waiter.ReadableEvents) 2225 defer c.WQ.EventUnregister(&we) 2226 2227 ept := endpointTester{c.EP} 2228 ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) 2229 2230 // Send a 1 byte payload so that we can record the current receive window. 2231 // Send a payload of half the size of rcvBufSize. 2232 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 2233 payload := []byte{1} 2234 c.SendPacket(payload, &context.Headers{ 2235 SrcPort: context.TestPort, 2236 DstPort: c.Port, 2237 Flags: header.TCPFlagAck, 2238 SeqNum: iss, 2239 AckNum: c.IRS.Add(1), 2240 RcvWnd: 30000, 2241 }) 2242 2243 // Wait for receive to be notified. 2244 select { 2245 case <-ch: 2246 case <-time.After(5 * time.Second): 2247 t.Fatalf("Timed out waiting for data to arrive") 2248 } 2249 2250 // Read the 1 byte payload we just sent. 2251 if got, want := payload, ept.CheckRead(t); !bytes.Equal(got, want) { 2252 t.Fatalf("got data: %v, want: %v", got, want) 2253 } 2254 2255 // Verify that the ACK does not shrink the window. 2256 pkt := c.GetPacket() 2257 iss = iss.Add(1) 2258 checker.IPv4(t, pkt, 2259 checker.TCP( 2260 checker.DstPort(context.TestPort), 2261 checker.TCPSeqNum(uint32(c.IRS)+1), 2262 checker.TCPAckNum(uint32(iss)), 2263 checker.TCPFlags(header.TCPFlagAck), 2264 ), 2265 ) 2266 // Stash the initial window. 2267 initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale 2268 initialLastAcceptableSeq := iss.Add(seqnum.Size(initialWnd)) 2269 // Now shrink the receive buffer to half its original size. 2270 c.EP.SocketOptions().SetReceiveBufferSize(int64(rcvBufSize/2), true) 2271 2272 data := generateRandomPayload(t, rcvBufSize) 2273 // Send a payload of half the size of rcvBufSize. 
2274 c.SendPacket(data[:rcvBufSize/2], &context.Headers{ 2275 SrcPort: context.TestPort, 2276 DstPort: c.Port, 2277 Flags: header.TCPFlagAck, 2278 SeqNum: iss, 2279 AckNum: c.IRS.Add(1), 2280 RcvWnd: 30000, 2281 }) 2282 iss = iss.Add(seqnum.Size(rcvBufSize / 2)) 2283 2284 // Verify that the ACK does not shrink the window. 2285 pkt = c.GetPacket() 2286 checker.IPv4(t, pkt, 2287 checker.TCP( 2288 checker.DstPort(context.TestPort), 2289 checker.TCPSeqNum(uint32(c.IRS)+1), 2290 checker.TCPAckNum(uint32(iss)), 2291 checker.TCPFlags(header.TCPFlagAck), 2292 ), 2293 ) 2294 newWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale 2295 newLastAcceptableSeq := iss.Add(seqnum.Size(newWnd)) 2296 if newLastAcceptableSeq.LessThan(initialLastAcceptableSeq) { 2297 t.Fatalf("receive window shrunk unexpectedly got: %d, want >= %d", newLastAcceptableSeq, initialLastAcceptableSeq) 2298 } 2299 2300 // Send another payload of half the size of rcvBufSize. This should fill up the 2301 // socket receive buffer and we should see a zero window. 2302 c.SendPacket(data[rcvBufSize/2:], &context.Headers{ 2303 SrcPort: context.TestPort, 2304 DstPort: c.Port, 2305 Flags: header.TCPFlagAck, 2306 SeqNum: iss, 2307 AckNum: c.IRS.Add(1), 2308 RcvWnd: 30000, 2309 }) 2310 iss = iss.Add(seqnum.Size(rcvBufSize / 2)) 2311 2312 checker.IPv4(t, c.GetPacket(), 2313 checker.TCP( 2314 checker.DstPort(context.TestPort), 2315 checker.TCPSeqNum(uint32(c.IRS)+1), 2316 checker.TCPAckNum(uint32(iss)), 2317 checker.TCPFlags(header.TCPFlagAck), 2318 checker.TCPWindow(0), 2319 ), 2320 ) 2321 2322 // Receive data and check it. 2323 read := ept.CheckReadFull(t, len(data), ch, 5*time.Second) 2324 if !bytes.Equal(data, read) { 2325 t.Fatalf("got data = %v, want = %v", read, data) 2326 } 2327 2328 // Check that we get an ACK for the newly non-zero window, which is the new 2329 // receive buffer size we set after the connection was established. 
2330 checker.IPv4(t, c.GetPacket(), 2331 checker.TCP( 2332 checker.DstPort(context.TestPort), 2333 checker.TCPSeqNum(uint32(c.IRS)+1), 2334 checker.TCPAckNum(uint32(iss)), 2335 checker.TCPFlags(header.TCPFlagAck), 2336 checker.TCPWindow(uint16(rcvBufSize/2)>>c.RcvdWindowScale), 2337 ), 2338 ) 2339 } 2340 2341 func TestSimpleSend(t *testing.T) { 2342 c := context.New(t, defaultMTU) 2343 defer c.Cleanup() 2344 2345 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 2346 2347 data := []byte{1, 2, 3} 2348 var r bytes.Reader 2349 r.Reset(data) 2350 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 2351 t.Fatalf("Write failed: %s", err) 2352 } 2353 2354 // Check that data is received. 2355 b := c.GetPacket() 2356 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 2357 checker.IPv4(t, b, 2358 checker.PayloadLen(len(data)+header.TCPMinimumSize), 2359 checker.TCP( 2360 checker.DstPort(context.TestPort), 2361 checker.TCPSeqNum(uint32(c.IRS)+1), 2362 checker.TCPAckNum(uint32(iss)), 2363 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 2364 ), 2365 ) 2366 2367 if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { 2368 t.Fatalf("got data = %v, want = %v", p, data) 2369 } 2370 2371 // Acknowledge the data. 2372 c.SendPacket(nil, &context.Headers{ 2373 SrcPort: context.TestPort, 2374 DstPort: c.Port, 2375 Flags: header.TCPFlagAck, 2376 SeqNum: iss, 2377 AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), 2378 RcvWnd: 30000, 2379 }) 2380 } 2381 2382 func TestZeroWindowSend(t *testing.T) { 2383 c := context.New(t, defaultMTU) 2384 defer c.Cleanup() 2385 2386 c.CreateConnected(context.TestInitialSequenceNumber, 0 /* rcvWnd */, -1 /* epRcvBuf */) 2387 2388 data := []byte{1, 2, 3} 2389 var r bytes.Reader 2390 r.Reset(data) 2391 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 2392 t.Fatalf("Write failed: %s", err) 2393 } 2394 2395 // Check if we got a zero-window probe. 
2396 b := c.GetPacket() 2397 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 2398 checker.IPv4(t, b, 2399 checker.PayloadLen(header.TCPMinimumSize), 2400 checker.TCP( 2401 checker.DstPort(context.TestPort), 2402 checker.TCPSeqNum(uint32(c.IRS)), 2403 checker.TCPAckNum(uint32(iss)), 2404 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 2405 ), 2406 ) 2407 2408 // Open up the window. Data should be received now. 2409 c.SendPacket(nil, &context.Headers{ 2410 SrcPort: context.TestPort, 2411 DstPort: c.Port, 2412 Flags: header.TCPFlagAck, 2413 SeqNum: iss, 2414 AckNum: c.IRS.Add(1), 2415 RcvWnd: 30000, 2416 }) 2417 2418 // Check that data is received. 2419 b = c.GetPacket() 2420 checker.IPv4(t, b, 2421 checker.PayloadLen(len(data)+header.TCPMinimumSize), 2422 checker.TCP( 2423 checker.DstPort(context.TestPort), 2424 checker.TCPSeqNum(uint32(c.IRS)+1), 2425 checker.TCPAckNum(uint32(iss)), 2426 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 2427 ), 2428 ) 2429 2430 if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { 2431 t.Fatalf("got data = %v, want = %v", p, data) 2432 } 2433 2434 // Acknowledge the data. 2435 c.SendPacket(nil, &context.Headers{ 2436 SrcPort: context.TestPort, 2437 DstPort: c.Port, 2438 Flags: header.TCPFlagAck, 2439 SeqNum: iss, 2440 AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), 2441 RcvWnd: 30000, 2442 }) 2443 } 2444 2445 func TestScaledWindowConnect(t *testing.T) { 2446 // This test ensures that window scaling is used when the peer 2447 // does advertise it and connection is established with Connect(). 2448 c := context.New(t, defaultMTU) 2449 defer c.Cleanup() 2450 2451 // Set the window size greater than the maximum non-scaled window. 
2452 c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, 65535*3, []byte{ 2453 header.TCPOptionWS, 3, 0, header.TCPOptionNOP, 2454 }) 2455 2456 data := []byte{1, 2, 3} 2457 var r bytes.Reader 2458 r.Reset(data) 2459 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 2460 t.Fatalf("Write failed: %s", err) 2461 } 2462 2463 // Check that data is received, and that advertised window is 0x5fff, 2464 // that is, that it is scaled. 2465 b := c.GetPacket() 2466 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 2467 checker.IPv4(t, b, 2468 checker.PayloadLen(len(data)+header.TCPMinimumSize), 2469 checker.TCP( 2470 checker.DstPort(context.TestPort), 2471 checker.TCPSeqNum(uint32(c.IRS)+1), 2472 checker.TCPAckNum(uint32(iss)), 2473 checker.TCPWindow(0x5fff), 2474 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 2475 ), 2476 ) 2477 } 2478 2479 func TestNonScaledWindowConnect(t *testing.T) { 2480 // This test ensures that window scaling is not used when the peer 2481 // doesn't advertise it and connection is established with Connect(). 2482 c := context.New(t, defaultMTU) 2483 defer c.Cleanup() 2484 2485 // Set the window size greater than the maximum non-scaled window. 2486 c.CreateConnected(context.TestInitialSequenceNumber, 30000, 65535*3) 2487 2488 data := []byte{1, 2, 3} 2489 var r bytes.Reader 2490 r.Reset(data) 2491 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 2492 t.Fatalf("Write failed: %s", err) 2493 } 2494 2495 // Check that data is received, and that advertised window is 0xffff, 2496 // that is, that it's not scaled. 
2497 b := c.GetPacket() 2498 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 2499 checker.IPv4(t, b, 2500 checker.PayloadLen(len(data)+header.TCPMinimumSize), 2501 checker.TCP( 2502 checker.DstPort(context.TestPort), 2503 checker.TCPSeqNum(uint32(c.IRS)+1), 2504 checker.TCPAckNum(uint32(iss)), 2505 checker.TCPWindow(0xffff), 2506 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 2507 ), 2508 ) 2509 } 2510 2511 func TestScaledWindowAccept(t *testing.T) { 2512 // This test ensures that window scaling is used when the peer 2513 // does advertise it and connection is established with Accept(). 2514 c := context.New(t, defaultMTU) 2515 defer c.Cleanup() 2516 2517 // Create EP and start listening. 2518 wq := &waiter.Queue{} 2519 ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) 2520 if err != nil { 2521 t.Fatalf("NewEndpoint failed: %s", err) 2522 } 2523 defer ep.Close() 2524 2525 // Set the window size greater than the maximum non-scaled window. 2526 ep.SocketOptions().SetReceiveBufferSize(65535*3, true) 2527 2528 if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 2529 t.Fatalf("Bind failed: %s", err) 2530 } 2531 2532 if err := ep.Listen(10); err != nil { 2533 t.Fatalf("Listen failed: %s", err) 2534 } 2535 2536 // Do 3-way handshake. 2537 // wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2 2538 c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS}) 2539 2540 // Try to accept the connection. 2541 we, ch := waiter.NewChannelEntry(nil) 2542 wq.EventRegister(&we, waiter.ReadableEvents) 2543 defer wq.EventUnregister(&we) 2544 2545 c.EP, _, err = ep.Accept(nil) 2546 if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 2547 // Wait for connection to be established. 
2548 select { 2549 case <-ch: 2550 c.EP, _, err = ep.Accept(nil) 2551 if err != nil { 2552 t.Fatalf("Accept failed: %s", err) 2553 } 2554 2555 case <-time.After(1 * time.Second): 2556 t.Fatalf("Timed out waiting for accept") 2557 } 2558 } 2559 2560 data := []byte{1, 2, 3} 2561 var r bytes.Reader 2562 r.Reset(data) 2563 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 2564 t.Fatalf("Write failed: %s", err) 2565 } 2566 2567 // Check that data is received, and that advertised window is 0x5fff, 2568 // that is, that it is scaled. 2569 b := c.GetPacket() 2570 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 2571 checker.IPv4(t, b, 2572 checker.PayloadLen(len(data)+header.TCPMinimumSize), 2573 checker.TCP( 2574 checker.DstPort(context.TestPort), 2575 checker.TCPSeqNum(uint32(c.IRS)+1), 2576 checker.TCPAckNum(uint32(iss)), 2577 checker.TCPWindow(0x5fff), 2578 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 2579 ), 2580 ) 2581 } 2582 2583 func TestNonScaledWindowAccept(t *testing.T) { 2584 // This test ensures that window scaling is not used when the peer 2585 // doesn't advertise it and connection is established with Accept(). 2586 c := context.New(t, defaultMTU) 2587 defer c.Cleanup() 2588 2589 // Create EP and start listening. 2590 wq := &waiter.Queue{} 2591 ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) 2592 if err != nil { 2593 t.Fatalf("NewEndpoint failed: %s", err) 2594 } 2595 defer ep.Close() 2596 2597 // Set the window size greater than the maximum non-scaled window. 2598 ep.SocketOptions().SetReceiveBufferSize(65535*3, true) 2599 2600 if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 2601 t.Fatalf("Bind failed: %s", err) 2602 } 2603 2604 if err := ep.Listen(10); err != nil { 2605 t.Fatalf("Listen failed: %s", err) 2606 } 2607 2608 // Do 3-way handshake w/ window scaling disabled. The SYN-ACK to the SYN 2609 // should not carry the window scaling option. 
2610 c.PassiveConnect(100, -1, header.TCPSynOptions{MSS: defaultIPv4MSS}) 2611 2612 // Try to accept the connection. 2613 we, ch := waiter.NewChannelEntry(nil) 2614 wq.EventRegister(&we, waiter.ReadableEvents) 2615 defer wq.EventUnregister(&we) 2616 2617 c.EP, _, err = ep.Accept(nil) 2618 if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 2619 // Wait for connection to be established. 2620 select { 2621 case <-ch: 2622 c.EP, _, err = ep.Accept(nil) 2623 if err != nil { 2624 t.Fatalf("Accept failed: %s", err) 2625 } 2626 2627 case <-time.After(1 * time.Second): 2628 t.Fatalf("Timed out waiting for accept") 2629 } 2630 } 2631 2632 data := []byte{1, 2, 3} 2633 var r bytes.Reader 2634 r.Reset(data) 2635 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 2636 t.Fatalf("Write failed: %s", err) 2637 } 2638 2639 // Check that data is received, and that advertised window is 0xffff, 2640 // that is, that it's not scaled. 2641 b := c.GetPacket() 2642 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 2643 checker.IPv4(t, b, 2644 checker.PayloadLen(len(data)+header.TCPMinimumSize), 2645 checker.TCP( 2646 checker.DstPort(context.TestPort), 2647 checker.TCPSeqNum(uint32(c.IRS)+1), 2648 checker.TCPAckNum(uint32(iss)), 2649 checker.TCPWindow(0xffff), 2650 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 2651 ), 2652 ) 2653 } 2654 2655 func TestZeroScaledWindowReceive(t *testing.T) { 2656 // This test ensures that the endpoint sends a non-zero window size 2657 // advertisement when the scaled window transitions from 0 to non-zero, 2658 // but the actual window (not scaled) hasn't gotten to zero. 2659 c := context.New(t, defaultMTU) 2660 defer c.Cleanup() 2661 2662 // Set the buffer size such that a window scale of 5 will be used. 
2663 const bufSz = 65535 * 10 2664 const ws = uint32(5) 2665 c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, bufSz, []byte{ 2666 header.TCPOptionWS, 3, 0, header.TCPOptionNOP, 2667 }) 2668 2669 // Write chunks of 50000 bytes. 2670 remain := 0 2671 sent := 0 2672 data := make([]byte, 50000) 2673 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 2674 // Keep writing till the window drops below len(data). 2675 for { 2676 c.SendPacket(data, &context.Headers{ 2677 SrcPort: context.TestPort, 2678 DstPort: c.Port, 2679 Flags: header.TCPFlagAck, 2680 SeqNum: iss.Add(seqnum.Size(sent)), 2681 AckNum: c.IRS.Add(1), 2682 RcvWnd: 30000, 2683 }) 2684 sent += len(data) 2685 pkt := c.GetPacket() 2686 checker.IPv4(t, pkt, 2687 checker.PayloadLen(header.TCPMinimumSize), 2688 checker.TCP( 2689 checker.DstPort(context.TestPort), 2690 checker.TCPSeqNum(uint32(c.IRS)+1), 2691 checker.TCPAckNum(uint32(iss)+uint32(sent)), 2692 checker.TCPFlags(header.TCPFlagAck), 2693 ), 2694 ) 2695 // Don't reduce window to zero here. 2696 if wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()); wnd<<ws < len(data) { 2697 remain = wnd << ws 2698 break 2699 } 2700 } 2701 2702 // Make the window non-zero, but the scaled window zero. 
2703 for remain >= 16 { 2704 data = data[:remain-15] 2705 c.SendPacket(data, &context.Headers{ 2706 SrcPort: context.TestPort, 2707 DstPort: c.Port, 2708 Flags: header.TCPFlagAck, 2709 SeqNum: iss.Add(seqnum.Size(sent)), 2710 AckNum: c.IRS.Add(1), 2711 RcvWnd: 30000, 2712 }) 2713 sent += len(data) 2714 pkt := c.GetPacket() 2715 checker.IPv4(t, pkt, 2716 checker.PayloadLen(header.TCPMinimumSize), 2717 checker.TCP( 2718 checker.DstPort(context.TestPort), 2719 checker.TCPSeqNum(uint32(c.IRS)+1), 2720 checker.TCPAckNum(uint32(iss)+uint32(sent)), 2721 checker.TCPFlags(header.TCPFlagAck), 2722 ), 2723 ) 2724 // Since the receive buffer is split between window advertisement and 2725 // application data buffer the window does not always reflect the space 2726 // available and actual space available can be a bit more than what is 2727 // advertised in the window. 2728 wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) 2729 if wnd == 0 { 2730 break 2731 } 2732 remain = wnd << ws 2733 } 2734 2735 // Read at least 2MSS of data. An ack should be sent in response to that. 2736 // Since buffer space is now split in half between window and application 2737 // data we need to read more than 1 MSS(65536) of data for a non-zero window 2738 // update to be sent. For 1MSS worth of window to be available we need to 2739 // read at least 128KB. Since our segments above were 50KB each it means 2740 // we need to read at 3 packets. 
2741 w := tcpip.LimitedWriter{ 2742 W: ioutil.Discard, 2743 N: defaultMTU * 2, 2744 } 2745 for w.N != 0 { 2746 res, err := c.EP.Read(&w, tcpip.ReadOptions{}) 2747 t.Logf("err=%v res=%#v", err, res) 2748 if err != nil { 2749 t.Fatalf("Read failed: %s", err) 2750 } 2751 } 2752 2753 checker.IPv4(t, c.GetPacket(), 2754 checker.PayloadLen(header.TCPMinimumSize), 2755 checker.TCP( 2756 checker.DstPort(context.TestPort), 2757 checker.TCPSeqNum(uint32(c.IRS)+1), 2758 checker.TCPAckNum(uint32(iss)+uint32(sent)), 2759 checker.TCPWindowGreaterThanEq(uint16(defaultMTU>>ws)), 2760 checker.TCPFlags(header.TCPFlagAck), 2761 ), 2762 ) 2763 } 2764 2765 func TestSegmentMerging(t *testing.T) { 2766 tests := []struct { 2767 name string 2768 stop func(tcpip.Endpoint) 2769 resume func(tcpip.Endpoint) 2770 }{ 2771 { 2772 "stop work", 2773 func(ep tcpip.Endpoint) { 2774 ep.(interface{ StopWork() }).StopWork() 2775 }, 2776 func(ep tcpip.Endpoint) { 2777 ep.(interface{ ResumeWork() }).ResumeWork() 2778 }, 2779 }, 2780 { 2781 "cork", 2782 func(ep tcpip.Endpoint) { 2783 ep.SocketOptions().SetCorkOption(true) 2784 }, 2785 func(ep tcpip.Endpoint) { 2786 ep.SocketOptions().SetCorkOption(false) 2787 }, 2788 }, 2789 } 2790 2791 for _, test := range tests { 2792 t.Run(test.name, func(t *testing.T) { 2793 c := context.New(t, defaultMTU) 2794 defer c.Cleanup() 2795 2796 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 2797 2798 // Send tcp.InitialCwnd number of segments to fill up 2799 // InitialWindow but don't ACK. That should prevent 2800 // anymore packets from going out. 2801 var r bytes.Reader 2802 for i := 0; i < tcp.InitialCwnd; i++ { 2803 r.Reset([]byte{0}) 2804 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 2805 t.Fatalf("Write #%d failed: %s", i+1, err) 2806 } 2807 } 2808 2809 // Now send the segments that should get merged as the congestion 2810 // window is full and we won't be able to send any more packets. 
2811 var allData []byte 2812 for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { 2813 allData = append(allData, data...) 2814 r.Reset(data) 2815 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 2816 t.Fatalf("Write #%d failed: %s", i+1, err) 2817 } 2818 } 2819 2820 // Check that we get tcp.InitialCwnd packets. 2821 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 2822 for i := 0; i < tcp.InitialCwnd; i++ { 2823 b := c.GetPacket() 2824 checker.IPv4(t, b, 2825 checker.PayloadLen(header.TCPMinimumSize+1), 2826 checker.TCP( 2827 checker.DstPort(context.TestPort), 2828 checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1), 2829 checker.TCPAckNum(uint32(iss)), 2830 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 2831 ), 2832 ) 2833 } 2834 2835 // Acknowledge the data. 2836 c.SendPacket(nil, &context.Headers{ 2837 SrcPort: context.TestPort, 2838 DstPort: c.Port, 2839 Flags: header.TCPFlagAck, 2840 SeqNum: iss, 2841 AckNum: c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload. 2842 RcvWnd: 30000, 2843 }) 2844 2845 // Check that data is received. 2846 b := c.GetPacket() 2847 checker.IPv4(t, b, 2848 checker.PayloadLen(len(allData)+header.TCPMinimumSize), 2849 checker.TCP( 2850 checker.DstPort(context.TestPort), 2851 checker.TCPSeqNum(uint32(c.IRS)+11), 2852 checker.TCPAckNum(uint32(iss)), 2853 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 2854 ), 2855 ) 2856 2857 if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, allData) { 2858 t.Fatalf("got data = %v, want = %v", got, allData) 2859 } 2860 2861 // Acknowledge the data. 
2862 c.SendPacket(nil, &context.Headers{ 2863 SrcPort: context.TestPort, 2864 DstPort: c.Port, 2865 Flags: header.TCPFlagAck, 2866 SeqNum: iss, 2867 AckNum: c.IRS.Add(11 + seqnum.Size(len(allData))), 2868 RcvWnd: 30000, 2869 }) 2870 }) 2871 } 2872 } 2873 2874 func TestDelay(t *testing.T) { 2875 c := context.New(t, defaultMTU) 2876 defer c.Cleanup() 2877 2878 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 2879 2880 c.EP.SocketOptions().SetDelayOption(true) 2881 2882 var allData []byte 2883 for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { 2884 allData = append(allData, data...) 2885 var r bytes.Reader 2886 r.Reset(data) 2887 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 2888 t.Fatalf("Write #%d failed: %s", i+1, err) 2889 } 2890 } 2891 2892 seq := c.IRS.Add(1) 2893 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 2894 for _, want := range [][]byte{allData[:1], allData[1:]} { 2895 // Check that data is received. 2896 b := c.GetPacket() 2897 checker.IPv4(t, b, 2898 checker.PayloadLen(len(want)+header.TCPMinimumSize), 2899 checker.TCP( 2900 checker.DstPort(context.TestPort), 2901 checker.TCPSeqNum(uint32(seq)), 2902 checker.TCPAckNum(uint32(iss)), 2903 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 2904 ), 2905 ) 2906 2907 if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, want) { 2908 t.Fatalf("got data = %v, want = %v", got, want) 2909 } 2910 2911 seq = seq.Add(seqnum.Size(len(want))) 2912 // Acknowledge the data. 
2913 c.SendPacket(nil, &context.Headers{ 2914 SrcPort: context.TestPort, 2915 DstPort: c.Port, 2916 Flags: header.TCPFlagAck, 2917 SeqNum: iss, 2918 AckNum: seq, 2919 RcvWnd: 30000, 2920 }) 2921 } 2922 } 2923 2924 func TestUndelay(t *testing.T) { 2925 c := context.New(t, defaultMTU) 2926 defer c.Cleanup() 2927 2928 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 2929 2930 c.EP.SocketOptions().SetDelayOption(true) 2931 2932 allData := [][]byte{{0}, {1, 2, 3}} 2933 for i, data := range allData { 2934 var r bytes.Reader 2935 r.Reset(data) 2936 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 2937 t.Fatalf("Write #%d failed: %s", i+1, err) 2938 } 2939 } 2940 2941 seq := c.IRS.Add(1) 2942 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 2943 // Check that data is received. 2944 first := c.GetPacket() 2945 checker.IPv4(t, first, 2946 checker.PayloadLen(len(allData[0])+header.TCPMinimumSize), 2947 checker.TCP( 2948 checker.DstPort(context.TestPort), 2949 checker.TCPSeqNum(uint32(seq)), 2950 checker.TCPAckNum(uint32(iss)), 2951 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 2952 ), 2953 ) 2954 2955 if got, want := first[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[0]; !bytes.Equal(got, want) { 2956 t.Fatalf("got first packet's data = %v, want = %v", got, want) 2957 } 2958 2959 seq = seq.Add(seqnum.Size(len(allData[0]))) 2960 2961 // Check that we don't get the second packet yet. 2962 c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) 2963 2964 c.EP.SocketOptions().SetDelayOption(false) 2965 2966 // Check that data is received. 
2967 second := c.GetPacket() 2968 checker.IPv4(t, second, 2969 checker.PayloadLen(len(allData[1])+header.TCPMinimumSize), 2970 checker.TCP( 2971 checker.DstPort(context.TestPort), 2972 checker.TCPSeqNum(uint32(seq)), 2973 checker.TCPAckNum(uint32(iss)), 2974 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 2975 ), 2976 ) 2977 2978 if got, want := second[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[1]; !bytes.Equal(got, want) { 2979 t.Fatalf("got second packet's data = %v, want = %v", got, want) 2980 } 2981 2982 seq = seq.Add(seqnum.Size(len(allData[1]))) 2983 2984 // Acknowledge the data. 2985 c.SendPacket(nil, &context.Headers{ 2986 SrcPort: context.TestPort, 2987 DstPort: c.Port, 2988 Flags: header.TCPFlagAck, 2989 SeqNum: iss, 2990 AckNum: seq, 2991 RcvWnd: 30000, 2992 }) 2993 } 2994 2995 func TestMSSNotDelayed(t *testing.T) { 2996 tests := []struct { 2997 name string 2998 fn func(tcpip.Endpoint) 2999 }{ 3000 {"no-op", func(tcpip.Endpoint) {}}, 3001 {"delay", func(ep tcpip.Endpoint) { ep.SocketOptions().SetDelayOption(true) }}, 3002 {"cork", func(ep tcpip.Endpoint) { ep.SocketOptions().SetCorkOption(true) }}, 3003 } 3004 3005 for _, test := range tests { 3006 t.Run(test.name, func(t *testing.T) { 3007 const maxPayload = 100 3008 c := context.New(t, defaultMTU) 3009 defer c.Cleanup() 3010 3011 c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ 3012 header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), 3013 }) 3014 3015 test.fn(c.EP) 3016 3017 allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)} 3018 for i, data := range allData { 3019 var r bytes.Reader 3020 r.Reset(data) 3021 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 3022 t.Fatalf("Write #%d failed: %s", i+1, err) 3023 } 3024 } 3025 3026 seq := c.IRS.Add(1) 3027 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 3028 for i, data := range allData { 3029 // Check that 
data is received. 3030 packet := c.GetPacket() 3031 checker.IPv4(t, packet, 3032 checker.PayloadLen(len(data)+header.TCPMinimumSize), 3033 checker.TCP( 3034 checker.DstPort(context.TestPort), 3035 checker.TCPSeqNum(uint32(seq)), 3036 checker.TCPAckNum(uint32(iss)), 3037 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 3038 ), 3039 ) 3040 3041 if got, want := packet[header.IPv4MinimumSize+header.TCPMinimumSize:], data; !bytes.Equal(got, want) { 3042 t.Fatalf("got packet #%d's data = %v, want = %v", i+1, got, want) 3043 } 3044 3045 seq = seq.Add(seqnum.Size(len(data))) 3046 } 3047 3048 // Acknowledge the data. 3049 c.SendPacket(nil, &context.Headers{ 3050 SrcPort: context.TestPort, 3051 DstPort: c.Port, 3052 Flags: header.TCPFlagAck, 3053 SeqNum: iss, 3054 AckNum: seq, 3055 RcvWnd: 30000, 3056 }) 3057 }) 3058 } 3059 } 3060 3061 func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { 3062 payloadMultiplier := 10 3063 dataLen := payloadMultiplier * maxPayload 3064 data := make([]byte, dataLen) 3065 for i := range data { 3066 data[i] = byte(i) 3067 } 3068 3069 var r bytes.Reader 3070 r.Reset(data) 3071 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 3072 t.Fatalf("Write failed: %s", err) 3073 } 3074 3075 // Check that data is received in chunks. 
3076 bytesReceived := 0 3077 numPackets := 0 3078 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 3079 for bytesReceived != dataLen { 3080 b := c.GetPacket() 3081 numPackets++ 3082 tcpHdr := header.TCP(header.IPv4(b).Payload()) 3083 payloadLen := len(tcpHdr.Payload()) 3084 checker.IPv4(t, b, 3085 checker.TCP( 3086 checker.DstPort(context.TestPort), 3087 checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), 3088 checker.TCPAckNum(uint32(iss)), 3089 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 3090 ), 3091 ) 3092 3093 pdata := data[bytesReceived : bytesReceived+payloadLen] 3094 if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) { 3095 t.Fatalf("got data = %v, want = %v", p, pdata) 3096 } 3097 bytesReceived += payloadLen 3098 var options []byte 3099 if c.TimeStampEnabled { 3100 // If timestamp option is enabled, echo back the timestamp and increment 3101 // the TSEcr value included in the packet and send that back as the TSVal. 3102 parsedOpts := tcpHdr.ParsedOptions() 3103 tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} 3104 header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:]) 3105 options = tsOpt[:] 3106 } 3107 // Acknowledge the data. 
3108 c.SendPacket(nil, &context.Headers{ 3109 SrcPort: context.TestPort, 3110 DstPort: c.Port, 3111 Flags: header.TCPFlagAck, 3112 SeqNum: iss, 3113 AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)), 3114 RcvWnd: 30000, 3115 TCPOpts: options, 3116 }) 3117 } 3118 if numPackets == 1 { 3119 t.Fatalf("expected write to be broken up into multiple packets, but got 1 packet") 3120 } 3121 } 3122 3123 func TestSendGreaterThanMTU(t *testing.T) { 3124 const maxPayload = 100 3125 c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) 3126 defer c.Cleanup() 3127 3128 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 3129 testBrokenUpWrite(t, c, maxPayload) 3130 } 3131 3132 func TestSetTTL(t *testing.T) { 3133 for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { 3134 t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { 3135 c := context.New(t, 65535) 3136 defer c.Cleanup() 3137 3138 var err tcpip.Error 3139 c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) 3140 if err != nil { 3141 t.Fatalf("NewEndpoint failed: %s", err) 3142 } 3143 3144 if err := c.EP.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil { 3145 t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) 3146 } 3147 3148 { 3149 err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) 3150 if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { 3151 t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) 3152 } 3153 } 3154 3155 // Receive SYN packet. 
3156 b := c.GetPacket() 3157 3158 checker.IPv4(t, b, checker.TTL(wantTTL)) 3159 }) 3160 } 3161 } 3162 3163 func TestActiveSendMSSLessThanMTU(t *testing.T) { 3164 const maxPayload = 100 3165 c := context.New(t, 65535) 3166 defer c.Cleanup() 3167 3168 c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ 3169 header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), 3170 }) 3171 testBrokenUpWrite(t, c, maxPayload) 3172 } 3173 3174 func TestPassiveSendMSSLessThanMTU(t *testing.T) { 3175 const maxPayload = 100 3176 const mtu = 1200 3177 c := context.New(t, mtu) 3178 defer c.Cleanup() 3179 3180 // Create EP and start listening. 3181 wq := &waiter.Queue{} 3182 ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) 3183 if err != nil { 3184 t.Fatalf("NewEndpoint failed: %s", err) 3185 } 3186 defer ep.Close() 3187 3188 // Set the buffer size to a deterministic size so that we can check the 3189 // window scaling option. 3190 const rcvBufferSize = 0x20000 3191 ep.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true) 3192 3193 if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 3194 t.Fatalf("Bind failed: %s", err) 3195 } 3196 3197 if err := ep.Listen(10); err != nil { 3198 t.Fatalf("Listen failed: %s", err) 3199 } 3200 3201 // Do 3-way handshake. 3202 c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) 3203 3204 // Try to accept the connection. 3205 we, ch := waiter.NewChannelEntry(nil) 3206 wq.EventRegister(&we, waiter.ReadableEvents) 3207 defer wq.EventUnregister(&we) 3208 3209 c.EP, _, err = ep.Accept(nil) 3210 if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 3211 // Wait for connection to be established. 
3212 select { 3213 case <-ch: 3214 c.EP, _, err = ep.Accept(nil) 3215 if err != nil { 3216 t.Fatalf("Accept failed: %s", err) 3217 } 3218 3219 case <-time.After(1 * time.Second): 3220 t.Fatalf("Timed out waiting for accept") 3221 } 3222 } 3223 3224 // Check that data gets properly segmented. 3225 testBrokenUpWrite(t, c, maxPayload) 3226 } 3227 3228 func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { 3229 const maxPayload = 536 3230 const mtu = 2000 3231 c := context.New(t, mtu) 3232 defer c.Cleanup() 3233 3234 opt := tcpip.TCPAlwaysUseSynCookies(true) 3235 if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { 3236 t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) 3237 } 3238 3239 // Create EP and start listening. 3240 wq := &waiter.Queue{} 3241 ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) 3242 if err != nil { 3243 t.Fatalf("NewEndpoint failed: %s", err) 3244 } 3245 defer ep.Close() 3246 3247 if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 3248 t.Fatalf("Bind failed: %s", err) 3249 } 3250 3251 if err := ep.Listen(10); err != nil { 3252 t.Fatalf("Listen failed: %s", err) 3253 } 3254 3255 // Do 3-way handshake. 3256 c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) 3257 3258 // Try to accept the connection. 3259 we, ch := waiter.NewChannelEntry(nil) 3260 wq.EventRegister(&we, waiter.ReadableEvents) 3261 defer wq.EventUnregister(&we) 3262 3263 c.EP, _, err = ep.Accept(nil) 3264 if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 3265 // Wait for connection to be established. 3266 select { 3267 case <-ch: 3268 c.EP, _, err = ep.Accept(nil) 3269 if err != nil { 3270 t.Fatalf("Accept failed: %s", err) 3271 } 3272 3273 case <-time.After(1 * time.Second): 3274 t.Fatalf("Timed out waiting for accept") 3275 } 3276 } 3277 3278 // Check that data gets properly segmented. 
3279 testBrokenUpWrite(t, c, maxPayload) 3280 } 3281 3282 func TestForwarderSendMSSLessThanMTU(t *testing.T) { 3283 const maxPayload = 100 3284 const mtu = 1200 3285 c := context.New(t, mtu) 3286 defer c.Cleanup() 3287 3288 s := c.Stack() 3289 ch := make(chan tcpip.Error, 1) 3290 f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) { 3291 var err tcpip.Error 3292 c.EP, err = r.CreateEndpoint(&c.WQ) 3293 ch <- err 3294 }) 3295 s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket) 3296 3297 // Do 3-way handshake. 3298 c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) 3299 3300 // Wait for connection to be available. 3301 select { 3302 case err := <-ch: 3303 if err != nil { 3304 t.Fatalf("Error creating endpoint: %s", err) 3305 } 3306 case <-time.After(2 * time.Second): 3307 t.Fatalf("Timed out waiting for connection") 3308 } 3309 3310 // Check that data gets properly segmented. 3311 testBrokenUpWrite(t, c, maxPayload) 3312 } 3313 3314 func TestSynOptionsOnActiveConnect(t *testing.T) { 3315 const mtu = 1400 3316 c := context.New(t, mtu) 3317 defer c.Cleanup() 3318 3319 // Create TCP endpoint. 3320 var err tcpip.Error 3321 c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) 3322 if err != nil { 3323 t.Fatalf("NewEndpoint failed: %s", err) 3324 } 3325 3326 // Set the buffer size to a deterministic size so that we can check the 3327 // window scaling option. 3328 const rcvBufferSize = 0x20000 3329 const wndScale = 3 3330 c.EP.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true) 3331 3332 // Start connection attempt. 3333 we, ch := waiter.NewChannelEntry(nil) 3334 c.WQ.EventRegister(&we, waiter.WritableEvents) 3335 defer c.WQ.EventUnregister(&we) 3336 3337 { 3338 err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) 3339 if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { 3340 t.Fatalf("c.EP.Connect(...) 
mismatch (-want +got):\n%s", d) 3341 } 3342 } 3343 3344 // Receive SYN packet. 3345 b := c.GetPacket() 3346 mss := uint16(mtu - header.IPv4MinimumSize - header.TCPMinimumSize) 3347 checker.IPv4(t, b, 3348 checker.TCP( 3349 checker.DstPort(context.TestPort), 3350 checker.TCPFlags(header.TCPFlagSyn), 3351 checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), 3352 ), 3353 ) 3354 3355 tcpHdr := header.TCP(header.IPv4(b).Payload()) 3356 c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) 3357 3358 // Wait for retransmit. 3359 time.Sleep(1 * time.Second) 3360 checker.IPv4(t, c.GetPacket(), 3361 checker.TCP( 3362 checker.DstPort(context.TestPort), 3363 checker.TCPFlags(header.TCPFlagSyn), 3364 checker.SrcPort(tcpHdr.SourcePort()), 3365 checker.TCPSeqNum(tcpHdr.SequenceNumber()), 3366 checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), 3367 ), 3368 ) 3369 3370 // Send SYN-ACK. 3371 iss := seqnum.Value(context.TestInitialSequenceNumber) 3372 c.SendPacket(nil, &context.Headers{ 3373 SrcPort: tcpHdr.DestinationPort(), 3374 DstPort: tcpHdr.SourcePort(), 3375 Flags: header.TCPFlagSyn | header.TCPFlagAck, 3376 SeqNum: iss, 3377 AckNum: c.IRS.Add(1), 3378 RcvWnd: 30000, 3379 }) 3380 3381 // Receive ACK packet. 3382 checker.IPv4(t, c.GetPacket(), 3383 checker.TCP( 3384 checker.DstPort(context.TestPort), 3385 checker.TCPFlags(header.TCPFlagAck), 3386 checker.TCPSeqNum(uint32(c.IRS)+1), 3387 checker.TCPAckNum(uint32(iss)+1), 3388 ), 3389 ) 3390 3391 // Wait for connection to be established. 3392 select { 3393 case <-ch: 3394 if err := c.EP.LastError(); err != nil { 3395 t.Fatalf("Connect failed: %s", err) 3396 } 3397 case <-time.After(1 * time.Second): 3398 t.Fatalf("Timed out waiting for connection") 3399 } 3400 } 3401 3402 func TestCloseListener(t *testing.T) { 3403 c := context.New(t, defaultMTU) 3404 defer c.Cleanup() 3405 3406 // Create listener. 
3407 var wq waiter.Queue 3408 ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) 3409 if err != nil { 3410 t.Fatalf("NewEndpoint failed: %s", err) 3411 } 3412 3413 if err := ep.Bind(tcpip.FullAddress{}); err != nil { 3414 t.Fatalf("Bind failed: %s", err) 3415 } 3416 3417 if err := ep.Listen(10); err != nil { 3418 t.Fatalf("Listen failed: %s", err) 3419 } 3420 3421 // Close the listener and measure how long it takes. 3422 t0 := time.Now() 3423 ep.Close() 3424 if diff := time.Now().Sub(t0); diff > 3*time.Second { 3425 t.Fatalf("Took too long to close: %s", diff) 3426 } 3427 } 3428 3429 func TestReceiveOnResetConnection(t *testing.T) { 3430 c := context.New(t, defaultMTU) 3431 defer c.Cleanup() 3432 3433 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 3434 3435 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 3436 // Send RST segment. 3437 c.SendPacket(nil, &context.Headers{ 3438 SrcPort: context.TestPort, 3439 DstPort: c.Port, 3440 Flags: header.TCPFlagRst, 3441 SeqNum: iss, 3442 RcvWnd: 30000, 3443 }) 3444 3445 // Try to read. 3446 we, ch := waiter.NewChannelEntry(nil) 3447 c.WQ.EventRegister(&we, waiter.ReadableEvents) 3448 defer c.WQ.EventUnregister(&we) 3449 3450 loop: 3451 for { 3452 switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) { 3453 case *tcpip.ErrWouldBlock: 3454 <-ch 3455 // Expect the state to be StateError and subsequent Reads to fail with HardError. 
3456 _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) 3457 if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { 3458 t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d) 3459 } 3460 break loop 3461 case *tcpip.ErrConnectionReset: 3462 break loop 3463 default: 3464 t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, &tcpip.ErrConnectionReset{}) 3465 } 3466 } 3467 3468 if tcp.EndpointState(c.EP.State()) != tcp.StateError { 3469 t.Fatalf("got EP state is not StateError") 3470 } 3471 3472 checkValid := func() []error { 3473 var errors []error 3474 if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { 3475 errors = append(errors, fmt.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)) 3476 } 3477 if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { 3478 errors = append(errors, fmt.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)) 3479 } 3480 if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { 3481 errors = append(errors, fmt.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)) 3482 } 3483 return errors 3484 } 3485 3486 start := time.Now() 3487 for time.Since(start) < time.Minute && len(checkValid()) > 0 { 3488 time.Sleep(50 * time.Millisecond) 3489 } 3490 for _, err := range checkValid() { 3491 t.Error(err) 3492 } 3493 } 3494 3495 func TestSendOnResetConnection(t *testing.T) { 3496 c := context.New(t, defaultMTU) 3497 defer c.Cleanup() 3498 3499 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 3500 3501 // Send RST segment. 3502 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 3503 c.SendPacket(nil, &context.Headers{ 3504 SrcPort: context.TestPort, 3505 DstPort: c.Port, 3506 Flags: header.TCPFlagRst, 3507 SeqNum: iss, 3508 RcvWnd: 30000, 3509 }) 3510 3511 // Wait for the RST to be received. 3512 time.Sleep(1 * time.Second) 3513 3514 // Try to write. 
3515 var r bytes.Reader 3516 r.Reset(make([]byte, 10)) 3517 _, err := c.EP.Write(&r, tcpip.WriteOptions{}) 3518 if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { 3519 t.Fatalf("c.EP.Write(...) mismatch (-want +got):\n%s", d) 3520 } 3521 } 3522 3523 // TestMaxRetransmitsTimeout tests if the connection is timed out after 3524 // a segment has been retransmitted MaxRetries times. 3525 func TestMaxRetransmitsTimeout(t *testing.T) { 3526 c := context.New(t, defaultMTU) 3527 defer c.Cleanup() 3528 3529 const numRetries = 2 3530 opt := tcpip.TCPMaxRetriesOption(numRetries) 3531 if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { 3532 t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) 3533 } 3534 3535 c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) 3536 3537 waitEntry, notifyCh := waiter.NewChannelEntry(nil) 3538 c.WQ.EventRegister(&waitEntry, waiter.EventHUp) 3539 defer c.WQ.EventUnregister(&waitEntry) 3540 3541 var r bytes.Reader 3542 r.Reset(make([]byte, 1)) 3543 _, err := c.EP.Write(&r, tcpip.WriteOptions{}) 3544 if err != nil { 3545 t.Fatalf("Write failed: %s", err) 3546 } 3547 3548 // Expect first transmit and MaxRetries retransmits. 3549 for i := 0; i < numRetries+1; i++ { 3550 checker.IPv4(t, c.GetPacket(), 3551 checker.TCP( 3552 checker.DstPort(context.TestPort), 3553 checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), 3554 ), 3555 ) 3556 } 3557 // Wait for the connection to timeout after MaxRetries retransmits. 3558 initRTO := 1 * time.Second 3559 select { 3560 case <-notifyCh: 3561 case <-time.After((2 << numRetries) * initRTO): 3562 t.Fatalf("connection still alive after maximum retransmits.\n") 3563 } 3564 3565 // Send an ACK and expect a RST as the connection would have been closed. 
3566 c.SendPacket(nil, &context.Headers{ 3567 SrcPort: context.TestPort, 3568 DstPort: c.Port, 3569 Flags: header.TCPFlagAck, 3570 }) 3571 3572 checker.IPv4(t, c.GetPacket(), 3573 checker.TCP( 3574 checker.DstPort(context.TestPort), 3575 checker.TCPFlags(header.TCPFlagRst), 3576 ), 3577 ) 3578 3579 if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { 3580 t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got) 3581 } 3582 if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { 3583 t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) 3584 } 3585 } 3586 3587 // TestMaxRTO tests if the retransmit interval caps to MaxRTO. 3588 func TestMaxRTO(t *testing.T) { 3589 c := context.New(t, defaultMTU) 3590 defer c.Cleanup() 3591 3592 rto := 1 * time.Second 3593 opt := tcpip.TCPMaxRTOOption(rto) 3594 if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { 3595 t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) 3596 } 3597 3598 c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) 3599 3600 var r bytes.Reader 3601 r.Reset(make([]byte, 1)) 3602 _, err := c.EP.Write(&r, tcpip.WriteOptions{}) 3603 if err != nil { 3604 t.Fatalf("Write failed: %s", err) 3605 } 3606 checker.IPv4(t, c.GetPacket(), 3607 checker.TCP( 3608 checker.DstPort(context.TestPort), 3609 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 3610 ), 3611 ) 3612 const numRetransmits = 2 3613 for i := 0; i < numRetransmits; i++ { 3614 start := time.Now() 3615 checker.IPv4(t, c.GetPacket(), 3616 checker.TCP( 3617 checker.DstPort(context.TestPort), 3618 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 3619 ), 3620 ) 3621 if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() { 3622 t.Errorf("Retransmit interval not capped to MaxRTO.\n") 3623 } 3624 } 3625 } 3626 3627 // 
TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is 3628 // unique on retransmits. 3629 func TestRetransmitIPv4IDUniqueness(t *testing.T) { 3630 for _, tc := range []struct { 3631 name string 3632 size int 3633 }{ 3634 {"1Byte", 1}, 3635 {"512Bytes", 512}, 3636 } { 3637 t.Run(tc.name, func(t *testing.T) { 3638 c := context.New(t, defaultMTU) 3639 defer c.Cleanup() 3640 3641 c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) 3642 3643 // Disabling PMTU discovery causes all packets sent from this socket to 3644 // have DF=0. This needs to be done because the IPv4 ID uniqueness 3645 // applies only to non-atomic IPv4 datagrams as defined in RFC 6864 3646 // Section 4, and datagrams with DF=0 are non-atomic. 3647 if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil { 3648 t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err) 3649 } 3650 3651 var r bytes.Reader 3652 r.Reset(make([]byte, tc.size)) 3653 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 3654 t.Fatalf("Write failed: %s", err) 3655 } 3656 pkt := c.GetPacket() 3657 checker.IPv4(t, pkt, 3658 checker.FragmentFlags(0), 3659 checker.TCP( 3660 checker.DstPort(context.TestPort), 3661 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 3662 ), 3663 ) 3664 idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): {}} 3665 // Expect two retransmitted packets, and that all packets received have 3666 // unique IPv4 ID values. 
3667 for i := 0; i <= 2; i++ { 3668 pkt := c.GetPacket() 3669 checker.IPv4(t, pkt, 3670 checker.FragmentFlags(0), 3671 checker.TCP( 3672 checker.DstPort(context.TestPort), 3673 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 3674 ), 3675 ) 3676 id := header.IPv4(pkt).ID() 3677 if _, exists := idSet[id]; exists { 3678 t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id) 3679 } 3680 idSet[id] = struct{}{} 3681 } 3682 }) 3683 } 3684 } 3685 3686 func TestFinImmediately(t *testing.T) { 3687 c := context.New(t, defaultMTU) 3688 defer c.Cleanup() 3689 3690 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 3691 3692 // Shutdown immediately, check that we get a FIN. 3693 if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { 3694 t.Fatalf("Shutdown failed: %s", err) 3695 } 3696 3697 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 3698 checker.IPv4(t, c.GetPacket(), 3699 checker.PayloadLen(header.TCPMinimumSize), 3700 checker.TCP( 3701 checker.DstPort(context.TestPort), 3702 checker.TCPSeqNum(uint32(c.IRS)+1), 3703 checker.TCPAckNum(uint32(iss)), 3704 checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), 3705 ), 3706 ) 3707 3708 // Ack and send FIN as well. 3709 c.SendPacket(nil, &context.Headers{ 3710 SrcPort: context.TestPort, 3711 DstPort: c.Port, 3712 Flags: header.TCPFlagAck | header.TCPFlagFin, 3713 SeqNum: iss, 3714 AckNum: c.IRS.Add(2), 3715 RcvWnd: 30000, 3716 }) 3717 3718 // Check that the stack acks the FIN. 
3719 checker.IPv4(t, c.GetPacket(), 3720 checker.PayloadLen(header.TCPMinimumSize), 3721 checker.TCP( 3722 checker.DstPort(context.TestPort), 3723 checker.TCPSeqNum(uint32(c.IRS)+2), 3724 checker.TCPAckNum(uint32(iss)+1), 3725 checker.TCPFlags(header.TCPFlagAck), 3726 ), 3727 ) 3728 } 3729 3730 func TestFinRetransmit(t *testing.T) { 3731 c := context.New(t, defaultMTU) 3732 defer c.Cleanup() 3733 3734 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 3735 3736 // Shutdown immediately, check that we get a FIN. 3737 if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { 3738 t.Fatalf("Shutdown failed: %s", err) 3739 } 3740 3741 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 3742 checker.IPv4(t, c.GetPacket(), 3743 checker.PayloadLen(header.TCPMinimumSize), 3744 checker.TCP( 3745 checker.DstPort(context.TestPort), 3746 checker.TCPSeqNum(uint32(c.IRS)+1), 3747 checker.TCPAckNum(uint32(iss)), 3748 checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), 3749 ), 3750 ) 3751 3752 // Don't acknowledge yet. We should get a retransmit of the FIN. 3753 checker.IPv4(t, c.GetPacket(), 3754 checker.PayloadLen(header.TCPMinimumSize), 3755 checker.TCP( 3756 checker.DstPort(context.TestPort), 3757 checker.TCPSeqNum(uint32(c.IRS)+1), 3758 checker.TCPAckNum(uint32(iss)), 3759 checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), 3760 ), 3761 ) 3762 3763 // Ack and send FIN as well. 3764 c.SendPacket(nil, &context.Headers{ 3765 SrcPort: context.TestPort, 3766 DstPort: c.Port, 3767 Flags: header.TCPFlagAck | header.TCPFlagFin, 3768 SeqNum: iss, 3769 AckNum: c.IRS.Add(2), 3770 RcvWnd: 30000, 3771 }) 3772 3773 // Check that the stack acks the FIN. 
3774 checker.IPv4(t, c.GetPacket(), 3775 checker.PayloadLen(header.TCPMinimumSize), 3776 checker.TCP( 3777 checker.DstPort(context.TestPort), 3778 checker.TCPSeqNum(uint32(c.IRS)+2), 3779 checker.TCPAckNum(uint32(iss)+1), 3780 checker.TCPFlags(header.TCPFlagAck), 3781 ), 3782 ) 3783 } 3784 3785 func TestFinWithNoPendingData(t *testing.T) { 3786 c := context.New(t, defaultMTU) 3787 defer c.Cleanup() 3788 3789 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 3790 3791 // Write something out, and have it acknowledged. 3792 view := make([]byte, 10) 3793 var r bytes.Reader 3794 r.Reset(view) 3795 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 3796 t.Fatalf("Write failed: %s", err) 3797 } 3798 3799 next := uint32(c.IRS) + 1 3800 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 3801 checker.IPv4(t, c.GetPacket(), 3802 checker.PayloadLen(len(view)+header.TCPMinimumSize), 3803 checker.TCP( 3804 checker.DstPort(context.TestPort), 3805 checker.TCPSeqNum(next), 3806 checker.TCPAckNum(uint32(iss)), 3807 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 3808 ), 3809 ) 3810 next += uint32(len(view)) 3811 3812 c.SendPacket(nil, &context.Headers{ 3813 SrcPort: context.TestPort, 3814 DstPort: c.Port, 3815 Flags: header.TCPFlagAck, 3816 SeqNum: iss, 3817 AckNum: seqnum.Value(next), 3818 RcvWnd: 30000, 3819 }) 3820 3821 // Shutdown, check that we get a FIN. 3822 if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { 3823 t.Fatalf("Shutdown failed: %s", err) 3824 } 3825 3826 checker.IPv4(t, c.GetPacket(), 3827 checker.PayloadLen(header.TCPMinimumSize), 3828 checker.TCP( 3829 checker.DstPort(context.TestPort), 3830 checker.TCPSeqNum(next), 3831 checker.TCPAckNum(uint32(iss)), 3832 checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), 3833 ), 3834 ) 3835 next++ 3836 3837 // Ack and send FIN as well. 
3838 c.SendPacket(nil, &context.Headers{ 3839 SrcPort: context.TestPort, 3840 DstPort: c.Port, 3841 Flags: header.TCPFlagAck | header.TCPFlagFin, 3842 SeqNum: iss, 3843 AckNum: seqnum.Value(next), 3844 RcvWnd: 30000, 3845 }) 3846 3847 // Check that the stack acks the FIN. 3848 checker.IPv4(t, c.GetPacket(), 3849 checker.PayloadLen(header.TCPMinimumSize), 3850 checker.TCP( 3851 checker.DstPort(context.TestPort), 3852 checker.TCPSeqNum(next), 3853 checker.TCPAckNum(uint32(iss)+1), 3854 checker.TCPFlags(header.TCPFlagAck), 3855 ), 3856 ) 3857 } 3858 3859 func TestFinWithPendingDataCwndFull(t *testing.T) { 3860 c := context.New(t, defaultMTU) 3861 defer c.Cleanup() 3862 3863 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 3864 3865 // Write enough segments to fill the congestion window before ACK'ing 3866 // any of them. 3867 view := make([]byte, 10) 3868 var r bytes.Reader 3869 for i := tcp.InitialCwnd; i > 0; i-- { 3870 r.Reset(view) 3871 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 3872 t.Fatalf("Write failed: %s", err) 3873 } 3874 } 3875 3876 next := uint32(c.IRS) + 1 3877 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 3878 for i := tcp.InitialCwnd; i > 0; i-- { 3879 checker.IPv4(t, c.GetPacket(), 3880 checker.PayloadLen(len(view)+header.TCPMinimumSize), 3881 checker.TCP( 3882 checker.DstPort(context.TestPort), 3883 checker.TCPSeqNum(next), 3884 checker.TCPAckNum(uint32(iss)), 3885 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 3886 ), 3887 ) 3888 next += uint32(len(view)) 3889 } 3890 3891 // Shutdown the connection, check that the FIN segment isn't sent 3892 // because the congestion window doesn't allow it. Wait until a 3893 // retransmit is received. 
3894 if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { 3895 t.Fatalf("Shutdown failed: %s", err) 3896 } 3897 3898 checker.IPv4(t, c.GetPacket(), 3899 checker.PayloadLen(len(view)+header.TCPMinimumSize), 3900 checker.TCP( 3901 checker.DstPort(context.TestPort), 3902 checker.TCPSeqNum(uint32(c.IRS)+1), 3903 checker.TCPAckNum(uint32(iss)), 3904 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 3905 ), 3906 ) 3907 3908 // Send the ACK that will allow the FIN to be sent as well. 3909 c.SendPacket(nil, &context.Headers{ 3910 SrcPort: context.TestPort, 3911 DstPort: c.Port, 3912 Flags: header.TCPFlagAck, 3913 SeqNum: iss, 3914 AckNum: seqnum.Value(next), 3915 RcvWnd: 30000, 3916 }) 3917 3918 checker.IPv4(t, c.GetPacket(), 3919 checker.PayloadLen(header.TCPMinimumSize), 3920 checker.TCP( 3921 checker.DstPort(context.TestPort), 3922 checker.TCPSeqNum(next), 3923 checker.TCPAckNum(uint32(iss)), 3924 checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), 3925 ), 3926 ) 3927 next++ 3928 3929 // Send a FIN that acknowledges everything. Get an ACK back. 3930 c.SendPacket(nil, &context.Headers{ 3931 SrcPort: context.TestPort, 3932 DstPort: c.Port, 3933 Flags: header.TCPFlagAck | header.TCPFlagFin, 3934 SeqNum: iss, 3935 AckNum: seqnum.Value(next), 3936 RcvWnd: 30000, 3937 }) 3938 3939 checker.IPv4(t, c.GetPacket(), 3940 checker.PayloadLen(header.TCPMinimumSize), 3941 checker.TCP( 3942 checker.DstPort(context.TestPort), 3943 checker.TCPSeqNum(next), 3944 checker.TCPAckNum(uint32(iss)+1), 3945 checker.TCPFlags(header.TCPFlagAck), 3946 ), 3947 ) 3948 } 3949 3950 func TestFinWithPendingData(t *testing.T) { 3951 c := context.New(t, defaultMTU) 3952 defer c.Cleanup() 3953 3954 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 3955 3956 // Write something out, and acknowledge it to get cwnd to 2. 
3957 view := make([]byte, 10) 3958 var r bytes.Reader 3959 r.Reset(view) 3960 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 3961 t.Fatalf("Write failed: %s", err) 3962 } 3963 3964 next := uint32(c.IRS) + 1 3965 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 3966 checker.IPv4(t, c.GetPacket(), 3967 checker.PayloadLen(len(view)+header.TCPMinimumSize), 3968 checker.TCP( 3969 checker.DstPort(context.TestPort), 3970 checker.TCPSeqNum(next), 3971 checker.TCPAckNum(uint32(iss)), 3972 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 3973 ), 3974 ) 3975 next += uint32(len(view)) 3976 3977 c.SendPacket(nil, &context.Headers{ 3978 SrcPort: context.TestPort, 3979 DstPort: c.Port, 3980 Flags: header.TCPFlagAck, 3981 SeqNum: iss, 3982 AckNum: seqnum.Value(next), 3983 RcvWnd: 30000, 3984 }) 3985 3986 // Write new data, but don't acknowledge it. 3987 r.Reset(view) 3988 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 3989 t.Fatalf("Write failed: %s", err) 3990 } 3991 3992 checker.IPv4(t, c.GetPacket(), 3993 checker.PayloadLen(len(view)+header.TCPMinimumSize), 3994 checker.TCP( 3995 checker.DstPort(context.TestPort), 3996 checker.TCPSeqNum(next), 3997 checker.TCPAckNum(uint32(iss)), 3998 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 3999 ), 4000 ) 4001 next += uint32(len(view)) 4002 4003 // Shutdown the connection, check that we do get a FIN. 4004 if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { 4005 t.Fatalf("Shutdown failed: %s", err) 4006 } 4007 4008 checker.IPv4(t, c.GetPacket(), 4009 checker.PayloadLen(header.TCPMinimumSize), 4010 checker.TCP( 4011 checker.DstPort(context.TestPort), 4012 checker.TCPSeqNum(next), 4013 checker.TCPAckNum(uint32(iss)), 4014 checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), 4015 ), 4016 ) 4017 next++ 4018 4019 // Send a FIN that acknowledges everything. Get an ACK back. 
4020 c.SendPacket(nil, &context.Headers{ 4021 SrcPort: context.TestPort, 4022 DstPort: c.Port, 4023 Flags: header.TCPFlagAck | header.TCPFlagFin, 4024 SeqNum: iss, 4025 AckNum: seqnum.Value(next), 4026 RcvWnd: 30000, 4027 }) 4028 4029 checker.IPv4(t, c.GetPacket(), 4030 checker.PayloadLen(header.TCPMinimumSize), 4031 checker.TCP( 4032 checker.DstPort(context.TestPort), 4033 checker.TCPSeqNum(next), 4034 checker.TCPAckNum(uint32(iss)+1), 4035 checker.TCPFlags(header.TCPFlagAck), 4036 ), 4037 ) 4038 } 4039 4040 func TestFinWithPartialAck(t *testing.T) { 4041 c := context.New(t, defaultMTU) 4042 defer c.Cleanup() 4043 4044 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 4045 4046 // Write something out, and acknowledge it to get cwnd to 2. Also send 4047 // FIN from the test side. 4048 view := make([]byte, 10) 4049 var r bytes.Reader 4050 r.Reset(view) 4051 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 4052 t.Fatalf("Write failed: %s", err) 4053 } 4054 4055 next := uint32(c.IRS) + 1 4056 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 4057 checker.IPv4(t, c.GetPacket(), 4058 checker.PayloadLen(len(view)+header.TCPMinimumSize), 4059 checker.TCP( 4060 checker.DstPort(context.TestPort), 4061 checker.TCPSeqNum(next), 4062 checker.TCPAckNum(uint32(iss)), 4063 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 4064 ), 4065 ) 4066 next += uint32(len(view)) 4067 4068 c.SendPacket(nil, &context.Headers{ 4069 SrcPort: context.TestPort, 4070 DstPort: c.Port, 4071 Flags: header.TCPFlagAck | header.TCPFlagFin, 4072 SeqNum: iss, 4073 AckNum: seqnum.Value(next), 4074 RcvWnd: 30000, 4075 }) 4076 4077 // Check that we get an ACK for the fin. 
4078 checker.IPv4(t, c.GetPacket(), 4079 checker.PayloadLen(header.TCPMinimumSize), 4080 checker.TCP( 4081 checker.DstPort(context.TestPort), 4082 checker.TCPSeqNum(next), 4083 checker.TCPAckNum(uint32(iss)+1), 4084 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 4085 ), 4086 ) 4087 4088 // Write new data, but don't acknowledge it. 4089 r.Reset(view) 4090 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 4091 t.Fatalf("Write failed: %s", err) 4092 } 4093 4094 checker.IPv4(t, c.GetPacket(), 4095 checker.PayloadLen(len(view)+header.TCPMinimumSize), 4096 checker.TCP( 4097 checker.DstPort(context.TestPort), 4098 checker.TCPSeqNum(next), 4099 checker.TCPAckNum(uint32(iss)+1), 4100 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 4101 ), 4102 ) 4103 next += uint32(len(view)) 4104 4105 // Shutdown the connection, check that we do get a FIN. 4106 if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { 4107 t.Fatalf("Shutdown failed: %s", err) 4108 } 4109 4110 checker.IPv4(t, c.GetPacket(), 4111 checker.PayloadLen(header.TCPMinimumSize), 4112 checker.TCP( 4113 checker.DstPort(context.TestPort), 4114 checker.TCPSeqNum(next), 4115 checker.TCPAckNum(uint32(iss)+1), 4116 checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), 4117 ), 4118 ) 4119 next++ 4120 4121 // Send an ACK for the data, but not for the FIN yet. 4122 c.SendPacket(nil, &context.Headers{ 4123 SrcPort: context.TestPort, 4124 DstPort: c.Port, 4125 Flags: header.TCPFlagAck, 4126 SeqNum: iss.Add(1), 4127 AckNum: seqnum.Value(next - 1), 4128 RcvWnd: 30000, 4129 }) 4130 4131 // Check that we don't get a retransmit of the FIN. 4132 c.CheckNoPacketTimeout("FIN retransmitted when data was ack'd", 100*time.Millisecond) 4133 4134 // Ack the FIN. 
4135 c.SendPacket(nil, &context.Headers{ 4136 SrcPort: context.TestPort, 4137 DstPort: c.Port, 4138 Flags: header.TCPFlagAck | header.TCPFlagFin, 4139 SeqNum: iss.Add(1), 4140 AckNum: seqnum.Value(next), 4141 RcvWnd: 30000, 4142 }) 4143 } 4144 4145 func TestUpdateListenBacklog(t *testing.T) { 4146 c := context.New(t, defaultMTU) 4147 defer c.Cleanup() 4148 4149 // Create listener. 4150 var wq waiter.Queue 4151 ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) 4152 if err != nil { 4153 t.Fatalf("NewEndpoint failed: %s", err) 4154 } 4155 4156 if err := ep.Bind(tcpip.FullAddress{}); err != nil { 4157 t.Fatalf("Bind failed: %s", err) 4158 } 4159 4160 if err := ep.Listen(10); err != nil { 4161 t.Fatalf("Listen failed: %s", err) 4162 } 4163 4164 // Update the backlog with another Listen() on the same endpoint. 4165 if err := ep.Listen(20); err != nil { 4166 t.Fatalf("Listen failed to update backlog: %s", err) 4167 } 4168 4169 ep.Close() 4170 } 4171 4172 func scaledSendWindow(t *testing.T, scale uint8) { 4173 // This test ensures that the endpoint is using the right scaling by 4174 // sending a buffer that is larger than the window size, and ensuring 4175 // that the endpoint doesn't send more than allowed. 4176 c := context.New(t, defaultMTU) 4177 defer c.Cleanup() 4178 4179 maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize 4180 c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 0, -1 /* epRcvBuf */, []byte{ 4181 header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), 4182 header.TCPOptionWS, 3, scale, header.TCPOptionNOP, 4183 }) 4184 4185 // Open up the window with a scaled value. 4186 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 4187 c.SendPacket(nil, &context.Headers{ 4188 SrcPort: context.TestPort, 4189 DstPort: c.Port, 4190 Flags: header.TCPFlagAck, 4191 SeqNum: iss, 4192 AckNum: c.IRS.Add(1), 4193 RcvWnd: 1, 4194 }) 4195 4196 // Send some data. 
Check that it's capped by the window size. 4197 view := make([]byte, 65535) 4198 var r bytes.Reader 4199 r.Reset(view) 4200 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 4201 t.Fatalf("Write failed: %s", err) 4202 } 4203 4204 // Check that only data that fits in the scaled window is sent. 4205 checker.IPv4(t, c.GetPacket(), 4206 checker.PayloadLen((1<<scale)+header.TCPMinimumSize), 4207 checker.TCP( 4208 checker.DstPort(context.TestPort), 4209 checker.TCPSeqNum(uint32(c.IRS)+1), 4210 checker.TCPAckNum(uint32(iss)), 4211 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 4212 ), 4213 ) 4214 4215 // Reset the connection to free resources. 4216 c.SendPacket(nil, &context.Headers{ 4217 SrcPort: context.TestPort, 4218 DstPort: c.Port, 4219 Flags: header.TCPFlagRst, 4220 SeqNum: iss, 4221 }) 4222 } 4223 4224 func TestScaledSendWindow(t *testing.T) { 4225 for scale := uint8(0); scale <= 14; scale++ { 4226 scaledSendWindow(t, scale) 4227 } 4228 } 4229 4230 func TestReceivedValidSegmentCountIncrement(t *testing.T) { 4231 c := context.New(t, defaultMTU) 4232 defer c.Cleanup() 4233 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 4234 stats := c.Stack().Stats() 4235 want := stats.TCP.ValidSegmentsReceived.Value() + 1 4236 4237 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 4238 c.SendPacket(nil, &context.Headers{ 4239 SrcPort: context.TestPort, 4240 DstPort: c.Port, 4241 Flags: header.TCPFlagAck, 4242 SeqNum: iss, 4243 AckNum: c.IRS.Add(1), 4244 RcvWnd: 30000, 4245 }) 4246 4247 if got := stats.TCP.ValidSegmentsReceived.Value(); got != want { 4248 t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %d, want = %d", got, want) 4249 } 4250 if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want { 4251 t.Errorf("got EP stats Stats.SegmentsReceived = %d, want = %d", got, want) 4252 } 4253 // Ensure there were no errors during handshake. 
If these stats have 4254 // incremented, then the connection should not have been established. 4255 if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 { 4256 t.Errorf("got EP stats Stats.SendErrors.NoRoute = %d, want = %d", got, 0) 4257 } 4258 } 4259 4260 func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { 4261 c := context.New(t, defaultMTU) 4262 defer c.Cleanup() 4263 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 4264 stats := c.Stack().Stats() 4265 want := stats.TCP.InvalidSegmentsReceived.Value() + 1 4266 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 4267 vv := c.BuildSegment(nil, &context.Headers{ 4268 SrcPort: context.TestPort, 4269 DstPort: c.Port, 4270 Flags: header.TCPFlagAck, 4271 SeqNum: iss, 4272 AckNum: c.IRS.Add(1), 4273 RcvWnd: 30000, 4274 }) 4275 tcpbuf := vv.ToView()[header.IPv4MinimumSize:] 4276 tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 4277 4278 c.SendSegment(vv) 4279 4280 if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want { 4281 t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %d, want = %d", got, want) 4282 } 4283 if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { 4284 t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) 4285 } 4286 } 4287 4288 func TestReceivedIncorrectChecksumIncrement(t *testing.T) { 4289 c := context.New(t, defaultMTU) 4290 defer c.Cleanup() 4291 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 4292 stats := c.Stack().Stats() 4293 want := stats.TCP.ChecksumErrors.Value() + 1 4294 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 4295 vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{ 4296 SrcPort: context.TestPort, 4297 DstPort: c.Port, 4298 Flags: header.TCPFlagAck, 4299 SeqNum: iss, 4300 AckNum: c.IRS.Add(1), 4301 RcvWnd: 30000, 4302 }) 4303 tcpbuf := 
vv.ToView()[header.IPv4MinimumSize:] 4304 // Overwrite a byte in the payload which should cause checksum 4305 // verification to fail. 4306 tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 4307 4308 c.SendSegment(vv) 4309 4310 if got := stats.TCP.ChecksumErrors.Value(); got != want { 4311 t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want) 4312 } 4313 if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ChecksumErrors.Value(); got != want { 4314 t.Errorf("got EP stats Stats.ReceiveErrors.ChecksumErrors = %d, want = %d", got, want) 4315 } 4316 } 4317 4318 func TestReceivedSegmentQueuing(t *testing.T) { 4319 // This test sends 200 segments containing a few bytes each to an 4320 // endpoint and checks that they're all received and acknowledged by 4321 // the endpoint, that is, that none of the segments are dropped by 4322 // internal queues. 4323 c := context.New(t, defaultMTU) 4324 defer c.Cleanup() 4325 4326 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 4327 4328 // Send 200 segments. 4329 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 4330 data := []byte{1, 2, 3} 4331 for i := 0; i < 200; i++ { 4332 c.SendPacket(data, &context.Headers{ 4333 SrcPort: context.TestPort, 4334 DstPort: c.Port, 4335 Flags: header.TCPFlagAck, 4336 SeqNum: iss.Add(seqnum.Size(i * len(data))), 4337 AckNum: c.IRS.Add(1), 4338 RcvWnd: 30000, 4339 }) 4340 } 4341 4342 // Receive ACKs for all segments. 
4343 last := iss.Add(seqnum.Size(200 * len(data))) 4344 for { 4345 b := c.GetPacket() 4346 checker.IPv4(t, b, 4347 checker.TCP( 4348 checker.DstPort(context.TestPort), 4349 checker.TCPSeqNum(uint32(c.IRS)+1), 4350 checker.TCPFlags(header.TCPFlagAck), 4351 ), 4352 ) 4353 tcpHdr := header.TCP(header.IPv4(b).Payload()) 4354 ack := seqnum.Value(tcpHdr.AckNumber()) 4355 if ack == last { 4356 break 4357 } 4358 4359 if last.LessThan(ack) { 4360 t.Fatalf("Acknowledge (%v) beyond the expected (%v)", ack, last) 4361 } 4362 } 4363 } 4364 4365 func TestReadAfterClosedState(t *testing.T) { 4366 // This test ensures that calling Read() or Peek() after the endpoint 4367 // has transitioned to closedState still works if there is pending 4368 // data. To transition to stateClosed without calling Close(), we must 4369 // shutdown the send path and the peer must send its own FIN. 4370 c := context.New(t, defaultMTU) 4371 defer c.Cleanup() 4372 4373 // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed 4374 // after 1 second in TIME_WAIT state. 4375 tcpTimeWaitTimeout := 1 * time.Second 4376 opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) 4377 if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { 4378 t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) 4379 } 4380 4381 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 4382 4383 we, ch := waiter.NewChannelEntry(nil) 4384 c.WQ.EventRegister(&we, waiter.ReadableEvents) 4385 defer c.WQ.EventUnregister(&we) 4386 4387 ept := endpointTester{c.EP} 4388 ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) 4389 4390 // Shutdown immediately for write, check that we get a FIN. 
4391 if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { 4392 t.Fatalf("Shutdown failed: %s", err) 4393 } 4394 4395 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 4396 checker.IPv4(t, c.GetPacket(), 4397 checker.PayloadLen(header.TCPMinimumSize), 4398 checker.TCP( 4399 checker.DstPort(context.TestPort), 4400 checker.TCPSeqNum(uint32(c.IRS)+1), 4401 checker.TCPAckNum(uint32(iss)), 4402 checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), 4403 ), 4404 ) 4405 4406 if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { 4407 t.Errorf("unexpected endpoint state: want %s, got %s", want, got) 4408 } 4409 4410 // Send some data and acknowledge the FIN. 4411 data := []byte{1, 2, 3} 4412 c.SendPacket(data, &context.Headers{ 4413 SrcPort: context.TestPort, 4414 DstPort: c.Port, 4415 Flags: header.TCPFlagAck | header.TCPFlagFin, 4416 SeqNum: iss, 4417 AckNum: c.IRS.Add(2), 4418 RcvWnd: 30000, 4419 }) 4420 4421 // Check that ACK is received. 4422 checker.IPv4(t, c.GetPacket(), 4423 checker.TCP( 4424 checker.DstPort(context.TestPort), 4425 checker.TCPSeqNum(uint32(c.IRS)+2), 4426 checker.TCPAckNum(uint32(iss)+uint32(len(data))+1), 4427 checker.TCPFlags(header.TCPFlagAck), 4428 ), 4429 ) 4430 4431 // Give the stack the chance to transition to closed state from 4432 // TIME_WAIT. 4433 time.Sleep(tcpTimeWaitTimeout * 2) 4434 4435 if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { 4436 t.Errorf("unexpected endpoint state: want %s, got %s", want, got) 4437 } 4438 4439 // Wait for receive to be notified. 4440 select { 4441 case <-ch: 4442 case <-time.After(1 * time.Second): 4443 t.Fatalf("Timed out waiting for data to arrive") 4444 } 4445 4446 // Check that peek works. 
4447 var peekBuf bytes.Buffer 4448 res, err := c.EP.Read(&peekBuf, tcpip.ReadOptions{Peek: true}) 4449 if err != nil { 4450 t.Fatalf("Peek failed: %s", err) 4451 } 4452 4453 if got, want := res.Count, len(data); got != want { 4454 t.Fatalf("res.Count = %d, want %d", got, want) 4455 } 4456 if !bytes.Equal(data, peekBuf.Bytes()) { 4457 t.Fatalf("got data = %v, want = %v", peekBuf.Bytes(), data) 4458 } 4459 4460 // Receive data. 4461 v := ept.CheckRead(t) 4462 if !bytes.Equal(data, v) { 4463 t.Fatalf("got data = %v, want = %v", v, data) 4464 } 4465 4466 // Now that we drained the queue, check that functions fail with the 4467 // right error code. 4468 ept.CheckReadError(t, &tcpip.ErrClosedForReceive{}) 4469 var buf bytes.Buffer 4470 { 4471 _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}) 4472 if d := cmp.Diff(&tcpip.ErrClosedForReceive{}, err); d != "" { 4473 t.Fatalf("c.EP.Read(_, {Peek: true}) mismatch (-want +got):\n%s", d) 4474 } 4475 } 4476 } 4477 4478 func TestReusePort(t *testing.T) { 4479 // This test ensures that ports are immediately available for reuse 4480 // after Close on the endpoints using them returns. 4481 c := context.New(t, defaultMTU) 4482 defer c.Cleanup() 4483 4484 // First case, just an endpoint that was bound. 
4485 var err tcpip.Error 4486 c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) 4487 if err != nil { 4488 t.Fatalf("NewEndpoint failed; %s", err) 4489 } 4490 c.EP.SocketOptions().SetReuseAddress(true) 4491 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 4492 t.Fatalf("Bind failed: %s", err) 4493 } 4494 4495 c.EP.Close() 4496 c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) 4497 if err != nil { 4498 t.Fatalf("NewEndpoint failed; %s", err) 4499 } 4500 c.EP.SocketOptions().SetReuseAddress(true) 4501 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 4502 t.Fatalf("Bind failed: %s", err) 4503 } 4504 c.EP.Close() 4505 4506 // Second case, an endpoint that was bound and is connecting.. 4507 c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) 4508 if err != nil { 4509 t.Fatalf("NewEndpoint failed; %s", err) 4510 } 4511 c.EP.SocketOptions().SetReuseAddress(true) 4512 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 4513 t.Fatalf("Bind failed: %s", err) 4514 } 4515 { 4516 err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) 4517 if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { 4518 t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) 4519 } 4520 } 4521 c.EP.Close() 4522 4523 c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) 4524 if err != nil { 4525 t.Fatalf("NewEndpoint failed; %s", err) 4526 } 4527 c.EP.SocketOptions().SetReuseAddress(true) 4528 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 4529 t.Fatalf("Bind failed: %s", err) 4530 } 4531 c.EP.Close() 4532 4533 // Third case, an endpoint that was bound and is listening. 
4534 c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) 4535 if err != nil { 4536 t.Fatalf("NewEndpoint failed; %s", err) 4537 } 4538 c.EP.SocketOptions().SetReuseAddress(true) 4539 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 4540 t.Fatalf("Bind failed: %s", err) 4541 } 4542 if err := c.EP.Listen(10); err != nil { 4543 t.Fatalf("Listen failed: %s", err) 4544 } 4545 c.EP.Close() 4546 4547 c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) 4548 if err != nil { 4549 t.Fatalf("NewEndpoint failed; %s", err) 4550 } 4551 c.EP.SocketOptions().SetReuseAddress(true) 4552 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 4553 t.Fatalf("Bind failed: %s", err) 4554 } 4555 if err := c.EP.Listen(10); err != nil { 4556 t.Fatalf("Listen failed: %s", err) 4557 } 4558 } 4559 4560 func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { 4561 t.Helper() 4562 4563 s := ep.SocketOptions().GetReceiveBufferSize() 4564 if int(s) != v { 4565 t.Fatalf("got receive buffer size = %d, want = %d", s, v) 4566 } 4567 } 4568 4569 func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { 4570 t.Helper() 4571 4572 if s := ep.SocketOptions().GetSendBufferSize(); int(s) != v { 4573 t.Fatalf("got send buffer size = %d, want = %d", s, v) 4574 } 4575 } 4576 4577 func TestDefaultBufferSizes(t *testing.T) { 4578 s := stack.New(stack.Options{ 4579 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, 4580 TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, 4581 }) 4582 4583 // Check the default values. 
4584 ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) 4585 if err != nil { 4586 t.Fatalf("NewEndpoint failed; %s", err) 4587 } 4588 defer func() { 4589 if ep != nil { 4590 ep.Close() 4591 } 4592 }() 4593 4594 checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize) 4595 checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) 4596 4597 // Change the default send buffer size. 4598 { 4599 opt := tcpip.TCPSendBufferSizeRangeOption{ 4600 Min: 1, 4601 Default: tcp.DefaultSendBufferSize * 2, 4602 Max: tcp.DefaultSendBufferSize * 20, 4603 } 4604 if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { 4605 t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) 4606 } 4607 } 4608 4609 ep.Close() 4610 ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) 4611 if err != nil { 4612 t.Fatalf("NewEndpoint failed; %s", err) 4613 } 4614 4615 checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2) 4616 checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) 4617 4618 // Change the default receive buffer size. 
4619 { 4620 opt := tcpip.TCPReceiveBufferSizeRangeOption{ 4621 Min: 1, 4622 Default: tcp.DefaultReceiveBufferSize * 3, 4623 Max: tcp.DefaultReceiveBufferSize * 30, 4624 } 4625 if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { 4626 t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) 4627 } 4628 } 4629 4630 ep.Close() 4631 ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) 4632 if err != nil { 4633 t.Fatalf("NewEndpoint failed; %s", err) 4634 } 4635 4636 checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2) 4637 checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*3) 4638 } 4639 4640 func TestMinMaxBufferSizes(t *testing.T) { 4641 s := stack.New(stack.Options{ 4642 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, 4643 TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, 4644 }) 4645 4646 // Check the default values. 4647 ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) 4648 if err != nil { 4649 t.Fatalf("NewEndpoint failed; %s", err) 4650 } 4651 defer ep.Close() 4652 4653 // Change the min/max values for send/receive 4654 { 4655 opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20} 4656 if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { 4657 t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) 4658 } 4659 } 4660 4661 { 4662 opt := tcpip.TCPSendBufferSizeRangeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30} 4663 if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { 4664 t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) 4665 } 4666 } 4667 4668 // Set values below the min/2. 
4669 ep.SocketOptions().SetReceiveBufferSize(99, true) 4670 checkRecvBufferSize(t, ep, 200) 4671 4672 ep.SocketOptions().SetSendBufferSize(149, true) 4673 4674 checkSendBufferSize(t, ep, 300) 4675 4676 // Set values above the max. 4677 ep.SocketOptions().SetReceiveBufferSize(1+tcp.DefaultReceiveBufferSize*20, true) 4678 // Values above max are capped at max and then doubled. 4679 checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2) 4680 4681 ep.SocketOptions().SetSendBufferSize(1+tcp.DefaultSendBufferSize*30, true) 4682 // Values above max are capped at max and then doubled. 4683 checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2) 4684 } 4685 4686 func TestBindToDeviceOption(t *testing.T) { 4687 s := stack.New(stack.Options{ 4688 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, 4689 TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}}) 4690 4691 ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) 4692 if err != nil { 4693 t.Fatalf("NewEndpoint failed; %s", err) 4694 } 4695 defer ep.Close() 4696 4697 if err := s.CreateNIC(321, loopback.New()); err != nil { 4698 t.Errorf("CreateNIC failed: %s", err) 4699 } 4700 4701 // nicIDPtr is used instead of taking the address of NICID literals, which is 4702 // a compiler error. 
4703 nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { 4704 return &s 4705 } 4706 4707 testActions := []struct { 4708 name string 4709 setBindToDevice *tcpip.NICID 4710 setBindToDeviceError tcpip.Error 4711 getBindToDevice int32 4712 }{ 4713 {"GetDefaultValue", nil, nil, 0}, 4714 {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0}, 4715 {"BindToExistent", nicIDPtr(321), nil, 321}, 4716 {"UnbindToDevice", nicIDPtr(0), nil, 0}, 4717 } 4718 for _, testAction := range testActions { 4719 t.Run(testAction.name, func(t *testing.T) { 4720 if testAction.setBindToDevice != nil { 4721 bindToDevice := int32(*testAction.setBindToDevice) 4722 if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { 4723 t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) 4724 } 4725 } 4726 bindToDevice := ep.SocketOptions().GetBindToDevice() 4727 if bindToDevice != testAction.getBindToDevice { 4728 t.Errorf("got bindToDevice = %d, want %d", bindToDevice, testAction.getBindToDevice) 4729 } 4730 }) 4731 } 4732 } 4733 4734 func makeStack() (*stack.Stack, tcpip.Error) { 4735 s := stack.New(stack.Options{ 4736 NetworkProtocols: []stack.NetworkProtocolFactory{ 4737 ipv4.NewProtocol, 4738 ipv6.NewProtocol, 4739 }, 4740 TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, 4741 }) 4742 4743 id := loopback.New() 4744 if testing.Verbose() { 4745 id = sniffer.New(id) 4746 } 4747 4748 if err := s.CreateNIC(1, id); err != nil { 4749 return nil, err 4750 } 4751 4752 for _, ct := range []struct { 4753 number tcpip.NetworkProtocolNumber 4754 address tcpip.Address 4755 }{ 4756 {ipv4.ProtocolNumber, context.StackAddr}, 4757 {ipv6.ProtocolNumber, context.StackV6Addr}, 4758 } { 4759 if err := s.AddAddress(1, ct.number, ct.address); err != nil { 4760 return nil, err 4761 } 4762 } 4763 4764 s.SetRouteTable([]tcpip.Route{ 4765 { 4766 Destination: header.IPv4EmptySubnet, 4767 
NIC: 1, 4768 }, 4769 { 4770 Destination: header.IPv6EmptySubnet, 4771 NIC: 1, 4772 }, 4773 }) 4774 4775 return s, nil 4776 } 4777 4778 func TestSelfConnect(t *testing.T) { 4779 // This test ensures that intentional self-connects work. In particular, 4780 // it checks that if an endpoint binds to say 127.0.0.1:1000 then 4781 // connects to 127.0.0.1:1000, then it will be connected to itself, and 4782 // is able to send and receive data through the same endpoint. 4783 s, err := makeStack() 4784 if err != nil { 4785 t.Fatal(err) 4786 } 4787 4788 var wq waiter.Queue 4789 ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) 4790 if err != nil { 4791 t.Fatalf("NewEndpoint failed: %s", err) 4792 } 4793 defer ep.Close() 4794 4795 if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 4796 t.Fatalf("Bind failed: %s", err) 4797 } 4798 4799 // Register for notification, then start connection attempt. 4800 waitEntry, notifyCh := waiter.NewChannelEntry(nil) 4801 wq.EventRegister(&waitEntry, waiter.WritableEvents) 4802 defer wq.EventUnregister(&waitEntry) 4803 4804 { 4805 err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}) 4806 if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { 4807 t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) 4808 } 4809 } 4810 4811 <-notifyCh 4812 if err := ep.LastError(); err != nil { 4813 t.Fatalf("Connect failed: %s", err) 4814 } 4815 4816 // Write something. 4817 data := []byte{1, 2, 3} 4818 var r bytes.Reader 4819 r.Reset(data) 4820 if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { 4821 t.Fatalf("Write failed: %s", err) 4822 } 4823 4824 // Read back what was written. 
4825 wq.EventUnregister(&waitEntry) 4826 wq.EventRegister(&waitEntry, waiter.ReadableEvents) 4827 ept := endpointTester{ep} 4828 rd := ept.CheckReadFull(t, len(data), notifyCh, 5*time.Second) 4829 4830 if !bytes.Equal(data, rd) { 4831 t.Fatalf("got data = %v, want = %v", rd, data) 4832 } 4833 } 4834 4835 func TestConnectAvoidsBoundPorts(t *testing.T) { 4836 addressTypes := func(t *testing.T, network string) []string { 4837 switch network { 4838 case "ipv4": 4839 return []string{"v4"} 4840 case "ipv6": 4841 return []string{"v6"} 4842 case "dual": 4843 return []string{"v6", "mapped"} 4844 default: 4845 t.Fatalf("unknown network: '%s'", network) 4846 } 4847 4848 panic("unreachable") 4849 } 4850 4851 address := func(t *testing.T, addressType string, isAny bool) tcpip.Address { 4852 switch addressType { 4853 case "v4": 4854 if isAny { 4855 return "" 4856 } 4857 return context.StackAddr 4858 case "v6": 4859 if isAny { 4860 return "" 4861 } 4862 return context.StackV6Addr 4863 case "mapped": 4864 if isAny { 4865 return context.V4MappedWildcardAddr 4866 } 4867 return context.StackV4MappedAddr 4868 default: 4869 t.Fatalf("unknown address type: '%s'", addressType) 4870 } 4871 4872 panic("unreachable") 4873 } 4874 // This test ensures that Endpoint.Connect doesn't select already-bound ports. 
4875 networks := []string{"ipv4", "ipv6", "dual"} 4876 for _, exhaustedNetwork := range networks { 4877 t.Run(fmt.Sprintf("exhaustedNetwork=%s", exhaustedNetwork), func(t *testing.T) { 4878 for _, exhaustedAddressType := range addressTypes(t, exhaustedNetwork) { 4879 t.Run(fmt.Sprintf("exhaustedAddressType=%s", exhaustedAddressType), func(t *testing.T) { 4880 for _, isAny := range []bool{false, true} { 4881 t.Run(fmt.Sprintf("isAny=%t", isAny), func(t *testing.T) { 4882 for _, candidateNetwork := range networks { 4883 t.Run(fmt.Sprintf("candidateNetwork=%s", candidateNetwork), func(t *testing.T) { 4884 for _, candidateAddressType := range addressTypes(t, candidateNetwork) { 4885 t.Run(fmt.Sprintf("candidateAddressType=%s", candidateAddressType), func(t *testing.T) { 4886 s, err := makeStack() 4887 if err != nil { 4888 t.Fatal(err) 4889 } 4890 4891 var wq waiter.Queue 4892 var eps []tcpip.Endpoint 4893 defer func() { 4894 for _, ep := range eps { 4895 ep.Close() 4896 } 4897 }() 4898 makeEP := func(network string) tcpip.Endpoint { 4899 var networkProtocolNumber tcpip.NetworkProtocolNumber 4900 switch network { 4901 case "ipv4": 4902 networkProtocolNumber = ipv4.ProtocolNumber 4903 case "ipv6", "dual": 4904 networkProtocolNumber = ipv6.ProtocolNumber 4905 default: 4906 t.Fatalf("unknown network: '%s'", network) 4907 } 4908 ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq) 4909 if err != nil { 4910 t.Fatalf("NewEndpoint failed: %s", err) 4911 } 4912 eps = append(eps, ep) 4913 switch network { 4914 case "ipv4": 4915 case "ipv6": 4916 ep.SocketOptions().SetV6Only(true) 4917 case "dual": 4918 ep.SocketOptions().SetV6Only(false) 4919 default: 4920 t.Fatalf("unknown network: '%s'", network) 4921 } 4922 return ep 4923 } 4924 4925 var v4reserved, v6reserved bool 4926 switch exhaustedAddressType { 4927 case "v4", "mapped": 4928 v4reserved = true 4929 case "v6": 4930 v6reserved = true 4931 // Dual stack sockets bound to v6 any reserve on v4 as 4932 // 
well. 4933 if isAny { 4934 switch exhaustedNetwork { 4935 case "ipv6": 4936 case "dual": 4937 v4reserved = true 4938 default: 4939 t.Fatalf("unknown address type: '%s'", exhaustedNetwork) 4940 } 4941 } 4942 default: 4943 t.Fatalf("unknown address type: '%s'", exhaustedAddressType) 4944 } 4945 var collides bool 4946 switch candidateAddressType { 4947 case "v4", "mapped": 4948 collides = v4reserved 4949 case "v6": 4950 collides = v6reserved 4951 default: 4952 t.Fatalf("unknown address type: '%s'", candidateAddressType) 4953 } 4954 4955 const ( 4956 start = 16000 4957 end = 16050 4958 ) 4959 if err := s.SetPortRange(start, end); err != nil { 4960 t.Fatalf("got s.SetPortRange(%d, %d) = %s, want = nil", start, end, err) 4961 } 4962 for i := start; i <= end; i++ { 4963 if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil { 4964 t.Fatalf("Bind(%d) failed: %s", i, err) 4965 } 4966 } 4967 var want tcpip.Error = &tcpip.ErrConnectStarted{} 4968 if collides { 4969 want = &tcpip.ErrNoPortAvailable{} 4970 } 4971 if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want { 4972 t.Fatalf("got ep.Connect(..) = %s, want = %s", err, want) 4973 } 4974 }) 4975 } 4976 }) 4977 } 4978 }) 4979 } 4980 }) 4981 } 4982 }) 4983 } 4984 } 4985 4986 func TestPathMTUDiscovery(t *testing.T) { 4987 // This test verifies the stack retransmits packets after it receives an 4988 // ICMP packet indicating that the path MTU has been exceeded. 4989 c := context.New(t, 1500) 4990 defer c.Cleanup() 4991 4992 // Create new connection with MSS of 1460. 4993 const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize 4994 c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ 4995 header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), 4996 }) 4997 4998 // Send 3200 bytes of data. 
4999 const writeSize = 3200 5000 data := make([]byte, writeSize) 5001 for i := range data { 5002 data[i] = byte(i) 5003 } 5004 var r bytes.Reader 5005 r.Reset(data) 5006 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 5007 t.Fatalf("Write failed: %s", err) 5008 } 5009 5010 receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte { 5011 var ret []byte 5012 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 5013 for i, size := range sizes { 5014 p := c.GetPacket() 5015 if i == which { 5016 ret = p 5017 } 5018 checker.IPv4(t, p, 5019 checker.PayloadLen(size+header.TCPMinimumSize), 5020 checker.TCP( 5021 checker.DstPort(context.TestPort), 5022 checker.TCPSeqNum(seqNum), 5023 checker.TCPAckNum(uint32(iss)), 5024 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 5025 ), 5026 ) 5027 seqNum += uint32(size) 5028 } 5029 return ret 5030 } 5031 5032 // Receive three packets. 5033 sizes := []int{maxPayload, maxPayload, writeSize - 2*maxPayload} 5034 first := receivePackets(c, sizes, 0, uint32(c.IRS)+1) 5035 5036 // Send "packet too big" messages back to netstack. 5037 const newMTU = 1200 5038 const newMaxPayload = newMTU - header.IPv4MinimumSize - header.TCPMinimumSize 5039 mtu := []byte{0, 0, newMTU / 256, newMTU % 256} 5040 c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, mtu, first, newMTU) 5041 5042 // See retransmitted packets. None exceeding the new max. 5043 sizes = []int{newMaxPayload, maxPayload - newMaxPayload, newMaxPayload, maxPayload - newMaxPayload, writeSize - 2*maxPayload} 5044 receivePackets(c, sizes, -1, uint32(c.IRS)+1) 5045 } 5046 5047 func TestTCPEndpointProbe(t *testing.T) { 5048 c := context.New(t, 1500) 5049 defer c.Cleanup() 5050 5051 invoked := make(chan struct{}) 5052 c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { 5053 // Validate that the endpoint ID is what we expect. 
5054 // 5055 // We don't do an extensive validation of every field but a 5056 // basic sanity test. 5057 if got, want := state.ID.LocalAddress, tcpip.Address(context.StackAddr); got != want { 5058 t.Fatalf("got LocalAddress: %q, want: %q", got, want) 5059 } 5060 if got, want := state.ID.LocalPort, c.Port; got != want { 5061 t.Fatalf("got LocalPort: %d, want: %d", got, want) 5062 } 5063 if got, want := state.ID.RemoteAddress, tcpip.Address(context.TestAddr); got != want { 5064 t.Fatalf("got RemoteAddress: %q, want: %q", got, want) 5065 } 5066 if got, want := state.ID.RemotePort, uint16(context.TestPort); got != want { 5067 t.Fatalf("got RemotePort: %d, want: %d", got, want) 5068 } 5069 5070 invoked <- struct{}{} 5071 }) 5072 5073 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 5074 5075 data := []byte{1, 2, 3} 5076 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 5077 c.SendPacket(data, &context.Headers{ 5078 SrcPort: context.TestPort, 5079 DstPort: c.Port, 5080 Flags: header.TCPFlagAck, 5081 SeqNum: iss, 5082 AckNum: c.IRS.Add(1), 5083 RcvWnd: 30000, 5084 }) 5085 5086 select { 5087 case <-invoked: 5088 case <-time.After(100 * time.Millisecond): 5089 t.Fatalf("TCP Probe function was not called") 5090 } 5091 } 5092 5093 func TestStackSetCongestionControl(t *testing.T) { 5094 testCases := []struct { 5095 cc tcpip.CongestionControlOption 5096 err tcpip.Error 5097 }{ 5098 {"reno", nil}, 5099 {"cubic", nil}, 5100 {"blahblah", &tcpip.ErrNoSuchFile{}}, 5101 } 5102 5103 for _, tc := range testCases { 5104 t.Run(fmt.Sprintf("SetTransportProtocolOption(.., %v)", tc.cc), func(t *testing.T) { 5105 c := context.New(t, 1500) 5106 defer c.Cleanup() 5107 5108 s := c.Stack() 5109 5110 var oldCC tcpip.CongestionControlOption 5111 if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil { 5112 t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err) 5113 } 5114 5115 if err := 
s.SetTransportProtocolOption(tcp.ProtocolNumber, &tc.cc); err != tc.err { 5116 t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = %s, want = %s", tcp.ProtocolNumber, tc.cc, tc.cc, err, tc.err) 5117 } 5118 5119 var cc tcpip.CongestionControlOption 5120 if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { 5121 t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) 5122 } 5123 5124 got, want := cc, oldCC 5125 // If SetTransportProtocolOption is expected to succeed 5126 // then the returned value for congestion control should 5127 // match the one specified in the 5128 // SetTransportProtocolOption call above, else it should 5129 // be what it was before the call to 5130 // SetTransportProtocolOption. 5131 if tc.err == nil { 5132 want = tc.cc 5133 } 5134 if got != want { 5135 t.Fatalf("got congestion control: %v, want: %v", got, want) 5136 } 5137 }) 5138 } 5139 } 5140 5141 func TestStackAvailableCongestionControl(t *testing.T) { 5142 c := context.New(t, 1500) 5143 defer c.Cleanup() 5144 5145 s := c.Stack() 5146 5147 // Query permitted congestion control algorithms. 5148 var aCC tcpip.TCPAvailableCongestionControlOption 5149 if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil { 5150 t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err) 5151 } 5152 if got, want := aCC, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want { 5153 t.Fatalf("got tcpip.TCPAvailableCongestionControlOption: %v, want: %v", got, want) 5154 } 5155 } 5156 5157 func TestStackSetAvailableCongestionControl(t *testing.T) { 5158 c := context.New(t, 1500) 5159 defer c.Cleanup() 5160 5161 s := c.Stack() 5162 5163 // Setting AvailableCongestionControlOption should fail. 
5164 aCC := tcpip.TCPAvailableCongestionControlOption("xyz") 5165 if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil { 5166 t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = nil, want non-nil", tcp.ProtocolNumber, aCC, aCC) 5167 } 5168 5169 // Verify that we still get the expected list of congestion control options. 5170 var cc tcpip.TCPAvailableCongestionControlOption 5171 if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { 5172 t.Fatalf("s.TransportProtocolOptio(%d, &%T(%s)): %s", tcp.ProtocolNumber, cc, cc, err) 5173 } 5174 if got, want := cc, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want { 5175 t.Fatalf("got tcpip.TCPAvailableCongestionControlOption = %s, want = %s", got, want) 5176 } 5177 } 5178 5179 func TestEndpointSetCongestionControl(t *testing.T) { 5180 testCases := []struct { 5181 cc tcpip.CongestionControlOption 5182 err tcpip.Error 5183 }{ 5184 {"reno", nil}, 5185 {"cubic", nil}, 5186 {"blahblah", &tcpip.ErrNoSuchFile{}}, 5187 } 5188 5189 for _, connected := range []bool{false, true} { 5190 for _, tc := range testCases { 5191 t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) { 5192 c := context.New(t, 1500) 5193 defer c.Cleanup() 5194 5195 // Create TCP endpoint. 
5196 var err tcpip.Error 5197 c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) 5198 if err != nil { 5199 t.Fatalf("NewEndpoint failed: %s", err) 5200 } 5201 5202 var oldCC tcpip.CongestionControlOption 5203 if err := c.EP.GetSockOpt(&oldCC); err != nil { 5204 t.Fatalf("c.EP.GetSockOpt(&%T) = %s", oldCC, err) 5205 } 5206 5207 if connected { 5208 c.Connect(context.TestInitialSequenceNumber, 32768 /* rcvWnd */, nil) 5209 } 5210 5211 if err := c.EP.SetSockOpt(&tc.cc); err != tc.err { 5212 t.Fatalf("got c.EP.SetSockOpt(&%#v) = %s, want %s", tc.cc, err, tc.err) 5213 } 5214 5215 var cc tcpip.CongestionControlOption 5216 if err := c.EP.GetSockOpt(&cc); err != nil { 5217 t.Fatalf("c.EP.GetSockOpt(&%T): %s", cc, err) 5218 } 5219 5220 got, want := cc, oldCC 5221 // If SetSockOpt is expected to succeed then the 5222 // returned value for congestion control should match 5223 // the one specified in the SetSockOpt above, else it 5224 // should be what it was before the call to SetSockOpt. 
5225 if tc.err == nil { 5226 want = tc.cc 5227 } 5228 if got != want { 5229 t.Fatalf("got congestion control = %+v, want = %+v", got, want) 5230 } 5231 }) 5232 } 5233 } 5234 } 5235 5236 func enableCUBIC(t *testing.T, c *context.Context) { 5237 t.Helper() 5238 opt := tcpip.CongestionControlOption("cubic") 5239 if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { 5240 t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)) %s", tcp.ProtocolNumber, opt, opt, err) 5241 } 5242 } 5243 5244 func TestKeepalive(t *testing.T) { 5245 c := context.New(t, defaultMTU) 5246 defer c.Cleanup() 5247 5248 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 5249 5250 const keepAliveIdle = 100 * time.Millisecond 5251 const keepAliveInterval = 3 * time.Second 5252 keepAliveIdleOpt := tcpip.KeepaliveIdleOption(keepAliveIdle) 5253 if err := c.EP.SetSockOpt(&keepAliveIdleOpt); err != nil { 5254 t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOpt, keepAliveIdle, err) 5255 } 5256 keepAliveIntervalOpt := tcpip.KeepaliveIntervalOption(keepAliveInterval) 5257 if err := c.EP.SetSockOpt(&keepAliveIntervalOpt); err != nil { 5258 t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOpt, keepAliveInterval, err) 5259 } 5260 c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5) 5261 if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5); err != nil { 5262 t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5): %s", err) 5263 } 5264 c.EP.SocketOptions().SetKeepAlive(true) 5265 5266 // 5 unacked keepalives are sent. ACK each one, and check that the 5267 // connection stays alive after 5. 
5268 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 5269 for i := 0; i < 10; i++ { 5270 b := c.GetPacket() 5271 checker.IPv4(t, b, 5272 checker.TCP( 5273 checker.DstPort(context.TestPort), 5274 checker.TCPSeqNum(uint32(c.IRS)), 5275 checker.TCPAckNum(uint32(iss)), 5276 checker.TCPFlags(header.TCPFlagAck), 5277 ), 5278 ) 5279 5280 // Acknowledge the keepalive. 5281 c.SendPacket(nil, &context.Headers{ 5282 SrcPort: context.TestPort, 5283 DstPort: c.Port, 5284 Flags: header.TCPFlagAck, 5285 SeqNum: iss, 5286 AckNum: c.IRS, 5287 RcvWnd: 30000, 5288 }) 5289 } 5290 5291 // Check that the connection is still alive. 5292 ept := endpointTester{c.EP} 5293 ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) 5294 5295 // Send some data and wait before ACKing it. Keepalives should be disabled 5296 // during this period. 5297 view := make([]byte, 3) 5298 var r bytes.Reader 5299 r.Reset(view) 5300 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 5301 t.Fatalf("Write failed: %s", err) 5302 } 5303 5304 next := uint32(c.IRS) + 1 5305 checker.IPv4(t, c.GetPacket(), 5306 checker.PayloadLen(len(view)+header.TCPMinimumSize), 5307 checker.TCP( 5308 checker.DstPort(context.TestPort), 5309 checker.TCPSeqNum(next), 5310 checker.TCPAckNum(uint32(iss)), 5311 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 5312 ), 5313 ) 5314 5315 // Wait for the packet to be retransmitted. Verify that no keepalives 5316 // were sent. 5317 checker.IPv4(t, c.GetPacket(), 5318 checker.PayloadLen(len(view)+header.TCPMinimumSize), 5319 checker.TCP( 5320 checker.DstPort(context.TestPort), 5321 checker.TCPSeqNum(next), 5322 checker.TCPAckNum(uint32(iss)), 5323 checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), 5324 ), 5325 ) 5326 c.CheckNoPacket("Keepalive packet received while unACKed data is pending") 5327 5328 next += uint32(len(view)) 5329 5330 // Send ACK. Keepalives should start sending again. 
5331 c.SendPacket(nil, &context.Headers{ 5332 SrcPort: context.TestPort, 5333 DstPort: c.Port, 5334 Flags: header.TCPFlagAck, 5335 SeqNum: iss, 5336 AckNum: seqnum.Value(next), 5337 RcvWnd: 30000, 5338 }) 5339 5340 // Now receive 5 keepalives, but don't ACK them. The connection 5341 // should be reset after 5. 5342 for i := 0; i < 5; i++ { 5343 b := c.GetPacket() 5344 checker.IPv4(t, b, 5345 checker.TCP( 5346 checker.DstPort(context.TestPort), 5347 checker.TCPSeqNum(next-1), 5348 checker.TCPAckNum(uint32(iss)), 5349 checker.TCPFlags(header.TCPFlagAck), 5350 ), 5351 ) 5352 } 5353 5354 // Sleep for a litte over the KeepAlive interval to make sure 5355 // the timer has time to fire after the last ACK and close the 5356 // close the socket. 5357 time.Sleep(keepAliveInterval + keepAliveInterval/2) 5358 5359 // The connection should be terminated after 5 unacked keepalives. 5360 // Send an ACK to trigger a RST from the stack as the endpoint should 5361 // be dead. 5362 c.SendPacket(nil, &context.Headers{ 5363 SrcPort: context.TestPort, 5364 DstPort: c.Port, 5365 Flags: header.TCPFlagAck, 5366 SeqNum: iss, 5367 AckNum: seqnum.Value(next), 5368 RcvWnd: 30000, 5369 }) 5370 5371 checker.IPv4(t, c.GetPacket(), 5372 checker.TCP(checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst)), 5373 ) 5374 5375 if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { 5376 t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got) 5377 } 5378 5379 ept.CheckReadError(t, &tcpip.ErrTimeout{}) 5380 5381 if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { 5382 t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) 5383 } 5384 if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { 5385 t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) 5386 } 5387 } 5388 5389 func executeHandshake(t *testing.T, c *context.Context, 
srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { 5390 t.Helper() 5391 // Send a SYN request. 5392 irs = seqnum.Value(context.TestInitialSequenceNumber) 5393 c.SendPacket(nil, &context.Headers{ 5394 SrcPort: srcPort, 5395 DstPort: context.StackPort, 5396 Flags: header.TCPFlagSyn, 5397 SeqNum: irs, 5398 RcvWnd: 30000, 5399 }) 5400 5401 // Receive the SYN-ACK reply. 5402 b := c.GetPacket() 5403 tcp := header.TCP(header.IPv4(b).Payload()) 5404 iss = seqnum.Value(tcp.SequenceNumber()) 5405 tcpCheckers := []checker.TransportChecker{ 5406 checker.SrcPort(context.StackPort), 5407 checker.DstPort(srcPort), 5408 checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), 5409 checker.TCPAckNum(uint32(irs) + 1), 5410 } 5411 5412 if synCookieInUse { 5413 // When cookies are in use window scaling is disabled. 5414 tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{ 5415 WS: -1, 5416 MSS: c.MSSWithoutOptions(), 5417 })) 5418 } 5419 5420 checker.IPv4(t, b, checker.TCP(tcpCheckers...)) 5421 5422 // Send ACK. 5423 c.SendPacket(nil, &context.Headers{ 5424 SrcPort: srcPort, 5425 DstPort: context.StackPort, 5426 Flags: header.TCPFlagAck, 5427 SeqNum: irs + 1, 5428 AckNum: iss + 1, 5429 RcvWnd: 30000, 5430 }) 5431 return irs, iss 5432 } 5433 5434 func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { 5435 t.Helper() 5436 // Send a SYN request. 5437 irs = seqnum.Value(context.TestInitialSequenceNumber) 5438 c.SendV6Packet(nil, &context.Headers{ 5439 SrcPort: srcPort, 5440 DstPort: context.StackPort, 5441 Flags: header.TCPFlagSyn, 5442 SeqNum: irs, 5443 RcvWnd: 30000, 5444 }) 5445 5446 // Receive the SYN-ACK reply. 
5447 b := c.GetV6Packet() 5448 tcp := header.TCP(header.IPv6(b).Payload()) 5449 iss = seqnum.Value(tcp.SequenceNumber()) 5450 tcpCheckers := []checker.TransportChecker{ 5451 checker.SrcPort(context.StackPort), 5452 checker.DstPort(srcPort), 5453 checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), 5454 checker.TCPAckNum(uint32(irs) + 1), 5455 } 5456 5457 if synCookieInUse { 5458 // When cookies are in use window scaling is disabled. 5459 tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{ 5460 WS: -1, 5461 MSS: c.MSSWithoutOptionsV6(), 5462 })) 5463 } 5464 5465 checker.IPv6(t, b, checker.TCP(tcpCheckers...)) 5466 5467 // Send ACK. 5468 c.SendV6Packet(nil, &context.Headers{ 5469 SrcPort: srcPort, 5470 DstPort: context.StackPort, 5471 Flags: header.TCPFlagAck, 5472 SeqNum: irs + 1, 5473 AckNum: iss + 1, 5474 RcvWnd: 30000, 5475 }) 5476 return irs, iss 5477 } 5478 5479 // TestListenBacklogFull tests that netstack does not complete handshakes if the 5480 // listen backlog for the endpoint is full. 5481 func TestListenBacklogFull(t *testing.T) { 5482 c := context.New(t, defaultMTU) 5483 defer c.Cleanup() 5484 5485 // Create TCP endpoint. 5486 var err tcpip.Error 5487 c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) 5488 if err != nil { 5489 t.Fatalf("NewEndpoint failed: %s", err) 5490 } 5491 5492 // Bind to wildcard. 5493 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 5494 t.Fatalf("Bind failed: %s", err) 5495 } 5496 5497 // Test acceptance. 5498 // Start listening. 5499 listenBacklog := 10 5500 if err := c.EP.Listen(listenBacklog); err != nil { 5501 t.Fatalf("Listen failed: %s", err) 5502 } 5503 5504 lastPortOffset := uint16(0) 5505 for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ { 5506 executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) 5507 } 5508 5509 time.Sleep(50 * time.Millisecond) 5510 5511 // Now execute send one more SYN. 
The stack should not respond as the backlog 5512 // is full at this point. 5513 c.SendPacket(nil, &context.Headers{ 5514 SrcPort: context.TestPort + lastPortOffset, 5515 DstPort: context.StackPort, 5516 Flags: header.TCPFlagSyn, 5517 SeqNum: seqnum.Value(context.TestInitialSequenceNumber), 5518 RcvWnd: 30000, 5519 }) 5520 c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) 5521 5522 // Try to accept the connections in the backlog. 5523 we, ch := waiter.NewChannelEntry(nil) 5524 c.WQ.EventRegister(&we, waiter.ReadableEvents) 5525 defer c.WQ.EventUnregister(&we) 5526 5527 for i := 0; i < listenBacklog; i++ { 5528 _, _, err = c.EP.Accept(nil) 5529 if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 5530 // Wait for connection to be established. 5531 select { 5532 case <-ch: 5533 _, _, err = c.EP.Accept(nil) 5534 if err != nil { 5535 t.Fatalf("Accept failed: %s", err) 5536 } 5537 5538 case <-time.After(1 * time.Second): 5539 t.Fatalf("Timed out waiting for accept") 5540 } 5541 } 5542 } 5543 5544 // Now verify that there are no more connections that can be accepted. 5545 _, _, err = c.EP.Accept(nil) 5546 if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 5547 select { 5548 case <-ch: 5549 t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) 5550 case <-time.After(1 * time.Second): 5551 } 5552 } 5553 5554 // Now a new handshake must succeed. 5555 executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) 5556 5557 newEP, _, err := c.EP.Accept(nil) 5558 if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 5559 // Wait for connection to be established. 5560 select { 5561 case <-ch: 5562 newEP, _, err = c.EP.Accept(nil) 5563 if err != nil { 5564 t.Fatalf("Accept failed: %s", err) 5565 } 5566 5567 case <-time.After(1 * time.Second): 5568 t.Fatalf("Timed out waiting for accept") 5569 } 5570 } 5571 5572 // Now verify that the TCP socket is usable and in a connected state. 
5573 data := "Don't panic" 5574 var r strings.Reader 5575 r.Reset(data) 5576 newEP.Write(&r, tcpip.WriteOptions{}) 5577 b := c.GetPacket() 5578 tcp := header.TCP(header.IPv4(b).Payload()) 5579 if string(tcp.Payload()) != data { 5580 t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data) 5581 } 5582 } 5583 5584 // TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a 5585 // non unicast IPv4 address are not accepted. 5586 func TestListenNoAcceptNonUnicastV4(t *testing.T) { 5587 multicastAddr := tcpiptestutil.MustParse4("224.0.1.2") 5588 otherMulticastAddr := tcpiptestutil.MustParse4("224.0.1.3") 5589 subnet := context.StackAddrWithPrefix.Subnet() 5590 subnetBroadcastAddr := subnet.Broadcast() 5591 5592 tests := []struct { 5593 name string 5594 srcAddr tcpip.Address 5595 dstAddr tcpip.Address 5596 }{ 5597 { 5598 name: "SourceUnspecified", 5599 srcAddr: header.IPv4Any, 5600 dstAddr: context.StackAddr, 5601 }, 5602 { 5603 name: "SourceBroadcast", 5604 srcAddr: header.IPv4Broadcast, 5605 dstAddr: context.StackAddr, 5606 }, 5607 { 5608 name: "SourceOurMulticast", 5609 srcAddr: multicastAddr, 5610 dstAddr: context.StackAddr, 5611 }, 5612 { 5613 name: "SourceOtherMulticast", 5614 srcAddr: otherMulticastAddr, 5615 dstAddr: context.StackAddr, 5616 }, 5617 { 5618 name: "DestUnspecified", 5619 srcAddr: context.TestAddr, 5620 dstAddr: header.IPv4Any, 5621 }, 5622 { 5623 name: "DestBroadcast", 5624 srcAddr: context.TestAddr, 5625 dstAddr: header.IPv4Broadcast, 5626 }, 5627 { 5628 name: "DestOurMulticast", 5629 srcAddr: context.TestAddr, 5630 dstAddr: multicastAddr, 5631 }, 5632 { 5633 name: "DestOtherMulticast", 5634 srcAddr: context.TestAddr, 5635 dstAddr: otherMulticastAddr, 5636 }, 5637 { 5638 name: "SrcSubnetBroadcast", 5639 srcAddr: subnetBroadcastAddr, 5640 dstAddr: context.StackAddr, 5641 }, 5642 { 5643 name: "DestSubnetBroadcast", 5644 srcAddr: context.TestAddr, 5645 dstAddr: subnetBroadcastAddr, 5646 }, 5647 } 5648 5649 for _, 
test := range tests {
		t.Run(test.name, func(t *testing.T) {
			c := context.New(t, defaultMTU)
			defer c.Cleanup()

			c.Create(-1)

			if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil {
				t.Fatalf("JoinGroup failed: %s", err)
			}

			if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
				t.Fatalf("Bind failed: %s", err)
			}

			if err := c.EP.Listen(1); err != nil {
				t.Fatalf("Listen failed: %s", err)
			}

			irs := seqnum.Value(context.TestInitialSequenceNumber)
			c.SendPacketWithAddrs(nil, &context.Headers{
				SrcPort: context.TestPort,
				DstPort: context.StackPort,
				Flags:   header.TCPFlagSyn,
				SeqNum:  irs,
				RcvWnd:  30000,
			}, test.srcAddr, test.dstAddr)
			c.CheckNoPacket("Should not have received a response")

			// Handle normal packet.
			c.SendPacketWithAddrs(nil, &context.Headers{
				SrcPort: context.TestPort,
				DstPort: context.StackPort,
				Flags:   header.TCPFlagSyn,
				SeqNum:  irs,
				RcvWnd:  30000,
			}, context.TestAddr, context.StackAddr)
			checker.IPv4(t, c.GetPacket(),
				checker.TCP(
					checker.SrcPort(context.StackPort),
					checker.DstPort(context.TestPort),
					checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
					checker.TCPAckNum(uint32(irs)+1)))
		})
	}
}

// TestListenNoAcceptNonUnicastV6 makes sure that TCP segments with a
// non unicast IPv6 address are not accepted.
5698 func TestListenNoAcceptNonUnicastV6(t *testing.T) { 5699 multicastAddr := tcpiptestutil.MustParse6("ff0e::101") 5700 otherMulticastAddr := tcpiptestutil.MustParse6("ff0e::102") 5701 5702 tests := []struct { 5703 name string 5704 srcAddr tcpip.Address 5705 dstAddr tcpip.Address 5706 }{ 5707 { 5708 "SourceUnspecified", 5709 header.IPv6Any, 5710 context.StackV6Addr, 5711 }, 5712 { 5713 "SourceAllNodes", 5714 header.IPv6AllNodesMulticastAddress, 5715 context.StackV6Addr, 5716 }, 5717 { 5718 "SourceOurMulticast", 5719 multicastAddr, 5720 context.StackV6Addr, 5721 }, 5722 { 5723 "SourceOtherMulticast", 5724 otherMulticastAddr, 5725 context.StackV6Addr, 5726 }, 5727 { 5728 "DestUnspecified", 5729 context.TestV6Addr, 5730 header.IPv6Any, 5731 }, 5732 { 5733 "DestAllNodes", 5734 context.TestV6Addr, 5735 header.IPv6AllNodesMulticastAddress, 5736 }, 5737 { 5738 "DestOurMulticast", 5739 context.TestV6Addr, 5740 multicastAddr, 5741 }, 5742 { 5743 "DestOtherMulticast", 5744 context.TestV6Addr, 5745 otherMulticastAddr, 5746 }, 5747 } 5748 5749 for _, test := range tests { 5750 t.Run(test.name, func(t *testing.T) { 5751 c := context.New(t, defaultMTU) 5752 defer c.Cleanup() 5753 5754 c.CreateV6Endpoint(true) 5755 5756 if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil { 5757 t.Fatalf("JoinGroup failed: %s", err) 5758 } 5759 5760 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 5761 t.Fatalf("Bind failed: %s", err) 5762 } 5763 5764 if err := c.EP.Listen(1); err != nil { 5765 t.Fatalf("Listen failed: %s", err) 5766 } 5767 5768 irs := seqnum.Value(context.TestInitialSequenceNumber) 5769 c.SendV6PacketWithAddrs(nil, &context.Headers{ 5770 SrcPort: context.TestPort, 5771 DstPort: context.StackPort, 5772 Flags: header.TCPFlagSyn, 5773 SeqNum: irs, 5774 RcvWnd: 30000, 5775 }, test.srcAddr, test.dstAddr) 5776 c.CheckNoPacket("Should not have received a response") 5777 5778 // Handle normal packet. 
5779 c.SendV6PacketWithAddrs(nil, &context.Headers{ 5780 SrcPort: context.TestPort, 5781 DstPort: context.StackPort, 5782 Flags: header.TCPFlagSyn, 5783 SeqNum: irs, 5784 RcvWnd: 30000, 5785 }, context.TestV6Addr, context.StackV6Addr) 5786 checker.IPv6(t, c.GetV6Packet(), 5787 checker.TCP( 5788 checker.SrcPort(context.StackPort), 5789 checker.DstPort(context.TestPort), 5790 checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), 5791 checker.TCPAckNum(uint32(irs)+1))) 5792 }) 5793 } 5794 } 5795 5796 func TestListenSynRcvdQueueFull(t *testing.T) { 5797 c := context.New(t, defaultMTU) 5798 defer c.Cleanup() 5799 5800 // Create TCP endpoint. 5801 var err tcpip.Error 5802 c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) 5803 if err != nil { 5804 t.Fatalf("NewEndpoint failed: %s", err) 5805 } 5806 5807 // Bind to wildcard. 5808 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 5809 t.Fatalf("Bind failed: %s", err) 5810 } 5811 5812 // Test acceptance. 5813 if err := c.EP.Listen(1); err != nil { 5814 t.Fatalf("Listen failed: %s", err) 5815 } 5816 5817 // Send two SYN's the first one should get a SYN-ACK, the 5818 // second one should not get any response and is dropped as 5819 // the accept queue is full. 5820 irs := seqnum.Value(context.TestInitialSequenceNumber) 5821 c.SendPacket(nil, &context.Headers{ 5822 SrcPort: context.TestPort, 5823 DstPort: context.StackPort, 5824 Flags: header.TCPFlagSyn, 5825 SeqNum: irs, 5826 RcvWnd: 30000, 5827 }) 5828 5829 // Receive the SYN-ACK reply. 
5830 b := c.GetPacket() 5831 tcp := header.TCP(header.IPv4(b).Payload()) 5832 iss := seqnum.Value(tcp.SequenceNumber()) 5833 tcpCheckers := []checker.TransportChecker{ 5834 checker.SrcPort(context.StackPort), 5835 checker.DstPort(context.TestPort), 5836 checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), 5837 checker.TCPAckNum(uint32(irs) + 1), 5838 } 5839 checker.IPv4(t, b, checker.TCP(tcpCheckers...)) 5840 5841 // Now complete the previous connection. 5842 // Send ACK. 5843 c.SendPacket(nil, &context.Headers{ 5844 SrcPort: context.TestPort, 5845 DstPort: context.StackPort, 5846 Flags: header.TCPFlagAck, 5847 SeqNum: irs + 1, 5848 AckNum: iss + 1, 5849 RcvWnd: 30000, 5850 }) 5851 5852 // Verify if that is delivered to the accept queue. 5853 we, ch := waiter.NewChannelEntry(nil) 5854 c.WQ.EventRegister(&we, waiter.ReadableEvents) 5855 defer c.WQ.EventUnregister(&we) 5856 <-ch 5857 5858 // Now execute send one more SYN. The stack should not respond as the backlog 5859 // is full at this point. 5860 c.SendPacket(nil, &context.Headers{ 5861 SrcPort: context.TestPort + 1, 5862 DstPort: context.StackPort, 5863 Flags: header.TCPFlagSyn, 5864 SeqNum: seqnum.Value(889), 5865 RcvWnd: 30000, 5866 }) 5867 c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) 5868 5869 // Try to accept the connections in the backlog. 5870 newEP, _, err := c.EP.Accept(nil) 5871 if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 5872 // Wait for connection to be established. 5873 select { 5874 case <-ch: 5875 newEP, _, err = c.EP.Accept(nil) 5876 if err != nil { 5877 t.Fatalf("Accept failed: %s", err) 5878 } 5879 5880 case <-time.After(1 * time.Second): 5881 t.Fatalf("Timed out waiting for accept") 5882 } 5883 } 5884 5885 // Now verify that the TCP socket is usable and in a connected state. 
5886 data := "Don't panic" 5887 var r strings.Reader 5888 r.Reset(data) 5889 newEP.Write(&r, tcpip.WriteOptions{}) 5890 pkt := c.GetPacket() 5891 tcp = header.IPv4(pkt).Payload() 5892 if string(tcp.Payload()) != data { 5893 t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data) 5894 } 5895 } 5896 5897 func TestListenBacklogFullSynCookieInUse(t *testing.T) { 5898 c := context.New(t, defaultMTU) 5899 defer c.Cleanup() 5900 5901 // Create TCP endpoint. 5902 var err tcpip.Error 5903 c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) 5904 if err != nil { 5905 t.Fatalf("NewEndpoint failed: %s", err) 5906 } 5907 5908 // Bind to wildcard. 5909 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 5910 t.Fatalf("Bind failed: %s", err) 5911 } 5912 5913 // Test for SynCookies usage after filling up the backlog. 5914 if err := c.EP.Listen(1); err != nil { 5915 t.Fatalf("Listen failed: %s", err) 5916 } 5917 5918 executeHandshake(t, c, context.TestPort, false) 5919 5920 // Wait for this to be delivered to the accept queue. 5921 time.Sleep(50 * time.Millisecond) 5922 5923 // Send a SYN request. 5924 irs := seqnum.Value(context.TestInitialSequenceNumber) 5925 c.SendPacket(nil, &context.Headers{ 5926 // pick a different src port for new SYN. 5927 SrcPort: context.TestPort + 1, 5928 DstPort: context.StackPort, 5929 Flags: header.TCPFlagSyn, 5930 SeqNum: irs, 5931 RcvWnd: 30000, 5932 }) 5933 // The Syn should be dropped as the endpoint's backlog is full. 5934 c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) 5935 5936 // Verify that there is only one acceptable connection at this point. 5937 we, ch := waiter.NewChannelEntry(nil) 5938 c.WQ.EventRegister(&we, waiter.ReadableEvents) 5939 defer c.WQ.EventUnregister(&we) 5940 5941 _, _, err = c.EP.Accept(nil) 5942 if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 5943 // Wait for connection to be established. 
5944 select { 5945 case <-ch: 5946 _, _, err = c.EP.Accept(nil) 5947 if err != nil { 5948 t.Fatalf("Accept failed: %s", err) 5949 } 5950 5951 case <-time.After(1 * time.Second): 5952 t.Fatalf("Timed out waiting for accept") 5953 } 5954 } 5955 5956 // Now verify that there are no more connections that can be accepted. 5957 _, _, err = c.EP.Accept(nil) 5958 if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 5959 select { 5960 case <-ch: 5961 t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) 5962 case <-time.After(1 * time.Second): 5963 } 5964 } 5965 } 5966 5967 func TestSYNRetransmit(t *testing.T) { 5968 c := context.New(t, defaultMTU) 5969 defer c.Cleanup() 5970 5971 // Create TCP endpoint. 5972 var err tcpip.Error 5973 c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) 5974 if err != nil { 5975 t.Fatalf("NewEndpoint failed: %s", err) 5976 } 5977 5978 // Bind to wildcard. 5979 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 5980 t.Fatalf("Bind failed: %s", err) 5981 } 5982 5983 // Start listening. 5984 if err := c.EP.Listen(10); err != nil { 5985 t.Fatalf("Listen failed: %s", err) 5986 } 5987 5988 // Send the same SYN packet multiple times. We should still get a valid SYN-ACK 5989 // reply. 5990 irs := seqnum.Value(context.TestInitialSequenceNumber) 5991 for i := 0; i < 5; i++ { 5992 c.SendPacket(nil, &context.Headers{ 5993 SrcPort: context.TestPort, 5994 DstPort: context.StackPort, 5995 Flags: header.TCPFlagSyn, 5996 SeqNum: irs, 5997 RcvWnd: 30000, 5998 }) 5999 } 6000 6001 // Receive the SYN-ACK reply. 
6002 tcpCheckers := []checker.TransportChecker{ 6003 checker.SrcPort(context.StackPort), 6004 checker.DstPort(context.TestPort), 6005 checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), 6006 checker.TCPAckNum(uint32(irs) + 1), 6007 } 6008 checker.IPv4(t, c.GetPacket(), checker.TCP(tcpCheckers...)) 6009 } 6010 6011 func TestSynRcvdBadSeqNumber(t *testing.T) { 6012 c := context.New(t, defaultMTU) 6013 defer c.Cleanup() 6014 6015 // Create TCP endpoint. 6016 var err tcpip.Error 6017 c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) 6018 if err != nil { 6019 t.Fatalf("NewEndpoint failed: %s", err) 6020 } 6021 6022 // Bind to wildcard. 6023 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 6024 t.Fatalf("Bind failed: %s", err) 6025 } 6026 6027 // Start listening. 6028 if err := c.EP.Listen(10); err != nil { 6029 t.Fatalf("Listen failed: %s", err) 6030 } 6031 6032 // Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state 6033 irs := seqnum.Value(context.TestInitialSequenceNumber) 6034 c.SendPacket(nil, &context.Headers{ 6035 SrcPort: context.TestPort, 6036 DstPort: context.StackPort, 6037 Flags: header.TCPFlagSyn, 6038 SeqNum: irs, 6039 RcvWnd: 30000, 6040 }) 6041 6042 // Receive the SYN-ACK reply. 
6043 b := c.GetPacket() 6044 tcpHdr := header.TCP(header.IPv4(b).Payload()) 6045 iss := seqnum.Value(tcpHdr.SequenceNumber()) 6046 tcpCheckers := []checker.TransportChecker{ 6047 checker.SrcPort(context.StackPort), 6048 checker.DstPort(context.TestPort), 6049 checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), 6050 checker.TCPAckNum(uint32(irs) + 1), 6051 } 6052 checker.IPv4(t, b, checker.TCP(tcpCheckers...)) 6053 6054 // Now send a packet with an out-of-window sequence number 6055 largeSeqnum := irs + seqnum.Value(tcpHdr.WindowSize()) + 1 6056 c.SendPacket(nil, &context.Headers{ 6057 SrcPort: context.TestPort, 6058 DstPort: context.StackPort, 6059 Flags: header.TCPFlagAck, 6060 SeqNum: largeSeqnum, 6061 AckNum: iss + 1, 6062 RcvWnd: 30000, 6063 }) 6064 6065 // Should receive an ACK with the expected SEQ number 6066 b = c.GetPacket() 6067 tcpCheckers = []checker.TransportChecker{ 6068 checker.SrcPort(context.StackPort), 6069 checker.DstPort(context.TestPort), 6070 checker.TCPFlags(header.TCPFlagAck), 6071 checker.TCPAckNum(uint32(irs) + 1), 6072 checker.TCPSeqNum(uint32(iss + 1)), 6073 } 6074 checker.IPv4(t, b, checker.TCP(tcpCheckers...)) 6075 6076 // Now that the socket replied appropriately with the ACK, 6077 // complete the connection to test that the large SEQ num 6078 // did not change the state from SYN-RCVD. 6079 6080 // Get setup to be notified about connection establishment. 6081 we, ch := waiter.NewChannelEntry(nil) 6082 c.WQ.EventRegister(&we, waiter.ReadableEvents) 6083 defer c.WQ.EventUnregister(&we) 6084 6085 // Send ACK to move to ESTABLISHED state. 
6086 c.SendPacket(nil, &context.Headers{ 6087 SrcPort: context.TestPort, 6088 DstPort: context.StackPort, 6089 Flags: header.TCPFlagAck, 6090 SeqNum: irs + 1, 6091 AckNum: iss + 1, 6092 RcvWnd: 30000, 6093 }) 6094 6095 <-ch 6096 newEP, _, err := c.EP.Accept(nil) 6097 if err != nil { 6098 t.Fatalf("Accept failed: %s", err) 6099 } 6100 6101 // Now verify that the TCP socket is usable and in a connected state. 6102 data := "Don't panic" 6103 var r strings.Reader 6104 r.Reset(data) 6105 if _, err := newEP.Write(&r, tcpip.WriteOptions{}); err != nil { 6106 t.Fatalf("Write failed: %s", err) 6107 } 6108 6109 pkt := c.GetPacket() 6110 tcpHdr = header.IPv4(pkt).Payload() 6111 if string(tcpHdr.Payload()) != data { 6112 t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data) 6113 } 6114 } 6115 6116 func TestPassiveConnectionAttemptIncrement(t *testing.T) { 6117 c := context.New(t, defaultMTU) 6118 defer c.Cleanup() 6119 6120 ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) 6121 if err != nil { 6122 t.Fatalf("NewEndpoint failed: %s", err) 6123 } 6124 c.EP = ep 6125 if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { 6126 t.Fatalf("Bind failed: %s", err) 6127 } 6128 if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { 6129 t.Errorf("unexpected endpoint state: want %s, got %s", want, got) 6130 } 6131 if err := c.EP.Listen(1); err != nil { 6132 t.Fatalf("Listen failed: %s", err) 6133 } 6134 if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want { 6135 t.Errorf("unexpected endpoint state: want %s, got %s", want, got) 6136 } 6137 6138 stats := c.Stack().Stats() 6139 want := stats.TCP.PassiveConnectionOpenings.Value() + 1 6140 6141 srcPort := uint16(context.TestPort) 6142 executeHandshake(t, c, srcPort+1, false) 6143 6144 we, ch := waiter.NewChannelEntry(nil) 6145 c.WQ.EventRegister(&we, waiter.ReadableEvents) 6146 defer 
c.WQ.EventUnregister(&we) 6147 6148 // Verify that there is only one acceptable connection at this point. 6149 _, _, err = c.EP.Accept(nil) 6150 if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 6151 // Wait for connection to be established. 6152 select { 6153 case <-ch: 6154 _, _, err = c.EP.Accept(nil) 6155 if err != nil { 6156 t.Fatalf("Accept failed: %s", err) 6157 } 6158 6159 case <-time.After(1 * time.Second): 6160 t.Fatalf("Timed out waiting for accept") 6161 } 6162 } 6163 6164 if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want { 6165 t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %d, want = %d", got, want) 6166 } 6167 } 6168 6169 func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { 6170 c := context.New(t, defaultMTU) 6171 defer c.Cleanup() 6172 6173 stats := c.Stack().Stats() 6174 ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) 6175 if err != nil { 6176 t.Fatalf("NewEndpoint failed: %s", err) 6177 } 6178 c.EP = ep 6179 if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { 6180 t.Fatalf("Bind failed: %s", err) 6181 } 6182 if err := c.EP.Listen(1); err != nil { 6183 t.Fatalf("Listen failed: %s", err) 6184 } 6185 6186 srcPort := uint16(context.TestPort) 6187 // Now attempt a handshakes it will fill up the accept backlog. 6188 executeHandshake(t, c, srcPort, false) 6189 6190 // Give time for the final ACK to be processed as otherwise the next handshake could 6191 // get accepted before the previous one based on goroutine scheduling. 6192 time.Sleep(50 * time.Millisecond) 6193 6194 want := stats.TCP.ListenOverflowSynDrop.Value() + 1 6195 6196 // Now we will send one more SYN and this one should get dropped 6197 // Send a SYN request. 
6198 c.SendPacket(nil, &context.Headers{ 6199 SrcPort: srcPort + 2, 6200 DstPort: context.StackPort, 6201 Flags: header.TCPFlagSyn, 6202 SeqNum: seqnum.Value(context.TestInitialSequenceNumber), 6203 RcvWnd: 30000, 6204 }) 6205 6206 checkValid := func() []error { 6207 var errors []error 6208 if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { 6209 errors = append(errors, fmt.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want)) 6210 } 6211 if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { 6212 errors = append(errors, fmt.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want)) 6213 } 6214 return errors 6215 } 6216 6217 start := time.Now() 6218 for time.Since(start) < time.Minute && len(checkValid()) > 0 { 6219 time.Sleep(50 * time.Millisecond) 6220 } 6221 for _, err := range checkValid() { 6222 t.Error(err) 6223 } 6224 if t.Failed() { 6225 t.FailNow() 6226 } 6227 6228 we, ch := waiter.NewChannelEntry(nil) 6229 c.WQ.EventRegister(&we, waiter.ReadableEvents) 6230 defer c.WQ.EventUnregister(&we) 6231 6232 // Now check that there is one acceptable connections. 6233 _, _, err = c.EP.Accept(nil) 6234 if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 6235 // Wait for connection to be established. 
6236 <-ch 6237 _, _, err = c.EP.Accept(nil) 6238 if err != nil { 6239 t.Fatalf("Accept failed: %s", err) 6240 } 6241 } 6242 } 6243 6244 func TestListenDropIncrement(t *testing.T) { 6245 c := context.New(t, defaultMTU) 6246 defer c.Cleanup() 6247 6248 stats := c.Stack().Stats() 6249 c.Create(-1 /*epRcvBuf*/) 6250 6251 if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { 6252 t.Fatalf("Bind failed: %s", err) 6253 } 6254 if err := c.EP.Listen(1 /*backlog*/); err != nil { 6255 t.Fatalf("Listen failed: %s", err) 6256 } 6257 6258 initialDropped := stats.DroppedPackets.Value() 6259 6260 // Send RST, FIN segments, that are expected to be dropped by the listener. 6261 c.SendPacket(nil, &context.Headers{ 6262 SrcPort: context.TestPort, 6263 DstPort: context.StackPort, 6264 Flags: header.TCPFlagRst, 6265 }) 6266 c.SendPacket(nil, &context.Headers{ 6267 SrcPort: context.TestPort, 6268 DstPort: context.StackPort, 6269 Flags: header.TCPFlagFin, 6270 }) 6271 6272 // To ensure that the RST, FIN sent earlier are indeed received and ignored 6273 // by the listener, send a SYN and wait for the SYN to be ACKd. 
6274 irs := seqnum.Value(context.TestInitialSequenceNumber) 6275 c.SendPacket(nil, &context.Headers{ 6276 SrcPort: context.TestPort, 6277 DstPort: context.StackPort, 6278 Flags: header.TCPFlagSyn, 6279 SeqNum: irs, 6280 }) 6281 checker.IPv4(t, c.GetPacket(), checker.TCP(checker.SrcPort(context.StackPort), 6282 checker.DstPort(context.TestPort), 6283 checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), 6284 checker.TCPAckNum(uint32(irs)+1), 6285 )) 6286 6287 if got, want := stats.DroppedPackets.Value(), initialDropped+2; got != want { 6288 t.Fatalf("got stats.DroppedPackets.Value() = %d, want = %d", got, want) 6289 } 6290 } 6291 6292 func TestEndpointBindListenAcceptState(t *testing.T) { 6293 c := context.New(t, defaultMTU) 6294 defer c.Cleanup() 6295 wq := &waiter.Queue{} 6296 ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) 6297 if err != nil { 6298 t.Fatalf("NewEndpoint failed: %s", err) 6299 } 6300 6301 if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 6302 t.Fatalf("Bind failed: %s", err) 6303 } 6304 if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { 6305 t.Errorf("unexpected endpoint state: want %s, got %s", want, got) 6306 } 6307 6308 ept := endpointTester{ep} 6309 ept.CheckReadError(t, &tcpip.ErrNotConnected{}) 6310 if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 { 6311 t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1) 6312 } 6313 6314 if err := ep.Listen(10); err != nil { 6315 t.Fatalf("Listen failed: %s", err) 6316 } 6317 if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { 6318 t.Errorf("unexpected endpoint state: want %s, got %s", want, got) 6319 } 6320 6321 c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}) 6322 6323 // Try to accept the connection. 
6324 we, ch := waiter.NewChannelEntry(nil) 6325 wq.EventRegister(&we, waiter.ReadableEvents) 6326 defer wq.EventUnregister(&we) 6327 6328 aep, _, err := ep.Accept(nil) 6329 if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 6330 // Wait for connection to be established. 6331 select { 6332 case <-ch: 6333 aep, _, err = ep.Accept(nil) 6334 if err != nil { 6335 t.Fatalf("Accept failed: %s", err) 6336 } 6337 6338 case <-time.After(1 * time.Second): 6339 t.Fatalf("Timed out waiting for accept") 6340 } 6341 } 6342 if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { 6343 t.Errorf("unexpected endpoint state: want %s, got %s", want, got) 6344 } 6345 { 6346 err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) 6347 if d := cmp.Diff(&tcpip.ErrAlreadyConnected{}, err); d != "" { 6348 t.Errorf("Connect(...) mismatch (-want +got):\n%s", d) 6349 } 6350 } 6351 // Listening endpoint remains in listen state. 6352 if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { 6353 t.Errorf("unexpected endpoint state: want %s, got %s", want, got) 6354 } 6355 6356 ep.Close() 6357 // Give worker goroutines time to receive the close notification. 6358 time.Sleep(1 * time.Second) 6359 if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { 6360 t.Errorf("unexpected endpoint state: want %s, got %s", want, got) 6361 } 6362 // Accepted endpoint remains open when the listen endpoint is closed. 6363 if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { 6364 t.Errorf("unexpected endpoint state: want %s, got %s", want, got) 6365 } 6366 6367 } 6368 6369 // This test verifies that the auto tuning does not grow the receive buffer if 6370 // the application is not reading the data actively. 
6371 func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { 6372 const mtu = 1500 6373 const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize 6374 6375 c := context.New(t, mtu) 6376 defer c.Cleanup() 6377 6378 stk := c.Stack() 6379 // Set lower limits for auto-tuning tests. This is required because the 6380 // test stops the worker which can cause packets to be dropped because 6381 // the segment queue holding unprocessed packets is limited to 500. 6382 const receiveBufferSize = 80 << 10 // 80KB. 6383 const maxReceiveBufferSize = receiveBufferSize * 10 6384 { 6385 opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize} 6386 if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { 6387 t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) 6388 } 6389 } 6390 6391 // Enable auto-tuning. 6392 { 6393 opt := tcpip.TCPModerateReceiveBufferOption(true) 6394 if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { 6395 t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) 6396 } 6397 } 6398 // Change the expected window scale to match the value needed for the 6399 // maximum buffer size defined above. 6400 c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) 6401 6402 rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) 6403 6404 // NOTE: The timestamp values in the sent packets are meaningless to the 6405 // peer so we just increment the timestamp value by 1 every batch as we 6406 // are not really using them for anything. Send a single byte to verify 6407 // the advertised window. 6408 tsVal := rawEP.TSVal + 1 6409 6410 // Introduce a 25ms latency by delaying the first byte. 6411 latency := 25 * time.Millisecond 6412 time.Sleep(latency) 6413 // Send an initial payload with atleast segment overhead size. 
The receive 6414 // window would not grow for smaller segments. 6415 rawEP.SendPacketWithTS(make([]byte, tcp.SegSize), tsVal) 6416 6417 pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) 6418 rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() 6419 6420 time.Sleep(25 * time.Millisecond) 6421 6422 // Allocate a large enough payload for the test. 6423 payloadSize := receiveBufferSize * 2 6424 b := make([]byte, payloadSize) 6425 6426 worker := (c.EP).(interface { 6427 StopWork() 6428 ResumeWork() 6429 }) 6430 tsVal++ 6431 6432 // Stop the worker goroutine. 6433 worker.StopWork() 6434 start := 0 6435 end := payloadSize / 2 6436 packetsSent := 0 6437 for ; start < end; start += mss { 6438 packetEnd := start + mss 6439 if start+mss > end { 6440 packetEnd = end 6441 } 6442 rawEP.SendPacketWithTS(b[start:packetEnd], tsVal) 6443 packetsSent++ 6444 } 6445 6446 // Resume the worker so that it only sees the packets once all of them 6447 // are waiting to be read. 6448 worker.ResumeWork() 6449 6450 // Since we sent almost the full receive buffer worth of data (some may have 6451 // been dropped due to segment overheads), we should get a zero window back. 6452 pkt = c.GetPacket() 6453 tcpHdr := header.TCP(header.IPv4(pkt).Payload()) 6454 gotRcvWnd := tcpHdr.WindowSize() 6455 wantAckNum := tcpHdr.AckNumber() 6456 if got, want := int(gotRcvWnd), 0; got != want { 6457 t.Fatalf("got rcvWnd: %d, want: %d", got, want) 6458 } 6459 6460 time.Sleep(25 * time.Millisecond) 6461 // Verify that sending more data when receiveBuffer is exhausted. 6462 rawEP.SendPacketWithTS(b[start:start+mss], tsVal) 6463 6464 // Now read all the data from the endpoint and verify that advertised 6465 // window increases to the full available buffer size. 6466 for { 6467 _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) 6468 if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 6469 break 6470 } 6471 } 6472 6473 // Verify that we receive a non-zero window update ACK. 
When running 6474 // under thread santizer this test can end up sending more than 1 6475 // ack, 1 for the non-zero window 6476 p := c.GetPacket() 6477 checker.IPv4(t, p, checker.TCP( 6478 checker.TCPAckNum(wantAckNum), 6479 func(t *testing.T, h header.Transport) { 6480 tcp, ok := h.(header.TCP) 6481 if !ok { 6482 return 6483 } 6484 // We use 10% here as the error margin upwards as the initial window we 6485 // got was afer 1 segment was already in the receive buffer queue. 6486 tolerance := 1.1 6487 if w := tcp.WindowSize(); w == 0 || w > uint16(float64(rcvWnd)*tolerance) { 6488 t.Errorf("expected a non-zero window: got %d, want <= %d", w, uint16(float64(rcvWnd)*tolerance)) 6489 } 6490 }, 6491 )) 6492 } 6493 6494 // This test verifies that the advertised window is auto-tuned up as the 6495 // application is reading the data that is being received. 6496 func TestReceiveBufferAutoTuning(t *testing.T) { 6497 const mtu = 1500 6498 const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize 6499 6500 c := context.New(t, mtu) 6501 defer c.Cleanup() 6502 6503 // Enable Auto-tuning. 6504 stk := c.Stack() 6505 // Disable out of window rate limiting for this test by setting it to 0 as we 6506 // use out of window ACKs to measure the advertised window. 6507 var tcpInvalidRateLimit stack.TCPInvalidRateLimitOption 6508 if err := stk.SetOption(tcpInvalidRateLimit); err != nil { 6509 t.Fatalf("e.stack.SetOption(%#v) = %s", tcpInvalidRateLimit, err) 6510 } 6511 6512 const receiveBufferSize = 80 << 10 // 80KB. 6513 const maxReceiveBufferSize = receiveBufferSize * 10 6514 { 6515 opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize} 6516 if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { 6517 t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) 6518 } 6519 } 6520 6521 // Enable auto-tuning. 
6522 { 6523 opt := tcpip.TCPModerateReceiveBufferOption(true) 6524 if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { 6525 t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) 6526 } 6527 } 6528 // Change the expected window scale to match the value needed for the 6529 // maximum buffer size used by stack. 6530 c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) 6531 6532 rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) 6533 tsVal := rawEP.TSVal 6534 rawEP.NextSeqNum-- 6535 rawEP.SendPacketWithTS(nil, tsVal) 6536 rawEP.NextSeqNum++ 6537 pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) 6538 curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale 6539 scaleRcvWnd := func(rcvWnd int) uint16 { 6540 return uint16(rcvWnd >> c.WindowScale) 6541 } 6542 // Allocate a large array to send to the endpoint. 6543 b := make([]byte, receiveBufferSize*48) 6544 6545 // In every iteration we will send double the number of bytes sent in 6546 // the previous iteration and read the same from the app. The received 6547 // window should grow by at least 2x of bytes read by the app in every 6548 // RTT. 6549 offset := 0 6550 payloadSize := receiveBufferSize / 8 6551 worker := (c.EP).(interface { 6552 StopWork() 6553 ResumeWork() 6554 }) 6555 latency := 1 * time.Millisecond 6556 for i := 0; i < 5; i++ { 6557 tsVal++ 6558 6559 // Stop the worker goroutine. 6560 worker.StopWork() 6561 start := offset 6562 end := offset + payloadSize 6563 totalSent := 0 6564 packetsSent := 0 6565 for ; start < end; start += mss { 6566 rawEP.SendPacketWithTS(b[start:start+mss], tsVal) 6567 totalSent += mss 6568 packetsSent++ 6569 } 6570 6571 // Resume it so that it only sees the packets once all of them 6572 // are waiting to be read. 6573 worker.ResumeWork() 6574 6575 // Give 1ms for the worker to process the packets. 
6576 time.Sleep(1 * time.Millisecond) 6577 6578 lastACK := c.GetPacket() 6579 // Discard any intermediate ACKs and only check the last ACK we get in a 6580 // short time period of few ms. 6581 for { 6582 time.Sleep(1 * time.Millisecond) 6583 pkt := c.GetPacketNonBlocking() 6584 if pkt == nil { 6585 break 6586 } 6587 lastACK = pkt 6588 } 6589 if got, want := int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()), int(scaleRcvWnd(curRcvWnd)); got > want { 6590 t.Fatalf("advertised window got: %d, want <= %d", got, want) 6591 } 6592 6593 // Now read all the data from the endpoint and invoke the 6594 // moderation API to allow for receive buffer auto-tuning 6595 // to happen before we measure the new window. 6596 totalCopied := 0 6597 for { 6598 res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) 6599 if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 6600 break 6601 } 6602 totalCopied += res.Count 6603 } 6604 6605 // Invoke the moderation API. This is required for auto-tuning 6606 // to happen. This method is normally expected to be invoked 6607 // from a higher layer than tcpip.Endpoint. So we simulate 6608 // copying to userspace by invoking it explicitly here. 6609 c.EP.ModerateRecvBuf(totalCopied) 6610 6611 // Now send a keep-alive packet to trigger an ACK so that we can 6612 // measure the new window. 6613 rawEP.NextSeqNum-- 6614 rawEP.SendPacketWithTS(nil, tsVal) 6615 rawEP.NextSeqNum++ 6616 6617 if i == 0 { 6618 // In the first iteration the receiver based RTT is not 6619 // yet known as a result the moderation code should not 6620 // increase the advertised window. 6621 rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd)) 6622 } else { 6623 // Read loop above could generate an ACK if the window had dropped to 6624 // zero and then read had opened it up. 6625 lastACK := c.GetPacket() 6626 // Discard any intermediate ACKs and only check the last ACK we get in a 6627 // short time period of few ms. 
6628 for { 6629 time.Sleep(1 * time.Millisecond) 6630 pkt := c.GetPacketNonBlocking() 6631 if pkt == nil { 6632 break 6633 } 6634 lastACK = pkt 6635 } 6636 curRcvWnd = int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()) << c.WindowScale 6637 // If thew new current window is close maxReceiveBufferSize then terminate 6638 // the loop. This can happen before all iterations are done due to timing 6639 // differences when running the test. 6640 if int(float64(curRcvWnd)*1.1) > maxReceiveBufferSize/2 { 6641 break 6642 } 6643 // Increase the latency after first two iterations to 6644 // establish a low RTT value in the receiver since it 6645 // only tracks the lowest value. This ensures that when 6646 // ModerateRcvBuf is called the elapsed time is always > 6647 // rtt. Without this the test is flaky due to delays due 6648 // to scheduling/wakeup etc. 6649 latency += 50 * time.Millisecond 6650 } 6651 time.Sleep(latency) 6652 offset += payloadSize 6653 payloadSize *= 2 6654 } 6655 // Check that at the end of our iterations the receive window grew close to the maximum 6656 // permissible size of maxReceiveBufferSize/2 6657 if got, want := int(float64(curRcvWnd)*1.1), maxReceiveBufferSize/2; got < want { 6658 t.Fatalf("unexpected rcvWnd got: %d, want > %d", got, want) 6659 } 6660 6661 } 6662 6663 func TestDelayEnabled(t *testing.T) { 6664 c := context.New(t, defaultMTU) 6665 defer c.Cleanup() 6666 checkDelayOption(t, c, false, false) // Delay is disabled by default. 
6667 6668 for _, delayEnabled := range []bool{false, true} { 6669 t.Run(fmt.Sprintf("delayEnabled=%t", delayEnabled), func(t *testing.T) { 6670 c := context.New(t, defaultMTU) 6671 defer c.Cleanup() 6672 opt := tcpip.TCPDelayEnabled(delayEnabled) 6673 if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { 6674 t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, delayEnabled, err) 6675 } 6676 checkDelayOption(t, c, opt, delayEnabled) 6677 }) 6678 } 6679 } 6680 6681 func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.TCPDelayEnabled, wantDelayOption bool) { 6682 t.Helper() 6683 6684 var gotDelayEnabled tcpip.TCPDelayEnabled 6685 if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { 6686 t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err) 6687 } 6688 if gotDelayEnabled != wantDelayEnabled { 6689 t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled) 6690 } 6691 6692 ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue)) 6693 if err != nil { 6694 t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err) 6695 } 6696 gotDelayOption := ep.SocketOptions().GetDelayOption() 6697 if gotDelayOption != wantDelayOption { 6698 t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption) 6699 } 6700 } 6701 6702 func TestTCPLingerTimeout(t *testing.T) { 6703 c := context.New(t, 1500 /* mtu */) 6704 defer c.Cleanup() 6705 6706 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 6707 6708 testCases := []struct { 6709 name string 6710 tcpLingerTimeout time.Duration 6711 want time.Duration 6712 }{ 6713 {"NegativeLingerTimeout", -123123, -1}, 6714 // Zero is treated same as the stack's default TCP_LINGER2 timeout. 
6715 {"ZeroLingerTimeout", 0, tcp.DefaultTCPLingerTimeout}, 6716 {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second}, 6717 // Values > stack's TCPLingerTimeout are capped to the stack's 6718 // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds) 6719 {"AboveMaxLingerTimeout", tcp.MaxTCPLingerTimeout + 5*time.Second, tcp.MaxTCPLingerTimeout}, 6720 } 6721 for _, tc := range testCases { 6722 t.Run(tc.name, func(t *testing.T) { 6723 v := tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout) 6724 if err := c.EP.SetSockOpt(&v); err != nil { 6725 t.Fatalf("SetSockOpt(&%T(%s)) = %s", v, tc.tcpLingerTimeout, err) 6726 } 6727 6728 v = 0 6729 if err := c.EP.GetSockOpt(&v); err != nil { 6730 t.Fatalf("GetSockOpt(&%T) = %s", v, err) 6731 } 6732 if got, want := time.Duration(v), tc.want; got != want { 6733 t.Fatalf("got linger timeout = %s, want = %s", got, want) 6734 } 6735 }) 6736 } 6737 } 6738 6739 func TestTCPTimeWaitRSTIgnored(t *testing.T) { 6740 c := context.New(t, defaultMTU) 6741 defer c.Cleanup() 6742 6743 wq := &waiter.Queue{} 6744 ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) 6745 if err != nil { 6746 t.Fatalf("NewEndpoint failed: %s", err) 6747 } 6748 if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 6749 t.Fatalf("Bind failed: %s", err) 6750 } 6751 6752 if err := ep.Listen(10); err != nil { 6753 t.Fatalf("Listen failed: %s", err) 6754 } 6755 6756 // Send a SYN request. 6757 iss := seqnum.Value(context.TestInitialSequenceNumber) 6758 c.SendPacket(nil, &context.Headers{ 6759 SrcPort: context.TestPort, 6760 DstPort: context.StackPort, 6761 Flags: header.TCPFlagSyn, 6762 SeqNum: iss, 6763 RcvWnd: 30000, 6764 }) 6765 6766 // Receive the SYN-ACK reply. 
6767 b := c.GetPacket() 6768 tcpHdr := header.TCP(header.IPv4(b).Payload()) 6769 c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) 6770 6771 ackHeaders := &context.Headers{ 6772 SrcPort: context.TestPort, 6773 DstPort: context.StackPort, 6774 Flags: header.TCPFlagAck, 6775 SeqNum: iss + 1, 6776 AckNum: c.IRS + 1, 6777 } 6778 6779 // Send ACK. 6780 c.SendPacket(nil, ackHeaders) 6781 6782 // Try to accept the connection. 6783 we, ch := waiter.NewChannelEntry(nil) 6784 wq.EventRegister(&we, waiter.ReadableEvents) 6785 defer wq.EventUnregister(&we) 6786 6787 c.EP, _, err = ep.Accept(nil) 6788 if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 6789 // Wait for connection to be established. 6790 select { 6791 case <-ch: 6792 c.EP, _, err = ep.Accept(nil) 6793 if err != nil { 6794 t.Fatalf("Accept failed: %s", err) 6795 } 6796 6797 case <-time.After(1 * time.Second): 6798 t.Fatalf("Timed out waiting for accept") 6799 } 6800 } 6801 6802 c.EP.Close() 6803 checker.IPv4(t, c.GetPacket(), checker.TCP( 6804 checker.SrcPort(context.StackPort), 6805 checker.DstPort(context.TestPort), 6806 checker.TCPSeqNum(uint32(c.IRS+1)), 6807 checker.TCPAckNum(uint32(iss)+1), 6808 checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) 6809 6810 finHeaders := &context.Headers{ 6811 SrcPort: context.TestPort, 6812 DstPort: context.StackPort, 6813 Flags: header.TCPFlagAck | header.TCPFlagFin, 6814 SeqNum: iss + 1, 6815 AckNum: c.IRS + 2, 6816 } 6817 6818 c.SendPacket(nil, finHeaders) 6819 6820 // Get the ACK to the FIN we just sent. 6821 checker.IPv4(t, c.GetPacket(), checker.TCP( 6822 checker.SrcPort(context.StackPort), 6823 checker.DstPort(context.TestPort), 6824 checker.TCPSeqNum(uint32(c.IRS+2)), 6825 checker.TCPAckNum(uint32(iss)+2), 6826 checker.TCPFlags(header.TCPFlagAck))) 6827 6828 // Now send a RST and this should be ignored and not 6829 // generate an ACK. 
6830 c.SendPacket(nil, &context.Headers{ 6831 SrcPort: context.TestPort, 6832 DstPort: context.StackPort, 6833 Flags: header.TCPFlagRst, 6834 SeqNum: iss + 1, 6835 AckNum: c.IRS + 2, 6836 }) 6837 6838 c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second) 6839 6840 // Out of order ACK should generate an immediate ACK in 6841 // TIME_WAIT. 6842 c.SendPacket(nil, &context.Headers{ 6843 SrcPort: context.TestPort, 6844 DstPort: context.StackPort, 6845 Flags: header.TCPFlagAck, 6846 SeqNum: iss + 1, 6847 AckNum: c.IRS + 3, 6848 }) 6849 6850 checker.IPv4(t, c.GetPacket(), checker.TCP( 6851 checker.SrcPort(context.StackPort), 6852 checker.DstPort(context.TestPort), 6853 checker.TCPSeqNum(uint32(c.IRS+2)), 6854 checker.TCPAckNum(uint32(iss)+2), 6855 checker.TCPFlags(header.TCPFlagAck))) 6856 } 6857 6858 func TestTCPTimeWaitOutOfOrder(t *testing.T) { 6859 c := context.New(t, defaultMTU) 6860 defer c.Cleanup() 6861 6862 wq := &waiter.Queue{} 6863 ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) 6864 if err != nil { 6865 t.Fatalf("NewEndpoint failed: %s", err) 6866 } 6867 if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 6868 t.Fatalf("Bind failed: %s", err) 6869 } 6870 6871 if err := ep.Listen(10); err != nil { 6872 t.Fatalf("Listen failed: %s", err) 6873 } 6874 6875 // Send a SYN request. 6876 iss := seqnum.Value(context.TestInitialSequenceNumber) 6877 c.SendPacket(nil, &context.Headers{ 6878 SrcPort: context.TestPort, 6879 DstPort: context.StackPort, 6880 Flags: header.TCPFlagSyn, 6881 SeqNum: iss, 6882 RcvWnd: 30000, 6883 }) 6884 6885 // Receive the SYN-ACK reply. 
6886 b := c.GetPacket() 6887 tcpHdr := header.TCP(header.IPv4(b).Payload()) 6888 c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) 6889 6890 ackHeaders := &context.Headers{ 6891 SrcPort: context.TestPort, 6892 DstPort: context.StackPort, 6893 Flags: header.TCPFlagAck, 6894 SeqNum: iss + 1, 6895 AckNum: c.IRS + 1, 6896 } 6897 6898 // Send ACK. 6899 c.SendPacket(nil, ackHeaders) 6900 6901 // Try to accept the connection. 6902 we, ch := waiter.NewChannelEntry(nil) 6903 wq.EventRegister(&we, waiter.ReadableEvents) 6904 defer wq.EventUnregister(&we) 6905 6906 c.EP, _, err = ep.Accept(nil) 6907 if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 6908 // Wait for connection to be established. 6909 select { 6910 case <-ch: 6911 c.EP, _, err = ep.Accept(nil) 6912 if err != nil { 6913 t.Fatalf("Accept failed: %s", err) 6914 } 6915 6916 case <-time.After(1 * time.Second): 6917 t.Fatalf("Timed out waiting for accept") 6918 } 6919 } 6920 6921 c.EP.Close() 6922 checker.IPv4(t, c.GetPacket(), checker.TCP( 6923 checker.SrcPort(context.StackPort), 6924 checker.DstPort(context.TestPort), 6925 checker.TCPSeqNum(uint32(c.IRS+1)), 6926 checker.TCPAckNum(uint32(iss)+1), 6927 checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) 6928 6929 finHeaders := &context.Headers{ 6930 SrcPort: context.TestPort, 6931 DstPort: context.StackPort, 6932 Flags: header.TCPFlagAck | header.TCPFlagFin, 6933 SeqNum: iss + 1, 6934 AckNum: c.IRS + 2, 6935 } 6936 6937 c.SendPacket(nil, finHeaders) 6938 6939 // Get the ACK to the FIN we just sent. 6940 checker.IPv4(t, c.GetPacket(), checker.TCP( 6941 checker.SrcPort(context.StackPort), 6942 checker.DstPort(context.TestPort), 6943 checker.TCPSeqNum(uint32(c.IRS+2)), 6944 checker.TCPAckNum(uint32(iss)+2), 6945 checker.TCPFlags(header.TCPFlagAck))) 6946 6947 // Out of order ACK should generate an immediate ACK in 6948 // TIME_WAIT. 
6949 c.SendPacket(nil, &context.Headers{ 6950 SrcPort: context.TestPort, 6951 DstPort: context.StackPort, 6952 Flags: header.TCPFlagAck, 6953 SeqNum: iss + 1, 6954 AckNum: c.IRS + 3, 6955 }) 6956 6957 checker.IPv4(t, c.GetPacket(), checker.TCP( 6958 checker.SrcPort(context.StackPort), 6959 checker.DstPort(context.TestPort), 6960 checker.TCPSeqNum(uint32(c.IRS+2)), 6961 checker.TCPAckNum(uint32(iss)+2), 6962 checker.TCPFlags(header.TCPFlagAck))) 6963 } 6964 6965 func TestTCPTimeWaitNewSyn(t *testing.T) { 6966 c := context.New(t, defaultMTU) 6967 defer c.Cleanup() 6968 6969 wq := &waiter.Queue{} 6970 ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) 6971 if err != nil { 6972 t.Fatalf("NewEndpoint failed: %s", err) 6973 } 6974 if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 6975 t.Fatalf("Bind failed: %s", err) 6976 } 6977 6978 if err := ep.Listen(10); err != nil { 6979 t.Fatalf("Listen failed: %s", err) 6980 } 6981 6982 // Send a SYN request. 6983 iss := seqnum.Value(context.TestInitialSequenceNumber) 6984 c.SendPacket(nil, &context.Headers{ 6985 SrcPort: context.TestPort, 6986 DstPort: context.StackPort, 6987 Flags: header.TCPFlagSyn, 6988 SeqNum: iss, 6989 RcvWnd: 30000, 6990 }) 6991 6992 // Receive the SYN-ACK reply. 6993 b := c.GetPacket() 6994 tcpHdr := header.TCP(header.IPv4(b).Payload()) 6995 c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) 6996 6997 ackHeaders := &context.Headers{ 6998 SrcPort: context.TestPort, 6999 DstPort: context.StackPort, 7000 Flags: header.TCPFlagAck, 7001 SeqNum: iss + 1, 7002 AckNum: c.IRS + 1, 7003 } 7004 7005 // Send ACK. 7006 c.SendPacket(nil, ackHeaders) 7007 7008 // Try to accept the connection. 7009 we, ch := waiter.NewChannelEntry(nil) 7010 wq.EventRegister(&we, waiter.ReadableEvents) 7011 defer wq.EventUnregister(&we) 7012 7013 c.EP, _, err = ep.Accept(nil) 7014 if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 7015 // Wait for connection to be established. 
7016 select { 7017 case <-ch: 7018 c.EP, _, err = ep.Accept(nil) 7019 if err != nil { 7020 t.Fatalf("Accept failed: %s", err) 7021 } 7022 7023 case <-time.After(1 * time.Second): 7024 t.Fatalf("Timed out waiting for accept") 7025 } 7026 } 7027 7028 c.EP.Close() 7029 checker.IPv4(t, c.GetPacket(), checker.TCP( 7030 checker.SrcPort(context.StackPort), 7031 checker.DstPort(context.TestPort), 7032 checker.TCPSeqNum(uint32(c.IRS+1)), 7033 checker.TCPAckNum(uint32(iss)+1), 7034 checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) 7035 7036 finHeaders := &context.Headers{ 7037 SrcPort: context.TestPort, 7038 DstPort: context.StackPort, 7039 Flags: header.TCPFlagAck | header.TCPFlagFin, 7040 SeqNum: iss + 1, 7041 AckNum: c.IRS + 2, 7042 } 7043 7044 c.SendPacket(nil, finHeaders) 7045 7046 // Get the ACK to the FIN we just sent. 7047 checker.IPv4(t, c.GetPacket(), checker.TCP( 7048 checker.SrcPort(context.StackPort), 7049 checker.DstPort(context.TestPort), 7050 checker.TCPSeqNum(uint32(c.IRS+2)), 7051 checker.TCPAckNum(uint32(iss)+2), 7052 checker.TCPFlags(header.TCPFlagAck))) 7053 7054 // Send a SYN request w/ sequence number lower than 7055 // the highest sequence number sent. We just reuse 7056 // the same number. 7057 iss = seqnum.Value(context.TestInitialSequenceNumber) 7058 c.SendPacket(nil, &context.Headers{ 7059 SrcPort: context.TestPort, 7060 DstPort: context.StackPort, 7061 Flags: header.TCPFlagSyn, 7062 SeqNum: iss, 7063 RcvWnd: 30000, 7064 }) 7065 7066 c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) 7067 7068 // drain any older notifications from the notification channel before attempting 7069 // 2nd connection. 7070 select { 7071 case <-ch: 7072 default: 7073 } 7074 7075 // Send a SYN request w/ sequence number higher than 7076 // the highest sequence number sent. 
7077 iss = iss.Add(3) 7078 c.SendPacket(nil, &context.Headers{ 7079 SrcPort: context.TestPort, 7080 DstPort: context.StackPort, 7081 Flags: header.TCPFlagSyn, 7082 SeqNum: iss, 7083 RcvWnd: 30000, 7084 }) 7085 7086 // Receive the SYN-ACK reply. 7087 b = c.GetPacket() 7088 tcpHdr = header.IPv4(b).Payload() 7089 c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) 7090 7091 ackHeaders = &context.Headers{ 7092 SrcPort: context.TestPort, 7093 DstPort: context.StackPort, 7094 Flags: header.TCPFlagAck, 7095 SeqNum: iss + 1, 7096 AckNum: c.IRS + 1, 7097 } 7098 7099 // Send ACK. 7100 c.SendPacket(nil, ackHeaders) 7101 7102 // Try to accept the connection. 7103 c.EP, _, err = ep.Accept(nil) 7104 if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 7105 // Wait for connection to be established. 7106 select { 7107 case <-ch: 7108 c.EP, _, err = ep.Accept(nil) 7109 if err != nil { 7110 t.Fatalf("Accept failed: %s", err) 7111 } 7112 7113 case <-time.After(1 * time.Second): 7114 t.Fatalf("Timed out waiting for accept") 7115 } 7116 } 7117 } 7118 7119 func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { 7120 c := context.New(t, defaultMTU) 7121 defer c.Cleanup() 7122 7123 // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed 7124 // after 5 seconds in TIME_WAIT state. 
7125 tcpTimeWaitTimeout := 5 * time.Second 7126 opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) 7127 if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { 7128 t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err) 7129 } 7130 7131 want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1 7132 7133 wq := &waiter.Queue{} 7134 ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) 7135 if err != nil { 7136 t.Fatalf("NewEndpoint failed: %s", err) 7137 } 7138 if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 7139 t.Fatalf("Bind failed: %s", err) 7140 } 7141 7142 if err := ep.Listen(10); err != nil { 7143 t.Fatalf("Listen failed: %s", err) 7144 } 7145 7146 // Send a SYN request. 7147 iss := seqnum.Value(context.TestInitialSequenceNumber) 7148 c.SendPacket(nil, &context.Headers{ 7149 SrcPort: context.TestPort, 7150 DstPort: context.StackPort, 7151 Flags: header.TCPFlagSyn, 7152 SeqNum: iss, 7153 RcvWnd: 30000, 7154 }) 7155 7156 // Receive the SYN-ACK reply. 7157 b := c.GetPacket() 7158 tcpHdr := header.TCP(header.IPv4(b).Payload()) 7159 c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) 7160 7161 ackHeaders := &context.Headers{ 7162 SrcPort: context.TestPort, 7163 DstPort: context.StackPort, 7164 Flags: header.TCPFlagAck, 7165 SeqNum: iss + 1, 7166 AckNum: c.IRS + 1, 7167 } 7168 7169 // Send ACK. 7170 c.SendPacket(nil, ackHeaders) 7171 7172 // Try to accept the connection. 7173 we, ch := waiter.NewChannelEntry(nil) 7174 wq.EventRegister(&we, waiter.ReadableEvents) 7175 defer wq.EventUnregister(&we) 7176 7177 c.EP, _, err = ep.Accept(nil) 7178 if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 7179 // Wait for connection to be established. 
7180 select { 7181 case <-ch: 7182 c.EP, _, err = ep.Accept(nil) 7183 if err != nil { 7184 t.Fatalf("Accept failed: %s", err) 7185 } 7186 7187 case <-time.After(1 * time.Second): 7188 t.Fatalf("Timed out waiting for accept") 7189 } 7190 } 7191 7192 c.EP.Close() 7193 checker.IPv4(t, c.GetPacket(), checker.TCP( 7194 checker.SrcPort(context.StackPort), 7195 checker.DstPort(context.TestPort), 7196 checker.TCPSeqNum(uint32(c.IRS+1)), 7197 checker.TCPAckNum(uint32(iss)+1), 7198 checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) 7199 7200 finHeaders := &context.Headers{ 7201 SrcPort: context.TestPort, 7202 DstPort: context.StackPort, 7203 Flags: header.TCPFlagAck | header.TCPFlagFin, 7204 SeqNum: iss + 1, 7205 AckNum: c.IRS + 2, 7206 } 7207 7208 c.SendPacket(nil, finHeaders) 7209 7210 // Get the ACK to the FIN we just sent. 7211 checker.IPv4(t, c.GetPacket(), checker.TCP( 7212 checker.SrcPort(context.StackPort), 7213 checker.DstPort(context.TestPort), 7214 checker.TCPSeqNum(uint32(c.IRS+2)), 7215 checker.TCPAckNum(uint32(iss)+2), 7216 checker.TCPFlags(header.TCPFlagAck))) 7217 7218 time.Sleep(2 * time.Second) 7219 7220 // Now send a duplicate FIN. This should cause the TIME_WAIT to extend 7221 // by another 5 seconds and also send us a duplicate ACK as it should 7222 // indicate that the final ACK was potentially lost. 7223 c.SendPacket(nil, finHeaders) 7224 7225 // Get the ACK to the FIN we just sent. 7226 checker.IPv4(t, c.GetPacket(), checker.TCP( 7227 checker.SrcPort(context.StackPort), 7228 checker.DstPort(context.TestPort), 7229 checker.TCPSeqNum(uint32(c.IRS+2)), 7230 checker.TCPAckNum(uint32(iss)+2), 7231 checker.TCPFlags(header.TCPFlagAck))) 7232 7233 // Sleep for 4 seconds so at this point we are 1 second past the 7234 // original tcpLingerTimeout of 5 seconds. 
7235 time.Sleep(4 * time.Second) 7236 7237 // Send an ACK and it should not generate any packet as the socket 7238 // should still be in TIME_WAIT for another another 5 seconds due 7239 // to the duplicate FIN we sent earlier. 7240 *ackHeaders = *finHeaders 7241 ackHeaders.SeqNum = ackHeaders.SeqNum + 1 7242 ackHeaders.Flags = header.TCPFlagAck 7243 c.SendPacket(nil, ackHeaders) 7244 7245 c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second) 7246 // Now sleep for another 2 seconds so that we are past the 7247 // extended TIME_WAIT of 7 seconds (2 + 5). 7248 time.Sleep(2 * time.Second) 7249 7250 // Resend the same ACK. 7251 c.SendPacket(nil, ackHeaders) 7252 7253 // Receive the RST that should be generated as there is no valid 7254 // endpoint. 7255 checker.IPv4(t, c.GetPacket(), checker.TCP( 7256 checker.SrcPort(context.StackPort), 7257 checker.DstPort(context.TestPort), 7258 checker.TCPSeqNum(uint32(ackHeaders.AckNum)), 7259 checker.TCPAckNum(0), 7260 checker.TCPFlags(header.TCPFlagRst))) 7261 7262 if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { 7263 t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = %d", got, want) 7264 } 7265 if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { 7266 t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) 7267 } 7268 } 7269 7270 func TestTCPCloseWithData(t *testing.T) { 7271 c := context.New(t, defaultMTU) 7272 defer c.Cleanup() 7273 7274 // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed 7275 // after 5 seconds in TIME_WAIT state. 
7276 tcpTimeWaitTimeout := 5 * time.Second 7277 opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) 7278 if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { 7279 t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err) 7280 } 7281 7282 wq := &waiter.Queue{} 7283 ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) 7284 if err != nil { 7285 t.Fatalf("NewEndpoint failed: %s", err) 7286 } 7287 if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 7288 t.Fatalf("Bind failed: %s", err) 7289 } 7290 7291 if err := ep.Listen(10); err != nil { 7292 t.Fatalf("Listen failed: %s", err) 7293 } 7294 7295 // Send a SYN request. 7296 iss := seqnum.Value(context.TestInitialSequenceNumber) 7297 c.SendPacket(nil, &context.Headers{ 7298 SrcPort: context.TestPort, 7299 DstPort: context.StackPort, 7300 Flags: header.TCPFlagSyn, 7301 SeqNum: iss, 7302 RcvWnd: 30000, 7303 }) 7304 7305 // Receive the SYN-ACK reply. 7306 b := c.GetPacket() 7307 tcpHdr := header.TCP(header.IPv4(b).Payload()) 7308 c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) 7309 7310 ackHeaders := &context.Headers{ 7311 SrcPort: context.TestPort, 7312 DstPort: context.StackPort, 7313 Flags: header.TCPFlagAck, 7314 SeqNum: iss + 1, 7315 AckNum: c.IRS + 1, 7316 RcvWnd: 30000, 7317 } 7318 7319 // Send ACK. 7320 c.SendPacket(nil, ackHeaders) 7321 7322 // Try to accept the connection. 7323 we, ch := waiter.NewChannelEntry(nil) 7324 wq.EventRegister(&we, waiter.ReadableEvents) 7325 defer wq.EventUnregister(&we) 7326 7327 c.EP, _, err = ep.Accept(nil) 7328 if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { 7329 // Wait for connection to be established. 
7330 select { 7331 case <-ch: 7332 c.EP, _, err = ep.Accept(nil) 7333 if err != nil { 7334 t.Fatalf("Accept failed: %s", err) 7335 } 7336 7337 case <-time.After(1 * time.Second): 7338 t.Fatalf("Timed out waiting for accept") 7339 } 7340 } 7341 7342 // Now trigger a passive close by sending a FIN. 7343 finHeaders := &context.Headers{ 7344 SrcPort: context.TestPort, 7345 DstPort: context.StackPort, 7346 Flags: header.TCPFlagAck | header.TCPFlagFin, 7347 SeqNum: iss + 1, 7348 AckNum: c.IRS + 2, 7349 RcvWnd: 30000, 7350 } 7351 7352 c.SendPacket(nil, finHeaders) 7353 7354 // Get the ACK to the FIN we just sent. 7355 checker.IPv4(t, c.GetPacket(), checker.TCP( 7356 checker.SrcPort(context.StackPort), 7357 checker.DstPort(context.TestPort), 7358 checker.TCPSeqNum(uint32(c.IRS+1)), 7359 checker.TCPAckNum(uint32(iss)+2), 7360 checker.TCPFlags(header.TCPFlagAck))) 7361 7362 // Now write a few bytes and then close the endpoint. 7363 data := []byte{1, 2, 3} 7364 7365 var r bytes.Reader 7366 r.Reset(data) 7367 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 7368 t.Fatalf("Write failed: %s", err) 7369 } 7370 7371 // Check that data is received. 7372 b = c.GetPacket() 7373 checker.IPv4(t, b, 7374 checker.PayloadLen(len(data)+header.TCPMinimumSize), 7375 checker.TCP( 7376 checker.DstPort(context.TestPort), 7377 checker.TCPSeqNum(uint32(c.IRS)+1), 7378 checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 7379 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 7380 ), 7381 ) 7382 7383 if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { 7384 t.Errorf("got data = %x, want = %x", p, data) 7385 } 7386 7387 c.EP.Close() 7388 // Check the FIN. 
7389 checker.IPv4(t, c.GetPacket(), checker.TCP( 7390 checker.SrcPort(context.StackPort), 7391 checker.DstPort(context.TestPort), 7392 checker.TCPSeqNum(uint32(c.IRS+1)+uint32(len(data))), 7393 checker.TCPAckNum(uint32(iss+2)), 7394 checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) 7395 7396 // First send a partial ACK. 7397 ackHeaders = &context.Headers{ 7398 SrcPort: context.TestPort, 7399 DstPort: context.StackPort, 7400 Flags: header.TCPFlagAck, 7401 SeqNum: iss + 2, 7402 AckNum: c.IRS + 1 + seqnum.Value(len(data)-1), 7403 RcvWnd: 30000, 7404 } 7405 c.SendPacket(nil, ackHeaders) 7406 7407 // Now send a full ACK. 7408 ackHeaders = &context.Headers{ 7409 SrcPort: context.TestPort, 7410 DstPort: context.StackPort, 7411 Flags: header.TCPFlagAck, 7412 SeqNum: iss + 2, 7413 AckNum: c.IRS + 1 + seqnum.Value(len(data)), 7414 RcvWnd: 30000, 7415 } 7416 c.SendPacket(nil, ackHeaders) 7417 7418 // Now ACK the FIN. 7419 ackHeaders.AckNum++ 7420 c.SendPacket(nil, ackHeaders) 7421 7422 // Now send an ACK and we should get a RST back as the endpoint should 7423 // be in CLOSED state. 7424 ackHeaders = &context.Headers{ 7425 SrcPort: context.TestPort, 7426 DstPort: context.StackPort, 7427 Flags: header.TCPFlagAck, 7428 SeqNum: iss + 2, 7429 AckNum: c.IRS + 1 + seqnum.Value(len(data)), 7430 RcvWnd: 30000, 7431 } 7432 c.SendPacket(nil, ackHeaders) 7433 7434 // Check the RST. 
7435 checker.IPv4(t, c.GetPacket(), checker.TCP( 7436 checker.SrcPort(context.StackPort), 7437 checker.DstPort(context.TestPort), 7438 checker.TCPSeqNum(uint32(ackHeaders.AckNum)), 7439 checker.TCPAckNum(0), 7440 checker.TCPFlags(header.TCPFlagRst))) 7441 } 7442 7443 func TestTCPUserTimeout(t *testing.T) { 7444 c := context.New(t, defaultMTU) 7445 defer c.Cleanup() 7446 7447 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 7448 7449 waitEntry, notifyCh := waiter.NewChannelEntry(nil) 7450 c.WQ.EventRegister(&waitEntry, waiter.EventHUp) 7451 defer c.WQ.EventUnregister(&waitEntry) 7452 7453 origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() 7454 7455 // Ensure that on the next retransmit timer fire, the user timeout has 7456 // expired. 7457 initRTO := 1 * time.Second 7458 userTimeout := initRTO / 2 7459 v := tcpip.TCPUserTimeoutOption(userTimeout) 7460 if err := c.EP.SetSockOpt(&v); err != nil { 7461 t.Fatalf("c.EP.SetSockOpt(&%T(%s): %s", v, userTimeout, err) 7462 } 7463 7464 // Send some data and wait before ACKing it. 7465 view := make([]byte, 3) 7466 var r bytes.Reader 7467 r.Reset(view) 7468 if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { 7469 t.Fatalf("Write failed: %s", err) 7470 } 7471 7472 next := uint32(c.IRS) + 1 7473 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 7474 checker.IPv4(t, c.GetPacket(), 7475 checker.PayloadLen(len(view)+header.TCPMinimumSize), 7476 checker.TCP( 7477 checker.DstPort(context.TestPort), 7478 checker.TCPSeqNum(next), 7479 checker.TCPAckNum(uint32(iss)), 7480 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), 7481 ), 7482 ) 7483 7484 // Wait for the retransmit timer to be fired and the user timeout to cause 7485 // close of the connection. 
7486 select { 7487 case <-notifyCh: 7488 case <-time.After(2 * initRTO): 7489 t.Fatalf("connection still alive after %s, should have been closed after %s", 2*initRTO, userTimeout) 7490 } 7491 7492 // No packet should be received as the connection should be silently 7493 // closed due to timeout. 7494 c.CheckNoPacket("unexpected packet received after userTimeout has expired") 7495 7496 next += uint32(len(view)) 7497 7498 // The connection should be terminated after userTimeout has expired. 7499 // Send an ACK to trigger a RST from the stack as the endpoint should 7500 // be dead. 7501 c.SendPacket(nil, &context.Headers{ 7502 SrcPort: context.TestPort, 7503 DstPort: c.Port, 7504 Flags: header.TCPFlagAck, 7505 SeqNum: iss, 7506 AckNum: seqnum.Value(next), 7507 RcvWnd: 30000, 7508 }) 7509 7510 checker.IPv4(t, c.GetPacket(), 7511 checker.TCP( 7512 checker.DstPort(context.TestPort), 7513 checker.TCPSeqNum(next), 7514 checker.TCPAckNum(uint32(0)), 7515 checker.TCPFlags(header.TCPFlagRst), 7516 ), 7517 ) 7518 7519 ept := endpointTester{c.EP} 7520 ept.CheckReadError(t, &tcpip.ErrTimeout{}) 7521 7522 if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { 7523 t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) 7524 } 7525 if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { 7526 t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) 7527 } 7528 } 7529 7530 func TestKeepaliveWithUserTimeout(t *testing.T) { 7531 c := context.New(t, defaultMTU) 7532 defer c.Cleanup() 7533 7534 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) 7535 7536 origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() 7537 7538 const keepAliveIdle = 100 * time.Millisecond 7539 const keepAliveInterval = 3 * time.Second 7540 keepAliveIdleOption := tcpip.KeepaliveIdleOption(keepAliveIdle) 7541 if err := c.EP.SetSockOpt(&keepAliveIdleOption); 
err != nil { 7542 t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOption, keepAliveIdle, err) 7543 } 7544 keepAliveIntervalOption := tcpip.KeepaliveIntervalOption(keepAliveInterval) 7545 if err := c.EP.SetSockOpt(&keepAliveIntervalOption); err != nil { 7546 t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOption, keepAliveInterval, err) 7547 } 7548 if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10); err != nil { 7549 t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10): %s", err) 7550 } 7551 c.EP.SocketOptions().SetKeepAlive(true) 7552 7553 // Set userTimeout to be the duration to be 1 keepalive 7554 // probes. Which means that after the first probe is sent 7555 // the second one should cause the connection to be 7556 // closed due to userTimeout being hit. 7557 userTimeout := tcpip.TCPUserTimeoutOption(keepAliveInterval) 7558 if err := c.EP.SetSockOpt(&userTimeout); err != nil { 7559 t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", userTimeout, keepAliveInterval, err) 7560 } 7561 7562 // Check that the connection is still alive. 7563 ept := endpointTester{c.EP} 7564 ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) 7565 7566 // Now receive 1 keepalives, but don't ACK it. 7567 b := c.GetPacket() 7568 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 7569 checker.IPv4(t, b, 7570 checker.TCP( 7571 checker.DstPort(context.TestPort), 7572 checker.TCPSeqNum(uint32(c.IRS)), 7573 checker.TCPAckNum(uint32(iss)), 7574 checker.TCPFlags(header.TCPFlagAck), 7575 ), 7576 ) 7577 7578 // Sleep for a litte over the KeepAlive interval to make sure 7579 // the timer has time to fire after the last ACK and close the 7580 // close the socket. 7581 time.Sleep(keepAliveInterval + keepAliveInterval/2) 7582 7583 // The connection should be closed with a timeout. 7584 // Send an ACK to trigger a RST from the stack as the endpoint should 7585 // be dead. 
7586 c.SendPacket(nil, &context.Headers{ 7587 SrcPort: context.TestPort, 7588 DstPort: c.Port, 7589 Flags: header.TCPFlagAck, 7590 SeqNum: iss, 7591 AckNum: c.IRS + 1, 7592 RcvWnd: 30000, 7593 }) 7594 7595 checker.IPv4(t, c.GetPacket(), 7596 checker.TCP( 7597 checker.DstPort(context.TestPort), 7598 checker.TCPSeqNum(uint32(c.IRS+1)), 7599 checker.TCPAckNum(uint32(0)), 7600 checker.TCPFlags(header.TCPFlagRst), 7601 ), 7602 ) 7603 7604 ept.CheckReadError(t, &tcpip.ErrTimeout{}) 7605 if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { 7606 t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) 7607 } 7608 if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { 7609 t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) 7610 } 7611 } 7612 7613 func TestIncreaseWindowOnRead(t *testing.T) { 7614 // This test ensures that the endpoint sends an ack, 7615 // after read() when the window grows by more than 1 MSS. 7616 c := context.New(t, defaultMTU) 7617 defer c.Cleanup() 7618 7619 const rcvBuf = 65535 * 10 7620 c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf) 7621 7622 // Write chunks of ~30000 bytes. It's important that two 7623 // payloads make it equal or longer than MSS. 
7624 remain := rcvBuf * 2 7625 sent := 0 7626 data := make([]byte, defaultMTU/2) 7627 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 7628 for remain > len(data) { 7629 c.SendPacket(data, &context.Headers{ 7630 SrcPort: context.TestPort, 7631 DstPort: c.Port, 7632 Flags: header.TCPFlagAck, 7633 SeqNum: iss.Add(seqnum.Size(sent)), 7634 AckNum: c.IRS.Add(1), 7635 RcvWnd: 30000, 7636 }) 7637 sent += len(data) 7638 remain -= len(data) 7639 pkt := c.GetPacket() 7640 checker.IPv4(t, pkt, 7641 checker.PayloadLen(header.TCPMinimumSize), 7642 checker.TCP( 7643 checker.DstPort(context.TestPort), 7644 checker.TCPSeqNum(uint32(c.IRS)+1), 7645 checker.TCPAckNum(uint32(iss)+uint32(sent)), 7646 checker.TCPFlags(header.TCPFlagAck), 7647 ), 7648 ) 7649 // Break once the window drops below defaultMTU/2 7650 if wnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize(); wnd < defaultMTU/2 { 7651 break 7652 } 7653 } 7654 7655 // We now have < 1 MSS in the buffer space. Read at least > 2 MSS 7656 // worth of data as receive buffer space 7657 w := tcpip.LimitedWriter{ 7658 W: ioutil.Discard, 7659 // defaultMTU is a good enough estimate for the MSS used for this 7660 // connection. 7661 N: defaultMTU * 2, 7662 } 7663 for w.N != 0 { 7664 _, err := c.EP.Read(&w, tcpip.ReadOptions{}) 7665 if err != nil { 7666 t.Fatalf("Read failed: %s", err) 7667 } 7668 } 7669 7670 // After reading > MSS worth of data, we surely crossed MSS. See the ack: 7671 checker.IPv4(t, c.GetPacket(), 7672 checker.PayloadLen(header.TCPMinimumSize), 7673 checker.TCP( 7674 checker.DstPort(context.TestPort), 7675 checker.TCPSeqNum(uint32(c.IRS)+1), 7676 checker.TCPAckNum(uint32(iss)+uint32(sent)), 7677 checker.TCPWindow(uint16(0xffff)), 7678 checker.TCPFlags(header.TCPFlagAck), 7679 ), 7680 ) 7681 } 7682 7683 func TestIncreaseWindowOnBufferResize(t *testing.T) { 7684 // This test ensures that the endpoint sends an ack, 7685 // after available recv buffer grows to more than 1 MSS. 
7686 c := context.New(t, defaultMTU) 7687 defer c.Cleanup() 7688 7689 const rcvBuf = 65535 * 10 7690 c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf) 7691 7692 // Write chunks of ~30000 bytes. It's important that two 7693 // payloads make it equal or longer than MSS. 7694 remain := rcvBuf 7695 sent := 0 7696 data := make([]byte, defaultMTU/2) 7697 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 7698 for remain > len(data) { 7699 c.SendPacket(data, &context.Headers{ 7700 SrcPort: context.TestPort, 7701 DstPort: c.Port, 7702 Flags: header.TCPFlagAck, 7703 SeqNum: iss.Add(seqnum.Size(sent)), 7704 AckNum: c.IRS.Add(1), 7705 RcvWnd: 30000, 7706 }) 7707 sent += len(data) 7708 remain -= len(data) 7709 checker.IPv4(t, c.GetPacket(), 7710 checker.PayloadLen(header.TCPMinimumSize), 7711 checker.TCP( 7712 checker.DstPort(context.TestPort), 7713 checker.TCPSeqNum(uint32(c.IRS)+1), 7714 checker.TCPAckNum(uint32(iss)+uint32(sent)), 7715 checker.TCPWindowLessThanEq(0xffff), 7716 checker.TCPFlags(header.TCPFlagAck), 7717 ), 7718 ) 7719 } 7720 7721 // Increasing the buffer from should generate an ACK, 7722 // since window grew from small value to larger equal MSS 7723 c.EP.SocketOptions().SetReceiveBufferSize(rcvBuf*2, true) 7724 checker.IPv4(t, c.GetPacket(), 7725 checker.PayloadLen(header.TCPMinimumSize), 7726 checker.TCP( 7727 checker.DstPort(context.TestPort), 7728 checker.TCPSeqNum(uint32(c.IRS)+1), 7729 checker.TCPAckNum(uint32(iss)+uint32(sent)), 7730 checker.TCPWindow(uint16(0xffff)), 7731 checker.TCPFlags(header.TCPFlagAck), 7732 ), 7733 ) 7734 } 7735 7736 func TestTCPDeferAccept(t *testing.T) { 7737 c := context.New(t, defaultMTU) 7738 defer c.Cleanup() 7739 7740 c.Create(-1) 7741 7742 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 7743 t.Fatal("Bind failed:", err) 7744 } 7745 7746 if err := c.EP.Listen(10); err != nil { 7747 t.Fatal("Listen failed:", err) 7748 } 7749 7750 const tcpDeferAccept = 1 * 
time.Second 7751 tcpDeferAcceptOption := tcpip.TCPDeferAcceptOption(tcpDeferAccept) 7752 if err := c.EP.SetSockOpt(&tcpDeferAcceptOption); err != nil { 7753 t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", tcpDeferAcceptOption, tcpDeferAccept, err) 7754 } 7755 7756 irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) 7757 7758 _, _, err := c.EP.Accept(nil) 7759 if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" { 7760 t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d) 7761 } 7762 7763 // Send data. This should result in an acceptable endpoint. 7764 c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ 7765 SrcPort: context.TestPort, 7766 DstPort: context.StackPort, 7767 Flags: header.TCPFlagAck, 7768 SeqNum: irs + 1, 7769 AckNum: iss + 1, 7770 }) 7771 7772 // Receive ACK for the data we sent. 7773 checker.IPv4(t, c.GetPacket(), checker.TCP( 7774 checker.DstPort(context.TestPort), 7775 checker.TCPFlags(header.TCPFlagAck), 7776 checker.TCPSeqNum(uint32(iss+1)), 7777 checker.TCPAckNum(uint32(irs+5)))) 7778 7779 // Give a bit of time for the socket to be delivered to the accept queue. 7780 time.Sleep(50 * time.Millisecond) 7781 aep, _, err := c.EP.Accept(nil) 7782 if err != nil { 7783 t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) 7784 } 7785 7786 aep.Close() 7787 // Closing aep without reading the data should trigger a RST. 
7788 checker.IPv4(t, c.GetPacket(), checker.TCP( 7789 checker.DstPort(context.TestPort), 7790 checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), 7791 checker.TCPSeqNum(uint32(iss+1)), 7792 checker.TCPAckNum(uint32(irs+5)))) 7793 } 7794 7795 func TestTCPDeferAcceptTimeout(t *testing.T) { 7796 c := context.New(t, defaultMTU) 7797 defer c.Cleanup() 7798 7799 c.Create(-1) 7800 7801 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 7802 t.Fatal("Bind failed:", err) 7803 } 7804 7805 if err := c.EP.Listen(10); err != nil { 7806 t.Fatal("Listen failed:", err) 7807 } 7808 7809 const tcpDeferAccept = 1 * time.Second 7810 tcpDeferAcceptOpt := tcpip.TCPDeferAcceptOption(tcpDeferAccept) 7811 if err := c.EP.SetSockOpt(&tcpDeferAcceptOpt); err != nil { 7812 t.Fatalf("c.EP.SetSockOpt(&%T(%s)) failed: %s", tcpDeferAcceptOpt, tcpDeferAccept, err) 7813 } 7814 7815 irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) 7816 7817 _, _, err := c.EP.Accept(nil) 7818 if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" { 7819 t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d) 7820 } 7821 7822 // Sleep for a little of the tcpDeferAccept timeout. 7823 time.Sleep(tcpDeferAccept + 100*time.Millisecond) 7824 7825 // On timeout expiry we should get a SYN-ACK retransmission. 7826 checker.IPv4(t, c.GetPacket(), checker.TCP( 7827 checker.SrcPort(context.StackPort), 7828 checker.DstPort(context.TestPort), 7829 checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), 7830 checker.TCPAckNum(uint32(irs)+1))) 7831 7832 // Send data. This should result in an acceptable endpoint. 7833 c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ 7834 SrcPort: context.TestPort, 7835 DstPort: context.StackPort, 7836 Flags: header.TCPFlagAck, 7837 SeqNum: irs + 1, 7838 AckNum: iss + 1, 7839 }) 7840 7841 // Receive ACK for the data we sent. 
7842 checker.IPv4(t, c.GetPacket(), checker.TCP( 7843 checker.SrcPort(context.StackPort), 7844 checker.DstPort(context.TestPort), 7845 checker.TCPFlags(header.TCPFlagAck), 7846 checker.TCPSeqNum(uint32(iss+1)), 7847 checker.TCPAckNum(uint32(irs+5)))) 7848 7849 // Give sometime for the endpoint to be delivered to the accept queue. 7850 time.Sleep(50 * time.Millisecond) 7851 aep, _, err := c.EP.Accept(nil) 7852 if err != nil { 7853 t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) 7854 } 7855 7856 aep.Close() 7857 // Closing aep without reading the data should trigger a RST. 7858 checker.IPv4(t, c.GetPacket(), checker.TCP( 7859 checker.SrcPort(context.StackPort), 7860 checker.DstPort(context.TestPort), 7861 checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), 7862 checker.TCPSeqNum(uint32(iss+1)), 7863 checker.TCPAckNum(uint32(irs+5)))) 7864 } 7865 7866 func TestResetDuringClose(t *testing.T) { 7867 c := context.New(t, defaultMTU) 7868 defer c.Cleanup() 7869 7870 c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRecvBuf */) 7871 // Send some data to make sure there is some unread 7872 // data to trigger a reset on c.Close. 7873 irs := c.IRS 7874 iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) 7875 c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ 7876 SrcPort: context.TestPort, 7877 DstPort: c.Port, 7878 Flags: header.TCPFlagAck, 7879 SeqNum: iss, 7880 AckNum: irs.Add(1), 7881 RcvWnd: 30000, 7882 }) 7883 7884 // Receive ACK for the data we sent. 7885 checker.IPv4(t, c.GetPacket(), checker.TCP( 7886 checker.DstPort(context.TestPort), 7887 checker.TCPFlags(header.TCPFlagAck), 7888 checker.TCPSeqNum(uint32(irs.Add(1))), 7889 checker.TCPAckNum(uint32(iss)+4))) 7890 7891 // Close in a separate goroutine so that we can trigger 7892 // a race with the RST we send below. 
This should not 7893 // panic due to the route being released depeding on 7894 // whether Close() sends an active RST or the RST sent 7895 // below is processed by the worker first. 7896 var wg sync.WaitGroup 7897 7898 wg.Add(1) 7899 go func() { 7900 defer wg.Done() 7901 c.SendPacket(nil, &context.Headers{ 7902 SrcPort: context.TestPort, 7903 DstPort: c.Port, 7904 SeqNum: iss.Add(4), 7905 AckNum: c.IRS.Add(5), 7906 RcvWnd: 30000, 7907 Flags: header.TCPFlagRst, 7908 }) 7909 }() 7910 7911 wg.Add(1) 7912 go func() { 7913 defer wg.Done() 7914 c.EP.Close() 7915 }() 7916 7917 wg.Wait() 7918 } 7919 7920 func TestStackTimeWaitReuse(t *testing.T) { 7921 c := context.New(t, defaultMTU) 7922 defer c.Cleanup() 7923 7924 s := c.Stack() 7925 var twReuse tcpip.TCPTimeWaitReuseOption 7926 if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil { 7927 t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &twReuse, err) 7928 } 7929 if got, want := twReuse, tcpip.TCPTimeWaitReuseLoopbackOnly; got != want { 7930 t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want) 7931 } 7932 } 7933 7934 func TestSetStackTimeWaitReuse(t *testing.T) { 7935 c := context.New(t, defaultMTU) 7936 defer c.Cleanup() 7937 7938 s := c.Stack() 7939 testCases := []struct { 7940 v int 7941 err tcpip.Error 7942 }{ 7943 {int(tcpip.TCPTimeWaitReuseDisabled), nil}, 7944 {int(tcpip.TCPTimeWaitReuseGlobal), nil}, 7945 {int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil}, 7946 {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, &tcpip.ErrInvalidOptionValue{}}, 7947 {int(tcpip.TCPTimeWaitReuseDisabled) - 1, &tcpip.ErrInvalidOptionValue{}}, 7948 } 7949 7950 for _, tc := range testCases { 7951 opt := tcpip.TCPTimeWaitReuseOption(tc.v) 7952 err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt) 7953 if got, want := err, tc.err; got != want { 7954 t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)) = %s, want = %s", tcp.ProtocolNumber, tc.v, tc.v, err, tc.err) 7955 } 7956 if 
tc.err != nil { 7957 continue 7958 } 7959 7960 var twReuse tcpip.TCPTimeWaitReuseOption 7961 if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil { 7962 t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want nil", tcp.ProtocolNumber, &twReuse, err) 7963 } 7964 7965 if got, want := twReuse, tcpip.TCPTimeWaitReuseOption(tc.v); got != want { 7966 t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want) 7967 } 7968 } 7969 } 7970 7971 // generateRandomPayload generates a random byte slice of the specified length 7972 // causing a fatal test failure if it is unable to do so. 7973 func generateRandomPayload(t *testing.T, n int) []byte { 7974 t.Helper() 7975 buf := make([]byte, n) 7976 if _, err := rand.Read(buf); err != nil { 7977 t.Fatalf("rand.Read(buf) failed: %s", err) 7978 } 7979 return buf 7980 }