github.com/koomox/wireguard-go@v0.0.0-20230722134753-17a50b2f22a3/tun/tcp_offload_linux_test.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package tun
     7  
     8  import (
     9  	"net/netip"
    10  	"testing"
    11  
    12  	"golang.org/x/sys/unix"
    13  	"github.com/koomox/wireguard-go/conn"
    14  	"gvisor.dev/gvisor/pkg/tcpip"
    15  	"gvisor.dev/gvisor/pkg/tcpip/header"
    16  )
    17  
    18  const (
    19  	offset = virtioNetHdrLen
    20  )
    21  
    22  var (
    23  	ip4PortA = netip.MustParseAddrPort("192.0.2.1:1")
    24  	ip4PortB = netip.MustParseAddrPort("192.0.2.2:1")
    25  	ip4PortC = netip.MustParseAddrPort("192.0.2.3:1")
    26  	ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1")
    27  	ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1")
    28  	ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
    29  )
    30  
    31  func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte {
    32  	totalLen := 40 + segmentSize
    33  	b := make([]byte, offset+int(totalLen), 65535)
    34  	ipv4H := header.IPv4(b[offset:])
    35  	srcAs4 := srcIPPort.Addr().As4()
    36  	dstAs4 := dstIPPort.Addr().As4()
    37  	ipFields := &header.IPv4Fields{
    38  		SrcAddr:     tcpip.Address(srcAs4[:]),
    39  		DstAddr:     tcpip.Address(dstAs4[:]),
    40  		Protocol:    unix.IPPROTO_TCP,
    41  		TTL:         64,
    42  		TotalLength: uint16(totalLen),
    43  	}
    44  	if ipFn != nil {
    45  		ipFn(ipFields)
    46  	}
    47  	ipv4H.Encode(ipFields)
    48  	tcpH := header.TCP(b[offset+20:])
    49  	tcpH.Encode(&header.TCPFields{
    50  		SrcPort:    srcIPPort.Port(),
    51  		DstPort:    dstIPPort.Port(),
    52  		SeqNum:     seq,
    53  		AckNum:     1,
    54  		DataOffset: 20,
    55  		Flags:      flags,
    56  		WindowSize: 3000,
    57  	})
    58  	ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
    59  	pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize))
    60  	tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
    61  	return b
    62  }
    63  
    64  func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
    65  	return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
    66  }
    67  
    68  func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte {
    69  	totalLen := 60 + segmentSize
    70  	b := make([]byte, offset+int(totalLen), 65535)
    71  	ipv6H := header.IPv6(b[offset:])
    72  	srcAs16 := srcIPPort.Addr().As16()
    73  	dstAs16 := dstIPPort.Addr().As16()
    74  	ipFields := &header.IPv6Fields{
    75  		SrcAddr:           tcpip.Address(srcAs16[:]),
    76  		DstAddr:           tcpip.Address(dstAs16[:]),
    77  		TransportProtocol: unix.IPPROTO_TCP,
    78  		HopLimit:          64,
    79  		PayloadLength:     uint16(segmentSize + 20),
    80  	}
    81  	if ipFn != nil {
    82  		ipFn(ipFields)
    83  	}
    84  	ipv6H.Encode(ipFields)
    85  	tcpH := header.TCP(b[offset+40:])
    86  	tcpH.Encode(&header.TCPFields{
    87  		SrcPort:    srcIPPort.Port(),
    88  		DstPort:    dstIPPort.Port(),
    89  		SeqNum:     seq,
    90  		AckNum:     1,
    91  		DataOffset: 20,
    92  		Flags:      flags,
    93  		WindowSize: 3000,
    94  	})
    95  	pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize))
    96  	tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
    97  	return b
    98  }
    99  
   100  func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
   101  	return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
   102  }
   103  
   104  func Test_handleVirtioRead(t *testing.T) {
   105  	tests := []struct {
   106  		name     string
   107  		hdr      virtioNetHdr
   108  		pktIn    []byte
   109  		wantLens []int
   110  		wantErr  bool
   111  	}{
   112  		{
   113  			"tcp4",
   114  			virtioNetHdr{
   115  				flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
   116  				gsoType:    unix.VIRTIO_NET_HDR_GSO_TCPV4,
   117  				gsoSize:    100,
   118  				hdrLen:     40,
   119  				csumStart:  20,
   120  				csumOffset: 16,
   121  			},
   122  			tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
   123  			[]int{140, 140},
   124  			false,
   125  		},
   126  		{
   127  			"tcp6",
   128  			virtioNetHdr{
   129  				flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
   130  				gsoType:    unix.VIRTIO_NET_HDR_GSO_TCPV6,
   131  				gsoSize:    100,
   132  				hdrLen:     60,
   133  				csumStart:  40,
   134  				csumOffset: 16,
   135  			},
   136  			tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
   137  			[]int{160, 160},
   138  			false,
   139  		},
   140  	}
   141  
   142  	for _, tt := range tests {
   143  		t.Run(tt.name, func(t *testing.T) {
   144  			out := make([][]byte, conn.IdealBatchSize)
   145  			sizes := make([]int, conn.IdealBatchSize)
   146  			for i := range out {
   147  				out[i] = make([]byte, 65535)
   148  			}
   149  			tt.hdr.encode(tt.pktIn)
   150  			n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
   151  			if err != nil {
   152  				if tt.wantErr {
   153  					return
   154  				}
   155  				t.Fatalf("got err: %v", err)
   156  			}
   157  			if n != len(tt.wantLens) {
   158  				t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
   159  			}
   160  			for i := range tt.wantLens {
   161  				if tt.wantLens[i] != sizes[i] {
   162  					t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
   163  				}
   164  			}
   165  		})
   166  	}
   167  }
   168  
   169  func flipTCP4Checksum(b []byte) []byte {
   170  	at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16
   171  	b[at] ^= 0xFF
   172  	b[at+1] ^= 0xFF
   173  	return b
   174  }
   175  
   176  func Fuzz_handleGRO(f *testing.F) {
   177  	pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)
   178  	pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101)
   179  	pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201)
   180  	pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)
   181  	pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101)
   182  	pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201)
   183  	f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, offset)
   184  	f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5 []byte, offset int) {
   185  		pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5}
   186  		toWrite := make([]int, 0, len(pkts))
   187  		handleGRO(pkts, offset, newTCPGROTable(), newTCPGROTable(), &toWrite)
   188  		if len(toWrite) > len(pkts) {
   189  			t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts))
   190  		}
   191  		seenWriteI := make(map[int]bool)
   192  		for _, writeI := range toWrite {
   193  			if writeI < 0 || writeI > len(pkts)-1 {
   194  				t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts))
   195  			}
   196  			if seenWriteI[writeI] {
   197  				t.Errorf("duplicate toWrite value: %d", writeI)
   198  			}
   199  			seenWriteI[writeI] = true
   200  		}
   201  	})
   202  }
   203  
   204  func Test_handleGRO(t *testing.T) {
   205  	tests := []struct {
   206  		name        string
   207  		pktsIn      [][]byte
   208  		wantToWrite []int
   209  		wantLens    []int
   210  		wantErr     bool
   211  	}{
   212  		{
   213  			"multiple flows",
   214  			[][]byte{
   215  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),   // v4 flow 1
   216  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1
   217  				tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // v4 flow 2
   218  				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),   // v6 flow 1
   219  				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // v6 flow 1
   220  				tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // v6 flow 2
   221  			},
   222  			[]int{0, 2, 3, 5},
   223  			[]int{240, 140, 260, 160},
   224  			false,
   225  		},
   226  		{
   227  			"PSH interleaved",
   228  			[][]byte{
   229  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),                     // v4 flow 1
   230  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
   231  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201),                   // v4 flow 1
   232  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301),                   // v4 flow 1
   233  				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),                     // v6 flow 1
   234  				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
   235  				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201),                   // v6 flow 1
   236  				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301),                   // v6 flow 1
   237  			},
   238  			[]int{0, 2, 4, 6},
   239  			[]int{240, 240, 260, 260},
   240  			false,
   241  		},
   242  		{
   243  			"coalesceItemInvalidCSum",
   244  			[][]byte{
   245  				flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
   246  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101),                 // v4 flow 1 seq 101 len 100
   247  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201),                 // v4 flow 1 seq 201 len 100
   248  			},
   249  			[]int{0, 1},
   250  			[]int{140, 240},
   251  			false,
   252  		},
   253  		{
   254  			"out of order",
   255  			[][]byte{
   256  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
   257  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),   // v4 flow 1 seq 1 len 100
   258  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
   259  			},
   260  			[]int{0},
   261  			[]int{340},
   262  			false,
   263  		},
   264  		{
   265  			"tcp4 unequal TTL",
   266  			[][]byte{
   267  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
   268  				tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
   269  					fields.TTL++
   270  				}),
   271  			},
   272  			[]int{0, 1},
   273  			[]int{140, 140},
   274  			false,
   275  		},
   276  		{
   277  			"tcp4 unequal ToS",
   278  			[][]byte{
   279  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
   280  				tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
   281  					fields.TOS++
   282  				}),
   283  			},
   284  			[]int{0, 1},
   285  			[]int{140, 140},
   286  			false,
   287  		},
   288  		{
   289  			"tcp4 unequal flags more fragments set",
   290  			[][]byte{
   291  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
   292  				tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
   293  					fields.Flags = 1
   294  				}),
   295  			},
   296  			[]int{0, 1},
   297  			[]int{140, 140},
   298  			false,
   299  		},
   300  		{
   301  			"tcp4 unequal flags DF set",
   302  			[][]byte{
   303  				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
   304  				tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
   305  					fields.Flags = 2
   306  				}),
   307  			},
   308  			[]int{0, 1},
   309  			[]int{140, 140},
   310  			false,
   311  		},
   312  		{
   313  			"tcp6 unequal hop limit",
   314  			[][]byte{
   315  				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
   316  				tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
   317  					fields.HopLimit++
   318  				}),
   319  			},
   320  			[]int{0, 1},
   321  			[]int{160, 160},
   322  			false,
   323  		},
   324  		{
   325  			"tcp6 unequal traffic class",
   326  			[][]byte{
   327  				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
   328  				tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
   329  					fields.TrafficClass++
   330  				}),
   331  			},
   332  			[]int{0, 1},
   333  			[]int{160, 160},
   334  			false,
   335  		},
   336  	}
   337  
   338  	for _, tt := range tests {
   339  		t.Run(tt.name, func(t *testing.T) {
   340  			toWrite := make([]int, 0, len(tt.pktsIn))
   341  			err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newTCPGROTable(), &toWrite)
   342  			if err != nil {
   343  				if tt.wantErr {
   344  					return
   345  				}
   346  				t.Fatalf("got err: %v", err)
   347  			}
   348  			if len(toWrite) != len(tt.wantToWrite) {
   349  				t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
   350  			}
   351  			for i, pktI := range tt.wantToWrite {
   352  				if tt.wantToWrite[i] != toWrite[i] {
   353  					t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i])
   354  				}
   355  				if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) {
   356  					t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:]))
   357  				}
   358  			}
   359  		})
   360  	}
   361  }
   362  
   363  func Test_isTCP4NoIPOptions(t *testing.T) {
   364  	valid := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
   365  	invalidLen := valid[:39]
   366  	invalidHeaderLen := make([]byte, len(valid))
   367  	copy(invalidHeaderLen, valid)
   368  	invalidHeaderLen[0] = 0x46
   369  	invalidProtocol := make([]byte, len(valid))
   370  	copy(invalidProtocol, valid)
   371  	invalidProtocol[9] = unix.IPPROTO_TCP + 1
   372  
   373  	tests := []struct {
   374  		name string
   375  		b    []byte
   376  		want bool
   377  	}{
   378  		{
   379  			"valid",
   380  			valid,
   381  			true,
   382  		},
   383  		{
   384  			"invalid length",
   385  			invalidLen,
   386  			false,
   387  		},
   388  		{
   389  			"invalid version",
   390  			[]byte{0x00},
   391  			false,
   392  		},
   393  		{
   394  			"invalid header len",
   395  			invalidHeaderLen,
   396  			false,
   397  		},
   398  		{
   399  			"invalid protocol",
   400  			invalidProtocol,
   401  			false,
   402  		},
   403  	}
   404  	for _, tt := range tests {
   405  		t.Run(tt.name, func(t *testing.T) {
   406  			if got := isTCP4NoIPOptions(tt.b); got != tt.want {
   407  				t.Errorf("isTCP4NoIPOptions() = %v, want %v", got, tt.want)
   408  			}
   409  		})
   410  	}
   411  }