github.com/metacubex/sing-tun@v0.2.7-0.20240512075008-89e7c6208eec/tun_linux_offload.go (about)

     1  //go:build linux
     2  
     3  /* SPDX-License-Identifier: MIT
     4   *
     5   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     6   */
     7  
     8  package tun
     9  
    10  import (
    11  	"bytes"
    12  	"encoding/binary"
    13  	"errors"
    14  	"fmt"
    15  	"io"
    16  	"unsafe"
    17  
    18  	"github.com/metacubex/sing-tun/internal/clashtcpip"
    19  	E "github.com/sagernet/sing/common/exceptions"
    20  
    21  	"golang.org/x/sys/unix"
    22  )
    23  
const (
	gsoMaxSize     = 65536 // upper bound on a GSO "super-packet" exchanged with the kernel
	tcpFlagsOffset = 13    // byte offset of the flags field within a TCP header
	idealBatchSize = 128   // preferred number of packets handled per batch
)
    29  
// TCP flag bits as they appear in the byte at tcpFlagsOffset.
const (
	tcpFlagFIN uint8 = 0x01
	tcpFlagPSH uint8 = 0x08
	tcpFlagACK uint8 = 0x10
)
    35  
// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
// kernel symbol is virtio_net_hdr. Field order and types must stay exactly as
// the kernel declares them; encode/decode copy this struct's memory verbatim.
type virtioNetHdr struct {
	flags      uint8  // VIRTIO_NET_HDR_F_* flags (e.g. NEEDS_CSUM)
	gsoType    uint8  // VIRTIO_NET_HDR_GSO_* segmentation type
	hdrLen     uint16 // combined L3+L4 header length
	gsoSize    uint16 // per-segment payload size when GSO is in effect
	csumStart  uint16 // offset where checksumming begins (start of L4 header)
	csumOffset uint16 // offset from csumStart at which the checksum is stored
}
    46  
    47  func (v *virtioNetHdr) decode(b []byte) error {
    48  	if len(b) < virtioNetHdrLen {
    49  		return io.ErrShortBuffer
    50  	}
    51  	copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
    52  	return nil
    53  }
    54  
    55  func (v *virtioNetHdr) encode(b []byte) error {
    56  	if len(b) < virtioNetHdrLen {
    57  		return io.ErrShortBuffer
    58  	}
    59  	copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
    60  	return nil
    61  }
    62  
const (
	// virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
	// shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
	// Each packet buffer reserves this many bytes in front of the packet for
	// the header (see handleGRO / handleVirtioRead).
	virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
)
    68  
// flowKey represents the key for a flow.
type flowKey struct {
	srcAddr, dstAddr [16]byte // IPv4 addresses occupy the leading 4 bytes; the remainder stays zero
	srcPort, dstPort uint16
	rxAck            uint32 // varying ack values should not be coalesced. Treat them as separate flows.
}
    75  
