// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package stack

import (
	"encoding/binary"
	"fmt"
	"sync"
	"time"

	"github.com/SagerNet/gvisor/pkg/tcpip"
	"github.com/SagerNet/gvisor/pkg/tcpip/hash/jenkins"
	"github.com/SagerNet/gvisor/pkg/tcpip/header"
	"github.com/SagerNet/gvisor/pkg/tcpip/transport/tcpconntrack"
)

// Connection tracking is used to track and manipulate packets for NAT rules.
// A connection is created for a packet if one does not already exist. Every
// connection contains two tuples (original and reply). The tuples are
// manipulated if there is a matching NAT rule. The packet is modified by
// looking at the tuples in each hook.
//
// Currently, only TCP tracking is supported.

// numBuckets is the number of buckets in the conntrack hash table (16K).
const numBuckets = 1 << 14

// direction is the direction of a tuple within its connection.
type direction int

const (
	// dirOriginal is the direction of a conn's original tuple. It is the
	// zero value of direction.
	dirOriginal direction = iota
	// dirReply is the direction of a conn's reply tuple.
	dirReply
)

// manipType is the manipulation type for the connection.
// TODO(github.com/SagerNet/issue/5696): Define this as a bit set and support SNAT and
// DNAT at the same time.
type manipType int

const (
	// manipNone means packets on the connection are not rewritten.
	manipNone manipType = iota
	// manipSource means the source address and port are rewritten.
	manipSource
	// manipDestination means the destination address and port are
	// rewritten.
	manipDestination
)

// tuple holds a connection's identifying and manipulating data in one
// direction. It is immutable.
//
// +stateify savable
type tuple struct {
	// tupleEntry is used to build an intrusive list of tuples.
	tupleEntry

	// tupleID identifies the packets that match this tuple.
	tupleID

	// conn is the connection tracking entry this tuple belongs to.
	conn *conn

	// direction is the direction of the tuple.
	direction direction
}

// tupleID uniquely identifies a connection in one direction. It currently
// contains enough information to distinguish between any TCP or UDP
// connection, and will need to be extended to support other protocols.
//
// +stateify savable
type tupleID struct {
	srcAddr    tcpip.Address
	srcPort    uint16
	dstAddr    tcpip.Address
	dstPort    uint16
	transProto tcpip.TransportProtocolNumber
	netProto   tcpip.NetworkProtocolNumber
}

// reply creates the reply tupleID: ti with its source and destination
// endpoints swapped and the protocols unchanged.
func (ti tupleID) reply() tupleID {
	return tupleID{
		srcAddr:    ti.dstAddr,
		srcPort:    ti.dstPort,
		dstAddr:    ti.srcAddr,
		dstPort:    ti.srcPort,
		transProto: ti.transProto,
		netProto:   ti.netProto,
	}
}

// conn is a tracked connection.
//
// +stateify savable
type conn struct {
	// original is the tuple in original direction. It is immutable.
	original tuple

	// reply is the tuple in reply direction. It is immutable.
	reply tuple

	// manip indicates if the packet should be manipulated. It is immutable.
	// TODO(github.com/SagerNet/issue/5696): Support updating manipulation type.
	manip manipType

	// tcbHook indicates if the packet is inbound or outbound to
	// update the state of tcb. It is immutable.
	tcbHook Hook

	// mu protects all mutable state.
	mu sync.Mutex `state:"nosave"`
	// tcb is TCB control block. It is used to keep track of states
	// of tcp connection and is protected by mu.
	tcb tcpconntrack.TCB
	// lastUsed is the last time the connection saw a relevant packet, and
	// is updated by each packet on the connection. It is protected by mu.
	//
	// TODO(github.com/SagerNet/issue/5939): do not use the ambient clock.
	lastUsed time.Time `state:".(unixTime)"`
}

// newConn creates a new connection whose original tuple is identified by
// orig and whose reply tuple is identified by reply. lastUsed is initialized
// to the current time. The connection is not inserted into any table.
func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
	conn := conn{
		manip:    manip,
		tcbHook:  hook,
		lastUsed: time.Now(),
	}
	// Both tuples point back at the connection that embeds them. The
	// original tuple keeps the zero direction, dirOriginal.
	conn.original = tuple{conn: &conn, tupleID: orig}
	conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
	return &conn
}

