// Source: github.com/sagernet/gvisor@v0.0.0-20240428053021-e691de28565f/pkg/tcpip/stack/conntrack.go

// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package stack

import (
	"encoding/binary"
	"fmt"
	"math"
	"math/rand"
	"sync"
	"time"

	"github.com/sagernet/gvisor/pkg/atomicbitops"
	"github.com/sagernet/gvisor/pkg/tcpip"
	"github.com/sagernet/gvisor/pkg/tcpip/hash/jenkins"
	"github.com/sagernet/gvisor/pkg/tcpip/header"
	"github.com/sagernet/gvisor/pkg/tcpip/transport/tcpconntrack"
)

// Connection tracking is used to track and manipulate packets for NAT rules.
// The connection is created for a packet if it does not exist. Every
// connection contains two tuples (original and reply). The tuples are
// manipulated if there is a matching NAT rule. The packet is modified by
// looking at the tuples in each hook.
//
// Connections are created for TCP, UDP and ICMP echo request packets (see
// getTupleID). Only TCP connections additionally have their protocol state
// tracked (see (*conn).update); all other protocols only record when they
// were last used.

// Our hash table has 16K buckets.
const numBuckets = 1 << 14

// Idle timeouts applied by (*conn).timedOut.
const (
	establishedTimeout   time.Duration = 5 * 24 * time.Hour
	unestablishedTimeout time.Duration = 120 * time.Second
)

// tuple holds a connection's identifying and manipulating data in one
// direction. It is immutable.
//
// NOTE(review): (*conn).performNAT writes to the reply tuple's tupleID fields
// while holding conn.mu, so "immutable" does not hold for the reply tuple
// while NAT is being set up.
//
// +stateify savable
type tuple struct {
	// tupleEntry is used to build an intrusive list of tuples.
	tupleEntry

	// conn is the connection tracking entry this tuple belongs to.
	conn *conn

	// reply is true iff the tuple's direction is opposite that of the first
	// packet seen on the connection.
	reply bool

	// tupleID is set at initialization and is immutable.
	tupleID tupleID
}

// tupleID uniquely identifies a trackable connection in one direction.
//
// +stateify savable
type tupleID struct {
	srcAddr tcpip.Address
	// The source port of a packet in the original direction is overloaded with
	// the ident of an Echo Request packet.
	//
	// This also matches the behaviour of sending packets on Linux where the
	// socket's source port value is used for the source port of outgoing packets
	// for TCP/UDP and the ident field for outgoing Echo Requests on Ping sockets:
	//
	//   IPv4: https://github.com/torvalds/linux/blob/c5c17547b778975b3d83a73c8d84e8fb5ecf3ba5/net/ipv4/ping.c#L810
	//   IPv6: https://github.com/torvalds/linux/blob/c5c17547b778975b3d83a73c8d84e8fb5ecf3ba5/net/ipv6/ping.c#L133
	srcPortOrEchoRequestIdent uint16
	dstAddr                   tcpip.Address
	// The opposite of srcPortOrEchoRequestIdent; the destination port of a packet
	// in the reply direction is overloaded with the ident of an Echo Reply.
	dstPortOrEchoReplyIdent uint16
	transProto              tcpip.TransportProtocolNumber
	netProto                tcpip.NetworkProtocolNumber
}

// reply creates the reply tupleID.
//
// The source and destination address/port pairs are swapped; the transport
// and network protocol numbers are carried over unchanged.
func (ti tupleID) reply() tupleID {
	return tupleID{
		srcAddr:                   ti.dstAddr,
		srcPortOrEchoRequestIdent: ti.dstPortOrEchoReplyIdent,
		dstAddr:                   ti.srcAddr,
		dstPortOrEchoReplyIdent:   ti.srcPortOrEchoRequestIdent,
		transProto:                ti.transProto,
		netProto:                  ti.netProto,
	}
}

// manipType records whether, and how, NAT has been set up for one direction
// (source or destination) of a connection.
type manipType int

const (
	// manipNotPerformed indicates that NAT has not been performed.
	manipNotPerformed manipType = iota

	// manipPerformed indicates that NAT was performed.
	manipPerformed

	// manipPerformedNoop indicates that NAT was performed but it was a no-op.
	manipPerformedNoop
)

// finalizeResult is the outcome of (*ConnTrack).finalize. It is stored in
// conn.finalizeResult as a uint32.
type finalizeResult uint32

const (
	// A finalizeResult must be explicitly set so we don't make use of the zero
	// value.
	_ finalizeResult = iota

	// finalizeResultSuccess means the reply tuple is in the table (or already
	// was, for this same connection).
	finalizeResultSuccess
	// finalizeResultConflict means another connection already owns the reply
	// tuple; the original tuple has been removed from the table.
	finalizeResultConflict
)

// conn is a tracked connection.
//
// +stateify savable
type conn struct {
	ct *ConnTrack

	// original is the tuple in original direction. It is immutable.
	original tuple

	// reply is the tuple in reply direction.
	reply tuple

	// finalizeOnce guarantees (*ConnTrack).finalize runs at most once per
	// connection; see (*conn).finalize.
	finalizeOnce sync.Once
	// Holds a finalizeResult.
	finalizeResult atomicbitops.Uint32

	mu connRWMutex `state:"nosave"`
	// sourceManip indicates the source manipulation type.
	//
	// +checklocks:mu
	sourceManip manipType
	// destinationManip indicates the destination's manipulation type.
	//
	// +checklocks:mu
	destinationManip manipType

	stateMu stateConnRWMutex `state:"nosave"`
	// tcb is TCB control block. It is used to keep track of states
	// of tcp connection.
	//
	// +checklocks:stateMu
	tcb tcpconntrack.TCB
	// lastUsed is the last time the connection saw a relevant packet, and
	// is updated by each packet on the connection.
	//
	// +checklocks:stateMu
	lastUsed tcpip.MonotonicTime
}

// timedOut returns whether the connection timed out based on its state.
//
// The idle time is measured from lastUsed to now. It takes stateMu for
// reading.
func (cn *conn) timedOut(now tcpip.MonotonicTime) bool {
	cn.stateMu.RLock()
	defer cn.stateMu.RUnlock()
	if cn.tcb.State() == tcpconntrack.ResultAlive {
		// Use the same default as Linux, which doesn't delete
		// established connections for 5(!) days.
		return now.Sub(cn.lastUsed) > establishedTimeout
	}
	// Use the same default as Linux, which lets connections in most states
	// other than established remain for <= 120 seconds.
	return now.Sub(cn.lastUsed) > unestablishedTimeout
}

