github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/tcpip/stack/conntrack.go (about) 1 // Copyright 2020 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package stack 16 17 import ( 18 "encoding/binary" 19 "fmt" 20 "math" 21 "math/rand" 22 "sync" 23 "time" 24 25 "github.com/nicocha30/gvisor-ligolo/pkg/atomicbitops" 26 "github.com/nicocha30/gvisor-ligolo/pkg/tcpip" 27 "github.com/nicocha30/gvisor-ligolo/pkg/tcpip/hash/jenkins" 28 "github.com/nicocha30/gvisor-ligolo/pkg/tcpip/header" 29 "github.com/nicocha30/gvisor-ligolo/pkg/tcpip/transport/tcpconntrack" 30 ) 31 32 // Connection tracking is used to track and manipulate packets for NAT rules. 33 // The connection is created for a packet if it does not exist. Every 34 // connection contains two tuples (original and reply). The tuples are 35 // manipulated if there is a matching NAT rule. The packet is modified by 36 // looking at the tuples in each hook. 37 // 38 // Currently, only TCP tracking is supported. 39 40 // Our hash table has 16K buckets. 41 const numBuckets = 1 << 14 42 43 const ( 44 establishedTimeout time.Duration = 5 * 24 * time.Hour 45 unestablishedTimeout time.Duration = 120 * time.Second 46 ) 47 48 // tuple holds a connection's identifying and manipulating data in one 49 // direction. It is immutable. 50 // 51 // +stateify savable 52 type tuple struct { 53 // tupleEntry is used to build an intrusive list of tuples. 
54 tupleEntry 55 56 // conn is the connection tracking entry this tuple belongs to. 57 conn *conn 58 59 // reply is true iff the tuple's direction is opposite that of the first 60 // packet seen on the connection. 61 reply bool 62 63 // tupleID is set at initialization and is immutable. 64 tupleID tupleID 65 } 66 67 // tupleID uniquely identifies a trackable connection in one direction. 68 // 69 // +stateify savable 70 type tupleID struct { 71 srcAddr tcpip.Address 72 // The source port of a packet in the original direction is overloaded with 73 // the ident of an Echo Request packet. 74 // 75 // This also matches the behaviour of sending packets on Linux where the 76 // socket's source port value is used for the source port of outgoing packets 77 // for TCP/UDP and the ident field for outgoing Echo Requests on Ping sockets: 78 // 79 // IPv4: https://github.com/torvalds/linux/blob/c5c17547b778975b3d83a73c8d84e8fb5ecf3ba5/net/ipv4/ping.c#L810 80 // IPv6: https://github.com/torvalds/linux/blob/c5c17547b778975b3d83a73c8d84e8fb5ecf3ba5/net/ipv6/ping.c#L133 81 srcPortOrEchoRequestIdent uint16 82 dstAddr tcpip.Address 83 // The opposite of srcPortOrEchoRequestIdent; the destination port of a packet 84 // in the reply direction is overloaded with the ident of an Echo Reply. 85 dstPortOrEchoReplyIdent uint16 86 transProto tcpip.TransportProtocolNumber 87 netProto tcpip.NetworkProtocolNumber 88 } 89 90 // reply creates the reply tupleID. 91 func (ti tupleID) reply() tupleID { 92 return tupleID{ 93 srcAddr: ti.dstAddr, 94 srcPortOrEchoRequestIdent: ti.dstPortOrEchoReplyIdent, 95 dstAddr: ti.srcAddr, 96 dstPortOrEchoReplyIdent: ti.srcPortOrEchoRequestIdent, 97 transProto: ti.transProto, 98 netProto: ti.netProto, 99 } 100 } 101 102 type manipType int 103 104 const ( 105 // manipNotPerformed indicates that NAT has not been performed. 106 manipNotPerformed manipType = iota 107 108 // manipPerformed indicates that NAT was performed. 
109 manipPerformed 110 111 // manipPerformedNoop indicates that NAT was performed but it was a no-op. 112 manipPerformedNoop 113 ) 114 115 type finalizeResult uint32 116 117 const ( 118 // A finalizeResult must be explicitly set so we don't make use of the zero 119 // value. 120 _ finalizeResult = iota 121 122 finalizeResultSuccess 123 finalizeResultConflict 124 ) 125 126 // conn is a tracked connection. 127 // 128 // +stateify savable 129 type conn struct { 130 ct *ConnTrack 131 132 // original is the tuple in original direction. It is immutable. 133 original tuple 134 135 // reply is the tuple in reply direction. 136 reply tuple 137 138 finalizeOnce sync.Once 139 // Holds a finalizeResult. 140 finalizeResult atomicbitops.Uint32 141 142 mu connRWMutex `state:"nosave"` 143 // sourceManip indicates the source manipulation type. 144 // 145 // +checklocks:mu 146 sourceManip manipType 147 // destinationManip indicates the destination's manipulation type. 148 // 149 // +checklocks:mu 150 destinationManip manipType 151 152 stateMu stateConnRWMutex `state:"nosave"` 153 // tcb is TCB control block. It is used to keep track of states 154 // of tcp connection. 155 // 156 // +checklocks:stateMu 157 tcb tcpconntrack.TCB 158 // lastUsed is the last time the connection saw a relevant packet, and 159 // is updated by each packet on the connection. 160 // 161 // +checklocks:stateMu 162 lastUsed tcpip.MonotonicTime 163 } 164 165 // timedOut returns whether the connection timed out based on its state. 166 func (cn *conn) timedOut(now tcpip.MonotonicTime) bool { 167 cn.stateMu.RLock() 168 defer cn.stateMu.RUnlock() 169 if cn.tcb.State() == tcpconntrack.ResultAlive { 170 // Use the same default as Linux, which doesn't delete 171 // established connections for 5(!) days. 172 return now.Sub(cn.lastUsed) > establishedTimeout 173 } 174 // Use the same default as Linux, which lets connections in most states 175 // other than established remain for <= 120 seconds. 
176 return now.Sub(cn.lastUsed) > unestablishedTimeout 177 } 178 179 // update the connection tracking state. 180 func (cn *conn) update(pkt PacketBufferPtr, reply bool) { 181 cn.stateMu.Lock() 182 defer cn.stateMu.Unlock() 183 184 // Mark the connection as having been used recently so it isn't reaped. 185 cn.lastUsed = cn.ct.clock.NowMonotonic() 186 187 if pkt.TransportProtocolNumber != header.TCPProtocolNumber { 188 return 189 } 190 191 tcpHeader := header.TCP(pkt.TransportHeader().Slice()) 192 193 // Update the state of tcb. tcb assumes it's always initialized on the 194 // client. However, we only need to know whether the connection is 195 // established or not, so the client/server distinction isn't important. 196 if cn.tcb.IsEmpty() { 197 cn.tcb.Init(tcpHeader, pkt.Data().Size()) 198 return 199 } 200 201 if reply { 202 cn.tcb.UpdateStateReply(tcpHeader, pkt.Data().Size()) 203 } else { 204 cn.tcb.UpdateStateOriginal(tcpHeader, pkt.Data().Size()) 205 } 206 } 207 208 // ConnTrack tracks all connections created for NAT rules. Most users are 209 // expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop. 210 // 211 // ConnTrack keeps all connections in a slice of buckets, each of which holds a 212 // linked list of tuples. This gives us some desirable properties: 213 // - Each bucket has its own lock, lessening lock contention. 214 // - The slice is large enough that lists stay short (<10 elements on average). 215 // Thus traversal is fast. 216 // - During linked list traversal we reap expired connections. This amortizes 217 // the cost of reaping them and makes reapUnused faster. 218 // 219 // Locks are ordered by their location in the buckets slice. That is, a 220 // goroutine that locks buckets[i] can only lock buckets[j] s.t. i < j. 221 // 222 // +stateify savable 223 type ConnTrack struct { 224 // seed is a one-time random value initialized at stack startup 225 // and is used in the calculation of hash keys for the list of buckets. 
226 // It is immutable. 227 seed uint32 228 229 // clock provides timing used to determine conntrack reapings. 230 clock tcpip.Clock 231 rand *rand.Rand 232 233 mu connTrackRWMutex `state:"nosave"` 234 // mu protects the buckets slice, but not buckets' contents. Only take 235 // the write lock if you are modifying the slice or saving for S/R. 236 // 237 // +checklocks:mu 238 buckets []bucket 239 } 240 241 // +stateify savable 242 type bucket struct { 243 mu bucketRWMutex `state:"nosave"` 244 // +checklocks:mu 245 tuples tupleList 246 } 247 248 // A netAndTransHeadersFunc returns the network and transport headers found 249 // in an ICMP payload. The transport layer's payload will not be returned. 250 // 251 // May panic if the packet does not hold the transport header. 252 type netAndTransHeadersFunc func(icmpPayload []byte, minTransHdrLen int) (netHdr header.Network, transHdrBytes []byte) 253 254 func v4NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) { 255 netHdr := header.IPv4(icmpPayload) 256 // Do not use netHdr.Payload() as we might not hold the full packet 257 // in the ICMP error; Payload() panics if the buffer is smaller than 258 // the total length specified in the IPv4 header. 259 transHdr := icmpPayload[netHdr.HeaderLength():] 260 return netHdr, transHdr[:minTransHdrLen] 261 } 262 263 func v6NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) { 264 netHdr := header.IPv6(icmpPayload) 265 // Do not use netHdr.Payload() as we might not hold the full packet 266 // in the ICMP error; Payload() panics if the IP payload is smaller than 267 // the payload length specified in the IPv6 header. 
268 transHdr := icmpPayload[header.IPv6MinimumSize:] 269 return netHdr, transHdr[:minTransHdrLen] 270 } 271 272 func getEmbeddedNetAndTransHeaders(pkt PacketBufferPtr, netHdrLength int, getNetAndTransHdr netAndTransHeadersFunc, transProto tcpip.TransportProtocolNumber) (header.Network, header.ChecksummableTransport, bool) { 273 switch transProto { 274 case header.TCPProtocolNumber: 275 if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.TCPMinimumSize); ok { 276 netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.TCPMinimumSize) 277 return netHeader, header.TCP(transHeaderBytes), true 278 } 279 case header.UDPProtocolNumber: 280 if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.UDPMinimumSize); ok { 281 netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.UDPMinimumSize) 282 return netHeader, header.UDP(transHeaderBytes), true 283 } 284 } 285 return nil, nil, false 286 } 287 288 func getHeaders(pkt PacketBufferPtr) (netHdr header.Network, transHdr header.Transport, isICMPError bool, ok bool) { 289 switch pkt.TransportProtocolNumber { 290 case header.TCPProtocolNumber: 291 if tcpHeader := header.TCP(pkt.TransportHeader().Slice()); len(tcpHeader) >= header.TCPMinimumSize { 292 return pkt.Network(), tcpHeader, false, true 293 } 294 return nil, nil, false, false 295 case header.UDPProtocolNumber: 296 if udpHeader := header.UDP(pkt.TransportHeader().Slice()); len(udpHeader) >= header.UDPMinimumSize { 297 return pkt.Network(), udpHeader, false, true 298 } 299 return nil, nil, false, false 300 case header.ICMPv4ProtocolNumber: 301 icmpHeader := header.ICMPv4(pkt.TransportHeader().Slice()) 302 if len(icmpHeader) < header.ICMPv4MinimumSize { 303 return nil, nil, false, false 304 } 305 306 switch icmpType := icmpHeader.Type(); icmpType { 307 case header.ICMPv4Echo, header.ICMPv4EchoReply: 308 return pkt.Network(), icmpHeader, false, true 309 case header.ICMPv4DstUnreachable, 
header.ICMPv4TimeExceeded, header.ICMPv4ParamProblem: 310 default: 311 panic(fmt.Sprintf("unexpected ICMPv4 type = %d", icmpType)) 312 } 313 314 h, ok := pkt.Data().PullUp(header.IPv4MinimumSize) 315 if !ok { 316 panic(fmt.Sprintf("should have a valid IPv4 packet; only have %d bytes, want at least %d bytes", pkt.Data().Size(), header.IPv4MinimumSize)) 317 } 318 319 if header.IPv4(h).HeaderLength() > header.IPv4MinimumSize { 320 // TODO(https://gvisor.dev/issue/6765): Handle IPv4 options. 321 panic("should have dropped packets with IPv4 options") 322 } 323 324 if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, v4NetAndTransHdr, pkt.tuple.tupleID.transProto); ok { 325 return netHdr, transHdr, true, true 326 } 327 return nil, nil, false, false 328 case header.ICMPv6ProtocolNumber: 329 icmpHeader := header.ICMPv6(pkt.TransportHeader().Slice()) 330 if len(icmpHeader) < header.ICMPv6MinimumSize { 331 return nil, nil, false, false 332 } 333 334 switch icmpType := icmpHeader.Type(); icmpType { 335 case header.ICMPv6EchoRequest, header.ICMPv6EchoReply: 336 return pkt.Network(), icmpHeader, false, true 337 case header.ICMPv6DstUnreachable, header.ICMPv6PacketTooBig, header.ICMPv6TimeExceeded, header.ICMPv6ParamProblem: 338 default: 339 panic(fmt.Sprintf("unexpected ICMPv6 type = %d", icmpType)) 340 } 341 342 h, ok := pkt.Data().PullUp(header.IPv6MinimumSize) 343 if !ok { 344 panic(fmt.Sprintf("should have a valid IPv6 packet; only have %d bytes, want at least %d bytes", pkt.Data().Size(), header.IPv6MinimumSize)) 345 } 346 347 // We do not support extension headers in ICMP errors so the next header 348 // in the IPv6 packet should be a tracked protocol if we reach this point. 349 // 350 // TODO(https://gvisor.dev/issue/6789): Support extension headers. 
351 transProto := pkt.tuple.tupleID.transProto 352 if got := header.IPv6(h).TransportProtocol(); got != transProto { 353 panic(fmt.Sprintf("got TransportProtocol() = %d, want = %d", got, transProto)) 354 } 355 356 if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, v6NetAndTransHdr, transProto); ok { 357 return netHdr, transHdr, true, true 358 } 359 return nil, nil, false, false 360 default: 361 panic(fmt.Sprintf("unexpected transport protocol = %d", pkt.TransportProtocolNumber)) 362 } 363 } 364 365 func getTupleIDForRegularPacket(netHdr header.Network, netProto tcpip.NetworkProtocolNumber, transHdr header.Transport, transProto tcpip.TransportProtocolNumber) tupleID { 366 return tupleID{ 367 srcAddr: netHdr.SourceAddress(), 368 srcPortOrEchoRequestIdent: transHdr.SourcePort(), 369 dstAddr: netHdr.DestinationAddress(), 370 dstPortOrEchoReplyIdent: transHdr.DestinationPort(), 371 transProto: transProto, 372 netProto: netProto, 373 } 374 } 375 376 func getTupleIDForPacketInICMPError(pkt PacketBufferPtr, getNetAndTransHdr netAndTransHeadersFunc, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) { 377 if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, netLen, getNetAndTransHdr, transProto); ok { 378 return tupleID{ 379 srcAddr: netHdr.DestinationAddress(), 380 srcPortOrEchoRequestIdent: transHdr.DestinationPort(), 381 dstAddr: netHdr.SourceAddress(), 382 dstPortOrEchoReplyIdent: transHdr.SourcePort(), 383 transProto: transProto, 384 netProto: netProto, 385 }, true 386 } 387 388 return tupleID{}, false 389 } 390 391 type getTupleIDDisposition int 392 393 const ( 394 getTupleIDNotOK getTupleIDDisposition = iota 395 getTupleIDOKAndAllowNewConn 396 getTupleIDOKAndDontAllowNewConn 397 ) 398 399 func getTupleIDForEchoPacket(pkt PacketBufferPtr, ident uint16, request bool) tupleID { 400 netHdr := pkt.Network() 401 tid := tupleID{ 402 srcAddr: netHdr.SourceAddress(), 403 
dstAddr: netHdr.DestinationAddress(), 404 transProto: pkt.TransportProtocolNumber, 405 netProto: pkt.NetworkProtocolNumber, 406 } 407 408 if request { 409 tid.srcPortOrEchoRequestIdent = ident 410 } else { 411 tid.dstPortOrEchoReplyIdent = ident 412 } 413 414 return tid 415 } 416 417 func getTupleID(pkt PacketBufferPtr) (tupleID, getTupleIDDisposition) { 418 switch pkt.TransportProtocolNumber { 419 case header.TCPProtocolNumber: 420 if transHeader := header.TCP(pkt.TransportHeader().Slice()); len(transHeader) >= header.TCPMinimumSize { 421 return getTupleIDForRegularPacket(pkt.Network(), pkt.NetworkProtocolNumber, transHeader, pkt.TransportProtocolNumber), getTupleIDOKAndAllowNewConn 422 } 423 case header.UDPProtocolNumber: 424 if transHeader := header.UDP(pkt.TransportHeader().Slice()); len(transHeader) >= header.UDPMinimumSize { 425 return getTupleIDForRegularPacket(pkt.Network(), pkt.NetworkProtocolNumber, transHeader, pkt.TransportProtocolNumber), getTupleIDOKAndAllowNewConn 426 } 427 case header.ICMPv4ProtocolNumber: 428 icmp := header.ICMPv4(pkt.TransportHeader().Slice()) 429 if len(icmp) < header.ICMPv4MinimumSize { 430 return tupleID{}, getTupleIDNotOK 431 } 432 433 switch icmp.Type() { 434 case header.ICMPv4Echo: 435 return getTupleIDForEchoPacket(pkt, icmp.Ident(), true /* request */), getTupleIDOKAndAllowNewConn 436 case header.ICMPv4EchoReply: 437 // Do not create a new connection in response to a reply packet as only 438 // the first packet of a connection should create a conntrack entry but 439 // a reply is never the first packet sent for a connection. 440 return getTupleIDForEchoPacket(pkt, icmp.Ident(), false /* request */), getTupleIDOKAndDontAllowNewConn 441 case header.ICMPv4DstUnreachable, header.ICMPv4TimeExceeded, header.ICMPv4ParamProblem: 442 default: 443 // Unsupported ICMP type for NAT-ing. 
444 return tupleID{}, getTupleIDNotOK 445 } 446 447 h, ok := pkt.Data().PullUp(header.IPv4MinimumSize) 448 if !ok { 449 return tupleID{}, getTupleIDNotOK 450 } 451 452 ipv4 := header.IPv4(h) 453 if ipv4.HeaderLength() > header.IPv4MinimumSize { 454 // TODO(https://gvisor.dev/issue/6765): Handle IPv4 options. 455 return tupleID{}, getTupleIDNotOK 456 } 457 458 if tid, ok := getTupleIDForPacketInICMPError(pkt, v4NetAndTransHdr, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok { 459 // Do not create a new connection in response to an ICMP error. 460 return tid, getTupleIDOKAndDontAllowNewConn 461 } 462 case header.ICMPv6ProtocolNumber: 463 icmp := header.ICMPv6(pkt.TransportHeader().Slice()) 464 if len(icmp) < header.ICMPv6MinimumSize { 465 return tupleID{}, getTupleIDNotOK 466 } 467 468 switch icmp.Type() { 469 case header.ICMPv6EchoRequest: 470 return getTupleIDForEchoPacket(pkt, icmp.Ident(), true /* request */), getTupleIDOKAndAllowNewConn 471 case header.ICMPv6EchoReply: 472 // Do not create a new connection in response to a reply packet as only 473 // the first packet of a connection should create a conntrack entry but 474 // a reply is never the first packet sent for a connection. 475 return getTupleIDForEchoPacket(pkt, icmp.Ident(), false /* request */), getTupleIDOKAndDontAllowNewConn 476 case header.ICMPv6DstUnreachable, header.ICMPv6PacketTooBig, header.ICMPv6TimeExceeded, header.ICMPv6ParamProblem: 477 default: 478 return tupleID{}, getTupleIDNotOK 479 } 480 481 h, ok := pkt.Data().PullUp(header.IPv6MinimumSize) 482 if !ok { 483 return tupleID{}, getTupleIDNotOK 484 } 485 486 // TODO(https://gvisor.dev/issue/6789): Handle extension headers. 487 if tid, ok := getTupleIDForPacketInICMPError(pkt, v6NetAndTransHdr, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok { 488 // Do not create a new connection in response to an ICMP error. 
489 return tid, getTupleIDOKAndDontAllowNewConn 490 } 491 } 492 493 return tupleID{}, getTupleIDNotOK 494 } 495 496 func (ct *ConnTrack) init() { 497 ct.mu.Lock() 498 defer ct.mu.Unlock() 499 ct.buckets = make([]bucket, numBuckets) 500 } 501 502 // getConnAndUpdate attempts to get a connection or creates one if no 503 // connection exists for the packet and packet's protocol is trackable. 504 // 505 // If the packet's protocol is trackable, the connection's state is updated to 506 // match the contents of the packet. 507 func (ct *ConnTrack) getConnAndUpdate(pkt PacketBufferPtr, skipChecksumValidation bool) *tuple { 508 // Get or (maybe) create a connection. 509 t := func() *tuple { 510 var allowNewConn bool 511 tid, res := getTupleID(pkt) 512 switch res { 513 case getTupleIDNotOK: 514 return nil 515 case getTupleIDOKAndAllowNewConn: 516 allowNewConn = true 517 case getTupleIDOKAndDontAllowNewConn: 518 allowNewConn = false 519 default: 520 panic(fmt.Sprintf("unhandled %[1]T = %[1]d", res)) 521 } 522 523 // Just skip bad packets. They'll be rejected later by the appropriate 524 // protocol package. 
525 switch pkt.TransportProtocolNumber { 526 case header.TCPProtocolNumber: 527 _, csumValid, ok := header.TCPValid( 528 header.TCP(pkt.TransportHeader().Slice()), 529 func() uint16 { return pkt.Data().Checksum() }, 530 uint16(pkt.Data().Size()), 531 tid.srcAddr, 532 tid.dstAddr, 533 pkt.RXChecksumValidated || skipChecksumValidation) 534 if !csumValid || !ok { 535 return nil 536 } 537 case header.UDPProtocolNumber: 538 lengthValid, csumValid := header.UDPValid( 539 header.UDP(pkt.TransportHeader().Slice()), 540 func() uint16 { return pkt.Data().Checksum() }, 541 uint16(pkt.Data().Size()), 542 pkt.NetworkProtocolNumber, 543 tid.srcAddr, 544 tid.dstAddr, 545 pkt.RXChecksumValidated || skipChecksumValidation) 546 if !lengthValid || !csumValid { 547 return nil 548 } 549 } 550 551 ct.mu.RLock() 552 bkt := &ct.buckets[ct.bucket(tid)] 553 ct.mu.RUnlock() 554 555 now := ct.clock.NowMonotonic() 556 if t := bkt.connForTID(tid, now); t != nil { 557 return t 558 } 559 560 if !allowNewConn { 561 return nil 562 } 563 564 bkt.mu.Lock() 565 defer bkt.mu.Unlock() 566 567 // Make sure a connection wasn't added between when we last checked the 568 // bucket and acquired the bucket's write lock. 569 if t := bkt.connForTIDRLocked(tid, now); t != nil { 570 return t 571 } 572 573 // This is the first packet we're seeing for the connection. Create an entry 574 // for this new connection. 575 conn := &conn{ 576 ct: ct, 577 original: tuple{tupleID: tid}, 578 reply: tuple{tupleID: tid.reply(), reply: true}, 579 lastUsed: now, 580 } 581 conn.original.conn = conn 582 conn.reply.conn = conn 583 584 // For now, we only map an entry for the packet's original tuple as NAT may be 585 // performed on this connection. Until the packet goes through all the hooks 586 // and its final address/port is known, we cannot know what the response 587 // packet's addresses/ports will look like. 
588 // 589 // This is okay because the destination cannot send its response until it 590 // receives the packet; the packet will only be received once all the hooks 591 // have been performed. 592 // 593 // See (*conn).finalize. 594 bkt.tuples.PushFront(&conn.original) 595 return &conn.original 596 }() 597 if t != nil { 598 t.conn.update(pkt, t.reply) 599 } 600 return t 601 } 602 603 func (ct *ConnTrack) connForTID(tid tupleID) *tuple { 604 ct.mu.RLock() 605 bkt := &ct.buckets[ct.bucket(tid)] 606 ct.mu.RUnlock() 607 608 return bkt.connForTID(tid, ct.clock.NowMonotonic()) 609 } 610 611 func (bkt *bucket) connForTID(tid tupleID, now tcpip.MonotonicTime) *tuple { 612 bkt.mu.RLock() 613 defer bkt.mu.RUnlock() 614 return bkt.connForTIDRLocked(tid, now) 615 } 616 617 // +checklocksread:bkt.mu 618 func (bkt *bucket) connForTIDRLocked(tid tupleID, now tcpip.MonotonicTime) *tuple { 619 for other := bkt.tuples.Front(); other != nil; other = other.Next() { 620 if tid == other.tupleID && !other.conn.timedOut(now) { 621 return other 622 } 623 } 624 return nil 625 } 626 627 func (ct *ConnTrack) finalize(cn *conn) finalizeResult { 628 ct.mu.RLock() 629 buckets := ct.buckets 630 ct.mu.RUnlock() 631 632 { 633 tid := cn.reply.tupleID 634 id := ct.bucketWithTableLength(tid, len(buckets)) 635 636 bkt := &buckets[id] 637 bkt.mu.Lock() 638 t := bkt.connForTIDRLocked(tid, ct.clock.NowMonotonic()) 639 if t == nil { 640 bkt.tuples.PushFront(&cn.reply) 641 bkt.mu.Unlock() 642 return finalizeResultSuccess 643 } 644 bkt.mu.Unlock() 645 646 if t.conn == cn { 647 // We already have an entry for the reply tuple. 648 // 649 // This can occur when the source address/port is the same as the 650 // destination address/port. In this scenario, tid == tid.reply(). 651 return finalizeResultSuccess 652 } 653 } 654 655 // Another connection for the reply already exists. Remove the original and 656 // let the caller know we failed. 
657 // 658 // TODO(https://gvisor.dev/issue/6850): Investigate handling this clash 659 // better. 660 661 tid := cn.original.tupleID 662 id := ct.bucketWithTableLength(tid, len(buckets)) 663 bkt := &buckets[id] 664 bkt.mu.Lock() 665 defer bkt.mu.Unlock() 666 bkt.tuples.Remove(&cn.original) 667 return finalizeResultConflict 668 } 669 670 func (cn *conn) getFinalizeResult() finalizeResult { 671 return finalizeResult(cn.finalizeResult.Load()) 672 } 673 674 // finalize attempts to finalize the connection and returns true iff the 675 // connection was successfully finalized. 676 // 677 // If the connection failed to finalize, the caller should drop the packet 678 // associated with the connection. 679 // 680 // If multiple goroutines attempt to finalize at the same time, only one 681 // goroutine will perform the work to finalize the connection, but all 682 // goroutines will block until the finalizing goroutine finishes finalizing. 683 func (cn *conn) finalize() bool { 684 cn.finalizeOnce.Do(func() { 685 cn.finalizeResult.Store(uint32(cn.ct.finalize(cn))) 686 }) 687 688 switch res := cn.getFinalizeResult(); res { 689 case finalizeResultSuccess: 690 return true 691 case finalizeResultConflict: 692 return false 693 default: 694 panic(fmt.Sprintf("unhandled result = %d", res)) 695 } 696 } 697 698 func (cn *conn) maybePerformNoopNAT(dnat bool) { 699 cn.mu.Lock() 700 defer cn.mu.Unlock() 701 702 var manip *manipType 703 if dnat { 704 manip = &cn.destinationManip 705 } else { 706 manip = &cn.sourceManip 707 } 708 709 if *manip == manipNotPerformed { 710 *manip = manipPerformedNoop 711 } 712 } 713 714 type portOrIdentRange struct { 715 start uint16 716 size uint32 717 } 718 719 // performNAT setups up the connection for the specified NAT and rewrites the 720 // packet. 721 // 722 // If NAT has already been performed on the connection, then the packet will 723 // be rewritten with the NAT performed on the connection, ignoring the passed 724 // address and port range. 
725 // 726 // Generally, only the first packet of a connection reaches this method; other 727 // packets will be manipulated without needing to modify the connection. 728 func (cn *conn) performNAT(pkt PacketBufferPtr, hook Hook, r *Route, portsOrIdents portOrIdentRange, natAddress tcpip.Address, dnat bool) { 729 lastPortOrIdent := func() uint16 { 730 lastPortOrIdent := uint32(portsOrIdents.start) + portsOrIdents.size - 1 731 if lastPortOrIdent > math.MaxUint16 { 732 panic(fmt.Sprintf("got lastPortOrIdent = %d, want <= MaxUint16(=%d); portsOrIdents=%#v", lastPortOrIdent, math.MaxUint16, portsOrIdents)) 733 } 734 return uint16(lastPortOrIdent) 735 }() 736 737 // Make sure the packet is re-written after performing NAT. 738 defer func() { 739 // handlePacket returns true if the packet may skip the NAT table as the 740 // connection is already NATed, but if we reach this point we must be in the 741 // NAT table, so the return value is useless for us. 742 _ = cn.handlePacket(pkt, hook, r) 743 }() 744 745 cn.mu.Lock() 746 defer cn.mu.Unlock() 747 748 var manip *manipType 749 var address *tcpip.Address 750 var portOrIdent *uint16 751 if dnat { 752 manip = &cn.destinationManip 753 address = &cn.reply.tupleID.srcAddr 754 portOrIdent = &cn.reply.tupleID.srcPortOrEchoRequestIdent 755 } else { 756 manip = &cn.sourceManip 757 address = &cn.reply.tupleID.dstAddr 758 portOrIdent = &cn.reply.tupleID.dstPortOrEchoReplyIdent 759 } 760 761 if *manip != manipNotPerformed { 762 return 763 } 764 *manip = manipPerformed 765 *address = natAddress 766 767 // Does the current port/ident fit in the range? 768 if portsOrIdents.start <= *portOrIdent && *portOrIdent <= lastPortOrIdent { 769 // Yes, is the current reply tuple unique? 770 if other := cn.ct.connForTID(cn.reply.tupleID); other == nil { 771 // Yes! No need to change the port. 772 return 773 } 774 } 775 776 // Try our best to find a port/ident that results in a unique reply tuple. 
777 // 778 // We limit the number of attempts to find a unique tuple to not waste a lot 779 // of time looking for a unique tuple. 780 // 781 // Matches linux behaviour introduced in 782 // https://github.com/torvalds/linux/commit/a504b703bb1da526a01593da0e4be2af9d9f5fa8. 783 const maxAttemptsForInitialRound uint32 = 128 784 const minAttemptsToContinue = 16 785 786 allowedInitialAttempts := maxAttemptsForInitialRound 787 if allowedInitialAttempts > portsOrIdents.size { 788 allowedInitialAttempts = portsOrIdents.size 789 } 790 791 for maxAttempts := allowedInitialAttempts; ; maxAttempts /= 2 { 792 // Start reach round with a random initial port/ident offset. 793 randOffset := cn.ct.rand.Uint32() 794 795 for i := uint32(0); i < maxAttempts; i++ { 796 newPortOrIdentU32 := uint32(portsOrIdents.start) + (randOffset+i)%portsOrIdents.size 797 if newPortOrIdentU32 > math.MaxUint16 { 798 panic(fmt.Sprintf("got newPortOrIdentU32 = %d, want <= MaxUint16(=%d); portsOrIdents=%#v, randOffset=%d", newPortOrIdentU32, math.MaxUint16, portsOrIdents, randOffset)) 799 } 800 801 *portOrIdent = uint16(newPortOrIdentU32) 802 803 if other := cn.ct.connForTID(cn.reply.tupleID); other == nil { 804 // We found a unique tuple! 805 return 806 } 807 } 808 809 if maxAttempts == portsOrIdents.size { 810 // We already tried all the ports/idents in the range so no need to keep 811 // trying. 812 return 813 } 814 815 if maxAttempts < minAttemptsToContinue { 816 return 817 } 818 } 819 820 // We did not find a unique tuple, use the last used port anyways. 821 // TODO(https://gvisor.dev/issue/6850): Handle not finding a unique tuple 822 // better (e.g. remove the connection and drop the packet). 823 } 824 825 // handlePacket attempts to handle a packet and perform NAT if the connection 826 // has had NAT performed on it. 827 // 828 // Returns true if the packet can skip the NAT table. 
func (cn *conn) handlePacket(pkt PacketBufferPtr, hook Hook, rt *Route) bool {
	// Applies this connection's NAT state to pkt at hook.
	//
	// Returns false (packet untouched, NAT not marked as done) when the
	// headers cannot be parsed or when the relevant manipulation is
	// manipNotPerformed. Otherwise it marks SNAT or DNAT as done on pkt and
	// returns true, rewriting the packet first when the manipulation is
	// manipPerformed. Panics at the Forward hook, at an unrecognized hook, or
	// if the same kind of NAT was already performed on pkt.
	netHdr, transHdr, isICMPError, ok := getHeaders(pkt)
	if !ok {
		return false
	}

	// Per hook, decide whether this is DNAT or SNAT (natDone points at the
	// matching per-packet flag) and how the transport checksum must be fixed
	// up after rewriting.
	fullChecksum := false
	updatePseudoHeader := false
	natDone := &pkt.snatDone
	dnat := false
	switch hook {
	case Prerouting:
		// Packet came from outside the stack so it must have a checksum set
		// already.
		fullChecksum = true
		updatePseudoHeader = true

		natDone = &pkt.dnatDone
		dnat = true
	case Input:
		// SNAT; neither checksum flag is set at this hook.
	case Forward:
		panic("should not handle packet in the forwarding hook")
	case Output:
		natDone = &pkt.dnatDone
		dnat = true
		// Output shares Postrouting's checksum decision below.
		fallthrough
	case Postrouting:
		// For TCP with GSO and NeedsCsum set only the pseudo-header is updated
		// (presumably the GSO path finishes the checksum later — confirm);
		// otherwise a full checksum is computed if the route requires one.
		if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
			updatePseudoHeader = true
		} else if rt.RequiresTXTransportChecksum() {
			fullChecksum = true
			updatePseudoHeader = true
		}
	default:
		panic(fmt.Sprintf("unrecognized hook = %d", hook))
	}

	// Each kind of NAT (SNAT or DNAT) may be applied to a packet at most once.
	if *natDone {
		panic(fmt.Sprintf("packet already had NAT(dnat=%t) performed at hook=%s; pkt=%#v", dnat, hook, pkt))
	}

	// TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
	// validated if checksum offloading is off. It may require IP defrag if the
	// packets are fragmented.

	reply := pkt.tuple.reply

	// Under the connection's read lock, fetch the tuple of the direction
	// opposite to the packet's, plus the manipulation that governs this NAT
	// kind. For reply-direction packets the roles swap: DNAT is governed by
	// sourceManip and SNAT by destinationManip.
	tid, manip := func() (tupleID, manipType) {
		cn.mu.RLock()
		defer cn.mu.RUnlock()

		if reply {
			tid := cn.original.tupleID

			if dnat {
				return tid, cn.sourceManip
			}
			return tid, cn.destinationManip
		}

		tid := cn.reply.tupleID
		if dnat {
			return tid, cn.destinationManip
		}
		return tid, cn.sourceManip
	}()
	switch manip {
	case manipNotPerformed:
		return false
	case manipPerformedNoop:
		// Nothing to rewrite; only record that this NAT kind was handled.
		*natDone = true
		return true
	case manipPerformed:
		// Continue below and rewrite the packet.
	default:
		panic(fmt.Sprintf("unhandled manip = %d", manip))
	}

	// The new values come from tid: its destination fields for SNAT and its
	// source fields for DNAT.
	newPort := tid.dstPortOrEchoReplyIdent
	newAddr := tid.dstAddr
	if dnat {
		newPort = tid.srcPortOrEchoRequestIdent
		newAddr = tid.srcAddr
	}

	// NOTE(review): the third argument presumably selects whether source
	// (rather than destination) fields are rewritten. It is inverted for ICMP
	// errors, whose headers returned by getHeaders describe the embedded
	// packet — confirm against rewritePacket.
	rewritePacket(
		netHdr,
		transHdr,
		!dnat != isICMPError,
		fullChecksum,
		updatePseudoHeader,
		newPort,
		newAddr,
	)

	*natDone = true

	if !isICMPError {
		return true
	}

	// We performed NAT on (erroneous) packet that triggered an ICMP response, but
	// not the ICMP packet itself.
	switch pkt.TransportProtocolNumber {
	case header.ICMPv4ProtocolNumber:
		// Recompute the ICMP checksum over the modified payload, then rewrite
		// the outer IPv4 address with the header-checksum-updating setters.
		icmp := header.ICMPv4(pkt.TransportHeader().Slice())
		// TODO(https://gvisor.dev/issue/6788): Incrementally update ICMP checksum.
		icmp.SetChecksum(0)
		icmp.SetChecksum(header.ICMPv4Checksum(icmp, pkt.Data().Checksum()))

		network := header.IPv4(pkt.NetworkHeader().Slice())
		if dnat {
			network.SetDestinationAddressWithChecksumUpdate(tid.srcAddr)
		} else {
			network.SetSourceAddressWithChecksumUpdate(tid.dstAddr)
		}
	case header.ICMPv6ProtocolNumber:
		// The ICMPv6 checksum covers the source/destination addresses, so the
		// post-NAT addresses are computed first and used for the checksum.
		// They are written to the IPv6 header afterwards.
		network := header.IPv6(pkt.NetworkHeader().Slice())
		srcAddr := network.SourceAddress()
		dstAddr := network.DestinationAddress()
		if dnat {
			dstAddr = tid.srcAddr
		} else {
			srcAddr = tid.dstAddr
		}

		icmp := header.ICMPv6(pkt.TransportHeader().Slice())
		// TODO(https://gvisor.dev/issue/6788): Incrementally update ICMP checksum.
		icmp.SetChecksum(0)
		payload := pkt.Data()
		icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
			Header:      icmp,
			Src:         srcAddr,
			Dst:         dstAddr,
			PayloadCsum: payload.Checksum(),
			PayloadLen:  payload.Size(),
		}))

		// IPv6 has no header checksum, so plain setters suffice here.
		if dnat {
			network.SetDestinationAddress(dstAddr)
		} else {
			network.SetSourceAddress(srcAddr)
		}
	}

	return true
}