// tcpGROTable holds flow and coalescing information for the purposes of GRO.
type tcpGROTable struct {
	itemsByFlow map[flowKey][]tcpGROItem // in-flight items grouped by flow
	itemsPool   [][]tcpGROItem           // recycled item slices; refilled by reset()
}
    81  
    82  func newTCPGROTable() *tcpGROTable {
    83  	t := &tcpGROTable{
    84  		itemsByFlow: make(map[flowKey][]tcpGROItem, idealBatchSize),
    85  		itemsPool:   make([][]tcpGROItem, idealBatchSize),
    86  	}
    87  	for i := range t.itemsPool {
    88  		t.itemsPool[i] = make([]tcpGROItem, 0, idealBatchSize)
    89  	}
    90  	return t
    91  }
    92  
    93  func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
    94  	key := flowKey{}
    95  	addrSize := dstAddr - srcAddr
    96  	copy(key.srcAddr[:], pkt[srcAddr:dstAddr])
    97  	copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize])
    98  	key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
    99  	key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
   100  	key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
   101  	return key
   102  }
   103  
   104  // lookupOrInsert looks up a flow for the provided packet and metadata,
   105  // returning the packets found for the flow, or inserting a new one if none
   106  // is found.
   107  func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
   108  	key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
   109  	items, ok := t.itemsByFlow[key]
   110  	if ok {
   111  		return items, ok
   112  	}
   113  	// TODO: insert() performs another map lookup. This could be rearranged to avoid.
   114  	t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
   115  	return nil, false
   116  }
   117  
   118  // insert an item in the table for the provided packet and packet metadata.
   119  func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
   120  	key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
   121  	item := tcpGROItem{
   122  		key:       key,
   123  		bufsIndex: uint16(bufsIndex),
   124  		gsoSize:   uint16(len(pkt[tcphOffset+tcphLen:])),
   125  		iphLen:    uint8(tcphOffset),
   126  		tcphLen:   uint8(tcphLen),
   127  		sentSeq:   binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
   128  		pshSet:    pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
   129  	}
   130  	items, ok := t.itemsByFlow[key]
   131  	if !ok {
   132  		items = t.newItems()
   133  	}
   134  	items = append(items, item)
   135  	t.itemsByFlow[key] = items
   136  }
   137  
   138  func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
   139  	items, _ := t.itemsByFlow[item.key]
   140  	items[i] = item
   141  }
   142  
   143  func (t *tcpGROTable) deleteAt(key flowKey, i int) {
   144  	items, _ := t.itemsByFlow[key]
   145  	items = append(items[:i], items[i+1:]...)
   146  	t.itemsByFlow[key] = items
   147  }
   148  
// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
// of a GRO evaluation across a vector of packets. Fields are ordered
// largest-to-smallest to minimize struct padding.
type tcpGROItem struct {
	key       flowKey
	sentSeq   uint32 // the sequence number
	bufsIndex uint16 // the index into the original bufs slice
	numMerged uint16 // the number of packets merged into this item
	gsoSize   uint16 // payload size
	iphLen    uint8  // ip header len
	tcphLen   uint8  // tcp header len
	pshSet    bool   // psh flag is set
}
   161  
   162  func (t *tcpGROTable) newItems() []tcpGROItem {
   163  	var items []tcpGROItem
   164  	items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
   165  	return items
   166  }
   167  
   168  func (t *tcpGROTable) reset() {
   169  	for k, items := range t.itemsByFlow {
   170  		items = items[:0]
   171  		t.itemsPool = append(t.itemsPool, items)
   172  		delete(t.itemsByFlow, k)
   173  	}
   174  }
   175  
// canCoalesce represents the outcome of checking if two TCP packets are
// candidates for coalescing.
type canCoalesce int

