github.com/metacubex/gvisor@v0.0.0-20240320004321-933faba989ec/pkg/tcpip/stack/gro.go (about)

     1  // Copyright 2022 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package stack
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"time"
    21  
    22  	"github.com/metacubex/gvisor/pkg/atomicbitops"
    23  	"github.com/metacubex/gvisor/pkg/sync"
    24  	"github.com/metacubex/gvisor/pkg/tcpip"
    25  	"github.com/metacubex/gvisor/pkg/tcpip/header"
    26  )
    27  
    28  // TODO(b/256037250): Enable by default.
    29  // TODO(b/256037250): We parse headers here. We should save those headers in
    30  // PacketBuffers so they don't have to be re-parsed later.
    31  // TODO(b/256037250): I still see the occasional SACK block in the zero-loss
    32  // benchmark, which should not happen.
    33  // TODO(b/256037250): Some dispatchers, e.g. XDP and RecvMmsg, can receive
    34  // multiple packets at a time. Even if the GRO interval is 0, there is an
    35  // opportunity for coalescing.
    36  // TODO(b/256037250): We're doing some header parsing here, which presents the
    37  // opportunity to skip it later.
    38  // TODO(b/256037250): We may be able to remove locking by pairing
    39  // groDispatchers with link endpoint dispatchers.
    40  
const (
	// groNBuckets is the number of GRO buckets.
	groNBuckets = 8

	// groNBucketsMask is ANDed with a packet's hash to select a bucket index.
	// It is groNBuckets - 1, so it only acts as a proper mask while
	// groNBuckets is a power of two.
	groNBucketsMask = groNBuckets - 1

	// groBucketSize is the size of each GRO bucket.
	groBucketSize = 8

	// groMaxPacketSize is the maximum size of a GRO'd packet.
	groMaxPacketSize = 1 << 16 // 64 KiB (65536 bytes).
)
    53  