// timedOut returns whether the connection timed out based on its state: it
// compares the time elapsed between lastUsed and now against a timeout
// chosen by the TCB state. It acquires cn.mu.
func (cn *conn) timedOut(now time.Time) bool {
	const establishedTimeout = 5 * 24 * time.Hour
	const defaultTimeout = 120 * time.Second
	cn.mu.Lock()
	defer cn.mu.Unlock()
	if cn.tcb.State() == tcpconntrack.ResultAlive {
		// Use the same default as Linux, which doesn't delete
		// established connections for 5(!) days.
		return now.Sub(cn.lastUsed) > establishedTimeout
	}
	// Use the same default as Linux, which lets connections in most states
	// other than established remain for <= 120 seconds.
	return now.Sub(cn.lastUsed) > defaultTimeout
}

// updateLocked updates the connection tracking state from tcpHeader. hook is
// the hook at which the packet was seen; it is compared against cn.tcbHook
// to decide whether the segment is treated as outbound or inbound.
//
// Precondition: cn.mu must be held.
func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
	// Update the state of tcb. tcb assumes it's always initialized on the
	// client. However, we only need to know whether the connection is
	// established or not, so the client/server distinction isn't important.
	if cn.tcb.IsEmpty() {
		cn.tcb.Init(tcpHeader)
	} else if hook == cn.tcbHook {
		cn.tcb.UpdateStateOutbound(tcpHeader)
	} else {
		cn.tcb.UpdateStateInbound(tcpHeader)
	}
}

// ConnTrack tracks all connections created for NAT rules. Most users are
// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop.
//
// ConnTrack keeps all connections in a slice of buckets, each of which holds a
// linked list of tuples. This gives us some desirable properties:
// - Each bucket has its own lock, lessening lock contention.
// - The slice is large enough that lists stay short (<10 elements on average).
//   Thus traversal is fast.
// - During linked list traversal we reap expired connections. This amortizes
//   the cost of reaping them and makes reapUnused faster.
//
// Locks are ordered by their location in the buckets slice. That is, a
// goroutine that locks buckets[i] can only lock buckets[j] s.t. i < j.
//
// +stateify savable
type ConnTrack struct {
	// seed is a one-time random value initialized at stack startup
	// and is used in the calculation of hash keys for the list of buckets.
	// It is immutable.
	seed uint32

	// mu protects the buckets slice, but not buckets' contents. Only take
	// the write lock if you are modifying the slice or saving for S/R.
	mu sync.RWMutex `state:"nosave"`

	// buckets is protected by mu.
	buckets []bucket
}

// bucket is one slot of the conntrack hash table: a lock plus the list of
// tuples that hash to this slot.
//
// +stateify savable
type bucket struct {
	// mu protects tuples.
	mu     sync.Mutex `state:"nosave"`
	tuples tupleList
}