const (
	coalescePrepend     canCoalesce = -1 // candidate belongs in front of the existing item
	coalesceUnavailable canCoalesce = 0  // packets cannot be merged
	coalesceAppend      canCoalesce = 1  // candidate belongs after the existing item
)
   185  
   186  // tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
   187  // described by item. This function makes considerations that match the kernel's
   188  // GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
   189  func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
   190  	pktTarget := bufs[item.bufsIndex][bufsOffset:]
   191  	if tcphLen != item.tcphLen {
   192  		// cannot coalesce with unequal tcp options len
   193  		return coalesceUnavailable
   194  	}
   195  	if tcphLen > 20 {
   196  		if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
   197  			// cannot coalesce with unequal tcp options
   198  			return coalesceUnavailable
   199  		}
   200  	}
   201  	if pkt[0]>>4 == 6 {
   202  		if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 {
   203  			// cannot coalesce with unequal Traffic class values
   204  			return coalesceUnavailable
   205  		}
   206  		if pkt[7] != pktTarget[7] {
   207  			// cannot coalesce with unequal Hop limit values
   208  			return coalesceUnavailable
   209  		}
   210  	} else {
   211  		if pkt[1] != pktTarget[1] {
   212  			// cannot coalesce with unequal ToS values
   213  			return coalesceUnavailable
   214  		}
   215  		if pkt[6]>>5 != pktTarget[6]>>5 {
   216  			// cannot coalesce with unequal DF or reserved bits. MF is checked
   217  			// further up the stack.
   218  			return coalesceUnavailable
   219  		}
   220  		if pkt[8] != pktTarget[8] {
   221  			// cannot coalesce with unequal TTL values
   222  			return coalesceUnavailable
   223  		}
   224  	}
   225  	// seq adjacency
   226  	lhsLen := item.gsoSize
   227  	lhsLen += item.numMerged * item.gsoSize
   228  	if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
   229  		if item.pshSet {
   230  			// We cannot append to a segment that has the PSH flag set, PSH
   231  			// can only be set on the final segment in a reassembled group.
   232  			return coalesceUnavailable
   233  		}
   234  		if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
   235  			// A smaller than gsoSize packet has been appended previously.
   236  			// Nothing can come after a smaller packet on the end.
   237  			return coalesceUnavailable
   238  		}
   239  		if gsoSize > item.gsoSize {
   240  			// We cannot have a larger packet following a smaller one.
   241  			return coalesceUnavailable
   242  		}
   243  		return coalesceAppend
   244  	} else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
   245  		if pshSet {
   246  			// We cannot prepend with a segment that has the PSH flag set, PSH
   247  			// can only be set on the final segment in a reassembled group.
   248  			return coalesceUnavailable
   249  		}
   250  		if gsoSize < item.gsoSize {
   251  			// We cannot have a larger packet following a smaller one.
   252  			return coalesceUnavailable
   253  		}
   254  		if gsoSize > item.gsoSize && item.numMerged > 0 {
   255  			// There's at least one previous merge, and we're larger than all
   256  			// previous. This would put multiple smaller packets on the end.
   257  			return coalesceUnavailable
   258  		}
   259  		return coalescePrepend
   260  	}
   261  	return coalesceUnavailable
   262  }
   263  
   264  func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
   265  	srcAddrAt := ipv4SrcAddrOffset
   266  	addrSize := 4
   267  	if isV6 {
   268  		srcAddrAt = ipv6SrcAddrOffset
   269  		addrSize = 16
   270  	}
   271  	tcpTotalLen := uint16(len(pkt) - int(iphLen))
   272  	tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen)
   273  	return ^checksumFold(pkt[iphLen:], tcpCSumNoFold) == 0
   274  }
   275  
// coalesceResult represents the result of attempting to coalesce two TCP
// packets.
type coalesceResult int

const (
	coalesceInsufficientCap coalesceResult = iota // target buffer lacks capacity; merging would reallocate
	coalescePSHEnding                             // cannot prepend in front of a PSH-terminated group
	coalesceItemInvalidCSum                       // the existing item failed checksum validation
	coalescePktInvalidCSum                        // the candidate packet failed checksum validation
	coalesceSuccess                               // packets were merged
)
   287  