// update the connection tracking state.
//
// lastUsed is refreshed for every packet. The TCB is only touched for TCP
// packets: the first TCP packet initializes it, later ones update it in the
// direction given by reply. It takes stateMu for writing.
func (cn *conn) update(pkt *PacketBuffer, reply bool) {
	cn.stateMu.Lock()
	defer cn.stateMu.Unlock()

	// Mark the connection as having been used recently so it isn't reaped.
	cn.lastUsed = cn.ct.clock.NowMonotonic()

	if pkt.TransportProtocolNumber != header.TCPProtocolNumber {
		return
	}

	tcpHeader := header.TCP(pkt.TransportHeader().Slice())

	// Update the state of tcb. tcb assumes it's always initialized on the
	// client. However, we only need to know whether the connection is
	// established or not, so the client/server distinction isn't important.
	if cn.tcb.IsEmpty() {
		cn.tcb.Init(tcpHeader, pkt.Data().Size())
		return
	}

	if reply {
		cn.tcb.UpdateStateReply(tcpHeader, pkt.Data().Size())
	} else {
		cn.tcb.UpdateStateOriginal(tcpHeader, pkt.Data().Size())
	}
}

// ConnTrack tracks all connections created for NAT rules. Most users are
// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop.
//
// ConnTrack keeps all connections in a slice of buckets, each of which holds a
// linked list of tuples. This gives us some desirable properties:
//   - Each bucket has its own lock, lessening lock contention.
//   - The slice is large enough that lists stay short (<10 elements on average).
//     Thus traversal is fast.
//   - During linked list traversal we reap expired connections. This amortizes
//     the cost of reaping them and makes reapUnused faster.
//
// NOTE(review): the lookup helper visible in this part of the file
// (connForTIDRLocked) only skips timed-out entries and does not remove them;
// the reaping described above presumably lives elsewhere — confirm.
//
// Locks are ordered by their location in the buckets slice. That is, a
// goroutine that locks buckets[i] can only lock buckets[j] s.t. i < j.
//
// +stateify savable
type ConnTrack struct {
	// seed is a one-time random value initialized at stack startup
	// and is used in the calculation of hash keys for the list of buckets.
	// It is immutable.
	seed uint32

	// clock provides timing used to determine conntrack reapings.
	clock tcpip.Clock
	// rand supplies the random starting offsets used by (*conn).performNAT
	// when searching for a free port/ident.
	rand *rand.Rand

	mu connTrackRWMutex `state:"nosave"`
	// mu protects the buckets slice, but not buckets' contents. Only take
	// the write lock if you are modifying the slice or saving for S/R.
	//
	// +checklocks:mu
	buckets []bucket
}

// bucket is one slot of the conntrack hash table: a list of tuples guarded by
// its own lock.
//
// +stateify savable
type bucket struct {
	mu bucketRWMutex `state:"nosave"`
	// +checklocks:mu
	tuples tupleList
}

// A netAndTransHeadersFunc returns the network and transport headers found
// in an ICMP payload. The transport layer's payload will not be returned.
//
// May panic if the packet does not hold the transport header.
type netAndTransHeadersFunc func(icmpPayload []byte, minTransHdrLen int) (netHdr header.Network, transHdrBytes []byte)

// v4NetAndTransHdr is the netAndTransHeadersFunc for IPv4. The transport
// header is taken to start HeaderLength() bytes into icmpPayload and is
// truncated to minTransHdrLen bytes.
func v4NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) {
	netHdr := header.IPv4(icmpPayload)
	// Do not use netHdr.Payload() as we might not hold the full packet
	// in the ICMP error; Payload() panics if the buffer is smaller than
	// the total length specified in the IPv4 header.
	transHdr := icmpPayload[netHdr.HeaderLength():]
	return netHdr, transHdr[:minTransHdrLen]
}

// v6NetAndTransHdr is the netAndTransHeadersFunc for IPv6. The transport
// header is taken to start right after the fixed IPv6 header (extension
// headers are not skipped) and is truncated to minTransHdrLen bytes.
func v6NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) {
	netHdr := header.IPv6(icmpPayload)
	// Do not use netHdr.Payload() as we might not hold the full packet
	// in the ICMP error; Payload() panics if the IP payload is smaller than
	// the payload length specified in the IPv6 header.
	transHdr := icmpPayload[header.IPv6MinimumSize:]
	return netHdr, transHdr[:minTransHdrLen]
}

// getEmbeddedNetAndTransHeaders returns the network and transport headers of
// the packet embedded in pkt's data (the payload of an ICMP error).
//
// It pulls up netHdrLength plus the minimum TCP or UDP header size from pkt's
// data and splits those bytes with getNetAndTransHdr. The final result is
// false if transProto is neither TCP nor UDP, or if the data is too short to
// pull up.
func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, getNetAndTransHdr netAndTransHeadersFunc, transProto tcpip.TransportProtocolNumber) (header.Network, header.ChecksummableTransport, bool) {
	switch transProto {
	case header.TCPProtocolNumber:
		if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.TCPMinimumSize); ok {
			netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.TCPMinimumSize)
			return netHeader, header.TCP(transHeaderBytes), true
		}
	case header.UDPProtocolNumber:
		if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.UDPMinimumSize); ok {
			netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.UDPMinimumSize)
			return netHeader, header.UDP(transHeaderBytes), true
		}
	}
	return nil, nil, false
}

