github.com/amnezia-vpn/amneziawg-go@v0.2.8/tun/offload_linux.go

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package tun
     7  
     8  import (
     9  	"bytes"
    10  	"encoding/binary"
    11  	"errors"
    12  	"io"
    13  	"unsafe"
    14  
    15  	"github.com/amnezia-vpn/amneziawg-go/conn"
    16  	"golang.org/x/sys/unix"
    17  )
    18  
    19  const tcpFlagsOffset = 13
    20  
    21  const (
    22  	tcpFlagFIN uint8 = 0x01
    23  	tcpFlagPSH uint8 = 0x08
    24  	tcpFlagACK uint8 = 0x10
    25  )
    26  
    27  // virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
    28  // kernel symbol is virtio_net_hdr.
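         // When IFF_VNET_HDR is enabled on the TUN device, this header is prepended
         // to every packet exchanged with the kernel and carries the GSO/checksum
         // offload metadata.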
    29  type virtioNetHdr struct {
    30  	flags      uint8
    31  	gsoType    uint8
    32  	hdrLen     uint16
    33  	gsoSize    uint16
    34  	csumStart  uint16
    35  	csumOffset uint16
    36  }
    37  
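         // decode and encode copy the struct from and to its raw 10-byte form.
         // virtioNetHdr has no interior padding, so an unsafe.Slice view of the
         // struct matches the size and field offsets of the kernel's virtio_net_hdr.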
    38  func (v *virtioNetHdr) decode(b []byte) error {
    39  	if len(b) < virtioNetHdrLen {
    40  		return io.ErrShortBuffer
    41  	}
    42  	copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
    43  	return nil
    44  }
    45  
    46  func (v *virtioNetHdr) encode(b []byte) error {
    47  	if len(b) < virtioNetHdrLen {
    48  		return io.ErrShortBuffer
    49  	}
    50  	copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
    51  	return nil
    52  }
    53  
    54  const (
    55  	// virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
    56  	// shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
    57  	virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
    58  )
    59  
    60  // tcpFlowKey represents the key for a TCP flow.
    61  type tcpFlowKey struct {
    62  	srcAddr, dstAddr [16]byte
    63  	srcPort, dstPort uint16
     64  	rxAck            uint32 // packets with differing ACK values must not be coalesced, so they are treated as separate flows
    65  	isV6             bool
    66  }
    67  
    68  // tcpGROTable holds flow and coalescing information for the purposes of TCP GRO.
    69  type tcpGROTable struct {
    70  	itemsByFlow map[tcpFlowKey][]tcpGROItem
    71  	itemsPool   [][]tcpGROItem
    72  }
    73  
    74  func newTCPGROTable() *tcpGROTable {
    75  	t := &tcpGROTable{
    76  		itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, conn.IdealBatchSize),
    77  		itemsPool:   make([][]tcpGROItem, conn.IdealBatchSize),
    78  	}
    79  	for i := range t.itemsPool {
    80  		t.itemsPool[i] = make([]tcpGROItem, 0, conn.IdealBatchSize)
    81  	}
    82  	return t
    83  }
    84  
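         // newTCPFlowKey builds the flow key from the packet's IP addresses, TCP
         // ports, and ACK number. For IPv4, for example, srcAddrOffset is 12,
         // dstAddrOffset is 16, and tcphOffset equals the IP header length (20
         // when there are no options).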
    85  func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey {
    86  	key := tcpFlowKey{}
    87  	addrSize := dstAddrOffset - srcAddrOffset
    88  	copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
    89  	copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
    90  	key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
    91  	key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
    92  	key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
    93  	key.isV6 = addrSize == 16
    94  	return key
    95  }
    96  
    97  // lookupOrInsert looks up a flow for the provided packet and metadata,
    98  // returning the packets found for the flow, or inserting a new one if none
    99  // is found.
   100  func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
   101  	key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
   102  	items, ok := t.itemsByFlow[key]
   103  	if ok {
   104  		return items, ok
   105  	}
   106  	// TODO: insert() performs another map lookup. This could be rearranged to avoid.
   107  	t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
   108  	return nil, false
   109  }
   110  
   111  // insert an item in the table for the provided packet and packet metadata.
   112  func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
   113  	key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
   114  	item := tcpGROItem{
   115  		key:       key,
   116  		bufsIndex: uint16(bufsIndex),
   117  		gsoSize:   uint16(len(pkt[tcphOffset+tcphLen:])),
   118  		iphLen:    uint8(tcphOffset),
   119  		tcphLen:   uint8(tcphLen),
   120  		sentSeq:   binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
   121  		pshSet:    pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
   122  	}
   123  	items, ok := t.itemsByFlow[key]
   124  	if !ok {
   125  		items = t.newItems()
   126  	}
   127  	items = append(items, item)
   128  	t.itemsByFlow[key] = items
   129  }
   130  
   131  func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
    132  	items := t.itemsByFlow[item.key]
   133  	items[i] = item
   134  }
   135  
   136  func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) {
    137  	items := t.itemsByFlow[key]
   138  	items = append(items[:i], items[i+1:]...)
   139  	t.itemsByFlow[key] = items
   140  }
   141  
   142  // tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
   143  // of a GRO evaluation across a vector of packets.
   144  type tcpGROItem struct {
   145  	key       tcpFlowKey
   146  	sentSeq   uint32 // the sequence number
   147  	bufsIndex uint16 // the index into the original bufs slice
   148  	numMerged uint16 // the number of packets merged into this item
   149  	gsoSize   uint16 // payload size
   150  	iphLen    uint8  // ip header len
   151  	tcphLen   uint8  // tcp header len
   152  	pshSet    bool   // psh flag is set
   153  }
   154  
   155  func (t *tcpGROTable) newItems() []tcpGROItem {
   156  	var items []tcpGROItem
   157  	items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
   158  	return items
   159  }
   160  
   161  func (t *tcpGROTable) reset() {
   162  	for k, items := range t.itemsByFlow {
   163  		items = items[:0]
   164  		t.itemsPool = append(t.itemsPool, items)
   165  		delete(t.itemsByFlow, k)
   166  	}
   167  }
   168  
   169  // udpFlowKey represents the key for a UDP flow.
   170  type udpFlowKey struct {
   171  	srcAddr, dstAddr [16]byte
   172  	srcPort, dstPort uint16
   173  	isV6             bool
   174  }
   175  
   176  // udpGROTable holds flow and coalescing information for the purposes of UDP GRO.
   177  type udpGROTable struct {
   178  	itemsByFlow map[udpFlowKey][]udpGROItem
   179  	itemsPool   [][]udpGROItem
   180  }
   181  
   182  func newUDPGROTable() *udpGROTable {
   183  	u := &udpGROTable{
   184  		itemsByFlow: make(map[udpFlowKey][]udpGROItem, conn.IdealBatchSize),
   185  		itemsPool:   make([][]udpGROItem, conn.IdealBatchSize),
   186  	}
   187  	for i := range u.itemsPool {
   188  		u.itemsPool[i] = make([]udpGROItem, 0, conn.IdealBatchSize)
   189  	}
   190  	return u
   191  }
   192  
   193  func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey {
   194  	key := udpFlowKey{}
   195  	addrSize := dstAddrOffset - srcAddrOffset
   196  	copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
   197  	copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
   198  	key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:])
   199  	key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:])
   200  	key.isV6 = addrSize == 16
   201  	return key
   202  }
   203  
   204  // lookupOrInsert looks up a flow for the provided packet and metadata,
   205  // returning the packets found for the flow, or inserting a new one if none
   206  // is found.
   207  func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) {
   208  	key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
   209  	items, ok := u.itemsByFlow[key]
   210  	if ok {
   211  		return items, ok
   212  	}
   213  	// TODO: insert() performs another map lookup. This could be rearranged to avoid.
   214  	u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false)
   215  	return nil, false
   216  }
   217  
   218  // insert an item in the table for the provided packet and packet metadata.
   219  func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) {
   220  	key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
   221  	item := udpGROItem{
   222  		key:              key,
   223  		bufsIndex:        uint16(bufsIndex),
   224  		gsoSize:          uint16(len(pkt[udphOffset+udphLen:])),
   225  		iphLen:           uint8(udphOffset),
   226  		cSumKnownInvalid: cSumKnownInvalid,
   227  	}
   228  	items, ok := u.itemsByFlow[key]
   229  	if !ok {
   230  		items = u.newItems()
   231  	}
   232  	items = append(items, item)
   233  	u.itemsByFlow[key] = items
   234  }
   235  
   236  func (u *udpGROTable) updateAt(item udpGROItem, i int) {
    237  	items := u.itemsByFlow[item.key]
   238  	items[i] = item
   239  }
   240  
   241  // udpGROItem represents bookkeeping data for a UDP packet during the lifetime
   242  // of a GRO evaluation across a vector of packets.
   243  type udpGROItem struct {
   244  	key              udpFlowKey
   245  	bufsIndex        uint16 // the index into the original bufs slice
   246  	numMerged        uint16 // the number of packets merged into this item
   247  	gsoSize          uint16 // payload size
   248  	iphLen           uint8  // ip header len
   249  	cSumKnownInvalid bool   // UDP header checksum validity; a false value DOES NOT imply valid, just unknown.
   250  }
   251  
   252  func (u *udpGROTable) newItems() []udpGROItem {
   253  	var items []udpGROItem
   254  	items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1]
   255  	return items
   256  }
   257  
   258  func (u *udpGROTable) reset() {
   259  	for k, items := range u.itemsByFlow {
   260  		items = items[:0]
   261  		u.itemsPool = append(u.itemsPool, items)
   262  		delete(u.itemsByFlow, k)
   263  	}
   264  }
   265  
    266  // canCoalesce represents the outcome of checking if two packets are
    267  // candidates for coalescing.
   268  type canCoalesce int
   269  
   270  const (
    271  	coalescePrepend     canCoalesce = -1 // pkt's data belongs in front of item's data
    272  	coalesceUnavailable canCoalesce = 0  // pkt cannot be merged with item
    273  	coalesceAppend      canCoalesce = 1  // pkt's data belongs after item's data
   274  )
   275  
   276  // ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB
   277  // meet all requirements to be merged as part of a GRO operation, otherwise it
   278  // returns false.
   279  func ipHeadersCanCoalesce(pktA, pktB []byte) bool {
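         	// Nine bytes is enough to reach every field compared below: version,
         	// Traffic class, and Hop limit for IPv6; ToS, flags, and TTL for IPv4
         	// (the farthest field, the IPv4 TTL, sits at byte offset 8).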
   280  	if len(pktA) < 9 || len(pktB) < 9 {
   281  		return false
   282  	}
   283  	if pktA[0]>>4 == 6 {
   284  		if pktA[0] != pktB[0] || pktA[1]>>4 != pktB[1]>>4 {
   285  			// cannot coalesce with unequal Traffic class values
   286  			return false
   287  		}
   288  		if pktA[7] != pktB[7] {
   289  			// cannot coalesce with unequal Hop limit values
   290  			return false
   291  		}
   292  	} else {
   293  		if pktA[1] != pktB[1] {
   294  			// cannot coalesce with unequal ToS values
   295  			return false
   296  		}
   297  		if pktA[6]>>5 != pktB[6]>>5 {
   298  			// cannot coalesce with unequal DF or reserved bits. MF is checked
   299  			// further up the stack.
   300  			return false
   301  		}
   302  		if pktA[8] != pktB[8] {
   303  			// cannot coalesce with unequal TTL values
   304  			return false
   305  		}
   306  	}
   307  	return true
   308  }
   309  
   310  // udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
   311  // described by item. iphLen and gsoSize describe pkt. bufs is the vector of
   312  // packets involved in the current GRO evaluation. bufsOffset is the offset at
   313  // which packet data begins within bufs.
   314  func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
   315  	pktTarget := bufs[item.bufsIndex][bufsOffset:]
   316  	if !ipHeadersCanCoalesce(pkt, pktTarget) {
   317  		return coalesceUnavailable
   318  	}
   319  	if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 {
    320  		// A packet smaller than gsoSize has already been appended to item;
    321  		// nothing more may follow that short final segment.
   322  		return coalesceUnavailable
   323  	}
   324  	if gsoSize > item.gsoSize {
   325  		// We cannot have a larger packet following a smaller one.
   326  		return coalesceUnavailable
   327  	}
   328  	return coalesceAppend
   329  }
   330  
   331  // tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
   332  // described by item. This function makes considerations that match the kernel's
   333  // GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
   334  func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
   335  	pktTarget := bufs[item.bufsIndex][bufsOffset:]
   336  	if tcphLen != item.tcphLen {
   337  		// cannot coalesce with unequal tcp options len
   338  		return coalesceUnavailable
   339  	}
   340  	if tcphLen > 20 {
   341  		if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
   342  			// cannot coalesce with unequal tcp options
   343  			return coalesceUnavailable
   344  		}
   345  	}
   346  	if !ipHeadersCanCoalesce(pkt, pktTarget) {
   347  		return coalesceUnavailable
   348  	}
    349  	// seq adjacency: lhsLen is item's expected payload length, assuming every merged segment carried gsoSize bytes
   350  	lhsLen := item.gsoSize
   351  	lhsLen += item.numMerged * item.gsoSize
   352  	if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
   353  		if item.pshSet {
   354  			// We cannot append to a segment that has the PSH flag set, PSH
   355  			// can only be set on the final segment in a reassembled group.
   356  			return coalesceUnavailable
   357  		}
   358  		if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
    359  			// A packet smaller than gsoSize has already been appended to item;
    360  			// nothing more may follow that short final segment.
   361  			return coalesceUnavailable
   362  		}
   363  		if gsoSize > item.gsoSize {
   364  			// We cannot have a larger packet following a smaller one.
   365  			return coalesceUnavailable
   366  		}
   367  		return coalesceAppend
   368  	} else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
   369  		if pshSet {
   370  			// We cannot prepend with a segment that has the PSH flag set, PSH
   371  			// can only be set on the final segment in a reassembled group.
   372  			return coalesceUnavailable
   373  		}
   374  		if gsoSize < item.gsoSize {
   375  			// We cannot have a larger packet following a smaller one.
   376  			return coalesceUnavailable
   377  		}
   378  		if gsoSize > item.gsoSize && item.numMerged > 0 {
   379  			// There's at least one previous merge, and we're larger than all
   380  			// previous. This would put multiple smaller packets on the end.
   381  			return coalesceUnavailable
   382  		}
   383  		return coalescePrepend
   384  	}
   385  	return coalesceUnavailable
   386  }
   387  
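         // checksumValid reports whether the transport-layer (TCP or UDP) checksum
         // of pkt verifies correctly, folding the pseudo-header checksum over the IP
         // source and destination addresses into the computation.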
   388  func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool {
   389  	srcAddrAt := ipv4SrcAddrOffset
   390  	addrSize := 4
   391  	if isV6 {
   392  		srcAddrAt = ipv6SrcAddrOffset
   393  		addrSize = 16
   394  	}
   395  	lenForPseudo := uint16(len(pkt) - int(iphLen))
   396  	cSum := pseudoHeaderChecksumNoFold(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo)
   397  	return ^checksum(pkt[iphLen:], cSum) == 0
   398  }
   399  
    400  // coalesceResult represents the result of attempting to coalesce two
    401  // packets.
   402  type coalesceResult int
   403  
   404  const (
   405  	coalesceInsufficientCap coalesceResult = iota
   406  	coalescePSHEnding
   407  	coalesceItemInvalidCSum
   408  	coalescePktInvalidCSum
   409  	coalesceSuccess
   410  )
   411  
   412  // coalesceUDPPackets attempts to coalesce pkt with the packet described by
   413  // item, and returns the outcome.
   414  func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
   415  	pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front
   416  	headersLen := item.iphLen + udphLen
   417  	coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
   418  
   419  	if cap(pktHead)-bufsOffset < coalescedLen {
   420  		// We don't want to allocate a new underlying array if capacity is
   421  		// too small.
   422  		return coalesceInsufficientCap
   423  	}
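         	// Validate the head packet's checksum only on the first merge; once a
         	// merge has succeeded, its checksum is already known to be good.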
   424  	if item.numMerged == 0 {
   425  		if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) {
   426  			return coalesceItemInvalidCSum
   427  		}
   428  	}
   429  	if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) {
   430  		return coalescePktInvalidCSum
   431  	}
   432  	extendBy := len(pkt) - int(headersLen)
   433  	bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
   434  	copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
   435  
   436  	item.numMerged++
   437  	return coalesceSuccess
   438  }
   439  
   440  // coalesceTCPPackets attempts to coalesce pkt with the packet described by
   441  // item, and returns the outcome. This function may swap bufs elements in the
   442  // event of a prepend as item's bufs index is already being tracked for writing
   443  // to a Device.
   444  func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
   445  	var pktHead []byte // the packet that will end up at the front
   446  	headersLen := item.iphLen + item.tcphLen
   447  	coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
   448  
   449  	// Copy data
   450  	if mode == coalescePrepend {
   451  		pktHead = pkt
   452  		if cap(pkt)-bufsOffset < coalescedLen {
   453  			// We don't want to allocate a new underlying array if capacity is
   454  			// too small.
   455  			return coalesceInsufficientCap
   456  		}
   457  		if pshSet {
   458  			return coalescePSHEnding
   459  		}
   460  		if item.numMerged == 0 {
   461  			if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
   462  				return coalesceItemInvalidCSum
   463  			}
   464  		}
   465  		if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
   466  			return coalescePktInvalidCSum
   467  		}
   468  		item.sentSeq = seq
   469  		extendBy := coalescedLen - len(pktHead)
   470  		bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
   471  		copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
   472  		// Flip the slice headers in bufs as part of prepend. The index of item
   473  		// is already being tracked for writing.
   474  		bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
   475  	} else {
   476  		pktHead = bufs[item.bufsIndex][bufsOffset:]
   477  		if cap(pktHead)-bufsOffset < coalescedLen {
   478  			// We don't want to allocate a new underlying array if capacity is
   479  			// too small.
   480  			return coalesceInsufficientCap
   481  		}
   482  		if item.numMerged == 0 {
   483  			if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
   484  				return coalesceItemInvalidCSum
   485  			}
   486  		}
   487  		if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
   488  			return coalescePktInvalidCSum
   489  		}
   490  		if pshSet {
   491  			// We are appending a segment with PSH set.
   492  			item.pshSet = pshSet
   493  			pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
   494  		}
   495  		extendBy := len(pkt) - int(headersLen)
   496  		bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
   497  		copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
   498  	}
   499  
   500  	if gsoSize > item.gsoSize {
   501  		item.gsoSize = gsoSize
   502  	}
   503  
   504  	item.numMerged++
   505  	return coalesceSuccess
   506  }
   507  
   508  const (
    509  	ipv4FlagMoreFragments uint8 = 0x20 // MF bit within the IPv4 flags/fragment-offset byte (header byte 6)
   510  )
   511  
   512  const (
    513  	ipv4SrcAddrOffset = 12 // source address offset within the IPv4 header
    514  	ipv6SrcAddrOffset = 8  // source address offset within the IPv6 header
   515  	maxUint16         = 1<<16 - 1
   516  )
   517  
   518  type groResult int
   519  
   520  const (
   521  	groResultNoop groResult = iota
   522  	groResultTableInsert
   523  	groResultCoalesced
   524  )
   525  
   526  // tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
   527  // existing packets tracked in table. It returns a groResultNoop when no
   528  // action was taken, groResultTableInsert when the evaluated packet was
   529  // inserted into table, and groResultCoalesced when the evaluated packet was
   530  // coalesced with another packet in table.
   531  func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult {
   532  	pkt := bufs[pktI][offset:]
   533  	if len(pkt) > maxUint16 {
   534  		// A valid IPv4 or IPv6 packet will never exceed this.
   535  		return groResultNoop
   536  	}
   537  	iphLen := int((pkt[0] & 0x0F) * 4)
   538  	if isV6 {
   539  		iphLen = 40
   540  		ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
   541  		if ipv6HPayloadLen != len(pkt)-iphLen {
   542  			return groResultNoop
   543  		}
   544  	} else {
   545  		totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
   546  		if totalLen != len(pkt) {
   547  			return groResultNoop
   548  		}
   549  	}
   550  	if len(pkt) < iphLen {
   551  		return groResultNoop
   552  	}
   553  	tcphLen := int((pkt[iphLen+12] >> 4) * 4)
   554  	if tcphLen < 20 || tcphLen > 60 {
   555  		return groResultNoop
   556  	}
   557  	if len(pkt) < iphLen+tcphLen {
   558  		return groResultNoop
   559  	}
   560  	if !isV6 {
   561  		if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
   562  			// no GRO support for fragmented segments for now
   563  			return groResultNoop
   564  		}
   565  	}
   566  	tcpFlags := pkt[iphLen+tcpFlagsOffset]
   567  	var pshSet bool
   568  	// not a candidate if any non-ACK flags (except PSH+ACK) are set
   569  	if tcpFlags != tcpFlagACK {
    570  		if tcpFlags != tcpFlagACK|tcpFlagPSH {
   571  			return groResultNoop
   572  		}
   573  		pshSet = true
   574  	}
   575  	gsoSize := uint16(len(pkt) - tcphLen - iphLen)
   576  	// not a candidate if payload len is 0
   577  	if gsoSize < 1 {
   578  		return groResultNoop
   579  	}
   580  	seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
   581  	srcAddrOffset := ipv4SrcAddrOffset
   582  	addrLen := 4
   583  	if isV6 {
   584  		srcAddrOffset = ipv6SrcAddrOffset
   585  		addrLen = 16
   586  	}
   587  	items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
   588  	if !existing {
   589  		return groResultTableInsert
   590  	}
   591  	for i := len(items) - 1; i >= 0; i-- {
   592  		// In the best case of packets arriving in order iterating in reverse is
   593  		// more efficient if there are multiple items for a given flow. This
   594  		// also enables a natural table.deleteAt() in the
   595  		// coalesceItemInvalidCSum case without the need for index tracking.
   596  		// This algorithm makes a best effort to coalesce in the event of
   597  		// unordered packets, where pkt may land anywhere in items from a
   598  		// sequence number perspective, however once an item is inserted into
   599  		// the table it is never compared across other items later.
   600  		item := items[i]
   601  		can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
   602  		if can != coalesceUnavailable {
   603  			result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
   604  			switch result {
   605  			case coalesceSuccess:
   606  				table.updateAt(item, i)
   607  				return groResultCoalesced
   608  			case coalesceItemInvalidCSum:
   609  				// delete the item with an invalid csum
   610  				table.deleteAt(item.key, i)
   611  			case coalescePktInvalidCSum:
   612  				// no point in inserting an item that we can't coalesce
   613  				return groResultNoop
   614  			default:
   615  			}
   616  		}
   617  	}
   618  	// failed to coalesce with any other packets; store the item in the flow
   619  	table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
   620  	return groResultTableInsert
   621  }
   622  
   623  // applyTCPCoalesceAccounting updates bufs to account for coalescing based on the
   624  // metadata found in table.
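         // For every item that absorbed at least one merge, the IP length fields and
         // IPv4 header checksum are rewritten, the pseudo-header checksum is seeded
         // at the TCP checksum offset, and a virtio_net_hdr describing the GSO
         // parameters is encoded directly in front of the packet. Items that were
         // never merged get a zeroed virtio_net_hdr.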
   625  func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error {
   626  	for _, items := range table.itemsByFlow {
   627  		for _, item := range items {
   628  			if item.numMerged > 0 {
   629  				hdr := virtioNetHdr{
   630  					flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
   631  					hdrLen:     uint16(item.iphLen + item.tcphLen),
   632  					gsoSize:    item.gsoSize,
   633  					csumStart:  uint16(item.iphLen),
    634  					csumOffset: 16, // TCP checksum field offset within the TCP header
   635  				}
   636  				pkt := bufs[item.bufsIndex][offset:]
   637  
   638  				// Recalculate the total len (IPv4) or payload len (IPv6).
   639  				// Recalculate the (IPv4) header checksum.
   640  				if item.key.isV6 {
   641  					hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
   642  					binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
   643  				} else {
   644  					hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
   645  					pkt[10], pkt[11] = 0, 0
   646  					binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
   647  					iphCSum := ^checksum(pkt[:item.iphLen], 0)            // compute IPv4 header checksum
   648  					binary.BigEndian.PutUint16(pkt[10:], iphCSum)         // set IPv4 header checksum field
   649  				}
   650  				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
   651  				if err != nil {
   652  					return err
   653  				}
   654  
   655  				// Calculate the pseudo header checksum and place it at the TCP
   656  				// checksum offset. Downstream checksum offloading will combine
   657  				// this with computation of the tcp header and payload checksum.
   658  				addrLen := 4
   659  				addrOffset := ipv4SrcAddrOffset
   660  				if item.key.isV6 {
   661  					addrLen = 16
   662  					addrOffset = ipv6SrcAddrOffset
   663  				}
   664  				srcAddrAt := offset + addrOffset
   665  				srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
   666  				dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
   667  				psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
   668  				binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
   669  			} else {
   670  				hdr := virtioNetHdr{}
   671  				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
   672  				if err != nil {
   673  					return err
   674  				}
   675  			}
   676  		}
   677  	}
   678  	return nil
   679  }
   680  
   681  // applyUDPCoalesceAccounting updates bufs to account for coalescing based on the
   682  // metadata found in table.
   683  func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error {
   684  	for _, items := range table.itemsByFlow {
   685  		for _, item := range items {
   686  			if item.numMerged > 0 {
   687  				hdr := virtioNetHdr{
   688  					flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
   689  					hdrLen:     uint16(item.iphLen + udphLen),
   690  					gsoSize:    item.gsoSize,
   691  					csumStart:  uint16(item.iphLen),
    692  					csumOffset: 6, // UDP checksum field offset within the UDP header
   693  				}
   694  				pkt := bufs[item.bufsIndex][offset:]
   695  
   696  				// Recalculate the total len (IPv4) or payload len (IPv6).
   697  				// Recalculate the (IPv4) header checksum.
   698  				hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4
   699  				if item.key.isV6 {
   700  					binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
   701  				} else {
   702  					pkt[10], pkt[11] = 0, 0
   703  					binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
   704  					iphCSum := ^checksum(pkt[:item.iphLen], 0)            // compute IPv4 header checksum
   705  					binary.BigEndian.PutUint16(pkt[10:], iphCSum)         // set IPv4 header checksum field
   706  				}
   707  				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
   708  				if err != nil {
   709  					return err
   710  				}
   711  
   712  				// Recalculate the UDP len field value
   713  				binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:])))
   714  
   715  				// Calculate the pseudo header checksum and place it at the UDP
   716  				// checksum offset. Downstream checksum offloading will combine
   717  				// this with computation of the udp header and payload checksum.
   718  				addrLen := 4
   719  				addrOffset := ipv4SrcAddrOffset
   720  				if item.key.isV6 {
   721  					addrLen = 16
   722  					addrOffset = ipv6SrcAddrOffset
   723  				}
   724  				srcAddrAt := offset + addrOffset
   725  				srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
   726  				dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
   727  				psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
   728  				binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
   729  			} else {
   730  				hdr := virtioNetHdr{}
   731  				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
   732  				if err != nil {
   733  					return err
   734  				}
   735  			}
   736  		}
   737  	}
   738  	return nil
   739  }
   740  
   741  type groCandidateType uint8
   742  
   743  const (
   744  	notGROCandidate groCandidateType = iota
   745  	tcp4GROCandidate
   746  	tcp6GROCandidate
   747  	udp4GROCandidate
   748  	udp6GROCandidate
   749  )
   750  
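         // packetIsGROCandidate classifies b for GRO evaluation. The length
         // thresholds used below are the smallest header stacks a candidate can
         // carry: 28 = IPv4 (20) + UDP (8), 40 = IPv4 (20) + TCP (20),
         // 48 = IPv6 (40) + UDP (8), and 60 = IPv6 (40) + TCP (20).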
   751  func packetIsGROCandidate(b []byte, canUDPGRO bool) groCandidateType {
   752  	if len(b) < 28 {
   753  		return notGROCandidate
   754  	}
   755  	if b[0]>>4 == 4 {
   756  		if b[0]&0x0F != 5 {
   757  			// IPv4 packets w/IP options do not coalesce
   758  			return notGROCandidate
   759  		}
   760  		if b[9] == unix.IPPROTO_TCP && len(b) >= 40 {
   761  			return tcp4GROCandidate
   762  		}
   763  		if b[9] == unix.IPPROTO_UDP && canUDPGRO {
   764  			return udp4GROCandidate
   765  		}
   766  	} else if b[0]>>4 == 6 {
   767  		if b[6] == unix.IPPROTO_TCP && len(b) >= 60 {
   768  			return tcp6GROCandidate
   769  		}
   770  		if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && canUDPGRO {
   771  			return udp6GROCandidate
   772  		}
   773  	}
   774  	return notGROCandidate
   775  }
   776  
   777  const (
    778  	udphLen = 8 // the UDP header is always 8 bytes
   779  )
   780  
   781  // udpGRO evaluates the UDP packet at pktI in bufs for coalescing with
   782  // existing packets tracked in table. It returns a groResultNoop when no
   783  // action was taken, groResultTableInsert when the evaluated packet was
   784  // inserted into table, and groResultCoalesced when the evaluated packet was
   785  // coalesced with another packet in table.
   786  func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult {
   787  	pkt := bufs[pktI][offset:]
   788  	if len(pkt) > maxUint16 {
   789  		// A valid IPv4 or IPv6 packet will never exceed this.
   790  		return groResultNoop
   791  	}
   792  	iphLen := int((pkt[0] & 0x0F) * 4)
   793  	if isV6 {
   794  		iphLen = 40
   795  		ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
   796  		if ipv6HPayloadLen != len(pkt)-iphLen {
   797  			return groResultNoop
   798  		}
   799  	} else {
   800  		totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
   801  		if totalLen != len(pkt) {
   802  			return groResultNoop
   803  		}
   804  	}
   805  	if len(pkt) < iphLen {
   806  		return groResultNoop
   807  	}
   808  	if len(pkt) < iphLen+udphLen {
   809  		return groResultNoop
   810  	}
   811  	if !isV6 {
   812  		if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
   813  			// no GRO support for fragmented segments for now
   814  			return groResultNoop
   815  		}
   816  	}
   817  	gsoSize := uint16(len(pkt) - udphLen - iphLen)
   818  	// not a candidate if payload len is 0
   819  	if gsoSize < 1 {
   820  		return groResultNoop
   821  	}
   822  	srcAddrOffset := ipv4SrcAddrOffset
   823  	addrLen := 4
   824  	if isV6 {
   825  		srcAddrOffset = ipv6SrcAddrOffset
   826  		addrLen = 16
   827  	}
   828  	items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI)
   829  	if !existing {
   830  		return groResultTableInsert
   831  	}
   832  	// With UDP we only check the last item, otherwise we could reorder packets
   833  	// for a given flow. We must also always insert a new item, or successfully
   834  	// coalesce with an existing item, for the same reason.
   835  	item := items[len(items)-1]
   836  	can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset)
   837  	var pktCSumKnownInvalid bool
   838  	if can == coalesceAppend {
   839  		result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6)
   840  		switch result {
   841  		case coalesceSuccess:
   842  			table.updateAt(item, len(items)-1)
   843  			return groResultCoalesced
   844  		case coalesceItemInvalidCSum:
   845  			// If the existing item has an invalid csum we take no action. A new
   846  			// item will be stored after it, and the existing item will never be
   847  			// revisited as part of future coalescing candidacy checks.
   848  		case coalescePktInvalidCSum:
   849  			// We must insert a new item, but we also mark it as invalid csum
   850  			// to prevent a repeat checksum validation.
   851  			pktCSumKnownInvalid = true
   852  		default:
   853  		}
   854  	}
   855  	// failed to coalesce with any other packets; store the item in the flow
   856  	table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid)
   857  	return groResultTableInsert
   858  }
   859  
   860  // handleGRO evaluates bufs for GRO, and writes the indices of the resulting
   861  // packets into toWrite. toWrite, tcpTable, and udpTable should initially be
   862  // empty (but non-nil), and are passed in to save allocs as the caller may reset
   863  // and recycle them across vectors of packets. canUDPGRO indicates if UDP GRO is
   864  // supported.
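         //
         // A typical caller (a batched write toward the device) looks roughly like
         // the following sketch; the surrounding names are illustrative only:
         //
         //	toWrite := make([]int, 0, len(bufs))
         //	tcpTable.reset()
         //	udpTable.reset()
         //	if err := handleGRO(bufs, offset, tcpTable, udpTable, canUDPGRO, &toWrite); err != nil {
         //		return 0, err
         //	}
         //	// write bufs[i] for each i in toWrite; coalesced packets carry a
         //	// populated virtioNetHdr at offset-virtioNetHdrLen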
   865  func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, canUDPGRO bool, toWrite *[]int) error {
   866  	for i := range bufs {
   867  		if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
   868  			return errors.New("invalid offset")
   869  		}
   870  		var result groResult
   871  		switch packetIsGROCandidate(bufs[i][offset:], canUDPGRO) {
   872  		case tcp4GROCandidate:
   873  			result = tcpGRO(bufs, offset, i, tcpTable, false)
   874  		case tcp6GROCandidate:
   875  			result = tcpGRO(bufs, offset, i, tcpTable, true)
   876  		case udp4GROCandidate:
   877  			result = udpGRO(bufs, offset, i, udpTable, false)
   878  		case udp6GROCandidate:
   879  			result = udpGRO(bufs, offset, i, udpTable, true)
   880  		}
   881  		switch result {
   882  		case groResultNoop:
   883  			hdr := virtioNetHdr{}
   884  			err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
   885  			if err != nil {
   886  				return err
   887  			}
   888  			fallthrough
   889  		case groResultTableInsert:
   890  			*toWrite = append(*toWrite, i)
   891  		}
   892  	}
   893  	errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable)
   894  	errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable)
   895  	return errors.Join(errTCP, errUDP)
   896  }
   897  
   898  // gsoSplit splits packets from in into outBuffs, writing the size of each
   899  // element into sizes. It returns the number of buffers populated, and/or an
   900  // error.
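         //
         // in holds a single coalesced "super packet" described by hdr, following
         // virtio_net_hdr semantics: hdr.hdrLen spans the IP and transport headers,
         // hdr.gsoSize is the payload carried by each produced segment, and
         // hdr.csumStart/hdr.csumOffset locate the transport checksum field.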
   901  func gsoSplit(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int, isV6 bool) (int, error) {
   902  	iphLen := int(hdr.csumStart)
   903  	srcAddrOffset := ipv6SrcAddrOffset
   904  	addrLen := 16
   905  	if !isV6 {
   906  		in[10], in[11] = 0, 0 // clear ipv4 header checksum
   907  		srcAddrOffset = ipv4SrcAddrOffset
   908  		addrLen = 4
   909  	}
   910  	transportCsumAt := int(hdr.csumStart + hdr.csumOffset)
   911  	in[transportCsumAt], in[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum
   912  	var firstTCPSeqNum uint32
   913  	var protocol uint8
   914  	if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 || hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV6 {
   915  		protocol = unix.IPPROTO_TCP
   916  		firstTCPSeqNum = binary.BigEndian.Uint32(in[hdr.csumStart+4:])
   917  	} else {
   918  		protocol = unix.IPPROTO_UDP
   919  	}
   920  	nextSegmentDataAt := int(hdr.hdrLen)
   921  	i := 0
   922  	for ; nextSegmentDataAt < len(in); i++ {
   923  		if i == len(outBuffs) {
   924  			return i - 1, ErrTooManySegments
   925  		}
   926  		nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
   927  		if nextSegmentEnd > len(in) {
   928  			nextSegmentEnd = len(in)
   929  		}
   930  		segmentDataLen := nextSegmentEnd - nextSegmentDataAt
   931  		totalLen := int(hdr.hdrLen) + segmentDataLen
   932  		sizes[i] = totalLen
   933  		out := outBuffs[i][outOffset:]
   934  
   935  		copy(out, in[:iphLen])
   936  		if !isV6 {
   937  			// For IPv4 we are responsible for incrementing the ID field,
   938  			// updating the total len field, and recalculating the header
   939  			// checksum.
   940  			if i > 0 {
   941  				id := binary.BigEndian.Uint16(out[4:])
   942  				id += uint16(i)
   943  				binary.BigEndian.PutUint16(out[4:], id)
   944  			}
   945  			binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
   946  			ipv4CSum := ^checksum(out[:iphLen], 0)
   947  			binary.BigEndian.PutUint16(out[10:], ipv4CSum)
   948  		} else {
   949  			// For IPv6 we are responsible for updating the payload length field.
   950  			binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
   951  		}
   952  
   953  		// copy transport header
   954  		copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
   955  
   956  		if protocol == unix.IPPROTO_TCP {
   957  			// set TCP seq and adjust TCP flags
   958  			tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
   959  			binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
   960  			if nextSegmentEnd != len(in) {
   961  				// FIN and PSH should only be set on last segment
   962  				clearFlags := tcpFlagFIN | tcpFlagPSH
   963  				out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
   964  			}
   965  		} else {
   966  			// set UDP header len
   967  			binary.BigEndian.PutUint16(out[hdr.csumStart+4:], uint16(segmentDataLen)+(hdr.hdrLen-hdr.csumStart))
   968  		}
   969  
   970  		// payload
   971  		copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
   972  
   973  		// transport checksum
   974  		transportHeaderLen := int(hdr.hdrLen - hdr.csumStart)
   975  		lenForPseudo := uint16(transportHeaderLen + segmentDataLen)
   976  		transportCSumNoFold := pseudoHeaderChecksumNoFold(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo)
   977  		transportCSum := ^checksum(out[hdr.csumStart:totalLen], transportCSumNoFold)
   978  		binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], transportCSum)
   979  
   980  		nextSegmentDataAt += int(hdr.gsoSize)
   981  	}
   982  	return i, nil
   983  }
   984  
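         // gsoNoneChecksum finalizes the transport checksum of a packet that requests
         // checksum offload without segmentation (VIRTIO_NET_HDR_GSO_NONE with
         // VIRTIO_NET_HDR_F_NEEDS_CSUM): the value already present at the checksum
         // offset, typically the pseudo-header checksum, is folded into the sum over
         // the transport header and payload.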
   985  func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
   986  	cSumAt := cSumStart + cSumOffset
   987  	// The initial value at the checksum offset should be summed with the
   988  	// checksum we compute. This is typically the pseudo-header checksum.
   989  	initial := binary.BigEndian.Uint16(in[cSumAt:])
   990  	in[cSumAt], in[cSumAt+1] = 0, 0
   991  	binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial)))
   992  	return nil
   993  }