github.com/amnezia-vpn/amneziawg-go@v0.2.8/tun/offload_linux_test.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package tun
     7  
     8  import (
     9  	"net/netip"
    10  	"testing"
    11  
    12  	"github.com/amnezia-vpn/amneziawg-go/conn"
    13  	"golang.org/x/sys/unix"
    14  	"gvisor.dev/gvisor/pkg/tcpip"
    15  	"gvisor.dev/gvisor/pkg/tcpip/header"
    16  )
    17  
    18  const (
    19  	offset = virtioNetHdrLen
    20  )
    21  
    22  var (
    23  	ip4PortA = netip.MustParseAddrPort("192.0.2.1:1")
    24  	ip4PortB = netip.MustParseAddrPort("192.0.2.2:1")
    25  	ip4PortC = netip.MustParseAddrPort("192.0.2.3:1")
    26  	ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1")
    27  	ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1")
    28  	ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
    29  )
    30  
    31  func udp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv4Fields)) []byte {
    32  	totalLen := 28 + payloadLen
    33  	b := make([]byte, offset+int(totalLen), 65535)
    34  	ipv4H := header.IPv4(b[offset:])
    35  	srcAs4 := srcIPPort.Addr().As4()
    36  	dstAs4 := dstIPPort.Addr().As4()
    37  	ipFields := &header.IPv4Fields{
    38  		SrcAddr:     tcpip.AddrFromSlice(srcAs4[:]),
    39  		DstAddr:     tcpip.AddrFromSlice(dstAs4[:]),
    40  		Protocol:    unix.IPPROTO_UDP,
    41  		TTL:         64,
    42  		TotalLength: uint16(totalLen),
    43  	}
    44  	if ipFn != nil {
    45  		ipFn(ipFields)
    46  	}
    47  	ipv4H.Encode(ipFields)
    48  	udpH := header.UDP(b[offset+20:])
    49  	udpH.Encode(&header.UDPFields{
    50  		SrcPort: srcIPPort.Port(),
    51  		DstPort: dstIPPort.Port(),
    52  		Length:  uint16(payloadLen + udphLen),
    53  	})
    54  	ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
    55  	pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(udphLen+payloadLen))
    56  	udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
    57  	return b
    58  }
    59  
    60  func udp6Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
    61  	return udp6PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil)
    62  }
    63  
    64  func udp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv6Fields)) []byte {
    65  	totalLen := 48 + payloadLen
    66  	b := make([]byte, offset+int(totalLen), 65535)
    67  	ipv6H := header.IPv6(b[offset:])
    68  	srcAs16 := srcIPPort.Addr().As16()
    69  	dstAs16 := dstIPPort.Addr().As16()
    70  	ipFields := &header.IPv6Fields{
    71  		SrcAddr:           tcpip.AddrFromSlice(srcAs16[:]),
    72  		DstAddr:           tcpip.AddrFromSlice(dstAs16[:]),
    73  		TransportProtocol: unix.IPPROTO_UDP,
    74  		HopLimit:          64,
    75  		PayloadLength:     uint16(payloadLen + udphLen),
    76  	}
    77  	if ipFn != nil {
    78  		ipFn(ipFields)
    79  	}
    80  	ipv6H.Encode(ipFields)
    81  	udpH := header.UDP(b[offset+40:])
    82  	udpH.Encode(&header.UDPFields{
    83  		SrcPort: srcIPPort.Port(),
    84  		DstPort: dstIPPort.Port(),
    85  		Length:  uint16(payloadLen + udphLen),
    86  	})
    87  	pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(udphLen+payloadLen))
    88  	udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
    89  	return b
    90  }
    91  
    92  func udp4Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
    93  	return udp4PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil)
    94  }
    95  
    96  func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte {
    97  	totalLen := 40 + segmentSize
    98  	b := make([]byte, offset+int(totalLen), 65535)
    99  	ipv4H := header.IPv4(b[offset:])
   100  	srcAs4 := srcIPPort.Addr().As4()
   101  	dstAs4 := dstIPPort.Addr().As4()
   102  	ipFields := &header.IPv4Fields{
   103  		SrcAddr:     tcpip.AddrFromSlice(srcAs4[:]),
   104  		DstAddr:     tcpip.AddrFromSlice(dstAs4[:]),
   105  		Protocol:    unix.IPPROTO_TCP,
   106  		TTL:         64,
   107  		TotalLength: uint16(totalLen),
   108  	}
   109  	if ipFn != nil {
   110  		ipFn(ipFields)
   111  	}
   112  	ipv4H.Encode(ipFields)
   113  	tcpH := header.TCP(b[offset+20:])
   114  	tcpH.Encode(&header.TCPFields{
   115  		SrcPort:    srcIPPort.Port(),
   116  		DstPort:    dstIPPort.Port(),
   117  		SeqNum:     seq,
   118  		AckNum:     1,
   119  		DataOffset: 20,
   120  		Flags:      flags,
   121  		WindowSize: 3000,
   122  	})
   123  	ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
   124  	pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize))
   125  	tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
   126  	return b
   127  }
   128  
   129  func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
   130  	return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
   131  }
   132  
   133  func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte {
   134  	totalLen := 60 + segmentSize
   135  	b := make([]byte, offset+int(totalLen), 65535)
   136  	ipv6H := header.IPv6(b[offset:])
   137  	srcAs16 := srcIPPort.Addr().As16()
   138  	dstAs16 := dstIPPort.Addr().As16()
   139  	ipFields := &header.IPv6Fields{
   140  		SrcAddr:           tcpip.AddrFromSlice(srcAs16[:]),
   141  		DstAddr:           tcpip.AddrFromSlice(dstAs16[:]),
   142  		TransportProtocol: unix.IPPROTO_TCP,
   143  		HopLimit:          64,
   144  		PayloadLength:     uint16(segmentSize + 20),
   145  	}
   146  	if ipFn != nil {
   147  		ipFn(ipFields)
   148  	}
   149  	ipv6H.Encode(ipFields)
   150  	tcpH := header.TCP(b[offset+40:])
   151  	tcpH.Encode(&header.TCPFields{
   152  		SrcPort:    srcIPPort.Port(),
   153  		DstPort:    dstIPPort.Port(),
   154  		SeqNum:     seq,
   155  		AckNum:     1,
   156  		DataOffset: 20,
   157  		Flags:      flags,
   158  		WindowSize: 3000,
   159  	})
   160  	pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize))
   161  	tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
   162  	return b
   163  }
   164  
   165  func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
   166  	return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
   167  }
   168  
   169  func Test_handleVirtioRead(t *testing.T) {
   170  	tests := []struct {
   171  		name     string
   172  		hdr      virtioNetHdr
   173  		pktIn    []byte
   174  		wantLens []int
   175  		wantErr  bool
   176  	}{
   177  		{
   178  			"tcp4",
   179  			virtioNetHdr{
   180  				flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
   181  				gsoType:    unix.VIRTIO_NET_HDR_GSO_TCPV4,
   182  				gsoSize:    100,
   183  				hdrLen:     40,
   184  				csumStart:  20,
   185  				csumOffset: 16,
   186  			},
   187  			tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
   188  			[]int{140, 140},
   189  			false,
   190  		},
   191  		{
   192  			"tcp6",
   193  			virtioNetHdr{
   194  				flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
   195  				gsoType:    unix.VIRTIO_NET_HDR_GSO_TCPV6,
   196  				gsoSize:    100,
   197  				hdrLen:     60,
   198  				csumStart:  40,
   199  				csumOffset: 16,
   200  			},
   201  			tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
   202  			[]int{160, 160},
   203  			false,
   204  		},
   205  		{
   206  			"udp4",
   207  			virtioNetHdr{
   208  				flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
   209  				gsoType:    unix.VIRTIO_NET_HDR_GSO_UDP_L4,
   210  				gsoSize:    100,
   211  				hdrLen:     28,
   212  				csumStart:  20,
   213  				csumOffset: 6,
   214  			},
   215  			udp4Packet(ip4PortA, ip4PortB, 200),
   216  			[]int{128, 128},
   217  			false,
   218  		},
   219  		{
   220  			"udp6",
   221  			virtioNetHdr{
   222  				flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
   223  				gsoType:    unix.VIRTIO_NET_HDR_GSO_UDP_L4,
   224  				gsoSize:    100,
   225  				hdrLen:     48,
   226  				csumStart:  40,
   227  				csumOffset: 6,
   228  			},
   229  			udp6Packet(ip6PortA, ip6PortB, 200),
   230  			[]int{148, 148},
   231  			false,
   232  		},
   233  	}
   234  
   235  	for _, tt := range tests {
   236  		t.Run(tt.name, func(t *testing.T) {
   237  			out := make([][]byte, conn.IdealBatchSize)
   238  			sizes := make([]int, conn.IdealBatchSize)
   239  			for i := range out {
   240  				out[i] = make([]byte, 65535)
   241  			}
   242  			tt.hdr.encode(tt.pktIn)
   243  			n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
   244  			if err != nil {
   245  				if tt.wantErr {
   246  					return
   247  				}
   248  				t.Fatalf("got err: %v", err)
   249  			}
   250  			if n != len(tt.wantLens) {
   251  				t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
   252  			}
   253  			for i := range tt.wantLens {
   254  				if tt.wantLens[i] != sizes[i] {
   255  					t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
   256  				}
   257  			}
   258  		})
   259  	}
   260  }
   261  
   262  func flipTCP4Checksum(b []byte) []byte {
   263  	at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16
   264  	b[at] ^= 0xFF
   265  	b[at+1] ^= 0xFF
   266  	return b
   267  }
   268  
   269  func flipUDP4Checksum(b []byte) []byte {
   270  	at := virtioNetHdrLen + 20 + 6 // 20 byte ipv4 header; udp csum offset is 6
   271  	b[at] ^= 0xFF
   272  	b[at+1] ^= 0xFF
   273  	return b
   274  }
   275  
   276  func Fuzz_handleGRO(f *testing.F) {
   277  	pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)
   278  	pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101)
   279  	pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201)
   280  	pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)
   281  	pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101)
   282  	pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201)
   283  	pkt6 := udp4Packet(ip4PortA, ip4PortB, 100)
   284  	pkt7 := udp4Packet(ip4PortA, ip4PortB, 100)
   285  	pkt8 := udp4Packet(ip4PortA, ip4PortC, 100)
   286  	pkt9 := udp6Packet(ip6PortA, ip6PortB, 100)
   287  	pkt10 := udp6Packet(ip6PortA, ip6PortB, 100)
   288  	pkt11 := udp6Packet(ip6PortA, ip6PortC, 100)
   289  	f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, true, offset)
   290  	f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, canUDPGRO bool, offset int) {
   291  		pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11}
   292  		toWrite := make([]int, 0, len(pkts))
   293  		handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), canUDPGRO, &toWrite)
   294  		if len(toWrite) > len(pkts) {
   295  			t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts))
   296  		}
   297  		seenWriteI := make(map[int]bool)
   298  		for _, writeI := range toWrite {
   299  			if writeI < 0 || writeI > len(pkts)-1 {
   300  				t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts))
   301  			}
   302  			if seenWriteI[writeI] {
   303  				t.Errorf("duplicate toWrite value: %d", writeI)
   304  			}
   305  			seenWriteI[writeI] = true
   306  		}
   307  	})
   308  }
   309  
   310  func Test_handleGRO(t *testing.T) {
   311  	tests := []struct {
   312  		name        string
   313  		pktsIn      [][]byte
   314  		canUDPGRO   bool
   315  		wantToWrite []int
   316  		wantLens    []int
   317  		wantErr     bool
   318  	}{
   319  		{
   320  			"multiple protocols and flows",
   321  			[][]byte{
   322  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),   // tcp4 flow 1
   323  				udp4Packet(ip4PortA, ip4PortB, 100),                         // udp4 flow 1
   324  				udp4Packet(ip4PortA, ip4PortC, 100),                         // udp4 flow 2
   325  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
   326  				tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
   327  				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),   // tcp6 flow 1
   328  				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
   329  				tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
   330  				udp4Packet(ip4PortA, ip4PortB, 100),                         // udp4 flow 1
   331  				udp6Packet(ip6PortA, ip6PortB, 100),                         // udp6 flow 1
   332  				udp6Packet(ip6PortA, ip6PortB, 100),                         // udp6 flow 1
   333  			},
   334  			true,
   335  			[]int{0, 1, 2, 4, 5, 7, 9},
   336  			[]int{240, 228, 128, 140, 260, 160, 248},
   337  			false,
   338  		},
   339  		{
   340  			"multiple protocols and flows no UDP GRO",
   341  			[][]byte{
   342  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),   // tcp4 flow 1
   343  				udp4Packet(ip4PortA, ip4PortB, 100),                         // udp4 flow 1
   344  				udp4Packet(ip4PortA, ip4PortC, 100),                         // udp4 flow 2
   345  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
   346  				tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
   347  				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),   // tcp6 flow 1
   348  				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
   349  				tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
   350  				udp4Packet(ip4PortA, ip4PortB, 100),                         // udp4 flow 1
   351  				udp6Packet(ip6PortA, ip6PortB, 100),                         // udp6 flow 1
   352  				udp6Packet(ip6PortA, ip6PortB, 100),                         // udp6 flow 1
   353  			},
   354  			false,
   355  			[]int{0, 1, 2, 4, 5, 7, 8, 9, 10},
   356  			[]int{240, 128, 128, 140, 260, 160, 128, 148, 148},
   357  			false,
   358  		},
   359  		{
   360  			"PSH interleaved",
   361  			[][]byte{
   362  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),                     // v4 flow 1
   363  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
   364  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201),                   // v4 flow 1
   365  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301),                   // v4 flow 1
   366  				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),                     // v6 flow 1
   367  				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
   368  				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201),                   // v6 flow 1
   369  				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301),                   // v6 flow 1
   370  			},
   371  			true,
   372  			[]int{0, 2, 4, 6},
   373  			[]int{240, 240, 260, 260},
   374  			false,
   375  		},
   376  		{
   377  			"coalesceItemInvalidCSum",
   378  			[][]byte{
   379  				flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
   380  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101),                 // v4 flow 1 seq 101 len 100
   381  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201),                 // v4 flow 1 seq 201 len 100
   382  				flipUDP4Checksum(udp4Packet(ip4PortA, ip4PortB, 100)),
   383  				udp4Packet(ip4PortA, ip4PortB, 100),
   384  				udp4Packet(ip4PortA, ip4PortB, 100),
   385  			},
   386  			true,
   387  			[]int{0, 1, 3, 4},
   388  			[]int{140, 240, 128, 228},
   389  			false,
   390  		},
   391  		{
   392  			"out of order",
   393  			[][]byte{
   394  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
   395  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),   // v4 flow 1 seq 1 len 100
   396  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
   397  			},
   398  			true,
   399  			[]int{0},
   400  			[]int{340},
   401  			false,
   402  		},
   403  		{
   404  			"unequal TTL",
   405  			[][]byte{
   406  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
   407  				tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
   408  					fields.TTL++
   409  				}),
   410  				udp4Packet(ip4PortA, ip4PortB, 100),
   411  				udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
   412  					fields.TTL++
   413  				}),
   414  			},
   415  			true,
   416  			[]int{0, 1, 2, 3},
   417  			[]int{140, 140, 128, 128},
   418  			false,
   419  		},
   420  		{
   421  			"unequal ToS",
   422  			[][]byte{
   423  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
   424  				tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
   425  					fields.TOS++
   426  				}),
   427  				udp4Packet(ip4PortA, ip4PortB, 100),
   428  				udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
   429  					fields.TOS++
   430  				}),
   431  			},
   432  			true,
   433  			[]int{0, 1, 2, 3},
   434  			[]int{140, 140, 128, 128},
   435  			false,
   436  		},
   437  		{
   438  			"unequal flags more fragments set",
   439  			[][]byte{
   440  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
   441  				tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
   442  					fields.Flags = 1
   443  				}),
   444  				udp4Packet(ip4PortA, ip4PortB, 100),
   445  				udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
   446  					fields.Flags = 1
   447  				}),
   448  			},
   449  			true,
   450  			[]int{0, 1, 2, 3},
   451  			[]int{140, 140, 128, 128},
   452  			false,
   453  		},
   454  		{
   455  			"unequal flags DF set",
   456  			[][]byte{
   457  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
   458  				tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
   459  					fields.Flags = 2
   460  				}),
   461  				udp4Packet(ip4PortA, ip4PortB, 100),
   462  				udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
   463  					fields.Flags = 2
   464  				}),
   465  			},
   466  			true,
   467  			[]int{0, 1, 2, 3},
   468  			[]int{140, 140, 128, 128},
   469  			false,
   470  		},
   471  		{
   472  			"ipv6 unequal hop limit",
   473  			[][]byte{
   474  				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
   475  				tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
   476  					fields.HopLimit++
   477  				}),
   478  				udp6Packet(ip6PortA, ip6PortB, 100),
   479  				udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
   480  					fields.HopLimit++
   481  				}),
   482  			},
   483  			true,
   484  			[]int{0, 1, 2, 3},
   485  			[]int{160, 160, 148, 148},
   486  			false,
   487  		},
   488  		{
   489  			"ipv6 unequal traffic class",
   490  			[][]byte{
   491  				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
   492  				tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
   493  					fields.TrafficClass++
   494  				}),
   495  				udp6Packet(ip6PortA, ip6PortB, 100),
   496  				udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
   497  					fields.TrafficClass++
   498  				}),
   499  			},
   500  			true,
   501  			[]int{0, 1, 2, 3},
   502  			[]int{160, 160, 148, 148},
   503  			false,
   504  		},
   505  	}
   506  
   507  	for _, tt := range tests {
   508  		t.Run(tt.name, func(t *testing.T) {
   509  			toWrite := make([]int, 0, len(tt.pktsIn))
   510  			err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.canUDPGRO, &toWrite)
   511  			if err != nil {
   512  				if tt.wantErr {
   513  					return
   514  				}
   515  				t.Fatalf("got err: %v", err)
   516  			}
   517  			if len(toWrite) != len(tt.wantToWrite) {
   518  				t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
   519  			}
   520  			for i, pktI := range tt.wantToWrite {
   521  				if tt.wantToWrite[i] != toWrite[i] {
   522  					t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i])
   523  				}
   524  				if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) {
   525  					t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:]))
   526  				}
   527  			}
   528  		})
   529  	}
   530  }
   531  
   532  func Test_packetIsGROCandidate(t *testing.T) {
   533  	tcp4 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
   534  	tcp4TooShort := tcp4[:39]
   535  	ip4InvalidHeaderLen := make([]byte, len(tcp4))
   536  	copy(ip4InvalidHeaderLen, tcp4)
   537  	ip4InvalidHeaderLen[0] = 0x46
   538  	ip4InvalidProtocol := make([]byte, len(tcp4))
   539  	copy(ip4InvalidProtocol, tcp4)
   540  	ip4InvalidProtocol[9] = unix.IPPROTO_GRE
   541  
   542  	tcp6 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
   543  	tcp6TooShort := tcp6[:59]
   544  	ip6InvalidProtocol := make([]byte, len(tcp6))
   545  	copy(ip6InvalidProtocol, tcp6)
   546  	ip6InvalidProtocol[6] = unix.IPPROTO_GRE
   547  
   548  	udp4 := udp4Packet(ip4PortA, ip4PortB, 100)[virtioNetHdrLen:]
   549  	udp4TooShort := udp4[:27]
   550  
   551  	udp6 := udp6Packet(ip6PortA, ip6PortB, 100)[virtioNetHdrLen:]
   552  	udp6TooShort := udp6[:47]
   553  
   554  	tests := []struct {
   555  		name      string
   556  		b         []byte
   557  		canUDPGRO bool
   558  		want      groCandidateType
   559  	}{
   560  		{
   561  			"tcp4",
   562  			tcp4,
   563  			true,
   564  			tcp4GROCandidate,
   565  		},
   566  		{
   567  			"tcp6",
   568  			tcp6,
   569  			true,
   570  			tcp6GROCandidate,
   571  		},
   572  		{
   573  			"udp4",
   574  			udp4,
   575  			true,
   576  			udp4GROCandidate,
   577  		},
   578  		{
   579  			"udp4 no support",
   580  			udp4,
   581  			false,
   582  			notGROCandidate,
   583  		},
   584  		{
   585  			"udp6",
   586  			udp6,
   587  			true,
   588  			udp6GROCandidate,
   589  		},
   590  		{
   591  			"udp6 no support",
   592  			udp6,
   593  			false,
   594  			notGROCandidate,
   595  		},
   596  		{
   597  			"udp4 too short",
   598  			udp4TooShort,
   599  			true,
   600  			notGROCandidate,
   601  		},
   602  		{
   603  			"udp6 too short",
   604  			udp6TooShort,
   605  			true,
   606  			notGROCandidate,
   607  		},
   608  		{
   609  			"tcp4 too short",
   610  			tcp4TooShort,
   611  			true,
   612  			notGROCandidate,
   613  		},
   614  		{
   615  			"tcp6 too short",
   616  			tcp6TooShort,
   617  			true,
   618  			notGROCandidate,
   619  		},
   620  		{
   621  			"invalid IP version",
   622  			[]byte{0x00},
   623  			true,
   624  			notGROCandidate,
   625  		},
   626  		{
   627  			"invalid IP header len",
   628  			ip4InvalidHeaderLen,
   629  			true,
   630  			notGROCandidate,
   631  		},
   632  		{
   633  			"ip4 invalid protocol",
   634  			ip4InvalidProtocol,
   635  			true,
   636  			notGROCandidate,
   637  		},
   638  		{
   639  			"ip6 invalid protocol",
   640  			ip6InvalidProtocol,
   641  			true,
   642  			notGROCandidate,
   643  		},
   644  	}
   645  	for _, tt := range tests {
   646  		t.Run(tt.name, func(t *testing.T) {
   647  			if got := packetIsGROCandidate(tt.b, tt.canUDPGRO); got != tt.want {
   648  				t.Errorf("packetIsGROCandidate() = %v, want %v", got, tt.want)
   649  			}
   650  		})
   651  	}
   652  }
   653  
   654  func Test_udpPacketsCanCoalesce(t *testing.T) {
   655  	udp4a := udp4Packet(ip4PortA, ip4PortB, 100)
   656  	udp4b := udp4Packet(ip4PortA, ip4PortB, 100)
   657  	udp4c := udp4Packet(ip4PortA, ip4PortB, 110)
   658  
   659  	type args struct {
   660  		pkt        []byte
   661  		iphLen     uint8
   662  		gsoSize    uint16
   663  		item       udpGROItem
   664  		bufs       [][]byte
   665  		bufsOffset int
   666  	}
   667  	tests := []struct {
   668  		name string
   669  		args args
   670  		want canCoalesce
   671  	}{
   672  		{
   673  			"coalesceAppend equal gso",
   674  			args{
   675  				pkt:     udp4a[offset:],
   676  				iphLen:  20,
   677  				gsoSize: 100,
   678  				item: udpGROItem{
   679  					gsoSize: 100,
   680  					iphLen:  20,
   681  				},
   682  				bufs: [][]byte{
   683  					udp4a,
   684  					udp4b,
   685  				},
   686  				bufsOffset: offset,
   687  			},
   688  			coalesceAppend,
   689  		},
   690  		{
   691  			"coalesceAppend smaller gso",
   692  			args{
   693  				pkt:     udp4a[offset : len(udp4a)-90],
   694  				iphLen:  20,
   695  				gsoSize: 10,
   696  				item: udpGROItem{
   697  					gsoSize: 100,
   698  					iphLen:  20,
   699  				},
   700  				bufs: [][]byte{
   701  					udp4a,
   702  					udp4b,
   703  				},
   704  				bufsOffset: offset,
   705  			},
   706  			coalesceAppend,
   707  		},
   708  		{
   709  			"coalesceUnavailable smaller gso previously appended",
   710  			args{
   711  				pkt:     udp4a[offset:],
   712  				iphLen:  20,
   713  				gsoSize: 100,
   714  				item: udpGROItem{
   715  					gsoSize: 100,
   716  					iphLen:  20,
   717  				},
   718  				bufs: [][]byte{
   719  					udp4c,
   720  					udp4b,
   721  				},
   722  				bufsOffset: offset,
   723  			},
   724  			coalesceUnavailable,
   725  		},
   726  		{
   727  			"coalesceUnavailable larger following smaller",
   728  			args{
   729  				pkt:     udp4c[offset:],
   730  				iphLen:  20,
   731  				gsoSize: 110,
   732  				item: udpGROItem{
   733  					gsoSize: 100,
   734  					iphLen:  20,
   735  				},
   736  				bufs: [][]byte{
   737  					udp4a,
   738  					udp4c,
   739  				},
   740  				bufsOffset: offset,
   741  			},
   742  			coalesceUnavailable,
   743  		},
   744  	}
   745  	for _, tt := range tests {
   746  		t.Run(tt.name, func(t *testing.T) {
   747  			if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, tt.args.bufs, tt.args.bufsOffset); got != tt.want {
   748  				t.Errorf("udpPacketsCanCoalesce() = %v, want %v", got, tt.want)
   749  			}
   750  		})
   751  	}
   752  }