// A groBucket holds packets that are undergoing GRO.
//
// A bucket stores at most groBucketSize packets. The groPacket structs live
// in packetsPrealloc; allocIdxs is a stack of indices into packetsPrealloc
// where allocIdxs[count:] names the entries that are currently unused.
type groBucket struct {
	// mu protects the fields of a bucket.
	mu sync.Mutex

	// count is the number of packets in the bucket.
	// +checklocks:mu
	count int

	// packets is the linked list of packets. New packets are pushed to the
	// back, so the front is the oldest entry.
	// +checklocks:mu
	packets groPacketList

	// packetsPrealloc and allocIdxs are used to preallocate and reuse
	// groPacket structs and avoid allocation.
	// +checklocks:mu
	packetsPrealloc [groBucketSize]groPacket

	// allocIdxs[count] is the index in packetsPrealloc of the next entry to
	// hand out; removing a packet writes its index back at the new count.
	// +checklocks:mu
	allocIdxs [groBucketSize]int
}
    75  
    76  // +checklocks:gb.mu
    77  func (gb *groBucket) full() bool {
    78  	return gb.count == groBucketSize
    79  }
    80  
    81  // insert inserts pkt into the bucket.
    82  // +checklocks:gb.mu
    83  func (gb *groBucket) insert(pkt *PacketBuffer, ipHdr []byte, tcpHdr header.TCP, ep NetworkEndpoint) {
    84  	groPkt := &gb.packetsPrealloc[gb.allocIdxs[gb.count]]
    85  	*groPkt = groPacket{
    86  		pkt:           pkt,
    87  		created:       time.Now(),
    88  		ep:            ep,
    89  		ipHdr:         ipHdr,
    90  		tcpHdr:        tcpHdr,
    91  		initialLength: pkt.Data().Size(), // pkt.Data() contains network header.
    92  		idx:           groPkt.idx,
    93  	}
    94  	gb.count++
    95  	gb.packets.PushBack(groPkt)
    96  }
    97  
    98  // removeOldest removes the oldest packet from gb and returns the contained
    99  // *PacketBuffer. gb must not be empty.
   100  // +checklocks:gb.mu
   101  func (gb *groBucket) removeOldest() *PacketBuffer {
   102  	pkt := gb.packets.Front()
   103  	gb.packets.Remove(pkt)
   104  	gb.count--
   105  	gb.allocIdxs[gb.count] = pkt.idx
   106  	ret := pkt.pkt
   107  	pkt.reset()
   108  	return ret
   109  }
   110  
   111  // removeOne removes a packet from gb. It also resets pkt to its zero value.
   112  // +checklocks:gb.mu
   113  func (gb *groBucket) removeOne(pkt *groPacket) {
   114  	gb.packets.Remove(pkt)
   115  	gb.count--
   116  	gb.allocIdxs[gb.count] = pkt.idx
   117  	pkt.reset()
   118  }
   119  
   120  // findGROPacket4 returns the groPkt that matches ipHdr and tcpHdr, or nil if
   121  // none exists. It also returns whether the groPkt should be flushed based on
   122  // differences between the two headers.
   123  // +checklocks:gb.mu
   124  func (gb *groBucket) findGROPacket4(pkt *PacketBuffer, ipHdr header.IPv4, tcpHdr header.TCP, ep NetworkEndpoint) (*groPacket, bool) {
   125  	for groPkt := gb.packets.Front(); groPkt != nil; groPkt = groPkt.Next() {
   126  		// Do the addresses match?
   127  		groIPHdr := header.IPv4(groPkt.ipHdr)
   128  		if ipHdr.SourceAddress() != groIPHdr.SourceAddress() || ipHdr.DestinationAddress() != groIPHdr.DestinationAddress() {
   129  			continue
   130  		}
   131  
   132  		// Do the ports match?
   133  		if tcpHdr.SourcePort() != groPkt.tcpHdr.SourcePort() || tcpHdr.DestinationPort() != groPkt.tcpHdr.DestinationPort() {
   134  			continue
   135  		}
   136  
   137  		// We've found a packet of the same flow.
   138  
   139  		// IP checks.
   140  		TOS, _ := ipHdr.TOS()
   141  		groTOS, _ := groIPHdr.TOS()
   142  		if ipHdr.TTL() != groIPHdr.TTL() || TOS != groTOS {
   143  			return groPkt, true
   144  		}
   145  
   146  		// TCP checks.
   147  		if shouldFlushTCP(groPkt, tcpHdr) {
   148  			return groPkt, true
   149  		}
   150  
   151  		// There's an upper limit on coalesced packet size.
   152  		if pkt.Data().Size()-header.IPv4MinimumSize-int(tcpHdr.DataOffset())+groPkt.pkt.Data().Size() >= groMaxPacketSize {
   153  			return groPkt, true
   154  		}
   155  
   156  		return groPkt, false
   157  	}
   158  
   159  	return nil, false
   160  }
   161  
   162  // findGROPacket6 returns the groPkt that matches ipHdr and tcpHdr, or nil if
   163  // none exists. It also returns whether the groPkt should be flushed based on
   164  // differences between the two headers.
   165  // +checklocks:gb.mu
   166  func (gb *groBucket) findGROPacket6(pkt *PacketBuffer, ipHdr header.IPv6, tcpHdr header.TCP, ep NetworkEndpoint) (*groPacket, bool) {
   167  	for groPkt := gb.packets.Front(); groPkt != nil; groPkt = groPkt.Next() {
   168  		// Do the addresses match?
   169  		groIPHdr := header.IPv6(groPkt.ipHdr)
   170  		if ipHdr.SourceAddress() != groIPHdr.SourceAddress() || ipHdr.DestinationAddress() != groIPHdr.DestinationAddress() {
   171  			continue
   172  		}
   173  
   174  		// Need to check that headers are the same except:
   175  		// - Traffic class, a difference of which causes a flush.
   176  		// - Hop limit, a difference of which causes a flush.
   177  		// - Length, which is checked later.
   178  		// - Version, which is checked by an earlier call to IsValid().
   179  		trafficClass, flowLabel := ipHdr.TOS()
   180  		groTrafficClass, groFlowLabel := groIPHdr.TOS()
   181  		if flowLabel != groFlowLabel || ipHdr.NextHeader() != groIPHdr.NextHeader() {
   182  			continue
   183  		}
   184  		// Unlike IPv4, IPv6 packets with extension headers can be coalesced.
   185  		if !bytes.Equal(ipHdr[header.IPv6MinimumSize:], groIPHdr[header.IPv6MinimumSize:]) {
   186  			continue
   187  		}
   188  
   189  		// Do the ports match?
   190  		if tcpHdr.SourcePort() != groPkt.tcpHdr.SourcePort() || tcpHdr.DestinationPort() != groPkt.tcpHdr.DestinationPort() {
   191  			continue
   192  		}
   193  
   194  		// We've found a packet of the same flow.
   195  
   196  		// TCP checks.
   197  		if shouldFlushTCP(groPkt, tcpHdr) {
   198  			return groPkt, true
   199  		}
   200  
   201  		// Do the traffic class and hop limit match?
   202  		if trafficClass != groTrafficClass || ipHdr.HopLimit() != groIPHdr.HopLimit() {
   203  			return groPkt, true
   204  		}
   205  
   206  		// This limit is artificial for IPv6 -- we could allow even
   207  		// larger packets via jumbograms.
   208  		if pkt.Data().Size()-len(ipHdr)-int(tcpHdr.DataOffset())+groPkt.pkt.Data().Size() >= groMaxPacketSize {
   209  			return groPkt, true
   210  		}
   211  
   212  		return groPkt, false
   213  	}
   214  
   215  	return nil, false
   216  }
   217  