// coalesceTCPPackets attempts to coalesce pkt with the packet described by
// item, returning the outcome. This function may swap bufs elements in the
// event of a prepend as item's bufs index is already being tracked for writing
// to a Device.
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
	var pktHead []byte // the packet that will end up at the front
	headersLen := item.iphLen + item.tcphLen
	coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)

	// Copy data
	if mode == coalescePrepend {
		// Prepend: pkt keeps its headers; item's payload is copied in after
		// pkt's payload, then the two bufs slice headers are swapped below.
		pktHead = pkt
		if cap(pkt)-bufsOffset < coalescedLen {
			// We don't want to allocate a new underlying array if capacity is
			// too small.
			return coalesceInsufficientCap
		}
		if pshSet {
			// PSH may only terminate a reassembled group; it cannot lead one.
			return coalescePSHEnding
		}
		if item.numMerged == 0 {
			// item's own checksum has not been validated by a prior merge yet.
			if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
				return coalesceItemInvalidCSum
			}
		}
		if !tcpChecksumValid(pkt, item.iphLen, isV6) {
			return coalescePktInvalidCSum
		}
		item.sentSeq = seq
		extendBy := coalescedLen - len(pktHead)
		bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
		copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
		// Flip the slice headers in bufs as part of prepend. The index of item
		// is already being tracked for writing.
		bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
	} else {
		// Append: item stays in front; pkt's payload is copied in behind it.
		pktHead = bufs[item.bufsIndex][bufsOffset:]
		if cap(pktHead)-bufsOffset < coalescedLen {
			// We don't want to allocate a new underlying array if capacity is
			// too small.
			return coalesceInsufficientCap
		}
		if item.numMerged == 0 {
			// item's own checksum has not been validated by a prior merge yet.
			if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
				return coalesceItemInvalidCSum
			}
		}
		if !tcpChecksumValid(pkt, item.iphLen, isV6) {
			return coalescePktInvalidCSum
		}
		if pshSet {
			// We are appending a segment with PSH set.
			item.pshSet = pshSet
			pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
		}
		extendBy := len(pkt) - int(headersLen)
		bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
		copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
	}

	if gsoSize > item.gsoSize {
		item.gsoSize = gsoSize
	}

	item.numMerged++
	return coalesceSuccess
}
   355  
const (
	ipv4FlagMoreFragments uint8 = 0x20 // MF bit within the IPv4 flags/fragment byte (byte 6)
)

const (
	ipv4SrcAddrOffset = 12 // byte offset of the IPv4 source address field
	ipv6SrcAddrOffset = 8  // byte offset of the IPv6 source address field
	maxUint16         = 1<<16 - 1
)
   365  
// tcpGROResult describes the outcome of evaluating a single packet in tcpGRO.
type tcpGROResult int

const (
	// tcpGROResultNoop: packet untouched; the caller must write it as-is.
	tcpGROResultNoop tcpGROResult = iota
	// tcpGROResultTableInsert: packet stored in the table for later merging.
	tcpGROResultTableInsert
	// tcpGROResultCoalesced: packet merged into an existing table entry.
	tcpGROResultCoalesced
)
   373  
// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
// existing packets tracked in table. It returns a tcpGROResultNoop when no
// action was taken, tcpGROResultTableInsert when the evaluated packet was
// inserted into table, and tcpGROResultCoalesced when the evaluated packet was
// coalesced with another packet in table.
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) tcpGROResult {
	pkt := bufs[pktI][offset:]
	if len(pkt) > maxUint16 {
		// A valid IPv4 or IPv6 packet will never exceed this.
		return tcpGROResultNoop
	}
	iphLen := int((pkt[0] & 0x0F) * 4)
	if isV6 {
		iphLen = 40
		ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
		if ipv6HPayloadLen != len(pkt)-iphLen {
			// Payload-length field disagrees with the buffer; leave it alone.
			return tcpGROResultNoop
		}
	} else {
		totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
		if totalLen != len(pkt) {
			// Total-length field disagrees with the buffer; leave it alone.
			return tcpGROResultNoop
		}
	}
	if len(pkt) < iphLen {
		return tcpGROResultNoop
	}
	// TCP data offset (header length in 32-bit words) is the upper nibble of
	// the byte 12 bytes into the TCP header.
	tcphLen := int((pkt[iphLen+12] >> 4) * 4)
	if tcphLen < 20 || tcphLen > 60 {
		return tcpGROResultNoop
	}
	if len(pkt) < iphLen+tcphLen {
		return tcpGROResultNoop
	}
	if !isV6 {
		if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
			// no GRO support for fragmented segments for now
			return tcpGROResultNoop
		}
	}
	tcpFlags := pkt[iphLen+tcpFlagsOffset]
	var pshSet bool
	// not a candidate if any non-ACK flags (except PSH+ACK) are set
	if tcpFlags != tcpFlagACK {
		if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
			return tcpGROResultNoop
		}
		pshSet = true
	}
	gsoSize := uint16(len(pkt) - tcphLen - iphLen)
	// not a candidate if payload len is 0
	if gsoSize < 1 {
		return tcpGROResultNoop
	}
	seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
	srcAddrOffset := ipv4SrcAddrOffset
	addrLen := 4
	if isV6 {
		srcAddrOffset = ipv6SrcAddrOffset
		addrLen = 16
	}
	items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
	if !existing {
		// lookupOrInsert already stored the packet as a new flow.
		return tcpGROResultNoop
	}
	for i := len(items) - 1; i >= 0; i-- {
		// In the best case of packets arriving in order iterating in reverse is
		// more efficient if there are multiple items for a given flow. This
		// also enables a natural table.deleteAt() in the
		// coalesceItemInvalidCSum case without the need for index tracking.
		// This algorithm makes a best effort to coalesce in the event of
		// unordered packets, where pkt may land anywhere in items from a
		// sequence number perspective, however once an item is inserted into
		// the table it is never compared across other items later.
		item := items[i]
		can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
		if can != coalesceUnavailable {
			result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
			switch result {
			case coalesceSuccess:
				table.updateAt(item, i)
				return tcpGROResultCoalesced
			case coalesceItemInvalidCSum:
				// delete the item with an invalid csum
				table.deleteAt(item.key, i)
			case coalescePktInvalidCSum:
				// no point in inserting an item that we can't coalesce
				return tcpGROResultNoop
			default:
			}
		}
	}
	// failed to coalesce with any other packets; store the item in the flow
	table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
	return tcpGROResultTableInsert
}
   470  
   471  func isTCP4NoIPOptions(b []byte) bool {
   472  	if len(b) < 40 {
   473  		return false
   474  	}
   475  	if b[0]>>4 != 4 {
   476  		return false
   477  	}
   478  	if b[0]&0x0F != 5 {
   479  		return false
   480  	}
   481  	if b[9] != unix.IPPROTO_TCP {
   482  		return false
   483  	}
   484  	return true
   485  }
   486  
   487  func isTCP6NoEH(b []byte) bool {
   488  	if len(b) < 60 {
   489  		return false
   490  	}
   491  	if b[0]>>4 != 6 {
   492  		return false
   493  	}
   494  	if b[6] != unix.IPPROTO_TCP {
   495  		return false
   496  	}
   497  	return true
   498  }
   499  
// applyCoalesceAccounting updates bufs to account for coalescing based on the
// metadata found in table.
func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 bool) error {
	for _, items := range table.itemsByFlow {
		for _, item := range items {
			if item.numMerged > 0 {
				hdr := virtioNetHdr{
					flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
					hdrLen:     uint16(item.iphLen + item.tcphLen),
					gsoSize:    item.gsoSize,
					csumStart:  uint16(item.iphLen),
					csumOffset: 16, // TCP checksum field sits 16 bytes into the TCP header
				}
				pkt := bufs[item.bufsIndex][offset:]

				// Recalculate the total len (IPv4) or payload len (IPv6).
				// Recalculate the (IPv4) header checksum.
				if isV6 {
					hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
					binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
				} else {
					hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
					pkt[10], pkt[11] = 0, 0                               // zero the old checksum before recomputing
					binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
					iphCSum := ^checksumFold(pkt[:item.iphLen], 0)        // compute IPv4 header checksum
					binary.BigEndian.PutUint16(pkt[10:], iphCSum)         // set IPv4 header checksum field
				}
				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
				if err != nil {
					return err
				}

				// Calculate the pseudo header checksum and place it at the TCP
				// checksum offset. Downstream checksum offloading will combine
				// this with computation of the tcp header and payload checksum.
				addrLen := 4
				addrOffset := ipv4SrcAddrOffset
				if isV6 {
					addrLen = 16
					addrOffset = ipv6SrcAddrOffset
				}
				srcAddrAt := offset + addrOffset
				srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
				dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
				psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
				binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksumFold([]byte{}, psum))
			} else {
				// Packet never merged with another: it still needs a zeroed
				// virtio header in front so it is treated as a plain packet.
				hdr := virtioNetHdr{}
				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
				if err != nil {
					return err
				}
			}
		}
	}
	return nil
}
   557  
// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
// empty (but non-nil), and are passed in to save allocs as the caller may reset
// and recycle them across vectors of packets.
func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error {
	for i := range bufs {
		if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
			// offset must leave room in front for a virtio header and still
			// point inside the buffer.
			return errors.New("invalid offset")
		}
		var result tcpGROResult
		switch {
		case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce
			result = tcpGRO(bufs, offset, i, tcp4Table, false)
		case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
			result = tcpGRO(bufs, offset, i, tcp6Table, true)
		}
		switch result {
		case tcpGROResultNoop:
			// Untouched packet: give it a zeroed virtio header, then fall
			// through so it is queued for writing like a table insert.
			hdr := virtioNetHdr{}
			err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
			if err != nil {
				return err
			}
			fallthrough
		case tcpGROResultTableInsert:
			*toWrite = append(*toWrite, i)
		}
	}
	// Coalesced packets are deliberately NOT appended to toWrite; their bytes
	// already live inside another queued buffer.
	err4 := applyCoalesceAccounting(bufs, offset, tcp4Table, false)
	err6 := applyCoalesceAccounting(bufs, offset, tcp6Table, true)
	return E.Errors(err4, err6)
}
   590  
// tcpTSO splits packets from in into outBuffs, writing the size of each
// element into sizes. It returns the number of buffers populated, and/or an
// error.
func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) {
	iphLen := int(hdr.csumStart) // csumStart doubles as the IP header length
	srcAddrOffset := ipv6SrcAddrOffset
	addrLen := 16
	if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
		in[10], in[11] = 0, 0 // clear ipv4 header checksum
		srcAddrOffset = ipv4SrcAddrOffset
		addrLen = 4
	}
	tcpCSumAt := int(hdr.csumStart + hdr.csumOffset)
	in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum
	firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:])
	nextSegmentDataAt := int(hdr.hdrLen)
	i := 0
	for ; nextSegmentDataAt < len(in); i++ {
		if i == len(outBuffs) {
			return i - 1, ErrTooManySegments
		}
		nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
		if nextSegmentEnd > len(in) {
			nextSegmentEnd = len(in) // the final segment may be short
		}
		segmentDataLen := nextSegmentEnd - nextSegmentDataAt
		totalLen := int(hdr.hdrLen) + segmentDataLen
		sizes[i] = totalLen
		out := outBuffs[i][outOffset:]

		// IP header
		copy(out, in[:iphLen])
		if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
			// For IPv4 we are responsible for incrementing the ID field,
			// updating the total len field, and recalculating the header
			// checksum.
			if i > 0 {
				id := binary.BigEndian.Uint16(out[4:])
				id += uint16(i)
				binary.BigEndian.PutUint16(out[4:], id)
			}
			binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
			ipv4CSum := ^checksumFold(out[:iphLen], 0)
			binary.BigEndian.PutUint16(out[10:], ipv4CSum)
		} else {
			// For IPv6 we are responsible for updating the payload length field.
			binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
		}

		// TCP header
		copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
		tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
		binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
		if nextSegmentEnd != len(in) {
			// FIN and PSH should only be set on last segment
			clearFlags := tcpFlagFIN | tcpFlagPSH
			out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
		}

		// payload
		copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])

		// TCP checksum
		tcpHLen := int(hdr.hdrLen - hdr.csumStart)
		tcpLenForPseudo := uint16(tcpHLen + segmentDataLen)
		tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo)
		tcpCSum := ^checksumFold(out[hdr.csumStart:totalLen], tcpCSumNoFold)
		binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum)

		nextSegmentDataAt += int(hdr.gsoSize)
	}
	return i, nil
}
   663  
   664  func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
   665  	cSumAt := cSumStart + cSumOffset
   666  	// The initial value at the checksum offset should be summed with the
   667  	// checksum we compute. This is typically the pseudo-header checksum.
   668  	initial := binary.BigEndian.Uint16(in[cSumAt:])
   669  	in[cSumAt], in[cSumAt+1] = 0, 0
   670  	binary.BigEndian.PutUint16(in[cSumAt:], ^checksumFold(in[cSumStart:], uint64(initial)))
   671  	return nil
   672  }
   673  
// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
// each buffer. It mutates sizes to reflect the size of each element of bufs,
// and returns the number of packets read.
func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
	var hdr virtioNetHdr
	err := hdr.decode(in)
	if err != nil {
		return 0, err
	}
	in = in[virtioNetHdrLen:] // step past the virtio header to the packet itself
	if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
		if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
			// This means CHECKSUM_PARTIAL in skb context. We are responsible
			// for computing the checksum starting at hdr.csumStart and placing
			// at hdr.csumOffset.
			err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset)
			if err != nil {
				return 0, err
			}
		}
		if len(in) > len(bufs[0][offset:]) {
			return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:]))
		}
		// Not a GSO super-packet: pass it through as a single packet.
		n := copy(bufs[0][offset:], in)
		sizes[0] = n
		return 1, nil
	}
	if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
		return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
	}

	// The IP version in the packet must agree with the advertised GSO type.
	ipVersion := in[0] >> 4
	switch ipVersion {
	case 4:
		if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
			return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
		}
	case 6:
		if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
			return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
		}
	default:
		return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
	}

	// Ensure the TCP data-offset byte (csumStart+12) is actually present.
	if len(in) <= int(hdr.csumStart+12) {
		return 0, errors.New("packet is too short")
	}
	// Don't trust hdr.hdrLen from the kernel as it can be equal to the length
	// of the entire first packet when the kernel is handling it as part of a
	// FORWARD path. Instead, parse the TCP header length and add it onto
	// csumStart, which is synonymous for IP header length.
	tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
	if tcpHLen < 20 || tcpHLen > 60 {
		// A TCP header must be between 20 and 60 bytes in length.
		return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
	}
	hdr.hdrLen = hdr.csumStart + tcpHLen

	if len(in) < int(hdr.hdrLen) {
		return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
	}

	if hdr.hdrLen < hdr.csumStart {
		return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
	}
	cSumAt := int(hdr.csumStart + hdr.csumOffset)
	if cSumAt+1 >= len(in) {
		return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
	}

	// Split the validated super-packet into individual segments.
	return tcpTSO(in, hdr, bufs, sizes, offset)
}
   747  
   748  func checksumNoFold(b []byte, initial uint64) uint64 {
   749  	return initial + uint64(clashtcpip.Sum(b))
   750  }
   751  
   752  func checksumFold(b []byte, initial uint64) uint16 {
   753  	ac := checksumNoFold(b, initial)
   754  	ac = (ac >> 16) + (ac & 0xffff)
   755  	ac = (ac >> 16) + (ac & 0xffff)
   756  	ac = (ac >> 16) + (ac & 0xffff)
   757  	ac = (ac >> 16) + (ac & 0xffff)
   758  	return uint16(ac)
   759  }
   760  
   761  func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
   762  	sum := checksumNoFold(srcAddr, 0)
   763  	sum = checksumNoFold(dstAddr, sum)
   764  	sum = checksumNoFold([]byte{0, protocol}, sum)
   765  	tmp := make([]byte, 2)
   766  	binary.BigEndian.PutUint16(tmp, totalLen)
   767  	return checksumNoFold(tmp, sum)
   768  }