// getHeaders returns the network and transport headers that describe pkt's
// connection.
//
// For TCP, UDP and ICMP echo packets these are pkt's own headers. For the
// ICMP error types listed below they are the headers of the packet embedded
// in the ICMP payload, and isICMPError is true. ok is false when a header is
// shorter than its minimum size or the embedded headers cannot be pulled up.
//
// Unlike getTupleID, which rejects unsupported packets by returning
// getTupleIDNotOK, getHeaders panics on an unexpected transport protocol or
// ICMP type, on ICMP error payloads too short to hold an IP header, on
// embedded IPv4 options, and on an embedded IPv6 next-header that differs
// from pkt.tuple's transport protocol. It presumably relies on callers only
// passing packets that already matched a tuple — confirm against callers.
func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.Transport, isICMPError bool, ok bool) {
	switch pkt.TransportProtocolNumber {
	case header.TCPProtocolNumber:
		if tcpHeader := header.TCP(pkt.TransportHeader().Slice()); len(tcpHeader) >= header.TCPMinimumSize {
			return pkt.Network(), tcpHeader, false, true
		}
		return nil, nil, false, false
	case header.UDPProtocolNumber:
		if udpHeader := header.UDP(pkt.TransportHeader().Slice()); len(udpHeader) >= header.UDPMinimumSize {
			return pkt.Network(), udpHeader, false, true
		}
		return nil, nil, false, false
	case header.ICMPv4ProtocolNumber:
		icmpHeader := header.ICMPv4(pkt.TransportHeader().Slice())
		if len(icmpHeader) < header.ICMPv4MinimumSize {
			return nil, nil, false, false
		}

		switch icmpType := icmpHeader.Type(); icmpType {
		case header.ICMPv4Echo, header.ICMPv4EchoReply:
			return pkt.Network(), icmpHeader, false, true
		case header.ICMPv4DstUnreachable, header.ICMPv4TimeExceeded, header.ICMPv4ParamProblem:
			// ICMP errors: fall out of the switch and inspect the embedded packet.
		default:
			panic(fmt.Sprintf("unexpected ICMPv4 type = %d", icmpType))
		}

		h, ok := pkt.Data().PullUp(header.IPv4MinimumSize)
		if !ok {
			panic(fmt.Sprintf("should have a valid IPv4 packet; only have %d bytes, want at least %d bytes", pkt.Data().Size(), header.IPv4MinimumSize))
		}

		if header.IPv4(h).HeaderLength() > header.IPv4MinimumSize {
			// TODO(https://gvisor.dev/issue/6765): Handle IPv4 options.
			panic("should have dropped packets with IPv4 options")
		}

		if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, v4NetAndTransHdr, pkt.tuple.tupleID.transProto); ok {
			return netHdr, transHdr, true, true
		}
		return nil, nil, false, false
	case header.ICMPv6ProtocolNumber:
		icmpHeader := header.ICMPv6(pkt.TransportHeader().Slice())
		if len(icmpHeader) < header.ICMPv6MinimumSize {
			return nil, nil, false, false
		}

		switch icmpType := icmpHeader.Type(); icmpType {
		case header.ICMPv6EchoRequest, header.ICMPv6EchoReply:
			return pkt.Network(), icmpHeader, false, true
		case header.ICMPv6DstUnreachable, header.ICMPv6PacketTooBig, header.ICMPv6TimeExceeded, header.ICMPv6ParamProblem:
			// ICMP errors: fall out of the switch and inspect the embedded packet.
		default:
			panic(fmt.Sprintf("unexpected ICMPv6 type = %d", icmpType))
		}

		h, ok := pkt.Data().PullUp(header.IPv6MinimumSize)
		if !ok {
			panic(fmt.Sprintf("should have a valid IPv6 packet; only have %d bytes, want at least %d bytes", pkt.Data().Size(), header.IPv6MinimumSize))
		}

		// We do not support extension headers in ICMP errors so the next header
		// in the IPv6 packet should be a tracked protocol if we reach this point.
		//
		// TODO(https://gvisor.dev/issue/6789): Support extension headers.
		transProto := pkt.tuple.tupleID.transProto
		if got := header.IPv6(h).TransportProtocol(); got != transProto {
			panic(fmt.Sprintf("got TransportProtocol() = %d, want = %d", got, transProto))
		}

		if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, v6NetAndTransHdr, transProto); ok {
			return netHdr, transHdr, true, true
		}
		return nil, nil, false, false
	default:
		panic(fmt.Sprintf("unexpected transport protocol = %d", pkt.TransportProtocolNumber))
	}
}

// getTupleIDForRegularPacket builds a tupleID directly from the addresses and
// ports carried in netHdr and transHdr.
func getTupleIDForRegularPacket(netHdr header.Network, netProto tcpip.NetworkProtocolNumber, transHdr header.Transport, transProto tcpip.TransportProtocolNumber) tupleID {
	return tupleID{
		srcAddr:                   netHdr.SourceAddress(),
		srcPortOrEchoRequestIdent: transHdr.SourcePort(),
		dstAddr:                   netHdr.DestinationAddress(),
		dstPortOrEchoReplyIdent:   transHdr.DestinationPort(),
		transProto:                transProto,
		netProto:                  netProto,
	}
}

// getTupleIDForPacketInICMPError builds a tupleID from the packet embedded in
// an ICMP error held in pkt's data.
//
// The source and destination of the embedded headers are swapped in the
// returned ID, presumably so that the ID describes the direction the ICMP
// error itself travels in rather than that of the packet which triggered it.
// The result is false if the embedded headers could not be obtained.
func getTupleIDForPacketInICMPError(pkt *PacketBuffer, getNetAndTransHdr netAndTransHeadersFunc, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) {
	if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, netLen, getNetAndTransHdr, transProto); ok {
		return tupleID{
			srcAddr:                   netHdr.DestinationAddress(),
			srcPortOrEchoRequestIdent: transHdr.DestinationPort(),
			dstAddr:                   netHdr.SourceAddress(),
			dstPortOrEchoReplyIdent:   transHdr.SourcePort(),
			transProto:                transProto,
			netProto:                  netProto,
		}, true
	}

	return tupleID{}, false
}

// getTupleIDDisposition is the second result of getTupleID: whether a tupleID
// could be derived and, if so, whether the packet may create a new
// connection.
type getTupleIDDisposition int

const (
	// getTupleIDNotOK means no tupleID could be derived from the packet.
	getTupleIDNotOK getTupleIDDisposition = iota
	// getTupleIDOKAndAllowNewConn means the tupleID is valid and the packet may
	// create a connection if none matches.
	getTupleIDOKAndAllowNewConn
	// getTupleIDOKAndDontAllowNewConn means the tupleID is valid but the packet
	// may only match an existing connection.
	getTupleIDOKAndDontAllowNewConn
)