// packetToTupleID converts packet to a tuple ID. It fails with
// ErrUnknownProtocol when pkt is not TCP or its TCP header is shorter than
// header.TCPMinimumSize.
//
// Preconditions: pkt.NetworkHeader() is valid.
216 func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) { 217 netHeader := pkt.Network() 218 if netHeader.TransportProtocol() != header.TCPProtocolNumber { 219 return tupleID{}, &tcpip.ErrUnknownProtocol{} 220 } 221 222 tcpHeader := header.TCP(pkt.TransportHeader().View()) 223 if len(tcpHeader) < header.TCPMinimumSize { 224 return tupleID{}, &tcpip.ErrUnknownProtocol{} 225 } 226 227 return tupleID{ 228 srcAddr: netHeader.SourceAddress(), 229 srcPort: tcpHeader.SourcePort(), 230 dstAddr: netHeader.DestinationAddress(), 231 dstPort: tcpHeader.DestinationPort(), 232 transProto: netHeader.TransportProtocol(), 233 netProto: pkt.NetworkProtocolNumber, 234 }, nil 235 } 236 237 func (ct *ConnTrack) init() { 238 ct.mu.Lock() 239 defer ct.mu.Unlock() 240 ct.buckets = make([]bucket, numBuckets) 241 } 242 243 // connFor gets the conn for pkt if it exists, or returns nil 244 // if it does not. It returns an error when pkt does not contain a valid TCP 245 // header. 246 // TODO(github.com/SagerNet/issue/6168): Support UDP. 247 func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { 248 tid, err := packetToTupleID(pkt) 249 if err != nil { 250 return nil, dirOriginal 251 } 252 return ct.connForTID(tid) 253 } 254 255 func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) { 256 bucket := ct.bucket(tid) 257 now := time.Now() 258 259 ct.mu.RLock() 260 defer ct.mu.RUnlock() 261 ct.buckets[bucket].mu.Lock() 262 defer ct.buckets[bucket].mu.Unlock() 263 264 // Iterate over the tuples in a bucket, cleaning up any unused 265 // connections we find. 266 for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() { 267 // Clean up any timed-out connections we happen to find. 268 if ct.reapTupleLocked(other, bucket, now) { 269 // The tuple expired. 
270 continue 271 } 272 if tid == other.tupleID { 273 return other.conn, other.direction 274 } 275 } 276 277 return nil, dirOriginal 278 } 279 280 func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn { 281 tid, err := packetToTupleID(pkt) 282 if err != nil { 283 return nil 284 } 285 if hook != Prerouting && hook != Output { 286 return nil 287 } 288 289 replyTID := tid.reply() 290 replyTID.srcAddr = address 291 replyTID.srcPort = port 292 293 conn, _ := ct.connForTID(tid) 294 if conn != nil { 295 // The connection is already tracked. 296 // TODO(github.com/SagerNet/issue/5696): Support updating an existing connection. 297 return nil 298 } 299 conn = newConn(tid, replyTID, manipDestination, hook) 300 ct.insertConn(conn) 301 return conn 302 } 303 304 func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn { 305 tid, err := packetToTupleID(pkt) 306 if err != nil { 307 return nil 308 } 309 if hook != Input && hook != Postrouting { 310 return nil 311 } 312 313 replyTID := tid.reply() 314 replyTID.dstAddr = address 315 replyTID.dstPort = port 316 317 conn, _ := ct.connForTID(tid) 318 if conn != nil { 319 // The connection is already tracked. 320 // TODO(github.com/SagerNet/issue/5696): Support updating an existing connection. 321 return nil 322 } 323 conn = newConn(tid, replyTID, manipSource, hook) 324 ct.insertConn(conn) 325 return conn 326 } 327 328 // insertConn inserts conn into the appropriate table bucket. 329 func (ct *ConnTrack) insertConn(conn *conn) { 330 // Lock the buckets in the correct order. 
331 tupleBucket := ct.bucket(conn.original.tupleID) 332 replyBucket := ct.bucket(conn.reply.tupleID) 333 ct.mu.RLock() 334 defer ct.mu.RUnlock() 335 if tupleBucket < replyBucket { 336 ct.buckets[tupleBucket].mu.Lock() 337 ct.buckets[replyBucket].mu.Lock() 338 } else if tupleBucket > replyBucket { 339 ct.buckets[replyBucket].mu.Lock() 340 ct.buckets[tupleBucket].mu.Lock() 341 } else { 342 // Both tuples are in the same bucket. 343 ct.buckets[tupleBucket].mu.Lock() 344 } 345 346 // Now that we hold the locks, ensure the tuple hasn't been inserted by 347 // another thread. 348 // TODO(github.com/SagerNet/issue/5773): Should check conn.reply.tupleID, too? 349 alreadyInserted := false 350 for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() { 351 if other.tupleID == conn.original.tupleID { 352 alreadyInserted = true 353 break 354 } 355 } 356 357 if !alreadyInserted { 358 // Add the tuple to the map. 359 ct.buckets[tupleBucket].tuples.PushFront(&conn.original) 360 ct.buckets[replyBucket].tuples.PushFront(&conn.reply) 361 } 362 363 // Unlocking can happen in any order. 364 ct.buckets[tupleBucket].mu.Unlock() 365 if tupleBucket != replyBucket { 366 ct.buckets[replyBucket].mu.Unlock() // +checklocksforce 367 } 368 } 369 370 // handlePacket will manipulate the port and address of the packet if the 371 // connection exists. Returns whether, after the packet traverses the tables, 372 // it should create a new entry in the table. 373 func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { 374 if pkt.NatDone { 375 return false 376 } 377 378 switch hook { 379 case Prerouting, Input, Output, Postrouting: 380 default: 381 return false 382 } 383 384 // TODO(github.com/SagerNet/issue/6168): Support UDP. 385 if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { 386 return false 387 } 388 389 conn, dir := ct.connFor(pkt) 390 // Connection not found for the packet. 
391 if conn == nil { 392 // If this is the last hook in the data path for this packet (Input if 393 // incoming, Postrouting if outgoing), indicate that a connection should be 394 // inserted by the end of this hook. 395 return hook == Input || hook == Postrouting 396 } 397 398 netHeader := pkt.Network() 399 tcpHeader := header.TCP(pkt.TransportHeader().View()) 400 if len(tcpHeader) < header.TCPMinimumSize { 401 return false 402 } 403 404 // TODO(github.com/SagerNet/issue/5748): TCP checksums on inbound packets should be 405 // validated if checksum offloading is off. It may require IP defrag if the 406 // packets are fragmented. 407 408 var newAddr tcpip.Address 409 var newPort uint16 410 411 updateSRCFields := false 412 413 switch hook { 414 case Prerouting, Output: 415 if conn.manip == manipDestination { 416 switch dir { 417 case dirOriginal: 418 newPort = conn.reply.srcPort 419 newAddr = conn.reply.srcAddr 420 case dirReply: 421 newPort = conn.original.dstPort 422 newAddr = conn.original.dstAddr 423 424 updateSRCFields = true 425 } 426 pkt.NatDone = true 427 } 428 case Input, Postrouting: 429 if conn.manip == manipSource { 430 switch dir { 431 case dirOriginal: 432 newPort = conn.reply.dstPort 433 newAddr = conn.reply.dstAddr 434 435 updateSRCFields = true 436 case dirReply: 437 newPort = conn.original.srcPort 438 newAddr = conn.original.srcAddr 439 } 440 pkt.NatDone = true 441 } 442 default: 443 panic(fmt.Sprintf("unrecognized hook = %s", hook)) 444 } 445 if !pkt.NatDone { 446 return false 447 } 448 449 fullChecksum := false 450 updatePseudoHeader := false 451 switch hook { 452 case Prerouting, Input: 453 case Output, Postrouting: 454 // Calculate the TCP checksum and set it. 
455 if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { 456 updatePseudoHeader = true 457 } else if r.RequiresTXTransportChecksum() { 458 fullChecksum = true 459 updatePseudoHeader = true 460 } 461 default: 462 panic(fmt.Sprintf("unrecognized hook = %s", hook)) 463 } 464 465 rewritePacket( 466 netHeader, 467 tcpHeader, 468 updateSRCFields, 469 fullChecksum, 470 updatePseudoHeader, 471 newPort, 472 newAddr, 473 ) 474 475 // Update the state of tcb. 476 conn.mu.Lock() 477 defer conn.mu.Unlock() 478 479 // Mark the connection as having been used recently so it isn't reaped. 480 conn.lastUsed = time.Now() 481 // Update connection state. 482 conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) 483 484 return false 485 } 486 487 // maybeInsertNoop tries to insert a no-op connection entry to keep connections 488 // from getting clobbered when replies arrive. It only inserts if there isn't 489 // already a connection for pkt. 490 // 491 // This should be called after traversing iptables rules only, to ensure that 492 // pkt.NatDone is set correctly. 493 func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { 494 // If there were a rule applying to this packet, it would be marked 495 // with NatDone. 496 if pkt.NatDone { 497 return 498 } 499 500 // We only track TCP connections. 501 if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { 502 return 503 } 504 505 // This is the first packet we're seeing for the TCP connection. Insert 506 // the noop entry (an identity mapping) so that the response doesn't 507 // get NATed, breaking the connection. 508 tid, err := packetToTupleID(pkt) 509 if err != nil { 510 return 511 } 512 conn := newConn(tid, tid.reply(), manipNone, hook) 513 conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) 514 ct.insertConn(conn) 515 } 516 517 // bucket gets the conntrack bucket for a tupleID. 
518 func (ct *ConnTrack) bucket(id tupleID) int { 519 h := jenkins.Sum32(ct.seed) 520 h.Write([]byte(id.srcAddr)) 521 h.Write([]byte(id.dstAddr)) 522 shortBuf := make([]byte, 2) 523 binary.LittleEndian.PutUint16(shortBuf, id.srcPort) 524 h.Write([]byte(shortBuf)) 525 binary.LittleEndian.PutUint16(shortBuf, id.dstPort) 526 h.Write([]byte(shortBuf)) 527 binary.LittleEndian.PutUint16(shortBuf, uint16(id.transProto)) 528 h.Write([]byte(shortBuf)) 529 binary.LittleEndian.PutUint16(shortBuf, uint16(id.netProto)) 530 h.Write([]byte(shortBuf)) 531 ct.mu.RLock() 532 defer ct.mu.RUnlock() 533 return int(h.Sum32()) % len(ct.buckets) 534 } 535 536 // reapUnused deletes timed out entries from the conntrack map. The rules for 537 // reaping are: 538 // - Most reaping occurs in connFor, which is called on each packet. connFor 539 // cleans up the bucket the packet's connection maps to. Thus calls to 540 // reapUnused should be fast. 541 // - Each call to reapUnused traverses a fraction of the conntrack table. 542 // Specifically, it traverses len(ct.buckets)/fractionPerReaping. 543 // - After reaping, reapUnused decides when it should next run based on the 544 // ratio of expired connections to examined connections. If the ratio is 545 // greater than maxExpiredPct, it schedules the next run quickly. Otherwise it 546 // slightly increases the interval between runs. 547 // - maxFullTraversal caps the time it takes to traverse the entire table. 548 // 549 // reapUnused returns the next bucket that should be checked and the time after 550 // which it should be called again. 
551 func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) { 552 const fractionPerReaping = 128 553 const maxExpiredPct = 50 554 const maxFullTraversal = 60 * time.Second 555 const minInterval = 10 * time.Millisecond 556 const maxInterval = maxFullTraversal / fractionPerReaping 557 558 now := time.Now() 559 checked := 0 560 expired := 0 561 var idx int 562 ct.mu.RLock() 563 defer ct.mu.RUnlock() 564 for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ { 565 idx = (i + start) % len(ct.buckets) 566 ct.buckets[idx].mu.Lock() 567 for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() { 568 checked++ 569 if ct.reapTupleLocked(tuple, idx, now) { 570 expired++ 571 } 572 } 573 ct.buckets[idx].mu.Unlock() 574 } 575 // We already checked buckets[idx]. 576 idx++ 577 578 // If half or more of the connections are expired, the table has gotten 579 // stale. Reschedule quickly. 580 expiredPct := 0 581 if checked != 0 { 582 expiredPct = expired * 100 / checked 583 } 584 if expiredPct > maxExpiredPct { 585 return idx, minInterval 586 } 587 if interval := prevInterval + minInterval; interval <= maxInterval { 588 // Increment the interval between runs. 589 return idx, interval 590 } 591 // We've hit the maximum interval. 592 return idx, maxInterval 593 } 594 595 // reapTupleLocked tries to remove tuple and its reply from the table. It 596 // returns whether the tuple's connection has timed out. 597 // 598 // Preconditions: 599 // * ct.mu is locked for reading. 600 // * bucket is locked. 601 func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool { 602 if !tuple.conn.timedOut(now) { 603 return false 604 } 605 606 // To maintain lock order, we can only reap these tuples if the reply 607 // appears later in the table. 608 replyBucket := ct.bucket(tuple.reply()) 609 if bucket > replyBucket { 610 return true 611 } 612 613 // Don't re-lock if both tuples are in the same bucket. 
614 differentBuckets := bucket != replyBucket 615 if differentBuckets { 616 ct.buckets[replyBucket].mu.Lock() 617 } 618 619 // We have the buckets locked and can remove both tuples. 620 if tuple.direction == dirOriginal { 621 ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply) 622 } else { 623 ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original) 624 } 625 ct.buckets[bucket].tuples.Remove(tuple) 626 627 // Don't re-unlock if both tuples are in the same bucket. 628 if differentBuckets { 629 ct.buckets[replyBucket].mu.Unlock() // +checklocksforce 630 } 631 632 return true 633 } 634 635 func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { 636 // Lookup the connection. The reply's original destination 637 // describes the original address. 638 tid := tupleID{ 639 srcAddr: epID.LocalAddress, 640 srcPort: epID.LocalPort, 641 dstAddr: epID.RemoteAddress, 642 dstPort: epID.RemotePort, 643 transProto: header.TCPProtocolNumber, 644 netProto: netProto, 645 } 646 conn, _ := ct.connForTID(tid) 647 if conn == nil { 648 // Not a tracked connection. 649 return "", 0, &tcpip.ErrNotConnected{} 650 } else if conn.manip != manipDestination { 651 // Unmanipulated destination. 652 return "", 0, &tcpip.ErrInvalidOptionValue{} 653 } 654 655 return conn.original.dstAddr, conn.original.dstPort, nil 656 }