github.com/koomox/wireguard-go@v0.0.0-20230722134753-17a50b2f22a3/tun/tcp_offload_linux_test.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package tun 7 8 import ( 9 "net/netip" 10 "testing" 11 12 "golang.org/x/sys/unix" 13 "github.com/koomox/wireguard-go/conn" 14 "gvisor.dev/gvisor/pkg/tcpip" 15 "gvisor.dev/gvisor/pkg/tcpip/header" 16 ) 17 18 const ( 19 offset = virtioNetHdrLen 20 ) 21 22 var ( 23 ip4PortA = netip.MustParseAddrPort("192.0.2.1:1") 24 ip4PortB = netip.MustParseAddrPort("192.0.2.2:1") 25 ip4PortC = netip.MustParseAddrPort("192.0.2.3:1") 26 ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1") 27 ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1") 28 ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1") 29 ) 30 31 func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte { 32 totalLen := 40 + segmentSize 33 b := make([]byte, offset+int(totalLen), 65535) 34 ipv4H := header.IPv4(b[offset:]) 35 srcAs4 := srcIPPort.Addr().As4() 36 dstAs4 := dstIPPort.Addr().As4() 37 ipFields := &header.IPv4Fields{ 38 SrcAddr: tcpip.Address(srcAs4[:]), 39 DstAddr: tcpip.Address(dstAs4[:]), 40 Protocol: unix.IPPROTO_TCP, 41 TTL: 64, 42 TotalLength: uint16(totalLen), 43 } 44 if ipFn != nil { 45 ipFn(ipFields) 46 } 47 ipv4H.Encode(ipFields) 48 tcpH := header.TCP(b[offset+20:]) 49 tcpH.Encode(&header.TCPFields{ 50 SrcPort: srcIPPort.Port(), 51 DstPort: dstIPPort.Port(), 52 SeqNum: seq, 53 AckNum: 1, 54 DataOffset: 20, 55 Flags: flags, 56 WindowSize: 3000, 57 }) 58 ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) 59 pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize)) 60 tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) 61 return b 62 } 63 64 func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { 65 return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) 66 } 67 68 func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte { 69 totalLen := 60 + segmentSize 70 b := make([]byte, offset+int(totalLen), 65535) 71 ipv6H := header.IPv6(b[offset:]) 72 srcAs16 := srcIPPort.Addr().As16() 73 dstAs16 := dstIPPort.Addr().As16() 74 ipFields := &header.IPv6Fields{ 75 SrcAddr: tcpip.Address(srcAs16[:]), 76 DstAddr: tcpip.Address(dstAs16[:]), 77 TransportProtocol: unix.IPPROTO_TCP, 78 HopLimit: 64, 79 PayloadLength: uint16(segmentSize + 20), 80 } 81 if ipFn != nil { 82 ipFn(ipFields) 83 } 84 ipv6H.Encode(ipFields) 85 tcpH := header.TCP(b[offset+40:]) 86 tcpH.Encode(&header.TCPFields{ 87 SrcPort: srcIPPort.Port(), 88 DstPort: dstIPPort.Port(), 89 SeqNum: seq, 90 AckNum: 1, 91 DataOffset: 20, 92 Flags: flags, 93 WindowSize: 3000, 94 }) 95 pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize)) 96 tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) 97 return b 98 } 99 100 func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { 101 return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) 102 } 103 104 func Test_handleVirtioRead(t *testing.T) { 105 tests := []struct { 106 name string 107 hdr virtioNetHdr 108 pktIn []byte 109 wantLens []int 110 wantErr bool 111 }{ 112 { 113 "tcp4", 114 virtioNetHdr{ 115 flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, 116 gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4, 117 gsoSize: 100, 118 hdrLen: 40, 119 csumStart: 20, 120 csumOffset: 16, 121 }, 122 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), 123 []int{140, 140}, 124 false, 125 }, 126 { 127 "tcp6", 128 virtioNetHdr{ 129 flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, 130 gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6, 131 gsoSize: 100, 132 hdrLen: 60, 133 csumStart: 40, 134 csumOffset: 16, 135 }, 136 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), 137 []int{160, 160}, 138 false, 139 }, 140 } 141 142 for _, tt := range tests { 143 t.Run(tt.name, func(t *testing.T) { 144 out := make([][]byte, conn.IdealBatchSize) 145 sizes := make([]int, conn.IdealBatchSize) 146 for i := range out { 147 out[i] = make([]byte, 65535) 148 } 149 tt.hdr.encode(tt.pktIn) 150 n, err := handleVirtioRead(tt.pktIn, out, sizes, offset) 151 if err != nil { 152 if tt.wantErr { 153 return 154 } 155 t.Fatalf("got err: %v", err) 156 } 157 if n != len(tt.wantLens) { 158 t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens)) 159 } 160 for i := range tt.wantLens { 161 if tt.wantLens[i] != sizes[i] { 162 t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i]) 163 } 164 } 165 }) 166 } 167 } 168 169 func flipTCP4Checksum(b []byte) []byte { 170 at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16 171 b[at] ^= 0xFF 172 b[at+1] ^= 0xFF 173 return b 174 } 175 176 func Fuzz_handleGRO(f *testing.F) { 177 pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1) 178 pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101) 179 pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201) 180 pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1) 181 pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101) 182 pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201) 183 f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, offset) 184 f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5 []byte, offset int) { 185 pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5} 186 toWrite := make([]int, 0, len(pkts)) 187 handleGRO(pkts, offset, newTCPGROTable(), newTCPGROTable(), &toWrite) 188 if len(toWrite) > len(pkts) { 189 t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) 190 } 191 seenWriteI := make(map[int]bool) 192 for _, writeI := range toWrite { 193 if writeI < 0 || writeI > len(pkts)-1 { 194 t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts)) 195 } 196 if seenWriteI[writeI] { 197 t.Errorf("duplicate toWrite value: %d", writeI) 198 } 199 seenWriteI[writeI] = true 200 } 201 }) 202 } 203 204 func Test_handleGRO(t *testing.T) { 205 tests := []struct { 206 name string 207 pktsIn [][]byte 208 wantToWrite []int 209 wantLens []int 210 wantErr bool 211 }{ 212 { 213 "multiple flows", 214 [][]byte{ 215 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 216 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 217 tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // v4 flow 2 218 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 219 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // v6 flow 1 220 tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // v6 flow 2 221 }, 222 []int{0, 2, 3, 5}, 223 []int{240, 140, 260, 160}, 224 false, 225 }, 226 { 227 "PSH interleaved", 228 [][]byte{ 229 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 230 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1 231 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 232 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1 233 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 234 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1 235 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1 236 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1 237 }, 238 []int{0, 2, 4, 6}, 239 []int{240, 240, 260, 260}, 240 false, 241 }, 242 { 243 "coalesceItemInvalidCSum", 244 [][]byte{ 245 flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100 246 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 247 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 248 }, 249 []int{0, 1}, 250 []int{140, 240}, 251 false, 252 }, 253 { 254 "out of order", 255 [][]byte{ 256 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 257 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100 258 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 259 }, 260 []int{0}, 261 []int{340}, 262 false, 263 }, 264 { 265 "tcp4 unequal TTL", 266 [][]byte{ 267 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), 268 tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { 269 fields.TTL++ 270 }), 271 }, 272 []int{0, 1}, 273 []int{140, 140}, 274 false, 275 }, 276 { 277 "tcp4 unequal ToS", 278 [][]byte{ 279 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), 280 tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { 281 fields.TOS++ 282 }), 283 }, 284 []int{0, 1}, 285 []int{140, 140}, 286 false, 287 }, 288 { 289 "tcp4 unequal flags more fragments set", 290 [][]byte{ 291 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), 292 tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { 293 fields.Flags = 1 294 }), 295 }, 296 []int{0, 1}, 297 []int{140, 140}, 298 false, 299 }, 300 { 301 "tcp4 unequal flags DF set", 302 [][]byte{ 303 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), 304 tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { 305 fields.Flags = 2 306 }), 307 }, 308 []int{0, 1}, 309 []int{140, 140}, 310 false, 311 }, 312 { 313 "tcp6 unequal hop limit", 314 [][]byte{ 315 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), 316 tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { 317 fields.HopLimit++ 318 }), 319 }, 320 []int{0, 1}, 321 []int{160, 160}, 322 false, 323 }, 324 { 325 "tcp6 unequal traffic class", 326 [][]byte{ 327 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), 328 tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { 329 fields.TrafficClass++ 330 }), 331 }, 332 []int{0, 1}, 333 []int{160, 160}, 334 false, 335 }, 336 } 337 338 for _, tt := range tests { 339 t.Run(tt.name, func(t *testing.T) { 340 toWrite := make([]int, 0, len(tt.pktsIn)) 341 err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newTCPGROTable(), &toWrite) 342 if err != nil { 343 if tt.wantErr { 344 return 345 } 346 t.Fatalf("got err: %v", err) 347 } 348 if len(toWrite) != len(tt.wantToWrite) { 349 t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite)) 350 } 351 for i, pktI := range tt.wantToWrite { 352 if tt.wantToWrite[i] != toWrite[i] { 353 t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i]) 354 } 355 if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) { 356 t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:])) 357 } 358 } 359 }) 360 } 361 } 362 363 func Test_isTCP4NoIPOptions(t *testing.T) { 364 valid := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] 365 invalidLen := valid[:39] 366 invalidHeaderLen := make([]byte, len(valid)) 367 copy(invalidHeaderLen, valid) 368 invalidHeaderLen[0] = 0x46 369 invalidProtocol := make([]byte, len(valid)) 370 copy(invalidProtocol, valid) 371 invalidProtocol[9] = unix.IPPROTO_TCP + 1 372 373 tests := []struct { 374 name string 375 b []byte 376 want bool 377 }{ 378 { 379 "valid", 380 valid, 381 true, 382 }, 383 { 384 "invalid length", 385 invalidLen, 386 false, 387 }, 388 { 389 "invalid version", 390 []byte{0x00}, 391 false, 392 }, 393 { 394 "invalid header len", 395 invalidHeaderLen, 396 false, 397 }, 398 { 399 "invalid protocol", 400 invalidProtocol, 401 false, 402 }, 403 } 404 for _, tt := range tests { 405 t.Run(tt.name, func(t *testing.T) { 406 if got := isTCP4NoIPOptions(tt.b); got != tt.want { 407 t.Errorf("isTCP4NoIPOptions() = %v, want %v", got, tt.want) 408 } 409 }) 410 } 411 }