// bucket gets the conntrack bucket for a tupleID.
977 // +checklocksread:ct.mu 978 func (ct *ConnTrack) bucket(id tupleID) int { 979 return ct.bucketWithTableLength(id, len(ct.buckets)) 980 } 981 982 func (ct *ConnTrack) bucketWithTableLength(id tupleID, tableLength int) int { 983 h := jenkins.Sum32(ct.seed) 984 h.Write(id.srcAddr.AsSlice()) 985 h.Write(id.dstAddr.AsSlice()) 986 shortBuf := make([]byte, 2) 987 binary.LittleEndian.PutUint16(shortBuf, id.srcPortOrEchoRequestIdent) 988 h.Write([]byte(shortBuf)) 989 binary.LittleEndian.PutUint16(shortBuf, id.dstPortOrEchoReplyIdent) 990 h.Write([]byte(shortBuf)) 991 binary.LittleEndian.PutUint16(shortBuf, uint16(id.transProto)) 992 h.Write([]byte(shortBuf)) 993 binary.LittleEndian.PutUint16(shortBuf, uint16(id.netProto)) 994 h.Write([]byte(shortBuf)) 995 return int(h.Sum32()) % tableLength 996 } 997 998 // reapUnused deletes timed out entries from the conntrack map. The rules for 999 // reaping are: 1000 // - Each call to reapUnused traverses a fraction of the conntrack table. 1001 // Specifically, it traverses len(ct.buckets)/fractionPerReaping. 1002 // - After reaping, reapUnused decides when it should next run based on the 1003 // ratio of expired connections to examined connections. If the ratio is 1004 // greater than maxExpiredPct, it schedules the next run quickly. Otherwise it 1005 // slightly increases the interval between runs. 1006 // - maxFullTraversal caps the time it takes to traverse the entire table. 1007 // 1008 // reapUnused returns the next bucket that should be checked and the time after 1009 // which it should be called again. 
1010 func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) { 1011 const fractionPerReaping = 128 1012 const maxExpiredPct = 50 1013 const maxFullTraversal = 60 * time.Second 1014 const minInterval = 10 * time.Millisecond 1015 const maxInterval = maxFullTraversal / fractionPerReaping 1016 1017 now := ct.clock.NowMonotonic() 1018 checked := 0 1019 expired := 0 1020 var idx int 1021 ct.mu.RLock() 1022 defer ct.mu.RUnlock() 1023 for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ { 1024 idx = (i + start) % len(ct.buckets) 1025 bkt := &ct.buckets[idx] 1026 bkt.mu.Lock() 1027 for tuple := bkt.tuples.Front(); tuple != nil; { 1028 // reapTupleLocked updates tuple's next pointer so we grab it here. 1029 nextTuple := tuple.Next() 1030 1031 checked++ 1032 if ct.reapTupleLocked(tuple, idx, bkt, now) { 1033 expired++ 1034 } 1035 1036 tuple = nextTuple 1037 } 1038 bkt.mu.Unlock() 1039 } 1040 // We already checked buckets[idx]. 1041 idx++ 1042 1043 // If half or more of the connections are expired, the table has gotten 1044 // stale. Reschedule quickly. 1045 expiredPct := 0 1046 if checked != 0 { 1047 expiredPct = expired * 100 / checked 1048 } 1049 if expiredPct > maxExpiredPct { 1050 return idx, minInterval 1051 } 1052 if interval := prevInterval + minInterval; interval <= maxInterval { 1053 // Increment the interval between runs. 1054 return idx, interval 1055 } 1056 // We've hit the maximum interval. 1057 return idx, maxInterval 1058 } 1059 1060 // reapTupleLocked tries to remove tuple and its reply from the table. It 1061 // returns whether the tuple's connection has timed out. 1062 // 1063 // Precondition: ct.mu is read locked and bkt.mu is write locked. 
1064 // +checklocksread:ct.mu 1065 // +checklocks:bkt.mu 1066 func (ct *ConnTrack) reapTupleLocked(reapingTuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool { 1067 if !reapingTuple.conn.timedOut(now) { 1068 return false 1069 } 1070 1071 var otherTuple *tuple 1072 if reapingTuple.reply { 1073 otherTuple = &reapingTuple.conn.original 1074 } else { 1075 otherTuple = &reapingTuple.conn.reply 1076 } 1077 1078 otherTupleBktID := ct.bucket(otherTuple.tupleID) 1079 replyTupleInserted := reapingTuple.conn.getFinalizeResult() == finalizeResultSuccess 1080 1081 // To maintain lock order, we can only reap both tuples if the tuple for the 1082 // other direction appears later in the table. 1083 if bktID > otherTupleBktID && replyTupleInserted { 1084 return true 1085 } 1086 1087 bkt.tuples.Remove(reapingTuple) 1088 1089 if !replyTupleInserted { 1090 // The other tuple is the reply which has not yet been inserted. 1091 return true 1092 } 1093 1094 // Reap the other connection. 1095 if bktID == otherTupleBktID { 1096 // Don't re-lock if both tuples are in the same bucket. 1097 bkt.tuples.Remove(otherTuple) 1098 } else { 1099 otherTupleBkt := &ct.buckets[otherTupleBktID] 1100 otherTupleBkt.mu.NestedLock(bucketLockOthertuple) 1101 otherTupleBkt.tuples.Remove(otherTuple) 1102 otherTupleBkt.mu.NestedUnlock(bucketLockOthertuple) 1103 } 1104 1105 return true 1106 } 1107 1108 func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { 1109 // Lookup the connection. The reply's original destination 1110 // describes the original address. 
1111 tid := tupleID{ 1112 srcAddr: epID.LocalAddress, 1113 srcPortOrEchoRequestIdent: epID.LocalPort, 1114 dstAddr: epID.RemoteAddress, 1115 dstPortOrEchoReplyIdent: epID.RemotePort, 1116 transProto: transProto, 1117 netProto: netProto, 1118 } 1119 t := ct.connForTID(tid) 1120 if t == nil { 1121 // Not a tracked connection. 1122 return tcpip.Address{}, 0, &tcpip.ErrNotConnected{} 1123 } 1124 1125 t.conn.mu.RLock() 1126 defer t.conn.mu.RUnlock() 1127 if t.conn.destinationManip == manipNotPerformed { 1128 // Unmanipulated destination. 1129 return tcpip.Address{}, 0, &tcpip.ErrInvalidOptionValue{} 1130 } 1131 1132 id := t.conn.original.tupleID 1133 return id.dstAddr, id.dstPortOrEchoReplyIdent, nil 1134 }