// found completes GRO processing of pkt after the bucket has been searched
// for its flow.
//
// groPkt is the held packet of the same flow, or nil if there is none.
// flushGROPkt says that groPkt must be delivered as-is rather than merged with
// pkt. ipHdr and tcpHdr are pkt's headers, ep is where packets are delivered,
// and updateIPHdr adds a byte count to the length field of an IP header.
//
// found must be called with gb.mu held. Every path through the function
// releases gb.mu before returning, and the lock is never held across a call to
// ep.HandlePacket.
// +checklocks:gb.mu
func (gb *groBucket) found(gd *groDispatcher, groPkt *groPacket, flushGROPkt bool, pkt *PacketBuffer, ipHdr []byte, tcpHdr header.TCP, ep NetworkEndpoint, updateIPHdr func([]byte, int)) {
	// Flush groPkt or merge the packets.
	// Capture sizes and flags now: merging below trims pkt's data.
	pktSize := pkt.Data().Size()
	flags := tcpHdr.Flags()
	dataOff := tcpHdr.DataOffset()
	tcpPayloadSize := pkt.Data().Size() - len(ipHdr) - int(dataOff)
	if flushGROPkt {
		// Flush the existing GRO packet. Don't hold bucket.mu while
		// processing the packet.
		pkt := groPkt.pkt
		gb.removeOne(groPkt)
		gb.mu.Unlock()
		ep.HandlePacket(pkt)
		// Drop the reference that was taken when the packet was inserted.
		pkt.DecRef()
		gb.mu.Lock()
		// From here on the incoming packet is treated as a new flow.
		groPkt = nil
	} else if groPkt != nil {
		// Merge pkt in to GRO packet.
		// Strip pkt's IP and TCP headers so only its payload is appended.
		pkt.Data().TrimFront(len(ipHdr) + int(dataOff))
		groPkt.pkt.Data().Merge(pkt.Data())
		// Update the IP total length.
		updateIPHdr(groPkt.ipHdr, tcpPayloadSize)
		// Add flags from the packet to the GRO packet.
		groPkt.tcpHdr.SetFlags(uint8(groPkt.tcpHdr.Flags() | (flags & (header.TCPFlagFin | header.TCPFlagPsh))))

		// pkt's payload now lives in groPkt; it must not be used again.
		pkt = nil
	}

	// Flush if the packet isn't the same size as the previous packets or
	// if certain flags are set. The reason for checking size equality is:
	// - If the packet is smaller than the others, this is likely the end
	//   of some message. Peers will send MSS-sized packets until they have
	//   insufficient data to do so.
	// - If the packet is larger than the others, this packet is either
	//   malformed, a local GSO packet, or has already been handled by host
	//   GRO.
	flush := header.TCPFlags(flags)&(header.TCPFlagUrg|header.TCPFlagPsh|header.TCPFlagRst|header.TCPFlagSyn|header.TCPFlagFin) != 0
	// Packets without any TCP payload are never held.
	flush = flush || tcpPayloadSize == 0
	if groPkt != nil {
		flush = flush || pktSize != groPkt.initialLength
	}

	switch {
	case flush && groPkt != nil:
		// A merge occurred and we need to flush groPkt.
		pkt := groPkt.pkt
		gb.removeOne(groPkt)
		gb.mu.Unlock()
		ep.HandlePacket(pkt)
		pkt.DecRef()
	case flush && groPkt == nil:
		// No merge occurred and the incoming packet needs to be flushed.
		gb.mu.Unlock()
		ep.HandlePacket(pkt)
	case !flush && groPkt == nil:
		// New flow and we don't need to flush. Insert pkt into GRO.
		// The bucket takes its own reference on pkt via IncRef.
		if gb.full() {
			// Head is always the oldest packet
			toFlush := gb.removeOldest()
			gb.insert(pkt.IncRef(), ipHdr, tcpHdr, ep)
			gb.mu.Unlock()
			ep.HandlePacket(toFlush)
			toFlush.DecRef()
		} else {
			gb.insert(pkt.IncRef(), ipHdr, tcpHdr, ep)
			gb.mu.Unlock()
		}
	default:
		// A merge occurred and we don't need to flush anything.
		gb.mu.Unlock()
	}

	// Schedule a timer if we never had one set before.
	// The compare-and-swap ensures only the caller that moves the state from
	// unset to set arms the timer.
	if gd.flushTimerState.CompareAndSwap(flushTimerUnset, flushTimerSet) {
		gd.flushTimer.Reset(gd.getInterval())
	}
}
   296  
