inet.af/netstack@v0.0.0-20220214151720-7585b01ddccf/tcpip/stack/conntrack.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package stack
    16  
    17  import (
    18  	"encoding/binary"
    19  	"fmt"
    20  	"math/rand"
    21  	"sync"
    22  	"sync/atomic"
    23  	"time"
    24  
    25  	"inet.af/netstack/tcpip"
    26  	"inet.af/netstack/tcpip/hash/jenkins"
    27  	"inet.af/netstack/tcpip/header"
    28  	"inet.af/netstack/tcpip/transport/tcpconntrack"
    29  )
    30  
// Connection tracking is used to track and manipulate packets for NAT rules.
// A connection is created for a packet if one does not already exist. Every
// connection contains two tuples (original and reply). The tuples are
// manipulated if there is a matching NAT rule, and in each hook the packet is
// rewritten according to the tuples.
//
// Currently, TCP and UDP connections are tracked, and ICMP errors are matched
// to the connection of the packet they quote. Only TCP connections carry
// protocol state (used to pick the expiry timeout).
    38  
    39  // Our hash table has 16K buckets.
    40  const numBuckets = 1 << 14
    41  
    42  const (
    43  	establishedTimeout   time.Duration = 5 * 24 * time.Hour
    44  	unestablishedTimeout time.Duration = 120 * time.Second
    45  )
    46  
    47  // tuple holds a connection's identifying and manipulating data in one
    48  // direction. It is immutable.
    49  //
    50  // +stateify savable
    51  type tuple struct {
    52  	// tupleEntry is used to build an intrusive list of tuples.
    53  	tupleEntry
    54  
    55  	// conn is the connection tracking entry this tuple belongs to.
    56  	conn *conn
    57  
    58  	// reply is true iff the tuple's direction is opposite that of the first
    59  	// packet seen on the connection.
    60  	reply bool
    61  
    62  	mu sync.RWMutex `state:"nosave"`
    63  	// +checklocks:mu
    64  	tupleID tupleID
    65  }
    66  
    67  func (t *tuple) id() tupleID {
    68  	t.mu.RLock()
    69  	defer t.mu.RUnlock()
    70  	return t.tupleID
    71  }
    72  
    73  // tupleID uniquely identifies a connection in one direction. It currently
    74  // contains enough information to distinguish between any TCP or UDP
    75  // connection, and will need to be extended to support other protocols.
    76  //
    77  // +stateify savable
    78  type tupleID struct {
    79  	srcAddr    tcpip.Address
    80  	srcPort    uint16
    81  	dstAddr    tcpip.Address
    82  	dstPort    uint16
    83  	transProto tcpip.TransportProtocolNumber
    84  	netProto   tcpip.NetworkProtocolNumber
    85  }
    86  
    87  // reply creates the reply tupleID.
    88  func (ti tupleID) reply() tupleID {
    89  	return tupleID{
    90  		srcAddr:    ti.dstAddr,
    91  		srcPort:    ti.dstPort,
    92  		dstAddr:    ti.srcAddr,
    93  		dstPort:    ti.srcPort,
    94  		transProto: ti.transProto,
    95  		netProto:   ti.netProto,
    96  	}
    97  }
    98  
    99  type manipType int
   100  
   101  const (
   102  	// manipNotPerformed indicates that NAT has not been performed.
   103  	manipNotPerformed manipType = iota
   104  
   105  	// manipPerformed indicates that NAT was performed.
   106  	manipPerformed
   107  
   108  	// manipPerformedNoop indicates that NAT was performed but it was a no-op.
   109  	manipPerformedNoop
   110  )
   111  
   112  type finalizeResult uint32
   113  
   114  const (
   115  	// A finalizeResult must be explicitly set so we don't make use of the zero
   116  	// value.
   117  	_ finalizeResult = iota
   118  
   119  	finalizeResultSuccess
   120  	finalizeResultConflict
   121  )
   122  
   123  // conn is a tracked connection.
   124  //
   125  // +stateify savable
   126  type conn struct {
   127  	ct *ConnTrack
   128  
   129  	// original is the tuple in original direction. It is immutable.
   130  	original tuple
   131  
   132  	// reply is the tuple in reply direction.
   133  	reply tuple
   134  
   135  	finalizeOnce sync.Once
   136  	// Holds a finalizeResult.
   137  	//
   138  	// +checkatomics
   139  	finalizeResult uint32
   140  
   141  	mu sync.RWMutex `state:"nosave"`
   142  	// sourceManip indicates the source manipulation type.
   143  	//
   144  	// +checklocks:mu
   145  	sourceManip manipType
   146  	// destinationManip indicates the destination's manipulation type.
   147  	//
   148  	// +checklocks:mu
   149  	destinationManip manipType
   150  
   151  	stateMu sync.RWMutex `state:"nosave"`
   152  	// tcb is TCB control block. It is used to keep track of states
   153  	// of tcp connection.
   154  	//
   155  	// +checklocks:stateMu
   156  	tcb tcpconntrack.TCB
   157  	// lastUsed is the last time the connection saw a relevant packet, and
   158  	// is updated by each packet on the connection.
   159  	//
   160  	// +checklocks:stateMu
   161  	lastUsed tcpip.MonotonicTime
   162  }
   163  
   164  // timedOut returns whether the connection timed out based on its state.
   165  func (cn *conn) timedOut(now tcpip.MonotonicTime) bool {
   166  	cn.stateMu.RLock()
   167  	defer cn.stateMu.RUnlock()
   168  	if cn.tcb.State() == tcpconntrack.ResultAlive {
   169  		// Use the same default as Linux, which doesn't delete
   170  		// established connections for 5(!) days.
   171  		return now.Sub(cn.lastUsed) > establishedTimeout
   172  	}
   173  	// Use the same default as Linux, which lets connections in most states
   174  	// other than established remain for <= 120 seconds.
   175  	return now.Sub(cn.lastUsed) > unestablishedTimeout
   176  }
   177  
   178  // update the connection tracking state.
   179  func (cn *conn) update(pkt *PacketBuffer, reply bool) {
   180  	cn.stateMu.Lock()
   181  	defer cn.stateMu.Unlock()
   182  
   183  	// Mark the connection as having been used recently so it isn't reaped.
   184  	cn.lastUsed = cn.ct.clock.NowMonotonic()
   185  
   186  	if pkt.TransportProtocolNumber != header.TCPProtocolNumber {
   187  		return
   188  	}
   189  
   190  	tcpHeader := header.TCP(pkt.TransportHeader().View())
   191  
   192  	// Update the state of tcb. tcb assumes it's always initialized on the
   193  	// client. However, we only need to know whether the connection is
   194  	// established or not, so the client/server distinction isn't important.
   195  	if cn.tcb.IsEmpty() {
   196  		cn.tcb.Init(tcpHeader)
   197  		return
   198  	}
   199  
   200  	if reply {
   201  		cn.tcb.UpdateStateReply(tcpHeader)
   202  	} else {
   203  		cn.tcb.UpdateStateOriginal(tcpHeader)
   204  	}
   205  }
   206  
   207  // ConnTrack tracks all connections created for NAT rules. Most users are
   208  // expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop.
   209  //
   210  // ConnTrack keeps all connections in a slice of buckets, each of which holds a
   211  // linked list of tuples. This gives us some desirable properties:
   212  // - Each bucket has its own lock, lessening lock contention.
   213  // - The slice is large enough that lists stay short (<10 elements on average).
   214  //   Thus traversal is fast.
   215  // - During linked list traversal we reap expired connections. This amortizes
   216  //   the cost of reaping them and makes reapUnused faster.
   217  //
   218  // Locks are ordered by their location in the buckets slice. That is, a
   219  // goroutine that locks buckets[i] can only lock buckets[j] s.t. i < j.
   220  //
   221  // +stateify savable
   222  type ConnTrack struct {
   223  	// seed is a one-time random value initialized at stack startup
   224  	// and is used in the calculation of hash keys for the list of buckets.
   225  	// It is immutable.
   226  	seed uint32
   227  
   228  	// clock provides timing used to determine conntrack reapings.
   229  	clock tcpip.Clock
   230  	rand  *rand.Rand
   231  
   232  	mu sync.RWMutex `state:"nosave"`
   233  	// mu protects the buckets slice, but not buckets' contents. Only take
   234  	// the write lock if you are modifying the slice or saving for S/R.
   235  	//
   236  	// +checklocks:mu
   237  	buckets []bucket
   238  }
   239  
   240  // +stateify savable
   241  type bucket struct {
   242  	mu sync.RWMutex `state:"nosave"`
   243  	// +checklocks:mu
   244  	tuples tupleList
   245  }
   246  
   247  // A netAndTransHeadersFunc returns the network and transport headers found
   248  // in an ICMP payload. The transport layer's payload will not be returned.
   249  //
   250  // May panic if the packet does not hold the transport header.
   251  type netAndTransHeadersFunc func(icmpPayload []byte, minTransHdrLen int) (netHdr header.Network, transHdrBytes []byte)
   252  
   253  func v4NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) {
   254  	netHdr := header.IPv4(icmpPayload)
   255  	// Do not use netHdr.Payload() as we might not hold the full packet
   256  	// in the ICMP error; Payload() panics if the buffer is smaller than
   257  	// the total length specified in the IPv4 header.
   258  	transHdr := icmpPayload[netHdr.HeaderLength():]
   259  	return netHdr, transHdr[:minTransHdrLen]
   260  }
   261  
   262  func v6NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) {
   263  	netHdr := header.IPv6(icmpPayload)
   264  	// Do not use netHdr.Payload() as we might not hold the full packet
   265  	// in the ICMP error; Payload() panics if the IP payload is smaller than
   266  	// the payload length specified in the IPv6 header.
   267  	transHdr := icmpPayload[header.IPv6MinimumSize:]
   268  	return netHdr, transHdr[:minTransHdrLen]
   269  }
   270  
   271  func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, getNetAndTransHdr netAndTransHeadersFunc, transProto tcpip.TransportProtocolNumber) (header.Network, header.ChecksummableTransport, bool) {
   272  	switch transProto {
   273  	case header.TCPProtocolNumber:
   274  		if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.TCPMinimumSize); ok {
   275  			netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.TCPMinimumSize)
   276  			return netHeader, header.TCP(transHeaderBytes), true
   277  		}
   278  	case header.UDPProtocolNumber:
   279  		if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.UDPMinimumSize); ok {
   280  			netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.UDPMinimumSize)
   281  			return netHeader, header.UDP(transHeaderBytes), true
   282  		}
   283  	}
   284  	return nil, nil, false
   285  }
   286  
   287  func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.ChecksummableTransport, isICMPError bool, ok bool) {
   288  	switch pkt.TransportProtocolNumber {
   289  	case header.TCPProtocolNumber:
   290  		if tcpHeader := header.TCP(pkt.TransportHeader().View()); len(tcpHeader) >= header.TCPMinimumSize {
   291  			return pkt.Network(), tcpHeader, false, true
   292  		}
   293  	case header.UDPProtocolNumber:
   294  		if udpHeader := header.UDP(pkt.TransportHeader().View()); len(udpHeader) >= header.UDPMinimumSize {
   295  			return pkt.Network(), udpHeader, false, true
   296  		}
   297  	case header.ICMPv4ProtocolNumber:
   298  		h, ok := pkt.Data().PullUp(header.IPv4MinimumSize)
   299  		if !ok {
   300  			panic(fmt.Sprintf("should have a valid IPv4 packet; only have %d bytes, want at least %d bytes", pkt.Data().Size(), header.IPv4MinimumSize))
   301  		}
   302  
   303  		if header.IPv4(h).HeaderLength() > header.IPv4MinimumSize {
   304  			// TODO(https://gvisor.dev/issue/6765): Handle IPv4 options.
   305  			panic("should have dropped packets with IPv4 options")
   306  		}
   307  
   308  		if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, v4NetAndTransHdr, pkt.tuple.id().transProto); ok {
   309  			return netHdr, transHdr, true, true
   310  		}
   311  	case header.ICMPv6ProtocolNumber:
   312  		h, ok := pkt.Data().PullUp(header.IPv6MinimumSize)
   313  		if !ok {
   314  			panic(fmt.Sprintf("should have a valid IPv6 packet; only have %d bytes, want at least %d bytes", pkt.Data().Size(), header.IPv6MinimumSize))
   315  		}
   316  
   317  		// We do not support extension headers in ICMP errors so the next header
   318  		// in the IPv6 packet should be a tracked protocol if we reach this point.
   319  		//
   320  		// TODO(https://gvisor.dev/issue/6789): Support extension headers.
   321  		transProto := pkt.tuple.id().transProto
   322  		if got := header.IPv6(h).TransportProtocol(); got != transProto {
   323  			panic(fmt.Sprintf("got TransportProtocol() = %d, want = %d", got, transProto))
   324  		}
   325  
   326  		if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, v6NetAndTransHdr, transProto); ok {
   327  			return netHdr, transHdr, true, true
   328  		}
   329  	}
   330  
   331  	return nil, nil, false, false
   332  }
   333  
   334  func getTupleIDForRegularPacket(netHdr header.Network, netProto tcpip.NetworkProtocolNumber, transHdr header.Transport, transProto tcpip.TransportProtocolNumber) tupleID {
   335  	return tupleID{
   336  		srcAddr:    netHdr.SourceAddress(),
   337  		srcPort:    transHdr.SourcePort(),
   338  		dstAddr:    netHdr.DestinationAddress(),
   339  		dstPort:    transHdr.DestinationPort(),
   340  		transProto: transProto,
   341  		netProto:   netProto,
   342  	}
   343  }
   344  
   345  func getTupleIDForPacketInICMPError(pkt *PacketBuffer, getNetAndTransHdr netAndTransHeadersFunc, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) {
   346  	if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, netLen, getNetAndTransHdr, transProto); ok {
   347  		return tupleID{
   348  			srcAddr:    netHdr.DestinationAddress(),
   349  			srcPort:    transHdr.DestinationPort(),
   350  			dstAddr:    netHdr.SourceAddress(),
   351  			dstPort:    transHdr.SourcePort(),
   352  			transProto: transProto,
   353  			netProto:   netProto,
   354  		}, true
   355  	}
   356  
   357  	return tupleID{}, false
   358  }
   359  
   360  func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) {
   361  	switch pkt.TransportProtocolNumber {
   362  	case header.TCPProtocolNumber:
   363  		if transHeader := header.TCP(pkt.TransportHeader().View()); len(transHeader) >= header.TCPMinimumSize {
   364  			return getTupleIDForRegularPacket(pkt.Network(), pkt.NetworkProtocolNumber, transHeader, pkt.TransportProtocolNumber), false, true
   365  		}
   366  	case header.UDPProtocolNumber:
   367  		if transHeader := header.UDP(pkt.TransportHeader().View()); len(transHeader) >= header.UDPMinimumSize {
   368  			return getTupleIDForRegularPacket(pkt.Network(), pkt.NetworkProtocolNumber, transHeader, pkt.TransportProtocolNumber), false, true
   369  		}
   370  	case header.ICMPv4ProtocolNumber:
   371  		icmp := header.ICMPv4(pkt.TransportHeader().View())
   372  		if len(icmp) < header.ICMPv4MinimumSize {
   373  			return tupleID{}, false, false
   374  		}
   375  
   376  		switch icmp.Type() {
   377  		case header.ICMPv4DstUnreachable, header.ICMPv4TimeExceeded, header.ICMPv4ParamProblem:
   378  		default:
   379  			return tupleID{}, false, false
   380  		}
   381  
   382  		h, ok := pkt.Data().PullUp(header.IPv4MinimumSize)
   383  		if !ok {
   384  			return tupleID{}, false, false
   385  		}
   386  
   387  		ipv4 := header.IPv4(h)
   388  		if ipv4.HeaderLength() > header.IPv4MinimumSize {
   389  			// TODO(https://gvisor.dev/issue/6765): Handle IPv4 options.
   390  			return tupleID{}, false, false
   391  		}
   392  
   393  		if tid, ok := getTupleIDForPacketInICMPError(pkt, v4NetAndTransHdr, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok {
   394  			return tid, true, true
   395  		}
   396  	case header.ICMPv6ProtocolNumber:
   397  		icmp := header.ICMPv6(pkt.TransportHeader().View())
   398  		if len(icmp) < header.ICMPv6MinimumSize {
   399  			return tupleID{}, false, false
   400  		}
   401  
   402  		switch icmp.Type() {
   403  		case header.ICMPv6DstUnreachable, header.ICMPv6PacketTooBig, header.ICMPv6TimeExceeded, header.ICMPv6ParamProblem:
   404  		default:
   405  			return tupleID{}, false, false
   406  		}
   407  
   408  		h, ok := pkt.Data().PullUp(header.IPv6MinimumSize)
   409  		if !ok {
   410  			return tupleID{}, false, false
   411  		}
   412  
   413  		// TODO(https://gvisor.dev/issue/6789): Handle extension headers.
   414  		if tid, ok := getTupleIDForPacketInICMPError(pkt, v6NetAndTransHdr, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok {
   415  			return tid, true, true
   416  		}
   417  	}
   418  
   419  	return tupleID{}, false, false
   420  }
   421  
   422  func (ct *ConnTrack) init() {
   423  	ct.mu.Lock()
   424  	defer ct.mu.Unlock()
   425  	ct.buckets = make([]bucket, numBuckets)
   426  }
   427  
   428  // getConnAndUpdate attempts to get a connection or creates one if no
   429  // connection exists for the packet and packet's protocol is trackable.
   430  //
   431  // If the packet's protocol is trackable, the connection's state is updated to
   432  // match the contents of the packet.
   433  func (ct *ConnTrack) getConnAndUpdate(pkt *PacketBuffer) *tuple {
   434  	// Get or (maybe) create a connection.
   435  	t := func() *tuple {
   436  		tid, isICMPError, ok := getTupleID(pkt)
   437  		if !ok {
   438  			return nil
   439  		}
   440  
   441  		bktID := ct.bucket(tid)
   442  
   443  		ct.mu.RLock()
   444  		bkt := &ct.buckets[bktID]
   445  		ct.mu.RUnlock()
   446  
   447  		now := ct.clock.NowMonotonic()
   448  		if t := bkt.connForTID(tid, now); t != nil {
   449  			return t
   450  		}
   451  
   452  		if isICMPError {
   453  			// Do not create a noop entry in response to an ICMP error.
   454  			return nil
   455  		}
   456  
   457  		bkt.mu.Lock()
   458  		defer bkt.mu.Unlock()
   459  
   460  		// Make sure a connection wasn't added between when we last checked the
   461  		// bucket and acquired the bucket's write lock.
   462  		if t := bkt.connForTIDRLocked(tid, now); t != nil {
   463  			return t
   464  		}
   465  
   466  		// This is the first packet we're seeing for the connection. Create an entry
   467  		// for this new connection.
   468  		conn := &conn{
   469  			ct:       ct,
   470  			original: tuple{tupleID: tid},
   471  			reply:    tuple{tupleID: tid.reply(), reply: true},
   472  			lastUsed: now,
   473  		}
   474  		conn.original.conn = conn
   475  		conn.reply.conn = conn
   476  
   477  		// For now, we only map an entry for the packet's original tuple as NAT may be
   478  		// performed on this connection. Until the packet goes through all the hooks
   479  		// and its final address/port is known, we cannot know what the response
   480  		// packet's addresses/ports will look like.
   481  		//
   482  		// This is okay because the destination cannot send its response until it
   483  		// receives the packet; the packet will only be received once all the hooks
   484  		// have been performed.
   485  		//
   486  		// See (*conn).finalize.
   487  		bkt.tuples.PushFront(&conn.original)
   488  		return &conn.original
   489  	}()
   490  	if t != nil {
   491  		t.conn.update(pkt, t.reply)
   492  	}
   493  	return t
   494  }
   495  
   496  func (ct *ConnTrack) connForTID(tid tupleID) *tuple {
   497  	bktID := ct.bucket(tid)
   498  
   499  	ct.mu.RLock()
   500  	bkt := &ct.buckets[bktID]
   501  	ct.mu.RUnlock()
   502  
   503  	return bkt.connForTID(tid, ct.clock.NowMonotonic())
   504  }
   505  
   506  func (bkt *bucket) connForTID(tid tupleID, now tcpip.MonotonicTime) *tuple {
   507  	bkt.mu.RLock()
   508  	defer bkt.mu.RUnlock()
   509  	return bkt.connForTIDRLocked(tid, now)
   510  }
   511  
   512  // +checklocksread:bkt.mu
   513  func (bkt *bucket) connForTIDRLocked(tid tupleID, now tcpip.MonotonicTime) *tuple {
   514  	for other := bkt.tuples.Front(); other != nil; other = other.Next() {
   515  		if tid == other.id() && !other.conn.timedOut(now) {
   516  			return other
   517  		}
   518  	}
   519  	return nil
   520  }
   521  
   522  func (ct *ConnTrack) finalize(cn *conn) finalizeResult {
   523  	ct.mu.RLock()
   524  	buckets := ct.buckets
   525  	ct.mu.RUnlock()
   526  
   527  	{
   528  		tid := cn.reply.id()
   529  		id := ct.bucket(tid)
   530  
   531  		bkt := &buckets[id]
   532  		bkt.mu.Lock()
   533  		if bkt.connForTIDRLocked(tid, ct.clock.NowMonotonic()) == nil {
   534  			bkt.tuples.PushFront(&cn.reply)
   535  			bkt.mu.Unlock()
   536  			return finalizeResultSuccess
   537  		}
   538  		bkt.mu.Unlock()
   539  	}
   540  
   541  	// Another connection for the reply already exists. Remove the original and
   542  	// let the caller know we failed.
   543  	//
   544  	// TODO(https://gvisor.dev/issue/6850): Investigate handling this clash
   545  	// better.
   546  
   547  	tid := cn.original.id()
   548  	id := ct.bucket(tid)
   549  	bkt := &buckets[id]
   550  	bkt.mu.Lock()
   551  	defer bkt.mu.Unlock()
   552  	bkt.tuples.Remove(&cn.original)
   553  	return finalizeResultConflict
   554  }
   555  
   556  func (cn *conn) getFinalizeResult() finalizeResult {
   557  	return finalizeResult(atomic.LoadUint32(&cn.finalizeResult))
   558  }
   559  
   560  // finalize attempts to finalize the connection and returns true iff the
   561  // connection was successfully finalized.
   562  //
   563  // If the connection failed to finalize, the caller should drop the packet
   564  // associated with the connection.
   565  //
   566  // If multiple goroutines attempt to finalize at the same time, only one
   567  // goroutine will perform the work to finalize the connection, but all
   568  // goroutines will block until the finalizing goroutine finishes finalizing.
   569  func (cn *conn) finalize() bool {
   570  	cn.finalizeOnce.Do(func() {
   571  		atomic.StoreUint32(&cn.finalizeResult, uint32(cn.ct.finalize(cn)))
   572  	})
   573  
   574  	switch res := cn.getFinalizeResult(); res {
   575  	case finalizeResultSuccess:
   576  		return true
   577  	case finalizeResultConflict:
   578  		return false
   579  	default:
   580  		panic(fmt.Sprintf("unhandled result = %d", res))
   581  	}
   582  }
   583  
   584  func (cn *conn) maybePerformNoopNAT(dnat bool) {
   585  	cn.mu.Lock()
   586  	defer cn.mu.Unlock()
   587  
   588  	var manip *manipType
   589  	if dnat {
   590  		manip = &cn.destinationManip
   591  	} else {
   592  		manip = &cn.sourceManip
   593  	}
   594  
   595  	if *manip == manipNotPerformed {
   596  		*manip = manipPerformedNoop
   597  	}
   598  }
   599  
   600  type portRange struct {
   601  	start uint16
   602  	size  uint16
   603  }
   604  
   605  // performNAT setups up the connection for the specified NAT and rewrites the
   606  // packet.
   607  //
   608  // If NAT has already been performed on the connection, then the packet will
   609  // be rewritten with the NAT performed on the connection, ignoring the passed
   610  // address and port range.
   611  //
   612  // Generally, only the first packet of a connection reaches this method; other
   613  // packets will be manipulated without needing to modify the connection.
   614  func (cn *conn) performNAT(pkt *PacketBuffer, hook Hook, r *Route, ports portRange, natAddress tcpip.Address, dnat bool) {
   615  	// Make sure the packet is re-written after performing NAT.
   616  	defer func() {
   617  		// handlePacket returns true if the packet may skip the NAT table as the
   618  		// connection is already NATed, but if we reach this point we must be in the
   619  		// NAT table, so the return value is useless for us.
   620  		_ = cn.handlePacket(pkt, hook, r)
   621  	}()
   622  
   623  	cn.mu.Lock()
   624  	defer cn.mu.Unlock()
   625  
   626  	cn.reply.mu.Lock()
   627  	defer cn.reply.mu.Unlock()
   628  
   629  	var manip *manipType
   630  	var address *tcpip.Address
   631  	var port *uint16
   632  	if dnat {
   633  		manip = &cn.destinationManip
   634  		address = &cn.reply.tupleID.srcAddr
   635  		port = &cn.reply.tupleID.srcPort
   636  	} else {
   637  		manip = &cn.sourceManip
   638  		address = &cn.reply.tupleID.dstAddr
   639  		port = &cn.reply.tupleID.dstPort
   640  	}
   641  
   642  	if *manip != manipNotPerformed {
   643  		return
   644  	}
   645  	*manip = manipPerformed
   646  	*address = natAddress
   647  
   648  	// Does the current port fit in the range?
   649  	if end := ports.start + ports.size - 1; *port >= ports.start && *port <= end {
   650  		// Yes, is the current reply tuple unique?
   651  		if other := cn.ct.connForTID(cn.reply.tupleID); other == nil {
   652  			// Yes! No need to change the port.
   653  			return
   654  		}
   655  	}
   656  
   657  	// Try our best to find a port that results in a unique reply tuple.
   658  	//
   659  	// We limit the number of attempts to find a unique tuple to not waste a lot
   660  	// of time looking for a unique tuple.
   661  	//
   662  	// Matches linux behaviour introduced in
   663  	// https://github.com/torvalds/linux/commit/a504b703bb1da526a01593da0e4be2af9d9f5fa8.
   664  	const maxAttemptsForInitialRound uint16 = 128
   665  	const minAttemptsToContinue = 16
   666  
   667  	allowedInitialAttempts := maxAttemptsForInitialRound
   668  	if allowedInitialAttempts > ports.size {
   669  		allowedInitialAttempts = ports.size
   670  	}
   671  
   672  	for maxAttempts := allowedInitialAttempts; ; maxAttempts /= 2 {
   673  		// Start reach round with a random initial port in the range.
   674  		initial := ports.start + uint16(cn.ct.rand.Uint32())%ports.size
   675  
   676  		for i := uint16(0); i < maxAttempts; i++ {
   677  			*port = initial + i%ports.size
   678  
   679  			if other := cn.ct.connForTID(cn.reply.tupleID); other == nil {
   680  				// We found a unique tuple!
   681  				return
   682  			}
   683  		}
   684  
   685  		if maxAttempts == ports.size {
   686  			// We already tried all the ports in the range so no need to keep trying.
   687  			return
   688  		}
   689  
   690  		if maxAttempts < minAttemptsToContinue {
   691  			return
   692  		}
   693  	}
   694  
   695  	// We did not find a unique tuple, use the last used port anyways.
   696  	// TODO(https://gvisor.dev/issue/6850): Handle not finding a unique tuple
   697  	// better (e.g. remove the connection and drop the packet).
   698  }
   699  
   700  // handlePacket attempts to handle a packet and perform NAT if the connection
   701  // has had NAT performed on it.
   702  //
   703  // Returns true if the packet can skip the NAT table.
   704  func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool {
   705  	netHdr, transHdr, isICMPError, ok := getHeaders(pkt)
   706  	if !ok {
   707  		return false
   708  	}
   709  
   710  	fullChecksum := false
   711  	updatePseudoHeader := false
   712  	natDone := &pkt.SNATDone
   713  	dnat := false
   714  	switch hook {
   715  	case Prerouting:
   716  		// Packet came from outside the stack so it must have a checksum set
   717  		// already.
   718  		fullChecksum = true
   719  		updatePseudoHeader = true
   720  
   721  		natDone = &pkt.DNATDone
   722  		dnat = true
   723  	case Input:
   724  	case Forward:
   725  		panic("should not handle packet in the forwarding hook")
   726  	case Output:
   727  		natDone = &pkt.DNATDone
   728  		dnat = true
   729  		fallthrough
   730  	case Postrouting:
   731  		if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
   732  			updatePseudoHeader = true
   733  		} else if rt.RequiresTXTransportChecksum() {
   734  			fullChecksum = true
   735  			updatePseudoHeader = true
   736  		}
   737  	default:
   738  		panic(fmt.Sprintf("unrecognized hook = %d", hook))
   739  	}
   740  
   741  	if *natDone {
   742  		panic(fmt.Sprintf("packet already had NAT(dnat=%t) performed at hook=%s; pkt=%#v", dnat, hook, pkt))
   743  	}
   744  
   745  	// TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
   746  	// validated if checksum offloading is off. It may require IP defrag if the
   747  	// packets are fragmented.
   748  
   749  	reply := pkt.tuple.reply
   750  
   751  	tid, manip := func() (tupleID, manipType) {
   752  		cn.mu.RLock()
   753  		defer cn.mu.RUnlock()
   754  
   755  		if reply {
   756  			tid := cn.original.id()
   757  
   758  			if dnat {
   759  				return tid, cn.sourceManip
   760  			}
   761  			return tid, cn.destinationManip
   762  		}
   763  
   764  		tid := cn.reply.id()
   765  		if dnat {
   766  			return tid, cn.destinationManip
   767  		}
   768  		return tid, cn.sourceManip
   769  	}()
   770  	switch manip {
   771  	case manipNotPerformed:
   772  		return false
   773  	case manipPerformedNoop:
   774  		*natDone = true
   775  		return true
   776  	case manipPerformed:
   777  	default:
   778  		panic(fmt.Sprintf("unhandled manip = %d", manip))
   779  	}
   780  
   781  	newPort := tid.dstPort
   782  	newAddr := tid.dstAddr
   783  	if dnat {
   784  		newPort = tid.srcPort
   785  		newAddr = tid.srcAddr
   786  	}
   787  
   788  	rewritePacket(
   789  		netHdr,
   790  		transHdr,
   791  		!dnat != isICMPError,
   792  		fullChecksum,
   793  		updatePseudoHeader,
   794  		newPort,
   795  		newAddr,
   796  	)
   797  
   798  	*natDone = true
   799  
   800  	if !isICMPError {
   801  		return true
   802  	}
   803  
   804  	// We performed NAT on (erroneous) packet that triggered an ICMP response, but
   805  	// not the ICMP packet itself.
   806  	switch pkt.TransportProtocolNumber {
   807  	case header.ICMPv4ProtocolNumber:
   808  		icmp := header.ICMPv4(pkt.TransportHeader().View())
   809  		// TODO(https://gvisor.dev/issue/6788): Incrementally update ICMP checksum.
   810  		icmp.SetChecksum(0)
   811  		icmp.SetChecksum(header.ICMPv4Checksum(icmp, pkt.Data().AsRange().Checksum()))
   812  
   813  		network := header.IPv4(pkt.NetworkHeader().View())
   814  		if dnat {
   815  			network.SetDestinationAddressWithChecksumUpdate(tid.srcAddr)
   816  		} else {
   817  			network.SetSourceAddressWithChecksumUpdate(tid.dstAddr)
   818  		}
   819  	case header.ICMPv6ProtocolNumber:
   820  		network := header.IPv6(pkt.NetworkHeader().View())
   821  		srcAddr := network.SourceAddress()
   822  		dstAddr := network.DestinationAddress()
   823  		if dnat {
   824  			dstAddr = tid.srcAddr
   825  		} else {
   826  			srcAddr = tid.dstAddr
   827  		}
   828  
   829  		icmp := header.ICMPv6(pkt.TransportHeader().View())
   830  		// TODO(https://gvisor.dev/issue/6788): Incrementally update ICMP checksum.
   831  		icmp.SetChecksum(0)
   832  		payload := pkt.Data()
   833  		icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
   834  			Header:      icmp,
   835  			Src:         srcAddr,
   836  			Dst:         dstAddr,
   837  			PayloadCsum: payload.AsRange().Checksum(),
   838  			PayloadLen:  payload.Size(),
   839  		}))
   840  
   841  		if dnat {
   842  			network.SetDestinationAddress(dstAddr)
   843  		} else {
   844  			network.SetSourceAddress(srcAddr)
   845  		}
   846  	}
   847  
   848  	return true
   849  }
   850  
   851  // bucket gets the conntrack bucket for a tupleID.
   852  func (ct *ConnTrack) bucket(id tupleID) int {
   853  	h := jenkins.Sum32(ct.seed)
   854  	h.Write([]byte(id.srcAddr))
   855  	h.Write([]byte(id.dstAddr))
   856  	shortBuf := make([]byte, 2)
   857  	binary.LittleEndian.PutUint16(shortBuf, id.srcPort)
   858  	h.Write([]byte(shortBuf))
   859  	binary.LittleEndian.PutUint16(shortBuf, id.dstPort)
   860  	h.Write([]byte(shortBuf))
   861  	binary.LittleEndian.PutUint16(shortBuf, uint16(id.transProto))
   862  	h.Write([]byte(shortBuf))
   863  	binary.LittleEndian.PutUint16(shortBuf, uint16(id.netProto))
   864  	h.Write([]byte(shortBuf))
   865  	ct.mu.RLock()
   866  	defer ct.mu.RUnlock()
   867  	return int(h.Sum32()) % len(ct.buckets)
   868  }
   869  
   870  // reapUnused deletes timed out entries from the conntrack map. The rules for
   871  // reaping are:
   872  // - Each call to reapUnused traverses a fraction of the conntrack table.
   873  //   Specifically, it traverses len(ct.buckets)/fractionPerReaping.
   874  // - After reaping, reapUnused decides when it should next run based on the
   875  //   ratio of expired connections to examined connections. If the ratio is
   876  //   greater than maxExpiredPct, it schedules the next run quickly. Otherwise it
   877  //   slightly increases the interval between runs.
   878  // - maxFullTraversal caps the time it takes to traverse the entire table.
   879  //
   880  // reapUnused returns the next bucket that should be checked and the time after
   881  // which it should be called again.
   882  func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) {
   883  	const fractionPerReaping = 128
   884  	const maxExpiredPct = 50
   885  	const maxFullTraversal = 60 * time.Second
   886  	const minInterval = 10 * time.Millisecond
   887  	const maxInterval = maxFullTraversal / fractionPerReaping
   888  
   889  	now := ct.clock.NowMonotonic()
   890  	checked := 0
   891  	expired := 0
   892  	var idx int
   893  	ct.mu.RLock()
   894  	defer ct.mu.RUnlock()
   895  	for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ {
   896  		idx = (i + start) % len(ct.buckets)
   897  		bkt := &ct.buckets[idx]
   898  		bkt.mu.Lock()
   899  		for tuple := bkt.tuples.Front(); tuple != nil; {
   900  			// reapTupleLocked updates tuple's next pointer so we grab it here.
   901  			nextTuple := tuple.Next()
   902  
   903  			checked++
   904  			if ct.reapTupleLocked(tuple, idx, bkt, now) {
   905  				expired++
   906  			}
   907  
   908  			tuple = nextTuple
   909  		}
   910  		bkt.mu.Unlock()
   911  	}
   912  	// We already checked buckets[idx].
   913  	idx++
   914  
   915  	// If half or more of the connections are expired, the table has gotten
   916  	// stale. Reschedule quickly.
   917  	expiredPct := 0
   918  	if checked != 0 {
   919  		expiredPct = expired * 100 / checked
   920  	}
   921  	if expiredPct > maxExpiredPct {
   922  		return idx, minInterval
   923  	}
   924  	if interval := prevInterval + minInterval; interval <= maxInterval {
   925  		// Increment the interval between runs.
   926  		return idx, interval
   927  	}
   928  	// We've hit the maximum interval.
   929  	return idx, maxInterval
   930  }
   931  
   932  // reapTupleLocked tries to remove tuple and its reply from the table. It
   933  // returns whether the tuple's connection has timed out.
   934  //
   935  // Precondition: ct.mu is read locked and bkt.mu is write locked.
   936  // +checklocksread:ct.mu
   937  // +checklocks:bkt.mu
   938  func (ct *ConnTrack) reapTupleLocked(reapingTuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool {
   939  	if !reapingTuple.conn.timedOut(now) {
   940  		return false
   941  	}
   942  
   943  	var otherTuple *tuple
   944  	if reapingTuple.reply {
   945  		otherTuple = &reapingTuple.conn.original
   946  	} else {
   947  		otherTuple = &reapingTuple.conn.reply
   948  	}
   949  
   950  	otherTupleBktID := ct.bucket(otherTuple.id())
   951  	replyTupleInserted := reapingTuple.conn.getFinalizeResult() == finalizeResultSuccess
   952  
   953  	// To maintain lock order, we can only reap both tuples if the tuple for the
   954  	// other direction appears later in the table.
   955  	if bktID > otherTupleBktID && replyTupleInserted {
   956  		return true
   957  	}
   958  
   959  	bkt.tuples.Remove(reapingTuple)
   960  
   961  	if !replyTupleInserted {
   962  		// The other tuple is the reply which has not yet been inserted.
   963  		return true
   964  	}
   965  
   966  	// Reap the other connection.
   967  	if bktID == otherTupleBktID {
   968  		// Don't re-lock if both tuples are in the same bucket.
   969  		bkt.tuples.Remove(otherTuple)
   970  	} else {
   971  		otherTupleBkt := &ct.buckets[otherTupleBktID]
   972  		otherTupleBkt.mu.Lock()
   973  		otherTupleBkt.tuples.Remove(otherTuple)
   974  		otherTupleBkt.mu.Unlock()
   975  	}
   976  
   977  	return true
   978  }
   979  
   980  func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
   981  	// Lookup the connection. The reply's original destination
   982  	// describes the original address.
   983  	tid := tupleID{
   984  		srcAddr:    epID.LocalAddress,
   985  		srcPort:    epID.LocalPort,
   986  		dstAddr:    epID.RemoteAddress,
   987  		dstPort:    epID.RemotePort,
   988  		transProto: transProto,
   989  		netProto:   netProto,
   990  	}
   991  	t := ct.connForTID(tid)
   992  	if t == nil {
   993  		// Not a tracked connection.
   994  		return "", 0, &tcpip.ErrNotConnected{}
   995  	}
   996  
   997  	t.conn.mu.RLock()
   998  	defer t.conn.mu.RUnlock()
   999  	if t.conn.destinationManip == manipNotPerformed {
  1000  		// Unmanipulated destination.
  1001  		return "", 0, &tcpip.ErrInvalidOptionValue{}
  1002  	}
  1003  
  1004  	id := t.conn.original.id()
  1005  	return id.dstAddr, id.dstPort, nil
  1006  }