github.com/FlowerWrong/netstack@v0.0.0-20191009141956-e5848263af28/tcpip/transport/tcp/accept.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package tcp 16 17 import ( 18 "crypto/sha1" 19 "encoding/binary" 20 "hash" 21 "io" 22 "sync" 23 "time" 24 25 "github.com/FlowerWrong/netstack/rand" 26 "github.com/FlowerWrong/netstack/sleep" 27 "github.com/FlowerWrong/netstack/tcpip" 28 "github.com/FlowerWrong/netstack/tcpip/header" 29 "github.com/FlowerWrong/netstack/tcpip/seqnum" 30 "github.com/FlowerWrong/netstack/tcpip/stack" 31 "github.com/FlowerWrong/netstack/waiter" 32 ) 33 34 const ( 35 // tsLen is the length, in bits, of the timestamp in the SYN cookie. 36 tsLen = 8 37 38 // tsMask is a mask for timestamp values (i.e., tsLen bits). 39 tsMask = (1 << tsLen) - 1 40 41 // tsOffset is the offset, in bits, of the timestamp in the SYN cookie. 42 tsOffset = 24 43 44 // hashMask is the mask for hash values (i.e., tsOffset bits). 45 hashMask = (1 << tsOffset) - 1 46 47 // maxTSDiff is the maximum allowed difference between a received cookie 48 // timestamp and the current timestamp. If the difference is greater 49 // than maxTSDiff, the cookie is expired. 50 maxTSDiff = 2 51 ) 52 53 var ( 54 // SynRcvdCountThreshold is the global maximum number of connections 55 // that are allowed to be in SYN-RCVD state before TCP starts using SYN 56 // cookies to accept connections. 57 // 58 // It is an exported variable only for testing, and should not otherwise 59 // be used by importers of this package. 60 SynRcvdCountThreshold uint64 = 1000 61 62 // mssTable is a slice containing the possible MSS values that we 63 // encode in the SYN cookie with two bits. 64 mssTable = []uint16{536, 1300, 1440, 1460} 65 ) 66 67 func encodeMSS(mss uint16) uint32 { 68 for i := len(mssTable) - 1; i > 0; i-- { 69 if mss >= mssTable[i] { 70 return uint32(i) 71 } 72 } 73 return 0 74 } 75 76 // syncRcvdCount is the number of endpoints in the SYN-RCVD state. The value is 77 // protected by a mutex so that we can increment only when it's guaranteed not 78 // to go above a threshold. 79 var synRcvdCount struct { 80 sync.Mutex 81 value uint64 82 pending sync.WaitGroup 83 } 84 85 // listenContext is used by a listening endpoint to store state used while 86 // listening for connections. This struct is allocated by the listen goroutine 87 // and must not be accessed or have its methods called concurrently as they 88 // may mutate the stored objects. 89 type listenContext struct { 90 stack *stack.Stack 91 rcvWnd seqnum.Size 92 nonce [2][sha1.BlockSize]byte 93 listenEP *endpoint 94 95 hasherMu sync.Mutex 96 hasher hash.Hash 97 v6only bool 98 netProto tcpip.NetworkProtocolNumber 99 // pendingMu protects pendingEndpoints. This should only be accessed 100 // by the listening endpoint's worker goroutine. 101 // 102 // Lock Ordering: listenEP.workerMu -> pendingMu 103 pendingMu sync.Mutex 104 // pending is used to wait for all pendingEndpoints to finish when 105 // a socket is closed. 106 pending sync.WaitGroup 107 // pendingEndpoints is a map of all endpoints for which a handshake is 108 // in progress. 109 pendingEndpoints map[stack.TransportEndpointID]*endpoint 110 } 111 112 // timeStamp returns an 8-bit timestamp with a granularity of 64 seconds. 113 func timeStamp() uint32 { 114 return uint32(time.Now().Unix()>>6) & tsMask 115 } 116 117 // incSynRcvdCount tries to increment the global number of endpoints in SYN-RCVD 118 // state. It succeeds if the increment doesn't make the count go beyond the 119 // threshold, and fails otherwise. 120 func incSynRcvdCount() bool { 121 synRcvdCount.Lock() 122 123 if synRcvdCount.value >= SynRcvdCountThreshold { 124 synRcvdCount.Unlock() 125 return false 126 } 127 128 synRcvdCount.pending.Add(1) 129 synRcvdCount.value++ 130 131 synRcvdCount.Unlock() 132 return true 133 } 134 135 // decSynRcvdCount atomically decrements the global number of endpoints in 136 // SYN-RCVD state. It must only be called if a previous call to incSynRcvdCount 137 // succeeded. 138 func decSynRcvdCount() { 139 synRcvdCount.Lock() 140 141 synRcvdCount.value-- 142 synRcvdCount.pending.Done() 143 synRcvdCount.Unlock() 144 } 145 146 // synCookiesInUse() returns true if the synRcvdCount is greater than 147 // SynRcvdCountThreshold. 148 func synCookiesInUse() bool { 149 synRcvdCount.Lock() 150 v := synRcvdCount.value 151 synRcvdCount.Unlock() 152 return v >= SynRcvdCountThreshold 153 } 154 155 // newListenContext creates a new listen context. 156 func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { 157 l := &listenContext{ 158 stack: stk, 159 rcvWnd: rcvWnd, 160 hasher: sha1.New(), 161 v6only: v6only, 162 netProto: netProto, 163 listenEP: listenEP, 164 pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint), 165 } 166 167 rand.Read(l.nonce[0][:]) 168 rand.Read(l.nonce[1][:]) 169 170 return l 171 } 172 173 // cookieHash calculates the cookieHash for the given id, timestamp and nonce 174 // index. The hash is used to create and validate cookies. 175 func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonceIndex int) uint32 { 176 177 // Initialize block with fixed-size data: local ports and v. 178 var payload [8]byte 179 binary.BigEndian.PutUint16(payload[0:], id.LocalPort) 180 binary.BigEndian.PutUint16(payload[2:], id.RemotePort) 181 binary.BigEndian.PutUint32(payload[4:], ts) 182 183 // Feed everything to the hasher. 184 l.hasherMu.Lock() 185 l.hasher.Reset() 186 l.hasher.Write(payload[:]) 187 l.hasher.Write(l.nonce[nonceIndex][:]) 188 io.WriteString(l.hasher, string(id.LocalAddress)) 189 io.WriteString(l.hasher, string(id.RemoteAddress)) 190 191 // Finalize the calculation of the hash and return the first 4 bytes. 192 h := make([]byte, 0, sha1.Size) 193 h = l.hasher.Sum(h) 194 l.hasherMu.Unlock() 195 196 return binary.BigEndian.Uint32(h[:]) 197 } 198 199 // createCookie creates a SYN cookie for the given id and incoming sequence 200 // number. 201 func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Value, data uint32) seqnum.Value { 202 ts := timeStamp() 203 v := l.cookieHash(id, 0, 0) + uint32(seq) + (ts << tsOffset) 204 v += (l.cookieHash(id, ts, 1) + data) & hashMask 205 return seqnum.Value(v) 206 } 207 208 // isCookieValid checks if the supplied cookie is valid for the given id and 209 // sequence number. If it is, it also returns the data originally encoded in the 210 // cookie when createCookie was called. 211 func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnum.Value, seq seqnum.Value) (uint32, bool) { 212 ts := timeStamp() 213 v := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq) 214 cookieTS := v >> tsOffset 215 if ((ts - cookieTS) & tsMask) > maxTSDiff { 216 return 0, false 217 } 218 219 return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true 220 } 221 222 // createConnectingEndpoint creates a new endpoint in a connecting state, with 223 // the connection parameters given by the arguments. 224 func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { 225 // Create a new endpoint. 226 netProto := l.netProto 227 if netProto == 0 { 228 netProto = s.route.NetProto 229 } 230 n := newEndpoint(l.stack, netProto, nil) 231 n.v6only = l.v6only 232 n.id = s.id 233 n.boundNICID = s.route.NICID() 234 n.route = s.route.Clone() 235 n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} 236 n.rcvBufSize = int(l.rcvWnd) 237 n.amss = mssForRoute(&n.route) 238 239 n.maybeEnableTimestamp(rcvdSynOpts) 240 n.maybeEnableSACKPermitted(rcvdSynOpts) 241 242 n.initGSO() 243 244 // Register new endpoint so that packets are routed to it. 245 if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort, n.bindToDevice); err != nil { 246 n.Close() 247 return nil, err 248 } 249 250 n.isRegistered = true 251 252 // Create sender and receiver. 253 // 254 // The receiver at least temporarily has a zero receive window scale, 255 // but the caller may change it (before starting the protocol loop). 256 n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS) 257 n.rcv = newReceiver(n, irs, seqnum.Size(n.initialReceiveWindow()), 0, seqnum.Size(n.receiveBufferSize())) 258 // Bootstrap the auto tuning algorithm. Starting at zero will result in 259 // a large step function on the first window adjustment causing the 260 // window to grow to a really large value. 261 n.rcvAutoParams.prevCopied = n.initialReceiveWindow() 262 263 return n, nil 264 } 265 266 // createEndpoint creates a new endpoint in connected state and then performs 267 // the TCP 3-way handshake. 268 func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { 269 // Create new endpoint. 270 irs := s.sequenceNumber 271 cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS)) 272 ep, err := l.createConnectingEndpoint(s, cookie, irs, opts) 273 if err != nil { 274 return nil, err 275 } 276 277 // listenEP is nil when listenContext is used by tcp.Forwarder. 278 if l.listenEP != nil { 279 l.listenEP.mu.Lock() 280 if l.listenEP.state != StateListen { 281 l.listenEP.mu.Unlock() 282 return nil, tcpip.ErrConnectionAborted 283 } 284 l.addPendingEndpoint(ep) 285 l.listenEP.mu.Unlock() 286 } 287 288 // Perform the 3-way handshake. 289 h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow())) 290 291 h.resetToSynRcvd(cookie, irs, opts) 292 if err := h.execute(); err != nil { 293 ep.stack.Stats().TCP.FailedConnectionAttempts.Increment() 294 ep.Close() 295 if l.listenEP != nil { 296 l.removePendingEndpoint(ep) 297 } 298 return nil, err 299 } 300 ep.mu.Lock() 301 ep.state = StateEstablished 302 ep.mu.Unlock() 303 304 // Update the receive window scaling. We can't do it before the 305 // handshake because it's possible that the peer doesn't support window 306 // scaling. 307 ep.rcv.rcvWndScale = h.effectiveRcvWndScale() 308 309 return ep, nil 310 } 311 312 func (l *listenContext) addPendingEndpoint(n *endpoint) { 313 l.pendingMu.Lock() 314 l.pendingEndpoints[n.id] = n 315 l.pending.Add(1) 316 l.pendingMu.Unlock() 317 } 318 319 func (l *listenContext) removePendingEndpoint(n *endpoint) { 320 l.pendingMu.Lock() 321 delete(l.pendingEndpoints, n.id) 322 l.pending.Done() 323 l.pendingMu.Unlock() 324 } 325 326 func (l *listenContext) closeAllPendingEndpoints() { 327 l.pendingMu.Lock() 328 for _, n := range l.pendingEndpoints { 329 n.notifyProtocolGoroutine(notifyClose) 330 } 331 l.pendingMu.Unlock() 332 l.pending.Wait() 333 } 334 335 // deliverAccepted delivers the newly-accepted endpoint to the listener. If the 336 // endpoint has transitioned out of the listen state, the new endpoint is closed 337 // instead. 338 func (e *endpoint) deliverAccepted(n *endpoint) { 339 e.mu.Lock() 340 state := e.state 341 e.pendingAccepted.Add(1) 342 defer e.pendingAccepted.Done() 343 acceptedChan := e.acceptedChan 344 e.mu.Unlock() 345 if state == StateListen { 346 acceptedChan <- n 347 e.waiterQueue.Notify(waiter.EventIn) 348 } else { 349 n.Close() 350 } 351 } 352 353 // handleSynSegment is called in its own goroutine once the listening endpoint 354 // receives a SYN segment. It is responsible for completing the handshake and 355 // queueing the new endpoint for acceptance. 356 // 357 // A limited number of these goroutines are allowed before TCP starts using SYN 358 // cookies to accept connections. 359 func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { 360 defer decSynRcvdCount() 361 defer e.decSynRcvdCount() 362 defer s.decRef() 363 n, err := ctx.createEndpointAndPerformHandshake(s, opts) 364 if err != nil { 365 e.stack.Stats().TCP.FailedConnectionAttempts.Increment() 366 return 367 } 368 ctx.removePendingEndpoint(n) 369 e.deliverAccepted(n) 370 } 371 372 func (e *endpoint) incSynRcvdCount() bool { 373 e.mu.Lock() 374 if e.synRcvdCount >= cap(e.acceptedChan) { 375 e.mu.Unlock() 376 return false 377 } 378 e.synRcvdCount++ 379 e.mu.Unlock() 380 return true 381 } 382 383 func (e *endpoint) decSynRcvdCount() { 384 e.mu.Lock() 385 e.synRcvdCount-- 386 e.mu.Unlock() 387 } 388 389 func (e *endpoint) acceptQueueIsFull() bool { 390 e.mu.Lock() 391 if l, c := len(e.acceptedChan)+e.synRcvdCount, cap(e.acceptedChan); l >= c { 392 e.mu.Unlock() 393 return true 394 } 395 e.mu.Unlock() 396 return false 397 } 398 399 // handleListenSegment is called when a listening endpoint receives a segment 400 // and needs to handle it. 401 func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { 402 switch s.flags { 403 case header.TCPFlagSyn: 404 opts := parseSynSegmentOptions(s) 405 if incSynRcvdCount() { 406 // Only handle the syn if the following conditions hold 407 // - accept queue is not full. 408 // - number of connections in synRcvd state is less than the 409 // backlog. 410 if !e.acceptQueueIsFull() && e.incSynRcvdCount() { 411 s.incRef() 412 go e.handleSynSegment(ctx, s, &opts) 413 return 414 } 415 decSynRcvdCount() 416 e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() 417 e.stack.Stats().DroppedPackets.Increment() 418 return 419 } else { 420 // If cookies are in use but the endpoint accept queue 421 // is full then drop the syn. 422 if e.acceptQueueIsFull() { 423 e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() 424 e.stack.Stats().DroppedPackets.Increment() 425 return 426 } 427 cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) 428 429 // Send SYN without window scaling because we currently 430 // dont't encode this information in the cookie. 431 // 432 // Enable Timestamp option if the original syn did have 433 // the timestamp option specified. 434 mss := mssForRoute(&s.route) 435 synOpts := header.TCPSynOptions{ 436 WS: -1, 437 TS: opts.TS, 438 TSVal: tcpTimeStamp(timeStampOffset()), 439 TSEcr: opts.TSVal, 440 MSS: uint16(mss), 441 } 442 sendSynTCP(&s.route, s.id, e.ttl, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) 443 e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() 444 } 445 446 case header.TCPFlagAck: 447 if e.acceptQueueIsFull() { 448 // Silently drop the ack as the application can't accept 449 // the connection at this point. The ack will be 450 // retransmitted by the sender anyway and we can 451 // complete the connection at the time of retransmit if 452 // the backlog has space. 453 e.stack.Stats().TCP.ListenOverflowAckDrop.Increment() 454 e.stack.Stats().DroppedPackets.Increment() 455 return 456 } 457 458 if !synCookiesInUse() { 459 // Send a reset as this is an ACK for which there is no 460 // half open connections and we are not using cookies 461 // yet. 462 // 463 // The only time we should reach here when a connection 464 // was opened and closed really quickly and a delayed 465 // ACK was received from the sender. 466 replyWithReset(s) 467 return 468 } 469 470 // Since SYN cookies are in use this is potentially an ACK to a 471 // SYN-ACK we sent but don't have a half open connection state 472 // as cookies are being used to protect against a potential SYN 473 // flood. In such cases validate the cookie and if valid create 474 // a fully connected endpoint and deliver to the accept queue. 475 // 476 // If not, silently drop the ACK to avoid leaking information 477 // when under a potential syn flood attack. 478 // 479 // Validate the cookie. 480 data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1) 481 if !ok || int(data) >= len(mssTable) { 482 e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment() 483 e.stack.Stats().DroppedPackets.Increment() 484 return 485 } 486 e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment() 487 // Create newly accepted endpoint and deliver it. 488 rcvdSynOptions := &header.TCPSynOptions{ 489 MSS: mssTable[data], 490 // Disable Window scaling as original SYN is 491 // lost. 492 WS: -1, 493 } 494 495 // When syn cookies are in use we enable timestamp only 496 // if the ack specifies the timestamp option assuming 497 // that the other end did in fact negotiate the 498 // timestamp option in the original SYN. 499 if s.parsedOptions.TS { 500 rcvdSynOptions.TS = true 501 rcvdSynOptions.TSVal = s.parsedOptions.TSVal 502 rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr 503 } 504 505 n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions) 506 if err != nil { 507 e.stack.Stats().TCP.FailedConnectionAttempts.Increment() 508 return 509 } 510 511 // clear the tsOffset for the newly created 512 // endpoint as the Timestamp was already 513 // randomly offset when the original SYN-ACK was 514 // sent above. 515 n.tsOffset = 0 516 517 // Switch state to connected. 518 n.state = StateEstablished 519 520 // Do the delivery in a separate goroutine so 521 // that we don't block the listen loop in case 522 // the application is slow to accept or stops 523 // accepting. 524 // 525 // NOTE: This won't result in an unbounded 526 // number of goroutines as we do check before 527 // entering here that there was at least some 528 // space available in the backlog. 529 go e.deliverAccepted(n) 530 } 531 } 532 533 // protocolListenLoop is the main loop of a listening TCP endpoint. It runs in 534 // its own goroutine and is responsible for handling connection requests. 535 func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { 536 e.mu.Lock() 537 v6only := e.v6only 538 e.mu.Unlock() 539 ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto) 540 541 defer func() { 542 // Mark endpoint as closed. This will prevent goroutines running 543 // handleSynSegment() from attempting to queue new connections 544 // to the endpoint. 545 e.mu.Lock() 546 e.state = StateClose 547 548 // close any endpoints in SYN-RCVD state. 549 ctx.closeAllPendingEndpoints() 550 551 // Do cleanup if needed. 552 e.completeWorkerLocked() 553 554 if e.drainDone != nil { 555 close(e.drainDone) 556 } 557 e.mu.Unlock() 558 559 // Notify waiters that the endpoint is shutdown. 560 e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) 561 }() 562 563 s := sleep.Sleeper{} 564 s.AddWaker(&e.notificationWaker, wakerForNotification) 565 s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) 566 for { 567 switch index, _ := s.Fetch(true); index { 568 case wakerForNotification: 569 n := e.fetchNotifications() 570 if n¬ifyClose != 0 { 571 return nil 572 } 573 if n¬ifyDrain != 0 { 574 for !e.segmentQueue.empty() { 575 s := e.segmentQueue.dequeue() 576 e.handleListenSegment(ctx, s) 577 s.decRef() 578 } 579 close(e.drainDone) 580 <-e.undrain 581 } 582 583 case wakerForNewSegment: 584 // Process at most maxSegmentsPerWake segments. 585 mayRequeue := true 586 for i := 0; i < maxSegmentsPerWake; i++ { 587 s := e.segmentQueue.dequeue() 588 if s == nil { 589 mayRequeue = false 590 break 591 } 592 593 e.handleListenSegment(ctx, s) 594 s.decRef() 595 } 596 597 // If the queue is not empty, make sure we'll wake up 598 // in the next iteration. 599 if mayRequeue && !e.segmentQueue.empty() { 600 e.newSegmentWaker.Assert() 601 } 602 } 603 } 604 }