github.com/amnezia-vpn/amnezia-wg@v0.1.8/tun/offload_linux_test.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package tun 7 8 import ( 9 "net/netip" 10 "testing" 11 12 "github.com/amnezia-vpn/amnezia-wg/conn" 13 "golang.org/x/sys/unix" 14 "gvisor.dev/gvisor/pkg/tcpip" 15 "gvisor.dev/gvisor/pkg/tcpip/header" 16 ) 17 18 const ( 19 offset = virtioNetHdrLen 20 ) 21 22 var ( 23 ip4PortA = netip.MustParseAddrPort("192.0.2.1:1") 24 ip4PortB = netip.MustParseAddrPort("192.0.2.2:1") 25 ip4PortC = netip.MustParseAddrPort("192.0.2.3:1") 26 ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1") 27 ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1") 28 ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1") 29 ) 30 31 func udp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv4Fields)) []byte { 32 totalLen := 28 + payloadLen 33 b := make([]byte, offset+int(totalLen), 65535) 34 ipv4H := header.IPv4(b[offset:]) 35 srcAs4 := srcIPPort.Addr().As4() 36 dstAs4 := dstIPPort.Addr().As4() 37 ipFields := &header.IPv4Fields{ 38 SrcAddr: tcpip.AddrFromSlice(srcAs4[:]), 39 DstAddr: tcpip.AddrFromSlice(dstAs4[:]), 40 Protocol: unix.IPPROTO_UDP, 41 TTL: 64, 42 TotalLength: uint16(totalLen), 43 } 44 if ipFn != nil { 45 ipFn(ipFields) 46 } 47 ipv4H.Encode(ipFields) 48 udpH := header.UDP(b[offset+20:]) 49 udpH.Encode(&header.UDPFields{ 50 SrcPort: srcIPPort.Port(), 51 DstPort: dstIPPort.Port(), 52 Length: uint16(payloadLen + udphLen), 53 }) 54 ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) 55 pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(udphLen+payloadLen)) 56 udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum)) 57 return b 58 } 59 60 func udp6Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte { 61 return udp6PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil) 62 } 63 64 func udp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv6Fields)) []byte { 65 totalLen := 48 + payloadLen 66 b := make([]byte, offset+int(totalLen), 65535) 67 ipv6H := header.IPv6(b[offset:]) 68 srcAs16 := srcIPPort.Addr().As16() 69 dstAs16 := dstIPPort.Addr().As16() 70 ipFields := &header.IPv6Fields{ 71 SrcAddr: tcpip.AddrFromSlice(srcAs16[:]), 72 DstAddr: tcpip.AddrFromSlice(dstAs16[:]), 73 TransportProtocol: unix.IPPROTO_UDP, 74 HopLimit: 64, 75 PayloadLength: uint16(payloadLen + udphLen), 76 } 77 if ipFn != nil { 78 ipFn(ipFields) 79 } 80 ipv6H.Encode(ipFields) 81 udpH := header.UDP(b[offset+40:]) 82 udpH.Encode(&header.UDPFields{ 83 SrcPort: srcIPPort.Port(), 84 DstPort: dstIPPort.Port(), 85 Length: uint16(payloadLen + udphLen), 86 }) 87 pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(udphLen+payloadLen)) 88 udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum)) 89 return b 90 } 91 92 func udp4Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte { 93 return udp4PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil) 94 } 95 96 func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte { 97 totalLen := 40 + segmentSize 98 b := make([]byte, offset+int(totalLen), 65535) 99 ipv4H := header.IPv4(b[offset:]) 100 srcAs4 := srcIPPort.Addr().As4() 101 dstAs4 := dstIPPort.Addr().As4() 102 ipFields := &header.IPv4Fields{ 103 SrcAddr: tcpip.AddrFromSlice(srcAs4[:]), 104 DstAddr: tcpip.AddrFromSlice(dstAs4[:]), 105 Protocol: unix.IPPROTO_TCP, 106 TTL: 64, 107 TotalLength: uint16(totalLen), 108 } 109 if ipFn != nil { 110 ipFn(ipFields) 111 } 112 ipv4H.Encode(ipFields) 113 tcpH := header.TCP(b[offset+20:]) 114 tcpH.Encode(&header.TCPFields{ 115 SrcPort: srcIPPort.Port(), 116 DstPort: dstIPPort.Port(), 117 SeqNum: seq, 118 AckNum: 1, 119 DataOffset: 20, 120 Flags: flags, 121 WindowSize: 3000, 122 }) 123 ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) 124 pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize)) 125 tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) 126 return b 127 } 128 129 func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { 130 return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) 131 } 132 133 func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte { 134 totalLen := 60 + segmentSize 135 b := make([]byte, offset+int(totalLen), 65535) 136 ipv6H := header.IPv6(b[offset:]) 137 srcAs16 := srcIPPort.Addr().As16() 138 dstAs16 := dstIPPort.Addr().As16() 139 ipFields := &header.IPv6Fields{ 140 SrcAddr: tcpip.AddrFromSlice(srcAs16[:]), 141 DstAddr: tcpip.AddrFromSlice(dstAs16[:]), 142 TransportProtocol: unix.IPPROTO_TCP, 143 HopLimit: 64, 144 PayloadLength: uint16(segmentSize + 20), 145 } 146 if ipFn != nil { 147 ipFn(ipFields) 148 } 149 ipv6H.Encode(ipFields) 150 tcpH := header.TCP(b[offset+40:]) 151 tcpH.Encode(&header.TCPFields{ 152 SrcPort: srcIPPort.Port(), 153 DstPort: dstIPPort.Port(), 154 SeqNum: seq, 155 AckNum: 1, 156 DataOffset: 20, 157 Flags: flags, 158 WindowSize: 3000, 159 }) 160 pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize)) 161 tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) 162 return b 163 } 164 165 func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { 166 return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) 167 } 168 169 func Test_handleVirtioRead(t *testing.T) { 170 tests := []struct { 171 name string 172 hdr virtioNetHdr 173 pktIn []byte 174 wantLens []int 175 wantErr bool 176 }{ 177 { 178 "tcp4", 179 virtioNetHdr{ 180 flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, 181 gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4, 182 gsoSize: 100, 183 hdrLen: 40, 184 csumStart: 20, 185 csumOffset: 16, 186 }, 187 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), 188 []int{140, 140}, 189 false, 190 }, 191 { 192 "tcp6", 193 virtioNetHdr{ 194 flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, 195 gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6, 196 gsoSize: 100, 197 hdrLen: 60, 198 csumStart: 40, 199 csumOffset: 16, 200 }, 201 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), 202 []int{160, 160}, 203 false, 204 }, 205 { 206 "udp4", 207 virtioNetHdr{ 208 flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, 209 gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4, 210 gsoSize: 100, 211 hdrLen: 28, 212 csumStart: 20, 213 csumOffset: 6, 214 }, 215 udp4Packet(ip4PortA, ip4PortB, 200), 216 []int{128, 128}, 217 false, 218 }, 219 { 220 "udp6", 221 virtioNetHdr{ 222 flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, 223 gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4, 224 gsoSize: 100, 225 hdrLen: 48, 226 csumStart: 40, 227 csumOffset: 6, 228 }, 229 udp6Packet(ip6PortA, ip6PortB, 200), 230 []int{148, 148}, 231 false, 232 }, 233 } 234 235 for _, tt := range tests { 236 t.Run(tt.name, func(t *testing.T) { 237 out := make([][]byte, conn.IdealBatchSize) 238 sizes := make([]int, conn.IdealBatchSize) 239 for i := range out { 240 out[i] = make([]byte, 65535) 241 } 242 tt.hdr.encode(tt.pktIn) 243 n, err := handleVirtioRead(tt.pktIn, out, sizes, offset) 244 if err != nil { 245 if tt.wantErr { 246 return 247 } 248 t.Fatalf("got err: %v", err) 249 } 250 if n != len(tt.wantLens) { 251 t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens)) 252 } 253 for i := range tt.wantLens { 254 if tt.wantLens[i] != sizes[i] { 255 t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i]) 256 } 257 } 258 }) 259 } 260 } 261 262 func flipTCP4Checksum(b []byte) []byte { 263 at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16 264 b[at] ^= 0xFF 265 b[at+1] ^= 0xFF 266 return b 267 } 268 269 func flipUDP4Checksum(b []byte) []byte { 270 at := virtioNetHdrLen + 20 + 6 // 20 byte ipv4 header; udp csum offset is 6 271 b[at] ^= 0xFF 272 b[at+1] ^= 0xFF 273 return b 274 } 275 276 func Fuzz_handleGRO(f *testing.F) { 277 pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1) 278 pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101) 279 pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201) 280 pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1) 281 pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101) 282 pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201) 283 pkt6 := udp4Packet(ip4PortA, ip4PortB, 100) 284 pkt7 := udp4Packet(ip4PortA, ip4PortB, 100) 285 pkt8 := udp4Packet(ip4PortA, ip4PortC, 100) 286 pkt9 := udp6Packet(ip6PortA, ip6PortB, 100) 287 pkt10 := udp6Packet(ip6PortA, ip6PortB, 100) 288 pkt11 := udp6Packet(ip6PortA, ip6PortC, 100) 289 f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, true, offset) 290 f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, canUDPGRO bool, offset int) { 291 pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11} 292 toWrite := make([]int, 0, len(pkts)) 293 handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), canUDPGRO, &toWrite) 294 if len(toWrite) > len(pkts) { 295 t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) 296 } 297 seenWriteI := make(map[int]bool) 298 for _, writeI := range toWrite { 299 if writeI < 0 || writeI > len(pkts)-1 { 300 t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts)) 301 } 302 if seenWriteI[writeI] { 303 t.Errorf("duplicate toWrite value: %d", writeI) 304 } 305 seenWriteI[writeI] = true 306 } 307 }) 308 } 309 310 func Test_handleGRO(t *testing.T) { 311 tests := []struct { 312 name string 313 pktsIn [][]byte 314 canUDPGRO bool 315 wantToWrite []int 316 wantLens []int 317 wantErr bool 318 }{ 319 { 320 "multiple protocols and flows", 321 [][]byte{ 322 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1 323 udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 324 udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2 325 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1 326 tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2 327 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1 328 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1 329 tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2 330 udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 331 udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 332 udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 333 }, 334 true, 335 []int{0, 1, 2, 4, 5, 7, 9}, 336 []int{240, 228, 128, 140, 260, 160, 248}, 337 false, 338 }, 339 { 340 "multiple protocols and flows no UDP GRO", 341 [][]byte{ 342 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1 343 udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 344 udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2 345 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1 346 tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2 347 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1 348 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1 349 tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2 350 udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 351 udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 352 udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 353 }, 354 false, 355 []int{0, 1, 2, 4, 5, 7, 8, 9, 10}, 356 []int{240, 128, 128, 140, 260, 160, 128, 148, 148}, 357 false, 358 }, 359 { 360 "PSH interleaved", 361 [][]byte{ 362 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 363 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1 364 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 365 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1 366 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 367 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1 368 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1 369 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1 370 }, 371 true, 372 []int{0, 2, 4, 6}, 373 []int{240, 240, 260, 260}, 374 false, 375 }, 376 { 377 "coalesceItemInvalidCSum", 378 [][]byte{ 379 flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100 380 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 381 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 382 flipUDP4Checksum(udp4Packet(ip4PortA, ip4PortB, 100)), 383 udp4Packet(ip4PortA, ip4PortB, 100), 384 udp4Packet(ip4PortA, ip4PortB, 100), 385 }, 386 true, 387 []int{0, 1, 3, 4}, 388 []int{140, 240, 128, 228}, 389 false, 390 }, 391 { 392 "out of order", 393 [][]byte{ 394 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 395 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100 396 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 397 }, 398 true, 399 []int{0}, 400 []int{340}, 401 false, 402 }, 403 { 404 "unequal TTL", 405 [][]byte{ 406 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), 407 tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { 408 fields.TTL++ 409 }), 410 udp4Packet(ip4PortA, ip4PortB, 100), 411 udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { 412 fields.TTL++ 413 }), 414 }, 415 true, 416 []int{0, 1, 2, 3}, 417 []int{140, 140, 128, 128}, 418 false, 419 }, 420 { 421 "unequal ToS", 422 [][]byte{ 423 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), 424 tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { 425 fields.TOS++ 426 }), 427 udp4Packet(ip4PortA, ip4PortB, 100), 428 udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { 429 fields.TOS++ 430 }), 431 }, 432 true, 433 []int{0, 1, 2, 3}, 434 []int{140, 140, 128, 128}, 435 false, 436 }, 437 { 438 "unequal flags more fragments set", 439 [][]byte{ 440 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), 441 tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { 442 fields.Flags = 1 443 }), 444 udp4Packet(ip4PortA, ip4PortB, 100), 445 udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { 446 fields.Flags = 1 447 }), 448 }, 449 true, 450 []int{0, 1, 2, 3}, 451 []int{140, 140, 128, 128}, 452 false, 453 }, 454 { 455 "unequal flags DF set", 456 [][]byte{ 457 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), 458 tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { 459 fields.Flags = 2 460 }), 461 udp4Packet(ip4PortA, ip4PortB, 100), 462 udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { 463 fields.Flags = 2 464 }), 465 }, 466 true, 467 []int{0, 1, 2, 3}, 468 []int{140, 140, 128, 128}, 469 false, 470 }, 471 { 472 "ipv6 unequal hop limit", 473 [][]byte{ 474 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), 475 tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { 476 fields.HopLimit++ 477 }), 478 udp6Packet(ip6PortA, ip6PortB, 100), 479 udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) { 480 fields.HopLimit++ 481 }), 482 }, 483 true, 484 []int{0, 1, 2, 3}, 485 []int{160, 160, 148, 148}, 486 false, 487 }, 488 { 489 "ipv6 unequal traffic class", 490 [][]byte{ 491 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), 492 tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { 493 fields.TrafficClass++ 494 }), 495 udp6Packet(ip6PortA, ip6PortB, 100), 496 udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) { 497 fields.TrafficClass++ 498 }), 499 }, 500 true, 501 []int{0, 1, 2, 3}, 502 []int{160, 160, 148, 148}, 503 false, 504 }, 505 } 506 507 for _, tt := range tests { 508 t.Run(tt.name, func(t *testing.T) { 509 toWrite := make([]int, 0, len(tt.pktsIn)) 510 err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.canUDPGRO, &toWrite) 511 if err != nil { 512 if tt.wantErr { 513 return 514 } 515 t.Fatalf("got err: %v", err) 516 } 517 if len(toWrite) != len(tt.wantToWrite) { 518 t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite)) 519 } 520 for i, pktI := range tt.wantToWrite { 521 if tt.wantToWrite[i] != toWrite[i] { 522 t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i]) 523 } 524 if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) { 525 t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:])) 526 } 527 } 528 }) 529 } 530 } 531 532 func Test_packetIsGROCandidate(t *testing.T) { 533 tcp4 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] 534 tcp4TooShort := tcp4[:39] 535 ip4InvalidHeaderLen := make([]byte, len(tcp4)) 536 copy(ip4InvalidHeaderLen, tcp4) 537 ip4InvalidHeaderLen[0] = 0x46 538 ip4InvalidProtocol := make([]byte, len(tcp4)) 539 copy(ip4InvalidProtocol, tcp4) 540 ip4InvalidProtocol[9] = unix.IPPROTO_GRE 541 542 tcp6 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] 543 tcp6TooShort := tcp6[:59] 544 ip6InvalidProtocol := make([]byte, len(tcp6)) 545 copy(ip6InvalidProtocol, tcp6) 546 ip6InvalidProtocol[6] = unix.IPPROTO_GRE 547 548 udp4 := udp4Packet(ip4PortA, ip4PortB, 100)[virtioNetHdrLen:] 549 udp4TooShort := udp4[:27] 550 551 udp6 := udp6Packet(ip6PortA, ip6PortB, 100)[virtioNetHdrLen:] 552 udp6TooShort := udp6[:47] 553 554 tests := []struct { 555 name string 556 b []byte 557 canUDPGRO bool 558 want groCandidateType 559 }{ 560 { 561 "tcp4", 562 tcp4, 563 true, 564 tcp4GROCandidate, 565 }, 566 { 567 "tcp6", 568 tcp6, 569 true, 570 tcp6GROCandidate, 571 }, 572 { 573 "udp4", 574 udp4, 575 true, 576 udp4GROCandidate, 577 }, 578 { 579 "udp4 no support", 580 udp4, 581 false, 582 notGROCandidate, 583 }, 584 { 585 "udp6", 586 udp6, 587 true, 588 udp6GROCandidate, 589 }, 590 { 591 "udp6 no support", 592 udp6, 593 false, 594 notGROCandidate, 595 }, 596 { 597 "udp4 too short", 598 udp4TooShort, 599 true, 600 notGROCandidate, 601 }, 602 { 603 "udp6 too short", 604 udp6TooShort, 605 true, 606 notGROCandidate, 607 }, 608 { 609 "tcp4 too short", 610 tcp4TooShort, 611 true, 612 notGROCandidate, 613 }, 614 { 615 "tcp6 too short", 616 tcp6TooShort, 617 true, 618 notGROCandidate, 619 }, 620 { 621 "invalid IP version", 622 []byte{0x00}, 623 true, 624 notGROCandidate, 625 }, 626 { 627 "invalid IP header len", 628 ip4InvalidHeaderLen, 629 true, 630 notGROCandidate, 631 }, 632 { 633 "ip4 invalid protocol", 634 ip4InvalidProtocol, 635 true, 636 notGROCandidate, 637 }, 638 { 639 "ip6 invalid protocol", 640 ip6InvalidProtocol, 641 true, 642 notGROCandidate, 643 }, 644 } 645 for _, tt := range tests { 646 t.Run(tt.name, func(t *testing.T) { 647 if got := packetIsGROCandidate(tt.b, tt.canUDPGRO); got != tt.want { 648 t.Errorf("packetIsGROCandidate() = %v, want %v", got, tt.want) 649 } 650 }) 651 } 652 } 653 654 func Test_udpPacketsCanCoalesce(t *testing.T) { 655 udp4a := udp4Packet(ip4PortA, ip4PortB, 100) 656 udp4b := udp4Packet(ip4PortA, ip4PortB, 100) 657 udp4c := udp4Packet(ip4PortA, ip4PortB, 110) 658 659 type args struct { 660 pkt []byte 661 iphLen uint8 662 gsoSize uint16 663 item udpGROItem 664 bufs [][]byte 665 bufsOffset int 666 } 667 tests := []struct { 668 name string 669 args args 670 want canCoalesce 671 }{ 672 { 673 "coalesceAppend equal gso", 674 args{ 675 pkt: udp4a[offset:], 676 iphLen: 20, 677 gsoSize: 100, 678 item: udpGROItem{ 679 gsoSize: 100, 680 iphLen: 20, 681 }, 682 bufs: [][]byte{ 683 udp4a, 684 udp4b, 685 }, 686 bufsOffset: offset, 687 }, 688 coalesceAppend, 689 }, 690 { 691 "coalesceAppend smaller gso", 692 args{ 693 pkt: udp4a[offset : len(udp4a)-90], 694 iphLen: 20, 695 gsoSize: 10, 696 item: udpGROItem{ 697 gsoSize: 100, 698 iphLen: 20, 699 }, 700 bufs: [][]byte{ 701 udp4a, 702 udp4b, 703 }, 704 bufsOffset: offset, 705 }, 706 coalesceAppend, 707 }, 708 { 709 "coalesceUnavailable smaller gso previously appended", 710 args{ 711 pkt: udp4a[offset:], 712 iphLen: 20, 713 gsoSize: 100, 714 item: udpGROItem{ 715 gsoSize: 100, 716 iphLen: 20, 717 }, 718 bufs: [][]byte{ 719 udp4c, 720 udp4b, 721 }, 722 bufsOffset: offset, 723 }, 724 coalesceUnavailable, 725 }, 726 { 727 "coalesceUnavailable larger following smaller", 728 args{ 729 pkt: udp4c[offset:], 730 iphLen: 20, 731 gsoSize: 110, 732 item: udpGROItem{ 733 gsoSize: 100, 734 iphLen: 20, 735 }, 736 bufs: [][]byte{ 737 udp4a, 738 udp4c, 739 }, 740 bufsOffset: offset, 741 }, 742 coalesceUnavailable, 743 }, 744 } 745 for _, tt := range tests { 746 t.Run(tt.name, func(t *testing.T) { 747 if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, tt.args.bufs, tt.args.bufsOffset); got != tt.want { 748 t.Errorf("udpPacketsCanCoalesce() = %v, want %v", got, tt.want) 749 } 750 }) 751 } 752 }