// A groPacket is a packet undergoing GRO. It may be several packets coalesced
// together.
type groPacket struct {
	// groPacketEntry is an intrusive list.
	groPacketEntry

	// pkt is the coalesced packet.
	pkt *PacketBuffer

	// ipHdr is the IP (v4 or v6) header for the coalesced packet. Its
	// length field is updated in place each time a packet is merged in.
	ipHdr []byte

	// tcpHdr is the TCP header for the coalesced packet. FIN and PSH flags
	// of merged packets are ORed into it.
	tcpHdr header.TCP

	// created is when the packet was received.
	created time.Time

	// ep is the endpoint to which the packet will be sent after GRO.
	ep NetworkEndpoint

	// initialLength is the length of the first packet in the flow. It is
	// used as a best-effort guess at MSS: senders will send MSS-sized
	// packets until they run out of data, so we coalesce as long as
	// packets are the same size.
	initialLength int

	// idx is the groPacket's index in its bucket packetsPrealloc. It is
	// immutable.
	idx int
}
   328  
   329  // reset resets all mutable fields of the groPacket.
   330  func (pk *groPacket) reset() {
   331  	*pk = groPacket{
   332  		idx: pk.idx,
   333  	}
   334  }
   335  
   336  // payloadSize is the payload size of the coalesced packet, which does not
   337  // include the network or transport headers.
   338  func (pk *groPacket) payloadSize() int {
   339  	return pk.pkt.Data().Size() - len(pk.ipHdr) - int(pk.tcpHdr.DataOffset())
   340  }
   341  
// Values held in groDispatcher.flushTimerState.
const (
	// flushTimerUnset means no flush is currently scheduled.
	flushTimerUnset = iota
	// flushTimerSet means the flush timer has been armed.
	flushTimerSet
	// flushTimerClosed means the dispatcher was closed; the timer must not
	// be armed again.
	flushTimerClosed
)
   348  