// getTupleIDForEchoPacket builds the tupleID for an ICMP echo packet.
//
// ident is stored in srcPortOrEchoRequestIdent for a request and in
// dstPortOrEchoReplyIdent for a reply; the other port/ident field is left as
// zero.
func getTupleIDForEchoPacket(pkt *PacketBuffer, ident uint16, request bool) tupleID {
	netHdr := pkt.Network()
	tid := tupleID{
		srcAddr:    netHdr.SourceAddress(),
		dstAddr:    netHdr.DestinationAddress(),
		transProto: pkt.TransportProtocolNumber,
		netProto:   pkt.NetworkProtocolNumber,
	}

	if request {
		tid.srcPortOrEchoRequestIdent = ident
	} else {
		tid.dstPortOrEchoReplyIdent = ident
	}

	return tid
}

// getTupleID derives the tupleID used to look up pkt's connection.
//
// TCP, UDP and ICMP echo requests may create a new connection. ICMP echo
// replies and the supported ICMP error types may only match an existing one.
// Everything else, including packets whose headers are too short and ICMPv4
// errors embedding IPv4 options, yields getTupleIDNotOK.
func getTupleID(pkt *PacketBuffer) (tupleID, getTupleIDDisposition) {
	switch pkt.TransportProtocolNumber {
	case header.TCPProtocolNumber:
		if transHeader := header.TCP(pkt.TransportHeader().Slice()); len(transHeader) >= header.TCPMinimumSize {
			return getTupleIDForRegularPacket(pkt.Network(), pkt.NetworkProtocolNumber, transHeader, pkt.TransportProtocolNumber), getTupleIDOKAndAllowNewConn
		}
	case header.UDPProtocolNumber:
		if transHeader := header.UDP(pkt.TransportHeader().Slice()); len(transHeader) >= header.UDPMinimumSize {
			return getTupleIDForRegularPacket(pkt.Network(), pkt.NetworkProtocolNumber, transHeader, pkt.TransportProtocolNumber), getTupleIDOKAndAllowNewConn
		}
	case header.ICMPv4ProtocolNumber:
		icmp := header.ICMPv4(pkt.TransportHeader().Slice())
		if len(icmp) < header.ICMPv4MinimumSize {
			return tupleID{}, getTupleIDNotOK
		}

		switch icmp.Type() {
		case header.ICMPv4Echo:
			return getTupleIDForEchoPacket(pkt, icmp.Ident(), true /* request */), getTupleIDOKAndAllowNewConn
		case header.ICMPv4EchoReply:
			// Do not create a new connection in response to a reply packet as only
			// the first packet of a connection should create a conntrack entry but
			// a reply is never the first packet sent for a connection.
			return getTupleIDForEchoPacket(pkt, icmp.Ident(), false /* request */), getTupleIDOKAndDontAllowNewConn
		case header.ICMPv4DstUnreachable, header.ICMPv4TimeExceeded, header.ICMPv4ParamProblem:
			// ICMP errors: fall out of the switch and inspect the embedded packet.
		default:
			// Unsupported ICMP type for NAT-ing.
			return tupleID{}, getTupleIDNotOK
		}

		h, ok := pkt.Data().PullUp(header.IPv4MinimumSize)
		if !ok {
			return tupleID{}, getTupleIDNotOK
		}

		ipv4 := header.IPv4(h)
		if ipv4.HeaderLength() > header.IPv4MinimumSize {
			// TODO(https://gvisor.dev/issue/6765): Handle IPv4 options.
			return tupleID{}, getTupleIDNotOK
		}

		if tid, ok := getTupleIDForPacketInICMPError(pkt, v4NetAndTransHdr, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok {
			// Do not create a new connection in response to an ICMP error.
			return tid, getTupleIDOKAndDontAllowNewConn
		}
	case header.ICMPv6ProtocolNumber:
		icmp := header.ICMPv6(pkt.TransportHeader().Slice())
		if len(icmp) < header.ICMPv6MinimumSize {
			return tupleID{}, getTupleIDNotOK
		}

		switch icmp.Type() {
		case header.ICMPv6EchoRequest:
			return getTupleIDForEchoPacket(pkt, icmp.Ident(), true /* request */), getTupleIDOKAndAllowNewConn
		case header.ICMPv6EchoReply:
			// Do not create a new connection in response to a reply packet as only
			// the first packet of a connection should create a conntrack entry but
			// a reply is never the first packet sent for a connection.
			return getTupleIDForEchoPacket(pkt, icmp.Ident(), false /* request */), getTupleIDOKAndDontAllowNewConn
		case header.ICMPv6DstUnreachable, header.ICMPv6PacketTooBig, header.ICMPv6TimeExceeded, header.ICMPv6ParamProblem:
			// ICMP errors: fall out of the switch and inspect the embedded packet.
		default:
			return tupleID{}, getTupleIDNotOK
		}

		h, ok := pkt.Data().PullUp(header.IPv6MinimumSize)
		if !ok {
			return tupleID{}, getTupleIDNotOK
		}

		// TODO(https://gvisor.dev/issue/6789): Handle extension headers.
		if tid, ok := getTupleIDForPacketInICMPError(pkt, v6NetAndTransHdr, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok {
			// Do not create a new connection in response to an ICMP error.
			return tid, getTupleIDOKAndDontAllowNewConn
		}
	}

	return tupleID{}, getTupleIDNotOK
}

// init allocates the bucket table with numBuckets empty buckets, holding
// ct.mu for writing while it does so.
func (ct *ConnTrack) init() {
	ct.mu.Lock()
	defer ct.mu.Unlock()
	ct.buckets = make([]bucket, numBuckets)
}

// getConnAndUpdate attempts to get a connection or creates one if no
// connection exists for the packet and packet's protocol is trackable.
//
// If the packet's protocol is trackable, the connection's state is updated to
// match the contents of the packet.
//
// It returns nil when no tupleID can be derived, when TCP/UDP validation
// fails, or when there is no matching connection and the packet is not
// allowed to create one.
func (ct *ConnTrack) getConnAndUpdate(pkt *PacketBuffer, skipChecksumValidation bool) *tuple {
	// Get or (maybe) create a connection.
	t := func() *tuple {
		var allowNewConn bool
		tid, res := getTupleID(pkt)
		switch res {
		case getTupleIDNotOK:
			return nil
		case getTupleIDOKAndAllowNewConn:
			allowNewConn = true
		case getTupleIDOKAndDontAllowNewConn:
			allowNewConn = false
		default:
			panic(fmt.Sprintf("unhandled %[1]T = %[1]d", res))
		}

		// Just skip bad packets. They'll be rejected later by the appropriate
		// protocol package.
		switch pkt.TransportProtocolNumber {
		case header.TCPProtocolNumber:
			_, csumValid, ok := header.TCPValid(
				header.TCP(pkt.TransportHeader().Slice()),
				func() uint16 { return pkt.Data().Checksum() },
				uint16(pkt.Data().Size()),
				tid.srcAddr,
				tid.dstAddr,
				pkt.RXChecksumValidated || skipChecksumValidation)
			if !csumValid || !ok {
				return nil
			}
		case header.UDPProtocolNumber:
			lengthValid, csumValid := header.UDPValid(
				header.UDP(pkt.TransportHeader().Slice()),
				func() uint16 { return pkt.Data().Checksum() },
				uint16(pkt.Data().Size()),
				pkt.NetworkProtocolNumber,
				tid.srcAddr,
				tid.dstAddr,
				pkt.RXChecksumValidated || skipChecksumValidation)
			if !lengthValid || !csumValid {
				return nil
			}
		}

		// ct.mu only guards the slice itself; the bucket has its own lock.
		ct.mu.RLock()
		bkt := &ct.buckets[ct.bucket(tid)]
		ct.mu.RUnlock()

		// Fast path: look the tuple up under the bucket's read lock.
		now := ct.clock.NowMonotonic()
		if t := bkt.connForTID(tid, now); t != nil {
			return t
		}

		if !allowNewConn {
			return nil
		}

		bkt.mu.Lock()
		defer bkt.mu.Unlock()

		// Make sure a connection wasn't added between when we last checked the
		// bucket and acquired the bucket's write lock.
		if t := bkt.connForTIDRLocked(tid, now); t != nil {
			return t
		}

		// This is the first packet we're seeing for the connection. Create an entry
		// for this new connection.
		conn := &conn{
			ct:       ct,
			original: tuple{tupleID: tid},
			reply:    tuple{tupleID: tid.reply(), reply: true},
			lastUsed: now,
		}
		conn.original.conn = conn
		conn.reply.conn = conn

		// For now, we only map an entry for the packet's original tuple as NAT may be
		// performed on this connection. Until the packet goes through all the hooks
		// and its final address/port is known, we cannot know what the response
		// packet's addresses/ports will look like.
		//
		// This is okay because the destination cannot send its response until it
		// receives the packet; the packet will only be received once all the hooks
		// have been performed.
		//
		// See (*conn).finalize.
		bkt.tuples.PushFront(&conn.original)
		return &conn.original
	}()
	if t != nil {
		t.conn.update(pkt, t.reply)
	}
	return t
}

// connForTID returns the tuple matching tid whose connection has not timed
// out as of the current monotonic time, or nil if there is none.
func (ct *ConnTrack) connForTID(tid tupleID) *tuple {
	ct.mu.RLock()
	bkt := &ct.buckets[ct.bucket(tid)]
	ct.mu.RUnlock()

	return bkt.connForTID(tid, ct.clock.NowMonotonic())
}

// connForTID is connForTIDRLocked wrapped in the bucket's read lock.
func (bkt *bucket) connForTID(tid tupleID, now tcpip.MonotonicTime) *tuple {
	bkt.mu.RLock()
	defer bkt.mu.RUnlock()
	return bkt.connForTIDRLocked(tid, now)
}

// connForTIDRLocked walks the bucket's list and returns the first tuple whose
// ID equals tid and whose connection has not timed out as of now, or nil.
// Timed-out entries are skipped, not removed.
//
// +checklocksread:bkt.mu
func (bkt *bucket) connForTIDRLocked(tid tupleID, now tcpip.MonotonicTime) *tuple {
	for other := bkt.tuples.Front(); other != nil; other = other.Next() {
		if tid == other.tupleID && !other.conn.timedOut(now) {
			return other
		}
	}
	return nil
}

// finalize inserts cn's reply tuple into the table.
//
// It returns finalizeResultSuccess if the reply tuple was inserted, or if a
// live entry with the same ID already exists and belongs to cn. If a live
// entry for the reply ID belongs to a different connection, cn's original
// tuple is removed from the table and finalizeResultConflict is returned.
//
// At most one bucket lock is held at a time. Bucket indices are computed
// against a snapshot of the buckets slice taken under ct.mu.
func (ct *ConnTrack) finalize(cn *conn) finalizeResult {
	ct.mu.RLock()
	buckets := ct.buckets
	ct.mu.RUnlock()

	{
		tid := cn.reply.tupleID
		id := ct.bucketWithTableLength(tid, len(buckets))

		bkt := &buckets[id]
		bkt.mu.Lock()
		t := bkt.connForTIDRLocked(tid, ct.clock.NowMonotonic())
		if t == nil {
			bkt.tuples.PushFront(&cn.reply)
			bkt.mu.Unlock()
			return finalizeResultSuccess
		}
		bkt.mu.Unlock()

		if t.conn == cn {
			// We already have an entry for the reply tuple.
			//
			// This can occur when the source address/port is the same as the
			// destination address/port. In this scenario, tid == tid.reply().
			return finalizeResultSuccess
		}
	}

	// Another connection for the reply already exists. Remove the original and
	// let the caller know we failed.
	//
	// TODO(https://gvisor.dev/issue/6850): Investigate handling this clash
	// better.

	tid := cn.original.tupleID
	id := ct.bucketWithTableLength(tid, len(buckets))
	bkt := &buckets[id]
	bkt.mu.Lock()
	defer bkt.mu.Unlock()
	bkt.tuples.Remove(&cn.original)
	return finalizeResultConflict
}

// getFinalizeResult atomically loads the stored finalizeResult. It is the
// zero value until (*conn).finalize has stored a result.
func (cn *conn) getFinalizeResult() finalizeResult {
	return finalizeResult(cn.finalizeResult.Load())
}

// finalize attempts to finalize the connection and returns true iff the
// connection was successfully finalized.
//
// If the connection failed to finalize, the caller should drop the packet
// associated with the connection.
//
// If multiple goroutines attempt to finalize at the same time, only one
// goroutine will perform the work to finalize the connection, but all
// goroutines will block until the finalizing goroutine finishes finalizing.
func (cn *conn) finalize() bool {
	cn.finalizeOnce.Do(func() {
		cn.finalizeResult.Store(uint32(cn.ct.finalize(cn)))
	})

	switch res := cn.getFinalizeResult(); res {
	case finalizeResultSuccess:
		return true
	case finalizeResultConflict:
		return false
	default:
		panic(fmt.Sprintf("unhandled result = %d", res))
	}
}

// If NAT has not been configured for this connection, either mark the
// connection as configured for "no-op NAT", in the case of DNAT, or, in the
// case of SNAT, perform source port remapping so that source ports used by
// locally-generated traffic do not conflict with ports occupied by existing NAT
// bindings.
//
// Note that in the typical case this is also a no-op, because `snatAction`
// will do nothing if the original tuple is already unique.
706 func (cn *conn) maybePerformNoopNAT(pkt *PacketBuffer, hook Hook, r *Route, dnat bool) { 707 cn.mu.Lock() 708 var manip *manipType 709 if dnat { 710 manip = &cn.destinationManip 711 } else { 712 manip = &cn.sourceManip 713 } 714 if *manip != manipNotPerformed { 715 cn.mu.Unlock() 716 _ = cn.handlePacket(pkt, hook, r) 717 return 718 } 719 if dnat { 720 *manip = manipPerformedNoop 721 cn.mu.Unlock() 722 _ = cn.handlePacket(pkt, hook, r) 723 return 724 } 725 cn.mu.Unlock() 726 727 // At this point, we know that NAT has not yet been performed on this 728 // connection, and the DNAT case has been handled with a no-op. For SNAT, we 729 // simply perform source port remapping to ensure that source ports for 730 // locally generated traffic do not clash with ports used by existing NAT 731 // bindings. 732 _, _ = snatAction(pkt, hook, r, 0, tcpip.Address{}, true /* changePort */, false /* changeAddress */) 733 } 734 735 type portOrIdentRange struct { 736 start uint16 737 size uint32 738 } 739 740 // performNAT setups up the connection for the specified NAT and rewrites the 741 // packet. 742 // 743 // If NAT has already been performed on the connection, then the packet will 744 // be rewritten with the NAT performed on the connection, ignoring the passed 745 // address and port range. 746 // 747 // Generally, only the first packet of a connection reaches this method; other 748 // packets will be manipulated without needing to modify the connection. 
func (cn *conn) performNAT(pkt *PacketBuffer, hook Hook, r *Route, portsOrIdents portOrIdentRange, natAddress tcpip.Address, dnat, changePort, changeAddress bool) {
	// Last port/ident in the requested range; the range must fit in a uint16.
	//
	// NOTE(review): with portsOrIdents.size == 0 and start == 0 the uint32
	// arithmetic wraps around and this panics — presumably callers never pass
	// an empty range; confirm.
	lastPortOrIdent := func() uint16 {
		lastPortOrIdent := uint32(portsOrIdents.start) + portsOrIdents.size - 1
		if lastPortOrIdent > math.MaxUint16 {
			panic(fmt.Sprintf("got lastPortOrIdent = %d, want <= MaxUint16(=%d); portsOrIdents=%#v", lastPortOrIdent, math.MaxUint16, portsOrIdents))
		}
		return uint16(lastPortOrIdent)
	}()

	// Make sure the packet is re-written after performing NAT.
	//
	// This defer is registered before the one that releases cn.mu, so it runs
	// after the lock has been dropped; handlePacket takes cn.mu for reading.
	defer func() {
		// handlePacket returns true if the packet may skip the NAT table as the
		// connection is already NATed, but if we reach this point we must be in the
		// NAT table, so the return value is useless for us.
		_ = cn.handlePacket(pkt, hook, r)
	}()

	cn.mu.Lock()
	defer cn.mu.Unlock()

	// NAT is recorded by editing the reply tuple: DNAT changes where replies
	// come from (the reply's source), SNAT changes where replies are sent to
	// (the reply's destination).
	var manip *manipType
	var address *tcpip.Address
	var portOrIdent *uint16
	if dnat {
		manip = &cn.destinationManip
		address = &cn.reply.tupleID.srcAddr
		portOrIdent = &cn.reply.tupleID.srcPortOrEchoRequestIdent
	} else {
		manip = &cn.sourceManip
		address = &cn.reply.tupleID.dstAddr
		portOrIdent = &cn.reply.tupleID.dstPortOrEchoReplyIdent
	}

	// NAT in this direction was already set up; only the deferred rewrite runs.
	if *manip != manipNotPerformed {
		return
	}
	*manip = manipPerformed
	if changeAddress {
		*address = natAddress
	}

	// Everything below here is port-fiddling.
	if !changePort {
		return
	}

	// Does the current port/ident fit in the range?
	if portsOrIdents.start <= *portOrIdent && *portOrIdent <= lastPortOrIdent {
		// Yes, is the current reply tuple unique?
		//
		// Or, does the reply tuple refer to the same connection as the current one that
		// we are NATing? This would apply, for example, to a self-connected socket,
		// where the original and reply tuples are identical.
		other := cn.ct.connForTID(cn.reply.tupleID)
		if other == nil || other.conn == cn {
			// Yes! No need to change the port.
			return
		}
	}

	// Try our best to find a port/ident that results in a unique reply tuple.
	//
	// We limit the number of attempts to find a unique tuple to not waste a lot
	// of time looking for a unique tuple.
	//
	// Matches linux behaviour introduced in
	// https://github.com/torvalds/linux/commit/a504b703bb1da526a01593da0e4be2af9d9f5fa8.
	const maxAttemptsForInitialRound uint32 = 128
	const minAttemptsToContinue = 16

	allowedInitialAttempts := maxAttemptsForInitialRound
	if allowedInitialAttempts > portsOrIdents.size {
		allowedInitialAttempts = portsOrIdents.size
	}

	// Each round tries maxAttempts consecutive candidates (wrapping within the
	// range) and the attempt budget is halved from one round to the next.
	for maxAttempts := allowedInitialAttempts; ; maxAttempts /= 2 {
		// Start each round with a random initial port/ident offset.
		randOffset := cn.ct.rand.Uint32()

		for i := uint32(0); i < maxAttempts; i++ {
			newPortOrIdentU32 := uint32(portsOrIdents.start) + (randOffset+i)%portsOrIdents.size
			if newPortOrIdentU32 > math.MaxUint16 {
				panic(fmt.Sprintf("got newPortOrIdentU32 = %d, want <= MaxUint16(=%d); portsOrIdents=%#v, randOffset=%d", newPortOrIdentU32, math.MaxUint16, portsOrIdents, randOffset))
			}

			*portOrIdent = uint16(newPortOrIdentU32)

			if other := cn.ct.connForTID(cn.reply.tupleID); other == nil {
				// We found a unique tuple!
				return
			}
		}

		if maxAttempts == portsOrIdents.size {
			// We already tried all the ports/idents in the range so no need to keep
			// trying.
			return
		}

		// Give up once a round had fewer than minAttemptsToContinue attempts.
		if maxAttempts < minAttemptsToContinue {
			return
		}
	}

	// NOTE: the loop above has no condition and only exits via return, so this
	// point is never reached; the remark below describes the two give-up
	// returns inside the loop.
	//
	// We did not find a unique tuple, use the last used port anyways.
	// TODO(https://gvisor.dev/issue/6850): Handle not finding a unique tuple
	// better (e.g. remove the connection and drop the packet).
}

// handlePacket attempts to handle a packet and perform NAT if the connection
// has had NAT performed on it.
//
// Returns true if the packet can skip the NAT table.
//
// It returns false, leaving the packet untouched, when the packet's headers
// cannot be obtained or when no NAT has been set up for the direction implied
// by hook. It panics if called for the Forward hook, for an unrecognized
// hook, or if the packet was already marked as NATed for that direction.
func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool {
	netHdr, transHdr, isICMPError, ok := getHeaders(pkt)
	if !ok {
		return false
	}

	// Defaults describe SNAT (used as-is for Input; Postrouting also adjusts
	// the checksum flags); the hook switch below overrides them as needed.
	fullChecksum := false
	updatePseudoHeader := false
	natDone := &pkt.snatDone
	dnat := false
	switch hook {
	case Prerouting:
		// Packet came from outside the stack so it must have a checksum set
		// already.
		fullChecksum = true
		updatePseudoHeader = true

		natDone = &pkt.dnatDone
		dnat = true
	case Input:
	case Forward:
		panic("should not handle packet in the forwarding hook")
	case Output:
		natDone = &pkt.dnatDone
		dnat = true
		// Output shares the checksum decisions made for Postrouting.
		fallthrough
	case Postrouting:
		if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
			updatePseudoHeader = true
		} else if rt.RequiresTXTransportChecksum() {
			fullChecksum = true
			updatePseudoHeader = true
		}
	default:
		panic(fmt.Sprintf("unrecognized hook = %d", hook))
	}

	if *natDone {
		panic(fmt.Sprintf("packet already had NAT(dnat=%t) performed at hook=%s; pkt=%#v", dnat, hook, pkt))
	}

	// TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
	// validated if checksum offloading is off. It may require IP defrag if the
	// packets are fragmented.

	reply := pkt.tuple.reply

	// Pick the tuple to rewrite towards and the manip state that applies. A
	// packet is rewritten to look like the *other* direction's tuple, and for
	// reply packets the source/destination manip roles are swapped.
	tid, manip := func() (tupleID, manipType) {
		cn.mu.RLock()
		defer cn.mu.RUnlock()

		if reply {
			tid := cn.original.tupleID

			if dnat {
				return tid, cn.sourceManip
			}
			return tid, cn.destinationManip
		}

		tid := cn.reply.tupleID
		if dnat {
			return tid, cn.destinationManip
		}
		return tid, cn.sourceManip
	}()
	switch manip {
	case manipNotPerformed:
		return false
	case manipPerformedNoop:
		*natDone = true
		return true
	case manipPerformed:
	default:
		panic(fmt.Sprintf("unhandled manip = %d", manip))
	}

	newPort := tid.dstPortOrEchoReplyIdent
	newAddr := tid.dstAddr
	if dnat {
		newPort = tid.srcPortOrEchoRequestIdent
		newAddr = tid.srcAddr
	}

	// NOTE(review): the third argument presumably selects whether the source
	// (rather than destination) fields are rewritten; it is inverted for ICMP
	// errors, whose embedded packet travelled in the opposite direction —
	// confirm against rewritePacket.
	rewritePacket(
		netHdr,
		transHdr,
		!dnat != isICMPError,
		fullChecksum,
		updatePseudoHeader,
		newPort,
		newAddr,
	)

	*natDone = true

	if !isICMPError {
		return true
	}

	// We performed NAT on (erroneous) packet that triggered an ICMP response, but
	// not the ICMP packet itself.
	switch pkt.TransportProtocolNumber {
	case header.ICMPv4ProtocolNumber:
		icmp := header.ICMPv4(pkt.TransportHeader().Slice())
		// TODO(https://gvisor.dev/issue/6788): Incrementally update ICMP checksum.
		icmp.SetChecksum(0)
		icmp.SetChecksum(header.ICMPv4Checksum(icmp, pkt.Data().Checksum()))

		network := header.IPv4(pkt.NetworkHeader().Slice())
		if dnat {
			network.SetDestinationAddressWithChecksumUpdate(tid.srcAddr)
		} else {
			network.SetSourceAddressWithChecksumUpdate(tid.dstAddr)
		}
	case header.ICMPv6ProtocolNumber:
		network := header.IPv6(pkt.NetworkHeader().Slice())
		srcAddr := network.SourceAddress()
		dstAddr := network.DestinationAddress()
		if dnat {
			dstAddr = tid.srcAddr
		} else {
			srcAddr = tid.dstAddr
		}

		// The ICMPv6 checksum is computed with the post-NAT addresses, so it is
		// recomputed before the outer IPv6 header is rewritten below.
		icmp := header.ICMPv6(pkt.TransportHeader().Slice())
		// TODO(https://gvisor.dev/issue/6788): Incrementally update ICMP checksum.
		icmp.SetChecksum(0)
		payload := pkt.Data()
		icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
			Header:      icmp,
			Src:         srcAddr,
			Dst:         dstAddr,
			PayloadCsum: payload.Checksum(),
			PayloadLen:  payload.Size(),
		}))

		if dnat {
			network.SetDestinationAddress(dstAddr)
		} else {
			network.SetSourceAddress(srcAddr)
		}
	}

	return true
}

