github.com/polevpn/netstack@v1.10.9/tcpip/transport/tcp/accept.go

// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package tcp

import (
	"crypto/sha1"
	"encoding/binary"
	"hash"
	"io"
	"sync"
	"time"

	"github.com/polevpn/netstack/rand"
	"github.com/polevpn/netstack/sleep"
	"github.com/polevpn/netstack/tcpip"
	"github.com/polevpn/netstack/tcpip/buffer"
	"github.com/polevpn/netstack/tcpip/header"
	"github.com/polevpn/netstack/tcpip/seqnum"
	"github.com/polevpn/netstack/tcpip/stack"
	"github.com/polevpn/netstack/waiter"
)

const (
	// tsLen is the length, in bits, of the timestamp in the SYN cookie.
	tsLen = 8

	// tsMask is a mask for timestamp values (i.e., tsLen bits).
	tsMask = (1 << tsLen) - 1

	// tsOffset is the offset, in bits, of the timestamp in the SYN cookie.
	tsOffset = 24

	// hashMask is the mask for hash values (i.e., tsOffset bits).
	hashMask = (1 << tsOffset) - 1

	// maxTSDiff is the maximum allowed difference between a received cookie
	// timestamp and the current timestamp. If the difference is greater
	// than maxTSDiff, the cookie is expired.
	maxTSDiff = 2
)

var (
	// SynRcvdCountThreshold is the global maximum number of connections
	// that are allowed to be in SYN-RCVD state before TCP starts using SYN
	// cookies to accept connections.
	//
	// It is an exported variable only for testing, and should not otherwise
	// be used by importers of this package.
	SynRcvdCountThreshold uint64 = 1000

	// mssTable is a slice containing the possible MSS values that we
	// encode in the SYN cookie with two bits.
	mssTable = []uint16{536, 1300, 1440, 1460}
)

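// encodeMSS returns the index of the largest mssTable entry that does not
// exceed mss, so that the value fits in the two bits of data carried by a SYN
// cookie. MSS values smaller than every table entry fall back to index 0
// (536 bytes); handleListenSegment maps the index back via mssTable[data]
// once a cookie is validated.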
func encodeMSS(mss uint16) uint32 {
	for i := len(mssTable) - 1; i > 0; i-- {
		if mss >= mssTable[i] {
			return uint32(i)
		}
	}
	return 0
}

// synRcvdCount is the number of endpoints in the SYN-RCVD state. The value is
// protected by a mutex so that we can increment only when it's guaranteed not
// to go above a threshold.
var synRcvdCount struct {
	sync.Mutex
	value   uint64
	pending sync.WaitGroup
}

// listenContext is used by a listening endpoint to store state used while
// listening for connections. This struct is allocated by the listen goroutine
// and must not be accessed or have its methods called concurrently as they
// may mutate the stored objects.
type listenContext struct {
	stack    *stack.Stack
	rcvWnd   seqnum.Size
	nonce    [2][sha1.BlockSize]byte
	listenEP *endpoint

	hasherMu sync.Mutex
	hasher   hash.Hash
	v6only   bool
	netProto tcpip.NetworkProtocolNumber

	// pendingMu protects pendingEndpoints. This should only be accessed
	// by the listening endpoint's worker goroutine.
	//
	// Lock Ordering: listenEP.workerMu -> pendingMu
	pendingMu sync.Mutex
	// pending is used to wait for all pendingEndpoints to finish when
	// a socket is closed.
	pending sync.WaitGroup
	// pendingEndpoints is a map of all endpoints for which a handshake is
	// in progress.
	pendingEndpoints map[stack.TransportEndpointID]*endpoint
}

// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
func timeStamp() uint32 {
	return uint32(time.Now().Unix()>>6) & tsMask
}

// incSynRcvdCount tries to increment the global number of endpoints in
// SYN-RCVD state. It succeeds if the increment doesn't make the count go
// beyond the threshold, and fails otherwise.
func incSynRcvdCount() bool {
	synRcvdCount.Lock()

	if synRcvdCount.value >= SynRcvdCountThreshold {
		synRcvdCount.Unlock()
		return false
	}

	synRcvdCount.pending.Add(1)
	synRcvdCount.value++

	synRcvdCount.Unlock()
	return true
}

// decSynRcvdCount atomically decrements the global number of endpoints in
// SYN-RCVD state. It must only be called if a previous call to incSynRcvdCount
// succeeded.
func decSynRcvdCount() {
	synRcvdCount.Lock()

	synRcvdCount.value--
	synRcvdCount.pending.Done()
	synRcvdCount.Unlock()
}

// synCookiesInUse returns true if synRcvdCount is greater than or equal to
// SynRcvdCountThreshold.
func synCookiesInUse() bool {
	synRcvdCount.Lock()
	v := synRcvdCount.value
	synRcvdCount.Unlock()
	return v >= SynRcvdCountThreshold
}

// newListenContext creates a new listen context.
func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
	l := &listenContext{
		stack:            stk,
		rcvWnd:           rcvWnd,
		hasher:           sha1.New(),
		v6only:           v6only,
		netProto:         netProto,
		listenEP:         listenEP,
		pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
	}

	rand.Read(l.nonce[0][:])
	rand.Read(l.nonce[1][:])

	return l
}

// cookieHash calculates the cookieHash for the given id, timestamp and nonce
// index. The hash is used to create and validate cookies.
func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonceIndex int) uint32 {

	// Initialize block with fixed-size data: local and remote ports, and
	// the timestamp.
	var payload [8]byte
	binary.BigEndian.PutUint16(payload[0:], id.LocalPort)
	binary.BigEndian.PutUint16(payload[2:], id.RemotePort)
	binary.BigEndian.PutUint32(payload[4:], ts)

	// Feed everything to the hasher.
	l.hasherMu.Lock()
	l.hasher.Reset()
	l.hasher.Write(payload[:])
	l.hasher.Write(l.nonce[nonceIndex][:])
	io.WriteString(l.hasher, string(id.LocalAddress))
	io.WriteString(l.hasher, string(id.RemoteAddress))

	// Finalize the calculation of the hash and return the first 4 bytes.
	h := make([]byte, 0, sha1.Size)
	h = l.hasher.Sum(h)
	l.hasherMu.Unlock()

	return binary.BigEndian.Uint32(h[:])
}

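// A SYN cookie is the 32-bit sequence number of our SYN-ACK. Once the first
// keyed hash and the client's ISN are subtracted back out (as isCookieValid
// does), its top tsLen bits carry the timeStamp() at creation and its low
// tsOffset bits carry a second keyed hash plus the encoded data (the MSS
// index), masked to hashMask. A cookie is rejected when its embedded
// timestamp is more than maxTSDiff ticks (of 64 seconds each) older than the
// current timestamp.
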
// createCookie creates a SYN cookie for the given id and incoming sequence
// number.
func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Value, data uint32) seqnum.Value {
	ts := timeStamp()
	v := l.cookieHash(id, 0, 0) + uint32(seq) + (ts << tsOffset)
	v += (l.cookieHash(id, ts, 1) + data) & hashMask
	return seqnum.Value(v)
}

// isCookieValid checks if the supplied cookie is valid for the given id and
// sequence number. If it is, it also returns the data originally encoded in
// the cookie when createCookie was called.
func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnum.Value, seq seqnum.Value) (uint32, bool) {
	ts := timeStamp()
	v := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq)
	cookieTS := v >> tsOffset
	if ((ts - cookieTS) & tsMask) > maxTSDiff {
		return 0, false
	}

	return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true
}

// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
	// Create a new endpoint.
	netProto := l.netProto
	if netProto == 0 {
		netProto = s.route.NetProto
	}
	n := newEndpoint(l.stack, netProto, nil)
	n.v6only = l.v6only
	n.ID = s.id
	n.boundNICID = s.route.NICID()
	n.route = s.route.Clone()
	n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
	n.rcvBufSize = int(l.rcvWnd)
	n.amss = mssForRoute(&n.route)

	n.maybeEnableTimestamp(rcvdSynOpts)
	n.maybeEnableSACKPermitted(rcvdSynOpts)

	n.initGSO()

	// Register new endpoint so that packets are routed to it.
	if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.bindToDevice); err != nil {
		n.Close()
		return nil, err
	}

	n.isRegistered = true

	// Create sender and receiver.
	//
	// The receiver at least temporarily has a zero receive window scale,
	// but the caller may change it (before starting the protocol loop).
	n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS)
	n.rcv = newReceiver(n, irs, seqnum.Size(n.initialReceiveWindow()), 0, seqnum.Size(n.receiveBufferSize()))
	// Bootstrap the auto tuning algorithm. Starting at zero will result in
	// a large step function on the first window adjustment causing the
	// window to grow to a really large value.
	n.rcvAutoParams.prevCopied = n.initialReceiveWindow()

	return n, nil
}

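// createConnectingEndpoint above is shared by the two passive-open paths in
// this file: createEndpointAndPerformHandshake below uses it for the normal
// SYN-RCVD path and then drives the full handshake, while handleListenSegment
// builds an endpoint from it directly when a valid SYN cookie is presented,
// since in that case no per-connection handshake state was kept.
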
// createEndpointAndPerformHandshake creates a new endpoint in connected state
// and then performs the TCP 3-way handshake.
func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
	// Create new endpoint.
	irs := s.sequenceNumber
	isn := generateSecureISN(s.id, l.stack.Seed())
	ep, err := l.createConnectingEndpoint(s, isn, irs, opts)
	if err != nil {
		return nil, err
	}

	// listenEP is nil when listenContext is used by tcp.Forwarder.
	if l.listenEP != nil {
		l.listenEP.mu.Lock()
		if l.listenEP.state != StateListen {
			l.listenEP.mu.Unlock()
			return nil, tcpip.ErrConnectionAborted
		}
		l.addPendingEndpoint(ep)
		l.listenEP.mu.Unlock()
	}

	// Perform the 3-way handshake.
	h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow()))

	h.resetToSynRcvd(isn, irs, opts)
	if err := h.execute(); err != nil {
		ep.Close()
		if l.listenEP != nil {
			l.removePendingEndpoint(ep)
		}
		return nil, err
	}
	ep.mu.Lock()
	ep.stack.Stats().TCP.CurrentEstablished.Increment()
	ep.state = StateEstablished
	ep.isConnectNotified = true
	ep.mu.Unlock()

	// Update the receive window scaling. We can't do it before the
	// handshake because it's possible that the peer doesn't support window
	// scaling.
	ep.rcv.rcvWndScale = h.effectiveRcvWndScale()

	return ep, nil
}

func (l *listenContext) addPendingEndpoint(n *endpoint) {
	l.pendingMu.Lock()
	l.pendingEndpoints[n.ID] = n
	l.pending.Add(1)
	l.pendingMu.Unlock()
}

func (l *listenContext) removePendingEndpoint(n *endpoint) {
	l.pendingMu.Lock()
	delete(l.pendingEndpoints, n.ID)
	l.pending.Done()
	l.pendingMu.Unlock()
}

func (l *listenContext) closeAllPendingEndpoints() {
	l.pendingMu.Lock()
	for _, n := range l.pendingEndpoints {
		n.notifyProtocolGoroutine(notifyClose)
	}
	l.pendingMu.Unlock()
	l.pending.Wait()
}

// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
// endpoint has transitioned out of the listen state, the new endpoint is
// closed instead.
func (e *endpoint) deliverAccepted(n *endpoint) {
	e.mu.Lock()
	state := e.state
	e.pendingAccepted.Add(1)
	defer e.pendingAccepted.Done()
	acceptedChan := e.acceptedChan
	e.mu.Unlock()
	if state == StateListen {
		acceptedChan <- n
		e.waiterQueue.Notify(waiter.EventIn)
	} else {
		n.Close()
	}
}

// handleSynSegment is called in its own goroutine once the listening endpoint
// receives a SYN segment. It is responsible for completing the handshake and
// queueing the new endpoint for acceptance.
//
// A limited number of these goroutines are allowed before TCP starts using SYN
// cookies to accept connections.
func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
	defer decSynRcvdCount()
	defer e.decSynRcvdCount()
	defer s.decRef()

	n, err := ctx.createEndpointAndPerformHandshake(s, opts)
	if err != nil {
		e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
		e.stats.FailedConnectionAttempts.Increment()
		return
	}
	ctx.removePendingEndpoint(n)
	// Start the protocol goroutine.
	wq := &waiter.Queue{}
	n.startAcceptedLoop(wq)
	e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()

	e.deliverAccepted(n)
}

func (e *endpoint) incSynRcvdCount() bool {
	e.mu.Lock()
	if e.synRcvdCount >= cap(e.acceptedChan) {
		e.mu.Unlock()
		return false
	}
	e.synRcvdCount++
	e.mu.Unlock()
	return true
}

func (e *endpoint) decSynRcvdCount() {
	e.mu.Lock()
	e.synRcvdCount--
	e.mu.Unlock()
}

func (e *endpoint) acceptQueueIsFull() bool {
	e.mu.Lock()
	if l, c := len(e.acceptedChan)+e.synRcvdCount, cap(e.acceptedChan); l >= c {
		e.mu.Unlock()
		return true
	}
	e.mu.Unlock()
	return false
}

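// Two counters bound the work a listener takes on: the package-level
// synRcvdCount caps handshakes in flight across all listeners at
// SynRcvdCountThreshold, beyond which SYN cookies take over, while the
// per-endpoint synRcvdCount (together with acceptQueueIsFull) caps in-flight
// handshakes plus not-yet-accepted connections at the listen backlog,
// cap(e.acceptedChan).
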
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
	if s.flagsAreSet(header.TCPFlagSyn | header.TCPFlagAck) {
		// RFC 793 section 3.4 page 35 (figure 12) outlines that a RST
		// must be sent in response to a SYN-ACK while in the listen
		// state to prevent completing a handshake from an old SYN.
		e.sendTCP(&s.route, s.id, buffer.VectorisedView{}, e.ttl, e.sendTOS, header.TCPFlagRst, s.ackNumber, 0, 0, nil, nil)
		return
	}

	// TODO(b/143300739): Use the userMSS of the listening socket
	// for accepted sockets.

	switch {
	case (s.flags & header.TCPFlagSyn) != 0:
		opts := parseSynSegmentOptions(s)
		if incSynRcvdCount() {
			// Only handle the SYN if the following conditions hold:
			//   - the accept queue is not full.
			//   - the number of connections in SYN-RCVD state is
			//     less than the backlog.
			if !e.acceptQueueIsFull() && e.incSynRcvdCount() {
				s.incRef()
				go e.handleSynSegment(ctx, s, &opts)
				return
			}
			decSynRcvdCount()
			e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
			e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
			e.stack.Stats().DroppedPackets.Increment()
			return
		} else {
			// If cookies are in use but the endpoint accept queue
			// is full then drop the SYN.
			if e.acceptQueueIsFull() {
				e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
				e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
				e.stack.Stats().DroppedPackets.Increment()
				return
			}
			cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))

			// Send the SYN-ACK without window scaling because we
			// currently don't encode this information in the
			// cookie.
			//
			// Enable the Timestamp option if the original SYN did
			// have the timestamp option specified.
			synOpts := header.TCPSynOptions{
				WS:    -1,
				TS:    opts.TS,
				TSVal: tcpTimeStamp(timeStampOffset()),
				TSEcr: opts.TSVal,
				MSS:   mssForRoute(&s.route),
			}
			e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts)
			e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
		}

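	// When SYN cookies are in use, an ACK arriving on the listening
	// endpoint may complete a cookie-based handshake: the segment's
	// acknowledgment number is one past the sequence number we chose for
	// the SYN-ACK (the cookie), and its sequence number is one past the
	// client's ISN, so s.ackNumber-1 and s.sequenceNumber-1 below are
	// exactly the values createCookie was called with.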
	case (s.flags & header.TCPFlagAck) != 0:
		if e.acceptQueueIsFull() {
			// Silently drop the ACK as the application can't accept
			// the connection at this point. The ACK will be
			// retransmitted by the sender anyway and we can
			// complete the connection at the time of retransmit if
			// the backlog has space.
			e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
			e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment()
			e.stack.Stats().DroppedPackets.Increment()
			return
		}

		if !synCookiesInUse() {
			// When not using SYN cookies, as per RFC 793, section 3.9, page 64:
			// Any acknowledgment is bad if it arrives on a connection still in
			// the LISTEN state. An acceptable reset segment should be formed
			// for any arriving ACK-bearing segment. The RST should be
			// formatted as follows:
			//
			//   <SEQ=SEG.ACK><CTL=RST>
			//
			// Send a reset as this is an ACK for which there is no
			// half-open connection and we are not using cookies
			// yet.
			//
			// The only time we should reach here is when a
			// connection was opened and closed really quickly and
			// a delayed ACK was received from the sender.
			replyWithReset(s)
			return
		}

		// Since SYN cookies are in use this is potentially an ACK to a
		// SYN-ACK we sent, but we don't have a half-open connection
		// state as cookies are being used to protect against a
		// potential SYN flood. In such cases validate the cookie and,
		// if it is valid, create a fully connected endpoint and deliver
		// it to the accept queue.
		//
		// If not, silently drop the ACK to avoid leaking information
		// when under a potential SYN flood attack.
		//
		// Validate the cookie.
		data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1)
		if !ok || int(data) >= len(mssTable) {
			e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment()
			e.stack.Stats().DroppedPackets.Increment()
			return
		}
		e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment()
		// Create newly accepted endpoint and deliver it.
		rcvdSynOptions := &header.TCPSynOptions{
			MSS: mssTable[data],
			// Disable window scaling as the original SYN is lost.
			WS: -1,
		}

		// When SYN cookies are in use we enable the timestamp option
		// only if the ACK specifies it, assuming that the other end
		// did in fact negotiate the timestamp option in the original
		// SYN.
		if s.parsedOptions.TS {
			rcvdSynOptions.TS = true
			rcvdSynOptions.TSVal = s.parsedOptions.TSVal
			rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
		}

		n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions)
		if err != nil {
			e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
			e.stats.FailedConnectionAttempts.Increment()
			return
		}

		// Clear the tsOffset for the newly created endpoint as the
		// timestamp was already randomly offset when the original
		// SYN-ACK was sent above.
		n.tsOffset = 0

		// Switch state to connected.
		n.stack.Stats().TCP.CurrentEstablished.Increment()
		n.state = StateEstablished
		n.isConnectNotified = true

		// Do the delivery in a separate goroutine so that we don't
		// block the listen loop in case the application is slow to
		// accept or stops accepting.
		//
		// NOTE: This won't result in an unbounded number of goroutines
		// as we do check before entering here that there was at least
		// some space available in the backlog.

		// Start the protocol goroutine.
		wq := &waiter.Queue{}
		n.startAcceptedLoop(wq)
		e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
		go e.deliverAccepted(n)
	}
}

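// The pieces above are tied together by protocolListenLoop below: segments
// queued against the listening endpoint are drained and handed to
// handleListenSegment, which either spawns a handshake goroutine
// (handleSynSegment) or completes a cookie handshake inline; when the loop is
// told to close, the deferred block moves the endpoint to StateClose, waits
// for every pending handshake via closeAllPendingEndpoints, and then wakes any
// blocked waiters.
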
// protocolListenLoop is the main loop of a listening TCP endpoint. It runs in
// its own goroutine and is responsible for handling connection requests.
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
	e.mu.Lock()
	v6only := e.v6only
	e.mu.Unlock()
	ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.NetProto)

	defer func() {
		// Mark endpoint as closed. This will prevent goroutines running
		// handleSynSegment() from attempting to queue new connections
		// to the endpoint.
		e.mu.Lock()
		e.state = StateClose

		// Close any endpoints in SYN-RCVD state.
		ctx.closeAllPendingEndpoints()

		// Do cleanup if needed.
		e.completeWorkerLocked()

		if e.drainDone != nil {
			close(e.drainDone)
		}
		e.mu.Unlock()

		// Notify waiters that the endpoint is shutdown.
		e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
	}()

	s := sleep.Sleeper{}
	s.AddWaker(&e.notificationWaker, wakerForNotification)
	s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
	for {
		switch index, _ := s.Fetch(true); index {
		case wakerForNotification:
			n := e.fetchNotifications()
			if n&notifyClose != 0 {
				return nil
			}
			if n&notifyDrain != 0 {
				for !e.segmentQueue.empty() {
					s := e.segmentQueue.dequeue()
					e.handleListenSegment(ctx, s)
					s.decRef()
				}
				close(e.drainDone)
				<-e.undrain
			}

		case wakerForNewSegment:
			// Process at most maxSegmentsPerWake segments.
			mayRequeue := true
			for i := 0; i < maxSegmentsPerWake; i++ {
				s := e.segmentQueue.dequeue()
				if s == nil {
					mayRequeue = false
					break
				}

				e.handleListenSegment(ctx, s)
				s.decRef()
			}

			// If the queue is not empty, make sure we'll wake up
			// in the next iteration.
			if mayRequeue && !e.segmentQueue.empty() {
				e.newSegmentWaker.Assert()
			}
		}
	}
}
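// For reference, the accept path in this file is driven from application code
// through the generic tcpip.Endpoint interface, roughly along these lines (a
// sketch only; the exact signatures and stack construction live in the tcpip
// and stack packages):
//
//	var wq waiter.Queue
//	ep, _ := stk.NewEndpoint(ProtocolNumber, ipv4.ProtocolNumber, &wq)
//	ep.Bind(tcpip.FullAddress{Port: 9999})
//	ep.Listen(10)                    // listen worker runs protocolListenLoop
//	newEP, newWQ, err := ep.Accept() // dequeues endpoints queued by deliverAccepted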