// groDispatcher coalesces incoming packets to increase throughput.
type groDispatcher struct {
	// intervalNS is the interval in nanoseconds. An interval of zero makes
	// dispatch pass packets straight through without coalescing.
	intervalNS atomicbitops.Int64

	// buckets hold the packets currently undergoing GRO. A packet's bucket
	// is chosen from a hash of its addresses and ports.
	buckets [groNBuckets]groBucket

	// flushTimerState holds one of the flushTimer* values and is changed
	// with atomic operations.
	flushTimerState atomicbitops.Int32
	// flushTimer fires to flush held packets; it is created by init.
	flushTimer *time.Timer
}
   359  
   360  func (gd *groDispatcher) init(interval time.Duration) {
   361  	gd.intervalNS.Store(interval.Nanoseconds())
   362  
   363  	for i := range gd.buckets {
   364  		bucket := &gd.buckets[i]
   365  		bucket.mu.Lock()
   366  		for j := range bucket.packetsPrealloc {
   367  			bucket.allocIdxs[j] = j
   368  			bucket.packetsPrealloc[j].idx = j
   369  		}
   370  		bucket.mu.Unlock()
   371  	}
   372  
   373  	// Create a timer to fire far from now and cancel it immediately.
   374  	//
   375  	// The timer will be reset when there is a need for it to fire.
   376  	gd.flushTimer = time.AfterFunc(time.Hour, func() {
   377  		if !gd.flushTimerState.CompareAndSwap(flushTimerSet, flushTimerUnset) {
   378  			// Timer was unset or GRO is closed, do nothing further.
   379  			return
   380  		}
   381  
   382  		interval := gd.getInterval()
   383  		if interval == 0 {
   384  			gd.flushAll()
   385  			return
   386  		}
   387  
   388  		if gd.flush() && gd.flushTimerState.CompareAndSwap(flushTimerUnset, flushTimerSet) {
   389  			// Only reset the timer if we have more packets and the timer was
   390  			// previously unset. If we have no packets left, the timer is already set
   391  			// or GRO is being closed, do not reset the timer.
   392  			gd.flushTimer.Reset(interval)
   393  		}
   394  	})
   395  	gd.flushTimer.Stop()
   396  }
   397  
   398  func (gd *groDispatcher) getInterval() time.Duration {
   399  	return time.Duration(gd.intervalNS.Load()) * time.Nanosecond
   400  }
   401  
   402  // setInterval is not thread-safe and so much be protected by callers.
   403  func (gd *groDispatcher) setInterval(interval time.Duration) {
   404  	gd.intervalNS.Store(interval.Nanoseconds())
   405  
   406  	if gd.flushTimerState.Load() == flushTimerSet {
   407  		// Timer was previously set, reset it.
   408  		gd.flushTimer.Reset(interval)
   409  	}
   410  }
   411  
   412  // dispatch sends pkt up the stack after it undergoes GRO coalescing.
   413  func (gd *groDispatcher) dispatch(pkt *PacketBuffer, netProto tcpip.NetworkProtocolNumber, ep NetworkEndpoint) {
   414  	// If GRO is disabled simply pass the packet along.
   415  	if gd.getInterval() == 0 {
   416  		ep.HandlePacket(pkt)
   417  		return
   418  	}
   419  
   420  	switch netProto {
   421  	case header.IPv4ProtocolNumber:
   422  		gd.dispatch4(pkt, ep)
   423  	case header.IPv6ProtocolNumber:
   424  		gd.dispatch6(pkt, ep)
   425  	default:
   426  		// We can't GRO this.
   427  		ep.HandlePacket(pkt)
   428  	}
   429  }
   430  
   431  func (gd *groDispatcher) dispatch4(pkt *PacketBuffer, ep NetworkEndpoint) {
   432  	// Immediately get the IPv4 and TCP headers. We need a way to hash the
   433  	// packet into its bucket, which requires addresses and ports. Linux
   434  	// simply gets a hash passed by hardware, but we're not so lucky.
   435  
   436  	// We only GRO TCP packets. The check for the transport protocol number
   437  	// is done below so that we can PullUp both the IP and TCP headers
   438  	// together.
   439  	hdrBytes, ok := pkt.Data().PullUp(header.IPv4MinimumSize + header.TCPMinimumSize)
   440  	if !ok {
   441  		ep.HandlePacket(pkt)
   442  		return
   443  	}
   444  	ipHdr := header.IPv4(hdrBytes)
   445  
   446  	// We don't handle fragments. That should be the vast majority of
   447  	// traffic, and simplifies handling.
   448  	if ipHdr.FragmentOffset() != 0 || ipHdr.Flags()&header.IPv4FlagMoreFragments != 0 {
   449  		ep.HandlePacket(pkt)
   450  		return
   451  	}
   452  
   453  	// We only handle TCP packets without IP options.
   454  	if ipHdr.HeaderLength() != header.IPv4MinimumSize || tcpip.TransportProtocolNumber(ipHdr.Protocol()) != header.TCPProtocolNumber {
   455  		ep.HandlePacket(pkt)
   456  		return
   457  	}
   458  	tcpHdr := header.TCP(hdrBytes[header.IPv4MinimumSize:])
   459  	ipHdr = ipHdr[:header.IPv4MinimumSize]
   460  	dataOff := tcpHdr.DataOffset()
   461  	if dataOff < header.TCPMinimumSize {
   462  		// Malformed packet: will be handled further up the stack.
   463  		ep.HandlePacket(pkt)
   464  		return
   465  	}
   466  	hdrBytes, ok = pkt.Data().PullUp(header.IPv4MinimumSize + int(dataOff))
   467  	if !ok {
   468  		// Malformed packet: will be handled further up the stack.
   469  		ep.HandlePacket(pkt)
   470  		return
   471  	}
   472  
   473  	tcpHdr = header.TCP(hdrBytes[header.IPv4MinimumSize:])
   474  
   475  	// If either checksum is bad, flush the packet. Since we don't know
   476  	// what bits were flipped, we can't identify this packet with a flow.
   477  	if !pkt.RXChecksumValidated {
   478  		if !ipHdr.IsValid(pkt.Data().Size()) || !ipHdr.IsChecksumValid() {
   479  			ep.HandlePacket(pkt)
   480  			return
   481  		}
   482  		payloadChecksum := pkt.Data().ChecksumAtOffset(header.IPv4MinimumSize + int(dataOff))
   483  		tcpPayloadSize := pkt.Data().Size() - header.IPv4MinimumSize - int(dataOff)
   484  		if !tcpHdr.IsChecksumValid(ipHdr.SourceAddress(), ipHdr.DestinationAddress(), payloadChecksum, uint16(tcpPayloadSize)) {
   485  			ep.HandlePacket(pkt)
   486  			return
   487  		}
   488  		// We've validated the checksum, no reason for others to do it
   489  		// again.
   490  		pkt.RXChecksumValidated = true
   491  	}
   492  
   493  	// Now we can get the bucket for the packet.
   494  	bucket := &gd.buckets[gd.bucketForPacket(ipHdr, tcpHdr)&groNBucketsMask]
   495  	bucket.mu.Lock()
   496  	groPkt, flushGROPkt := bucket.findGROPacket4(pkt, ipHdr, tcpHdr, ep)
   497  	bucket.found(gd, groPkt, flushGROPkt, pkt, ipHdr, tcpHdr, ep, updateIPv4Hdr)
   498  }
   499  
   500  func (gd *groDispatcher) dispatch6(pkt *PacketBuffer, ep NetworkEndpoint) {
   501  	// Immediately get the IPv6 and TCP headers. We need a way to hash the
   502  	// packet into its bucket, which requires addresses and ports. Linux
   503  	// simply gets a hash passed by hardware, but we're not so lucky.
   504  
   505  	hdrBytes, ok := pkt.Data().PullUp(header.IPv6MinimumSize)
   506  	if !ok {
   507  		ep.HandlePacket(pkt)
   508  		return
   509  	}
   510  	ipHdr := header.IPv6(hdrBytes)
   511  
   512  	// Getting the IP header (+ extension headers) size is a bit of a pain
   513  	// on IPv6.
   514  	transProto := tcpip.TransportProtocolNumber(ipHdr.NextHeader())
   515  	buf := pkt.Data().ToBuffer()
   516  	buf.TrimFront(header.IPv6MinimumSize)
   517  	it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(transProto), buf)
   518  	ipHdrSize := int(header.IPv6MinimumSize)
   519  	for {
   520  		transProto = tcpip.TransportProtocolNumber(it.NextHeaderIdentifier())
   521  		extHdr, done, err := it.Next()
   522  		if err != nil {
   523  			ep.HandlePacket(pkt)
   524  			return
   525  		}
   526  		if done {
   527  			break
   528  		}
   529  		switch extHdr.(type) {
   530  		// We can GRO these, so just skip over them.
   531  		case header.IPv6HopByHopOptionsExtHdr:
   532  		case header.IPv6RoutingExtHdr:
   533  		case header.IPv6DestinationOptionsExtHdr:
   534  		default:
   535  			// This is either a TCP header or something we can't handle.
   536  			ipHdrSize = int(it.HeaderOffset())
   537  			done = true
   538  		}
   539  		extHdr.Release()
   540  		if done {
   541  			break
   542  		}
   543  	}
   544  
   545  	hdrBytes, ok = pkt.Data().PullUp(ipHdrSize + header.TCPMinimumSize)
   546  	if !ok {
   547  		ep.HandlePacket(pkt)
   548  		return
   549  	}
   550  	ipHdr = header.IPv6(hdrBytes[:ipHdrSize])
   551  
   552  	// We only handle TCP packets.
   553  	if transProto != header.TCPProtocolNumber {
   554  		ep.HandlePacket(pkt)
   555  		return
   556  	}
   557  	tcpHdr := header.TCP(hdrBytes[ipHdrSize:])
   558  	dataOff := tcpHdr.DataOffset()
   559  	if dataOff < header.TCPMinimumSize {
   560  		// Malformed packet: will be handled further up the stack.
   561  		ep.HandlePacket(pkt)
   562  		return
   563  	}
   564  
   565  	hdrBytes, ok = pkt.Data().PullUp(ipHdrSize + int(dataOff))
   566  	if !ok {
   567  		// Malformed packet: will be handled further up the stack.
   568  		ep.HandlePacket(pkt)
   569  		return
   570  	}
   571  	tcpHdr = header.TCP(hdrBytes[ipHdrSize:])
   572  
   573  	// If either checksum is bad, flush the packet. Since we don't know
   574  	// what bits were flipped, we can't identify this packet with a flow.
   575  	if !pkt.RXChecksumValidated {
   576  		if !ipHdr.IsValid(pkt.Data().Size()) {
   577  			ep.HandlePacket(pkt)
   578  			return
   579  		}
   580  		payloadChecksum := pkt.Data().ChecksumAtOffset(ipHdrSize + int(dataOff))
   581  		tcpPayloadSize := pkt.Data().Size() - ipHdrSize - int(dataOff)
   582  		if !tcpHdr.IsChecksumValid(ipHdr.SourceAddress(), ipHdr.DestinationAddress(), payloadChecksum, uint16(tcpPayloadSize)) {
   583  			ep.HandlePacket(pkt)
   584  			return
   585  		}
   586  		// We've validated the checksum, no reason for others to do it
   587  		// again.
   588  		pkt.RXChecksumValidated = true
   589  	}
   590  
   591  	// Now we can get the bucket for the packet.
   592  	bucket := &gd.buckets[gd.bucketForPacket(ipHdr, tcpHdr)&groNBucketsMask]
   593  	bucket.mu.Lock()
   594  	groPkt, flushGROPkt := bucket.findGROPacket6(pkt, ipHdr, tcpHdr, ep)
   595  	bucket.found(gd, groPkt, flushGROPkt, pkt, ipHdr, tcpHdr, ep, updateIPv6Hdr)
   596  }
   597  
   598  func (gd *groDispatcher) bucketForPacket(ipHdr header.Network, tcpHdr header.TCP) int {
   599  	// TODO(b/256037250): Use jenkins or checksum. Write a test to print
   600  	// distribution.
   601  	var sum int
   602  	srcAddr := ipHdr.SourceAddress()
   603  	for _, val := range srcAddr.AsSlice() {
   604  		sum += int(val)
   605  	}
   606  	dstAddr := ipHdr.DestinationAddress()
   607  	for _, val := range dstAddr.AsSlice() {
   608  		sum += int(val)
   609  	}
   610  	sum += int(tcpHdr.SourcePort())
   611  	sum += int(tcpHdr.DestinationPort())
   612  	return sum
   613  }
   614  
   615  // flush sends any packets older than interval up the stack.
   616  //
   617  // Returns true iff packets remain.
   618  func (gd *groDispatcher) flush() bool {
   619  	interval := gd.intervalNS.Load()
   620  	old := time.Now().Add(-time.Duration(interval) * time.Nanosecond)
   621  	return gd.flushSinceOrEqualTo(old)
   622  }
   623  
   624  // flushSinceOrEqualTo sends any packets older than or equal to the specified
   625  // time.
   626  //
   627  // Returns true iff packets remain.
   628  func (gd *groDispatcher) flushSinceOrEqualTo(old time.Time) bool {
   629  	type pair struct {
   630  		pkt *PacketBuffer
   631  		ep  NetworkEndpoint
   632  	}
   633  
   634  	hasMore := false
   635  
   636  	for i := range gd.buckets {
   637  		// Put packets in a slice so we don't have to hold bucket.mu
   638  		// when we call HandlePacket.
   639  		var pairsBacking [groNBuckets]pair
   640  		pairs := pairsBacking[:0]
   641  
   642  		bucket := &gd.buckets[i]
   643  		bucket.mu.Lock()
   644  		for groPkt := bucket.packets.Front(); groPkt != nil; groPkt = groPkt.Next() {
   645  			if groPkt.created.After(old) {
   646  				// Packets are ordered by age, so we can move
   647  				// on once we find one that's too new.
   648  				hasMore = true
   649  				break
   650  			} else {
   651  				pairs = append(pairs, pair{groPkt.pkt, groPkt.ep})
   652  				bucket.removeOne(groPkt)
   653  			}
   654  		}
   655  		bucket.mu.Unlock()
   656  
   657  		for _, pair := range pairs {
   658  			pair.ep.HandlePacket(pair.pkt)
   659  			pair.pkt.DecRef()
   660  		}
   661  	}
   662  
   663  	return hasMore
   664  }
   665  
   666  func (gd *groDispatcher) flushAll() {
   667  	if gd.flushSinceOrEqualTo(time.Now()) {
   668  		panic("packets unexpectedly remain after flushing all")
   669  	}
   670  }
   671  
   672  // close stops the GRO goroutine and releases any held packets.
   673  func (gd *groDispatcher) close() {
   674  	gd.flushTimer.Stop()
   675  	// Prevent the timer from being scheduled again.
   676  	gd.flushTimerState.Store(flushTimerClosed)
   677  
   678  	for i := range gd.buckets {
   679  		bucket := &gd.buckets[i]
   680  		bucket.mu.Lock()
   681  		for groPkt := bucket.packets.Front(); groPkt != nil; groPkt = bucket.packets.Front() {
   682  			groPkt.pkt.DecRef()
   683  			bucket.removeOne(groPkt)
   684  		}
   685  		bucket.mu.Unlock()
   686  	}
   687  }
   688  
   689  // String implements fmt.Stringer.
   690  func (gd *groDispatcher) String() string {
   691  	ret := "GRO state: \n"
   692  	for i := range gd.buckets {
   693  		bucket := &gd.buckets[i]
   694  		bucket.mu.Lock()
   695  		ret += fmt.Sprintf("bucket %d: %d packets: ", i, bucket.count)
   696  		for groPkt := bucket.packets.Front(); groPkt != nil; groPkt = groPkt.Next() {
   697  			ret += fmt.Sprintf("%s (%d), ", groPkt.created, groPkt.pkt.Data().Size())
   698  		}
   699  		ret += "\n"
   700  		bucket.mu.Unlock()
   701  	}
   702  	return ret
   703  }
   704  
   705  // shouldFlushTCP returns whether the TCP headers indicate that groPkt should
   706  // be flushed
   707  func shouldFlushTCP(groPkt *groPacket, tcpHdr header.TCP) bool {
   708  	flags := tcpHdr.Flags()
   709  	groPktFlags := groPkt.tcpHdr.Flags()
   710  	dataOff := tcpHdr.DataOffset()
   711  	if flags&header.TCPFlagCwr != 0 || // Is congestion control occurring?
   712  		(flags^groPktFlags)&^(header.TCPFlagCwr|header.TCPFlagFin|header.TCPFlagPsh) != 0 || // Do the flags differ besides CRW, FIN, and PSH?
   713  		tcpHdr.AckNumber() != groPkt.tcpHdr.AckNumber() || // Do the ACKs match?
   714  		dataOff != groPkt.tcpHdr.DataOffset() || // Are the TCP headers the same length?
   715  		groPkt.tcpHdr.SequenceNumber()+uint32(groPkt.payloadSize()) != tcpHdr.SequenceNumber() { // Does the incoming packet match the expected sequence number?
   716  		return true
   717  	}
   718  	// The options, including timestamps, must be identical.
   719  	return !bytes.Equal(tcpHdr[header.TCPMinimumSize:], groPkt.tcpHdr[header.TCPMinimumSize:])
   720  }
   721  
   722  func updateIPv4Hdr(ipHdrBytes []byte, newBytes int) {
   723  	ipHdr := header.IPv4(ipHdrBytes)
   724  	ipHdr.SetTotalLength(ipHdr.TotalLength() + uint16(newBytes))
   725  }
   726  
   727  func updateIPv6Hdr(ipHdrBytes []byte, newBytes int) {
   728  	ipHdr := header.IPv6(ipHdrBytes)
   729  	ipHdr.SetPayloadLength(ipHdr.PayloadLength() + uint16(newBytes))
   730  }