// bucket gets the conntrack bucket for a tupleID.
1010 // +checklocksread:ct.mu 1011 func (ct *ConnTrack) bucket(id tupleID) int { 1012 return ct.bucketWithTableLength(id, len(ct.buckets)) 1013 } 1014 1015 func (ct *ConnTrack) bucketWithTableLength(id tupleID, tableLength int) int { 1016 h := jenkins.Sum32(ct.seed) 1017 h.Write(id.srcAddr.AsSlice()) 1018 h.Write(id.dstAddr.AsSlice()) 1019 shortBuf := make([]byte, 2) 1020 binary.LittleEndian.PutUint16(shortBuf, id.srcPortOrEchoRequestIdent) 1021 h.Write([]byte(shortBuf)) 1022 binary.LittleEndian.PutUint16(shortBuf, id.dstPortOrEchoReplyIdent) 1023 h.Write([]byte(shortBuf)) 1024 binary.LittleEndian.PutUint16(shortBuf, uint16(id.transProto)) 1025 h.Write([]byte(shortBuf)) 1026 binary.LittleEndian.PutUint16(shortBuf, uint16(id.netProto)) 1027 h.Write([]byte(shortBuf)) 1028 return int(h.Sum32()) % tableLength 1029 } 1030 1031 // reapUnused deletes timed out entries from the conntrack map. The rules for 1032 // reaping are: 1033 // - Each call to reapUnused traverses a fraction of the conntrack table. 1034 // Specifically, it traverses len(ct.buckets)/fractionPerReaping. 1035 // - After reaping, reapUnused decides when it should next run based on the 1036 // ratio of expired connections to examined connections. If the ratio is 1037 // greater than maxExpiredPct, it schedules the next run quickly. Otherwise it 1038 // slightly increases the interval between runs. 1039 // - maxFullTraversal caps the time it takes to traverse the entire table. 1040 // 1041 // reapUnused returns the next bucket that should be checked and the time after 1042 // which it should be called again. 
1043 func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) { 1044 const fractionPerReaping = 128 1045 const maxExpiredPct = 50 1046 const maxFullTraversal = 60 * time.Second 1047 const minInterval = 10 * time.Millisecond 1048 const maxInterval = maxFullTraversal / fractionPerReaping 1049 1050 now := ct.clock.NowMonotonic() 1051 checked := 0 1052 expired := 0 1053 var idx int 1054 ct.mu.RLock() 1055 defer ct.mu.RUnlock() 1056 for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ { 1057 idx = (i + start) % len(ct.buckets) 1058 bkt := &ct.buckets[idx] 1059 bkt.mu.Lock() 1060 for tuple := bkt.tuples.Front(); tuple != nil; { 1061 // reapTupleLocked updates tuple's next pointer so we grab it here. 1062 nextTuple := tuple.Next() 1063 1064 checked++ 1065 if ct.reapTupleLocked(tuple, idx, bkt, now) { 1066 expired++ 1067 } 1068 1069 tuple = nextTuple 1070 } 1071 bkt.mu.Unlock() 1072 } 1073 // We already checked buckets[idx]. 1074 idx++ 1075 1076 // If half or more of the connections are expired, the table has gotten 1077 // stale. Reschedule quickly. 1078 expiredPct := 0 1079 if checked != 0 { 1080 expiredPct = expired * 100 / checked 1081 } 1082 if expiredPct > maxExpiredPct { 1083 return idx, minInterval 1084 } 1085 if interval := prevInterval + minInterval; interval <= maxInterval { 1086 // Increment the interval between runs. 1087 return idx, interval 1088 } 1089 // We've hit the maximum interval. 1090 return idx, maxInterval 1091 } 1092 1093 // reapTupleLocked tries to remove tuple and its reply from the table. It 1094 // returns whether the tuple's connection has timed out. 1095 // 1096 // Precondition: ct.mu is read locked and bkt.mu is write locked. 
1097 // +checklocksread:ct.mu 1098 // +checklocks:bkt.mu 1099 func (ct *ConnTrack) reapTupleLocked(reapingTuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool { 1100 if !reapingTuple.conn.timedOut(now) { 1101 return false 1102 } 1103 1104 var otherTuple *tuple 1105 if reapingTuple.reply { 1106 otherTuple = &reapingTuple.conn.original 1107 } else { 1108 otherTuple = &reapingTuple.conn.reply 1109 } 1110 1111 otherTupleBktID := ct.bucket(otherTuple.tupleID) 1112 replyTupleInserted := reapingTuple.conn.getFinalizeResult() == finalizeResultSuccess 1113 1114 // To maintain lock order, we can only reap both tuples if the tuple for the 1115 // other direction appears later in the table. 1116 if bktID > otherTupleBktID && replyTupleInserted { 1117 return true 1118 } 1119 1120 bkt.tuples.Remove(reapingTuple) 1121 1122 if !replyTupleInserted { 1123 // The other tuple is the reply which has not yet been inserted. 1124 return true 1125 } 1126 1127 // Reap the other connection. 1128 if bktID == otherTupleBktID { 1129 // Don't re-lock if both tuples are in the same bucket. 1130 bkt.tuples.Remove(otherTuple) 1131 } else { 1132 otherTupleBkt := &ct.buckets[otherTupleBktID] 1133 otherTupleBkt.mu.NestedLock(bucketLockOthertuple) 1134 otherTupleBkt.tuples.Remove(otherTuple) 1135 otherTupleBkt.mu.NestedUnlock(bucketLockOthertuple) 1136 } 1137 1138 return true 1139 } 1140 1141 func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { 1142 // Lookup the connection. The reply's original destination 1143 // describes the original address. 
1144 tid := tupleID{ 1145 srcAddr: epID.LocalAddress, 1146 srcPortOrEchoRequestIdent: epID.LocalPort, 1147 dstAddr: epID.RemoteAddress, 1148 dstPortOrEchoReplyIdent: epID.RemotePort, 1149 transProto: transProto, 1150 netProto: netProto, 1151 } 1152 t := ct.connForTID(tid) 1153 if t == nil { 1154 // Not a tracked connection. 1155 return tcpip.Address{}, 0, &tcpip.ErrNotConnected{} 1156 } 1157 1158 t.conn.mu.RLock() 1159 defer t.conn.mu.RUnlock() 1160 if t.conn.destinationManip == manipNotPerformed { 1161 // Unmanipulated destination. 1162 return tcpip.Address{}, 0, &tcpip.ErrInvalidOptionValue{} 1163 } 1164 1165 id := t.conn.original.tupleID 1166 return id.dstAddr, id.dstPortOrEchoReplyIdent, nil 1167 }