github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/transport/udp/udp_test.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package udp_test 16 17 import ( 18 "bytes" 19 "fmt" 20 "io/ioutil" 21 "math/rand" 22 "testing" 23 24 "github.com/google/go-cmp/cmp" 25 "github.com/SagerNet/gvisor/pkg/tcpip" 26 "github.com/SagerNet/gvisor/pkg/tcpip/buffer" 27 "github.com/SagerNet/gvisor/pkg/tcpip/checker" 28 "github.com/SagerNet/gvisor/pkg/tcpip/faketime" 29 "github.com/SagerNet/gvisor/pkg/tcpip/header" 30 "github.com/SagerNet/gvisor/pkg/tcpip/link/channel" 31 "github.com/SagerNet/gvisor/pkg/tcpip/link/loopback" 32 "github.com/SagerNet/gvisor/pkg/tcpip/link/sniffer" 33 "github.com/SagerNet/gvisor/pkg/tcpip/network/ipv4" 34 "github.com/SagerNet/gvisor/pkg/tcpip/network/ipv6" 35 "github.com/SagerNet/gvisor/pkg/tcpip/stack" 36 "github.com/SagerNet/gvisor/pkg/tcpip/testutil" 37 "github.com/SagerNet/gvisor/pkg/tcpip/transport/icmp" 38 "github.com/SagerNet/gvisor/pkg/tcpip/transport/udp" 39 "github.com/SagerNet/gvisor/pkg/waiter" 40 ) 41 42 // Addresses and ports used for testing. It is recommended that tests stick to 43 // using these addresses as it allows using the testFlow helper. 44 // Naming rules: 'stack*'' denotes local addresses and ports, while 'test*' 45 // represents the remote endpoint. 
46 const ( 47 v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" 48 stackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" 49 testV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" 50 stackV4MappedAddr = v4MappedAddrPrefix + stackAddr 51 testV4MappedAddr = v4MappedAddrPrefix + testAddr 52 multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr 53 broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr 54 v4MappedWildcardAddr = v4MappedAddrPrefix + "\x00\x00\x00\x00" 55 56 stackAddr = "\x0a\x00\x00\x01" 57 stackPort = 1234 58 testAddr = "\x0a\x00\x00\x02" 59 testPort = 4096 60 invalidPort = 8192 61 multicastAddr = "\xe8\x2b\xd3\xea" 62 multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" 63 broadcastAddr = header.IPv4Broadcast 64 testTOS = 0x80 65 66 // defaultMTU is the MTU, in bytes, used throughout the tests, except 67 // where another value is explicitly used. It is chosen to match the MTU 68 // of loopback interfaces on linux systems. 69 defaultMTU = 65536 70 ) 71 72 // header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in 73 // a packet header. These values are used to populate a header or verify one. 74 // Note that because they are used in packet headers, the addresses are never in 75 // a V4-mapped format. 76 type header4Tuple struct { 77 srcAddr tcpip.FullAddress 78 dstAddr tcpip.FullAddress 79 } 80 81 // testFlow implements a helper type used for sending and receiving test 82 // packets. A given test flow value defines 1) the socket endpoint used for the 83 // test and 2) the type of packet send or received on the endpoint. E.g., a 84 // multicastV6Only flow is a V6 multicast packet passing through a V6-only 85 // endpoint. The type provides helper methods to characterize the flow (e.g., 86 // isV4) as well as return a proper header4Tuple for it. 
type testFlow int

const (
	unicastV4         testFlow = iota // V4 unicast on a V4 socket
	unicastV4in6                      // V4-mapped unicast on a V6-dual socket
	unicastV6                         // V6 unicast on a V6 socket
	unicastV6Only                     // V6 unicast on a V6-only socket
	multicastV4                       // V4 multicast on a V4 socket
	multicastV4in6                    // V4-mapped multicast on a V6-dual socket
	multicastV6                       // V6 multicast on a V6 socket
	multicastV6Only                   // V6 multicast on a V6-only socket
	broadcast                         // V4 broadcast on a V4 socket
	broadcastIn6                      // V4-mapped broadcast on a V6-dual socket
	reverseMulticast4                 // V4 multicast src. Must fail.
	reverseMulticast6                 // V6 multicast src. Must fail.
)

// String returns the identifier of the flow constant, or "unknown" for any
// value outside the enumeration above.
func (flow testFlow) String() string {
	names := [...]string{
		unicastV4:         "unicastV4",
		unicastV4in6:      "unicastV4in6",
		unicastV6:         "unicastV6",
		unicastV6Only:     "unicastV6Only",
		multicastV4:       "multicastV4",
		multicastV4in6:    "multicastV4in6",
		multicastV6:       "multicastV6",
		multicastV6Only:   "multicastV6Only",
		broadcast:         "broadcast",
		broadcastIn6:      "broadcastIn6",
		reverseMulticast4: "reverseMulticast4",
		reverseMulticast6: "reverseMulticast6",
	}
	if flow < 0 || int(flow) >= len(names) {
		return "unknown"
	}
	return names[flow]
}

// packetDirection tells whether a flow's packet is read by the endpoint
// (incoming) or written by it (outgoing).
type packetDirection int

const (
	incoming packetDirection = iota
	outgoing
)

// header4Tuple builds the addressing quadruple for this flow in direction d.
// The result is always in on-the-wire form: V4-mapped addresses only exist at
// the socket API, never in packet headers.
146 func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { 147 var h header4Tuple 148 if flow.isV4() { 149 if d == outgoing { 150 h = header4Tuple{ 151 srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, 152 dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, 153 } 154 } else { 155 h = header4Tuple{ 156 srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, 157 dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, 158 } 159 } 160 if flow.isMulticast() { 161 h.dstAddr.Addr = multicastAddr 162 } else if flow.isBroadcast() { 163 h.dstAddr.Addr = broadcastAddr 164 } 165 } else { // IPv6 166 if d == outgoing { 167 h = header4Tuple{ 168 srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, 169 dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, 170 } 171 } else { 172 h = header4Tuple{ 173 srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, 174 dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, 175 } 176 } 177 if flow.isMulticast() { 178 h.dstAddr.Addr = multicastV6Addr 179 } 180 } 181 if flow.isReverseMulticast() { 182 h.srcAddr.Addr = flow.getMcastAddr() 183 } 184 return h 185 } 186 187 func (flow testFlow) getMcastAddr() tcpip.Address { 188 if flow.isV4() { 189 return multicastAddr 190 } 191 return multicastV6Addr 192 } 193 194 // mapAddrIfApplicable converts the given V4 address into its V4-mapped version 195 // if it is applicable to the flow. 196 func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address { 197 if flow.isMapped() { 198 return v4MappedAddrPrefix + v4Addr 199 } 200 return v4Addr 201 } 202 203 // netProto returns the protocol number used for the network packet. 204 func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { 205 if flow.isV4() { 206 return ipv4.ProtocolNumber 207 } 208 return ipv6.ProtocolNumber 209 } 210 211 // sockProto returns the protocol number used when creating the socket 212 // endpoint for this flow. 
213 func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { 214 switch flow { 215 case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6: 216 return ipv6.ProtocolNumber 217 case unicastV4, multicastV4, broadcast, reverseMulticast4: 218 return ipv4.ProtocolNumber 219 default: 220 panic(fmt.Sprintf("invalid testFlow given: %d", flow)) 221 } 222 } 223 224 func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) { 225 if flow.isV4() { 226 return checker.IPv4 227 } 228 return checker.IPv6 229 } 230 231 func (flow testFlow) isV6() bool { return !flow.isV4() } 232 func (flow testFlow) isV4() bool { 233 return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped() 234 } 235 236 func (flow testFlow) isV6Only() bool { 237 switch flow { 238 case unicastV6Only, multicastV6Only: 239 return true 240 case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: 241 return false 242 default: 243 panic(fmt.Sprintf("invalid testFlow given: %d", flow)) 244 } 245 } 246 247 func (flow testFlow) isMulticast() bool { 248 switch flow { 249 case multicastV4, multicastV4in6, multicastV6, multicastV6Only: 250 return true 251 case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: 252 return false 253 default: 254 panic(fmt.Sprintf("invalid testFlow given: %d", flow)) 255 } 256 } 257 258 func (flow testFlow) isBroadcast() bool { 259 switch flow { 260 case broadcast, broadcastIn6: 261 return true 262 case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6: 263 return false 264 default: 265 panic(fmt.Sprintf("invalid testFlow given: %d", flow)) 266 } 267 } 268 269 func (flow testFlow) isMapped() bool { 270 switch flow { 271 case unicastV4in6, 
multicastV4in6, broadcastIn6: 272 return true 273 case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6: 274 return false 275 default: 276 panic(fmt.Sprintf("invalid testFlow given: %d", flow)) 277 } 278 } 279 280 func (flow testFlow) isReverseMulticast() bool { 281 switch flow { 282 case reverseMulticast4, reverseMulticast6: 283 return true 284 default: 285 return false 286 } 287 } 288 289 type testContext struct { 290 t *testing.T 291 linkEP *channel.Endpoint 292 s *stack.Stack 293 294 ep tcpip.Endpoint 295 wq waiter.Queue 296 } 297 298 func newDualTestContext(t *testing.T, mtu uint32) *testContext { 299 t.Helper() 300 return newDualTestContextWithHandleLocal(t, mtu, true) 301 } 302 303 func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal bool) *testContext { 304 t.Helper() 305 306 options := stack.Options{ 307 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 308 TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, 309 HandleLocal: handleLocal, 310 Clock: &faketime.NullClock{}, 311 } 312 s := stack.New(options) 313 ep := channel.New(256, mtu, "") 314 wep := stack.LinkEndpoint(ep) 315 316 if testing.Verbose() { 317 wep = sniffer.New(ep) 318 } 319 if err := s.CreateNIC(1, wep); err != nil { 320 t.Fatalf("CreateNIC failed: %s", err) 321 } 322 323 if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil { 324 t.Fatalf("AddAddress failed: %s", err) 325 } 326 327 if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil { 328 t.Fatalf("AddAddress failed: %s", err) 329 } 330 331 s.SetRouteTable([]tcpip.Route{ 332 { 333 Destination: header.IPv4EmptySubnet, 334 NIC: 1, 335 }, 336 { 337 Destination: header.IPv6EmptySubnet, 338 NIC: 1, 339 }, 340 }) 341 342 return &testContext{ 343 t: t, 344 s: s, 345 linkEP: ep, 346 } 347 } 348 349 func (c *testContext) 
cleanup() { 350 if c.ep != nil { 351 c.ep.Close() 352 } 353 } 354 355 func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) { 356 c.t.Helper() 357 358 var err tcpip.Error 359 c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq) 360 if err != nil { 361 c.t.Fatal("NewEndpoint failed: ", err) 362 } 363 } 364 365 func (c *testContext) createEndpointForFlow(flow testFlow) { 366 c.t.Helper() 367 368 c.createEndpoint(flow.sockProto()) 369 if flow.isV6Only() { 370 c.ep.SocketOptions().SetV6Only(true) 371 } else if flow.isBroadcast() { 372 c.ep.SocketOptions().SetBroadcast(true) 373 } 374 } 375 376 // getPacketAndVerify reads a packet from the link endpoint and verifies the 377 // header against expected values from the given test flow. In addition, it 378 // calls any extra checker functions provided. 379 func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { 380 c.t.Helper() 381 382 p, ok := c.linkEP.Read() 383 if !ok { 384 c.t.Fatalf("Packet wasn't written out") 385 return nil 386 } 387 388 if p.Proto != flow.netProto() { 389 c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) 390 } 391 392 if got, want := p.Pkt.TransportProtocolNumber, header.UDPProtocolNumber; got != want { 393 c.t.Errorf("got p.Pkt.TransportProtocolNumber = %d, want = %d", got, want) 394 } 395 396 vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) 397 b := vv.ToView() 398 399 h := flow.header4Tuple(outgoing) 400 checkers = append( 401 checkers, 402 checker.SrcAddr(h.srcAddr.Addr), 403 checker.DstAddr(h.dstAddr.Addr), 404 checker.UDP(checker.DstPort(h.dstAddr.Port)), 405 ) 406 flow.checkerFn()(c.t, b, checkers...) 407 return b 408 } 409 410 // injectPacket creates a packet of the given flow and with the given payload, 411 // and injects it into the link endpoint. If badChecksum is true, the packet has 412 // a bad checksum in the UDP header. 
413 func (c *testContext) injectPacket(flow testFlow, payload []byte, badChecksum bool) { 414 c.t.Helper() 415 416 h := flow.header4Tuple(incoming) 417 if flow.isV4() { 418 buf := c.buildV4Packet(payload, &h) 419 if badChecksum { 420 // Invalidate the UDP header checksum field, taking care to avoid 421 // overflow to zero, which would disable checksum validation. 422 for u := header.UDP(buf[header.IPv4MinimumSize:]); ; { 423 u.SetChecksum(u.Checksum() + 1) 424 if u.Checksum() != 0 { 425 break 426 } 427 } 428 } 429 c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ 430 Data: buf.ToVectorisedView(), 431 })) 432 } else { 433 buf := c.buildV6Packet(payload, &h) 434 if badChecksum { 435 // Invalidate the UDP header checksum field (Unlike IPv4, zero is 436 // a valid checksum value for IPv6 so no need to avoid it). 437 u := header.UDP(buf[header.IPv6MinimumSize:]) 438 u.SetChecksum(u.Checksum() + 1) 439 } 440 c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ 441 Data: buf.ToVectorisedView(), 442 })) 443 } 444 } 445 446 // buildV6Packet creates a V6 test packet with the given payload and header 447 // values in a buffer. 448 func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View { 449 // Allocate a buffer for data and headers. 450 buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) 451 payloadStart := len(buf) - len(payload) 452 copy(buf[payloadStart:], payload) 453 454 // Initialize the IP header. 455 ip := header.IPv6(buf) 456 ip.Encode(&header.IPv6Fields{ 457 TrafficClass: testTOS, 458 PayloadLength: uint16(header.UDPMinimumSize + len(payload)), 459 TransportProtocol: udp.ProtocolNumber, 460 HopLimit: 65, 461 SrcAddr: h.srcAddr.Addr, 462 DstAddr: h.dstAddr.Addr, 463 }) 464 465 // Initialize the UDP header. 
466 u := header.UDP(buf[header.IPv6MinimumSize:]) 467 u.Encode(&header.UDPFields{ 468 SrcPort: h.srcAddr.Port, 469 DstPort: h.dstAddr.Port, 470 Length: uint16(header.UDPMinimumSize + len(payload)), 471 }) 472 473 // Calculate the UDP pseudo-header checksum. 474 xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) 475 476 // Calculate the UDP checksum and set it. 477 xsum = header.Checksum(payload, xsum) 478 u.SetChecksum(^u.CalculateChecksum(xsum)) 479 480 return buf 481 } 482 483 // buildV4Packet creates a V4 test packet with the given payload and header 484 // values in a buffer. 485 func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View { 486 // Allocate a buffer for data and headers. 487 buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) 488 payloadStart := len(buf) - len(payload) 489 copy(buf[payloadStart:], payload) 490 491 // Initialize the IP header. 492 ip := header.IPv4(buf) 493 ip.Encode(&header.IPv4Fields{ 494 TOS: testTOS, 495 TotalLength: uint16(len(buf)), 496 TTL: 65, 497 Protocol: uint8(udp.ProtocolNumber), 498 SrcAddr: h.srcAddr.Addr, 499 DstAddr: h.dstAddr.Addr, 500 }) 501 ip.SetChecksum(^ip.CalculateChecksum()) 502 503 // Initialize the UDP header. 504 u := header.UDP(buf[header.IPv4MinimumSize:]) 505 u.Encode(&header.UDPFields{ 506 SrcPort: h.srcAddr.Port, 507 DstPort: h.dstAddr.Port, 508 Length: uint16(header.UDPMinimumSize + len(payload)), 509 }) 510 511 // Calculate the UDP pseudo-header checksum. 512 xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) 513 514 // Calculate the UDP checksum and set it. 
515 xsum = header.Checksum(payload, xsum) 516 u.SetChecksum(^u.CalculateChecksum(xsum)) 517 518 return buf 519 } 520 521 func newPayload() []byte { 522 return newMinPayload(30) 523 } 524 525 func newMinPayload(minSize int) []byte { 526 b := make([]byte, minSize+rand.Intn(100)) 527 for i := range b { 528 b[i] = byte(rand.Intn(256)) 529 } 530 return b 531 } 532 533 func TestBindToDeviceOption(t *testing.T) { 534 s := stack.New(stack.Options{ 535 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, 536 TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, 537 Clock: &faketime.NullClock{}, 538 }) 539 540 ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) 541 if err != nil { 542 t.Fatalf("NewEndpoint failed; %s", err) 543 } 544 defer ep.Close() 545 546 opts := stack.NICOptions{Name: "my_device"} 547 if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil { 548 t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %s", opts, err) 549 } 550 551 // nicIDPtr is used instead of taking the address of NICID literals, which is 552 // a compiler error. 
553 nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { 554 return &s 555 } 556 557 testActions := []struct { 558 name string 559 setBindToDevice *tcpip.NICID 560 setBindToDeviceError tcpip.Error 561 getBindToDevice int32 562 }{ 563 {"GetDefaultValue", nil, nil, 0}, 564 {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0}, 565 {"BindToExistent", nicIDPtr(321), nil, 321}, 566 {"UnbindToDevice", nicIDPtr(0), nil, 0}, 567 } 568 for _, testAction := range testActions { 569 t.Run(testAction.name, func(t *testing.T) { 570 if testAction.setBindToDevice != nil { 571 bindToDevice := int32(*testAction.setBindToDevice) 572 if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { 573 t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) 574 } 575 } 576 bindToDevice := ep.SocketOptions().GetBindToDevice() 577 if bindToDevice != testAction.getBindToDevice { 578 t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice) 579 } 580 }) 581 } 582 } 583 584 // testReadInternal sends a packet of the given test flow into the stack by 585 // injecting it into the link endpoint. It then attempts to read it from the 586 // UDP endpoint and depending on if this was expected to succeed verifies its 587 // correctness including any additional checker functions provided. 588 func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) { 589 c.t.Helper() 590 591 payload := newPayload() 592 c.injectPacket(flow, payload, false) 593 594 // Try to receive the data. 595 we, ch := waiter.NewChannelEntry(nil) 596 c.wq.EventRegister(&we, waiter.ReadableEvents) 597 defer c.wq.EventUnregister(&we) 598 599 // Take a snapshot of the stats to validate them at the end of the test. 
600 epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() 601 602 var buf bytes.Buffer 603 res, err := c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true}) 604 if _, ok := err.(*tcpip.ErrWouldBlock); ok { 605 // Wait for data to become available. 606 select { 607 case <-ch: 608 res, err = c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true}) 609 610 default: 611 if packetShouldBeDropped { 612 return // expected to time out 613 } 614 c.t.Fatal("timed out waiting for data") 615 } 616 } 617 618 if expectReadError && err != nil { 619 c.checkEndpointReadStats(1, epstats, err) 620 return 621 } 622 623 if err != nil { 624 c.t.Fatal("Read failed:", err) 625 } 626 627 if packetShouldBeDropped { 628 c.t.Fatalf("Read unexpectedly received data from %s", res.RemoteAddr.Addr) 629 } 630 631 // Check the read result. 632 h := flow.header4Tuple(incoming) 633 if diff := cmp.Diff(tcpip.ReadResult{ 634 Count: buf.Len(), 635 Total: buf.Len(), 636 RemoteAddr: tcpip.FullAddress{Addr: h.srcAddr.Addr}, 637 }, res, checker.IgnoreCmpPath( 638 "ControlMessages", // ControlMessages will be checked later. 639 "RemoteAddr.NIC", 640 "RemoteAddr.Port", 641 )); diff != "" { 642 c.t.Fatalf("Read: unexpected result (-want +got):\n%s", diff) 643 } 644 645 // Check the payload. 646 v := buf.Bytes() 647 if !bytes.Equal(payload, v) { 648 c.t.Fatalf("got payload = %x, want = %x", v, payload) 649 } 650 651 // Run any checkers against the ControlMessages. 652 for _, f := range checkers { 653 f(c.t, res.ControlMessages) 654 } 655 656 c.checkEndpointReadStats(1, epstats, err) 657 } 658 659 // testRead sends a packet of the given test flow into the stack by injecting it 660 // into the link endpoint. It then reads it from the UDP endpoint and verifies 661 // its correctness including any additional checker functions provided. 
662 func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) { 663 c.t.Helper() 664 testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...) 665 } 666 667 // testFailingRead sends a packet of the given test flow into the stack by 668 // injecting it into the link endpoint. It then tries to read it from the UDP 669 // endpoint and expects this to fail. 670 func testFailingRead(c *testContext, flow testFlow, expectReadError bool) { 671 c.t.Helper() 672 testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError) 673 } 674 675 func TestBindEphemeralPort(t *testing.T) { 676 c := newDualTestContext(t, defaultMTU) 677 defer c.cleanup() 678 679 c.createEndpoint(ipv6.ProtocolNumber) 680 681 if err := c.ep.Bind(tcpip.FullAddress{}); err != nil { 682 t.Fatalf("ep.Bind(...) failed: %s", err) 683 } 684 } 685 686 func TestBindReservedPort(t *testing.T) { 687 c := newDualTestContext(t, defaultMTU) 688 defer c.cleanup() 689 690 c.createEndpoint(ipv6.ProtocolNumber) 691 692 if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { 693 c.t.Fatalf("Connect failed: %s", err) 694 } 695 696 addr, err := c.ep.GetLocalAddress() 697 if err != nil { 698 t.Fatalf("GetLocalAddress failed: %s", err) 699 } 700 701 // We can't bind the address reserved by the connected endpoint above. 702 { 703 ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) 704 if err != nil { 705 t.Fatalf("NewEndpoint failed: %s", err) 706 } 707 defer ep.Close() 708 { 709 err := ep.Bind(addr) 710 if _, ok := err.(*tcpip.ErrPortInUse); !ok { 711 t.Fatalf("got ep.Bind(...) 
= %s, want = %s", err, &tcpip.ErrPortInUse{}) 712 } 713 } 714 } 715 716 func() { 717 ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) 718 if err != nil { 719 t.Fatalf("NewEndpoint failed: %s", err) 720 } 721 defer ep.Close() 722 // We can't bind ipv4-any on the port reserved by the connected endpoint 723 // above, since the endpoint is dual-stack. 724 { 725 err := ep.Bind(tcpip.FullAddress{Port: addr.Port}) 726 if _, ok := err.(*tcpip.ErrPortInUse); !ok { 727 t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{}) 728 } 729 } 730 // We can bind an ipv4 address on this port, though. 731 if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil { 732 t.Fatalf("ep.Bind(...) failed: %s", err) 733 } 734 }() 735 736 // Once the connected endpoint releases its port reservation, we are able to 737 // bind ipv4-any once again. 738 c.ep.Close() 739 func() { 740 ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) 741 if err != nil { 742 t.Fatalf("NewEndpoint failed: %s", err) 743 } 744 defer ep.Close() 745 if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil { 746 t.Fatalf("ep.Bind(...) failed: %s", err) 747 } 748 }() 749 } 750 751 func TestV4ReadOnV6(t *testing.T) { 752 c := newDualTestContext(t, defaultMTU) 753 defer c.cleanup() 754 755 c.createEndpointForFlow(unicastV4in6) 756 757 // Bind to wildcard. 758 if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { 759 c.t.Fatalf("Bind failed: %s", err) 760 } 761 762 // Test acceptance. 763 testRead(c, unicastV4in6) 764 } 765 766 func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) { 767 c := newDualTestContext(t, defaultMTU) 768 defer c.cleanup() 769 770 c.createEndpointForFlow(unicastV4in6) 771 772 // Bind to v4 mapped wildcard. 773 if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil { 774 c.t.Fatalf("Bind failed: %s", err) 775 } 776 777 // Test acceptance. 
778 testRead(c, unicastV4in6) 779 } 780 781 func TestV4ReadOnBoundToV4Mapped(t *testing.T) { 782 c := newDualTestContext(t, defaultMTU) 783 defer c.cleanup() 784 785 c.createEndpointForFlow(unicastV4in6) 786 787 // Bind to local address. 788 if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { 789 c.t.Fatalf("Bind failed: %s", err) 790 } 791 792 // Test acceptance. 793 testRead(c, unicastV4in6) 794 } 795 796 func TestV6ReadOnV6(t *testing.T) { 797 c := newDualTestContext(t, defaultMTU) 798 defer c.cleanup() 799 800 c.createEndpointForFlow(unicastV6) 801 802 // Bind to wildcard. 803 if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { 804 c.t.Fatalf("Bind failed: %s", err) 805 } 806 807 // Test acceptance. 808 testRead(c, unicastV6) 809 } 810 811 // TestV4ReadSelfSource checks that packets coming from a local IP address are 812 // correctly dropped when handleLocal is true and not otherwise. 813 func TestV4ReadSelfSource(t *testing.T) { 814 for _, tt := range []struct { 815 name string 816 handleLocal bool 817 wantErr tcpip.Error 818 wantInvalidSource uint64 819 }{ 820 {"HandleLocal", false, nil, 0}, 821 {"NoHandleLocal", true, &tcpip.ErrWouldBlock{}, 1}, 822 } { 823 t.Run(tt.name, func(t *testing.T) { 824 c := newDualTestContextWithHandleLocal(t, defaultMTU, tt.handleLocal) 825 defer c.cleanup() 826 827 c.createEndpointForFlow(unicastV4) 828 829 if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { 830 t.Fatalf("Bind failed: %s", err) 831 } 832 833 payload := newPayload() 834 h := unicastV4.header4Tuple(incoming) 835 h.srcAddr = h.dstAddr 836 837 buf := c.buildV4Packet(payload, &h) 838 c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ 839 Data: buf.ToVectorisedView(), 840 })) 841 842 if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource { 843 t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want 
%d", got, tt.wantInvalidSource) 844 } 845 846 if _, err := c.ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tt.wantErr { 847 t.Errorf("got c.ep.Read = %s, want = %s", err, tt.wantErr) 848 } 849 }) 850 } 851 } 852 853 func TestV4ReadOnV4(t *testing.T) { 854 c := newDualTestContext(t, defaultMTU) 855 defer c.cleanup() 856 857 c.createEndpointForFlow(unicastV4) 858 859 // Bind to wildcard. 860 if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { 861 c.t.Fatalf("Bind failed: %s", err) 862 } 863 864 // Test acceptance. 865 testRead(c, unicastV4) 866 } 867 868 // TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast 869 // address and receive data sent to that address. 870 func TestReadOnBoundToMulticast(t *testing.T) { 871 // FIXME(b/128189410): multicastV4in6 currently doesn't work as 872 // AddMembershipOption doesn't handle V4in6 addresses. 873 for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} { 874 t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { 875 c := newDualTestContext(t, defaultMTU) 876 defer c.cleanup() 877 878 c.createEndpointForFlow(flow) 879 880 // Bind to multicast address. 881 mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr()) 882 if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil { 883 c.t.Fatal("Bind failed:", err) 884 } 885 886 // Join multicast group. 887 ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr} 888 if err := c.ep.SetSockOpt(&ifoptSet); err != nil { 889 c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err) 890 } 891 892 // Check that we receive multicast packets but not unicast or broadcast 893 // ones. 894 testRead(c, flow) 895 testFailingRead(c, broadcast, false /* expectReadError */) 896 testFailingRead(c, unicastV4, false /* expectReadError */) 897 }) 898 } 899 } 900 901 // TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast 902 // address and can receive only broadcast data. 
903 func TestV4ReadOnBoundToBroadcast(t *testing.T) { 904 for _, flow := range []testFlow{broadcast, broadcastIn6} { 905 t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { 906 c := newDualTestContext(t, defaultMTU) 907 defer c.cleanup() 908 909 c.createEndpointForFlow(flow) 910 911 // Bind to broadcast address. 912 bcastAddr := flow.mapAddrIfApplicable(broadcastAddr) 913 if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil { 914 c.t.Fatalf("Bind failed: %s", err) 915 } 916 917 // Check that we receive broadcast packets but not unicast ones. 918 testRead(c, flow) 919 testFailingRead(c, unicastV4, false /* expectReadError */) 920 }) 921 } 922 } 923 924 // TestReadFromMulticast checks that an endpoint will NOT receive a packet 925 // that was sent with multicast SOURCE address. 926 func TestReadFromMulticast(t *testing.T) { 927 for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} { 928 t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { 929 c := newDualTestContext(t, defaultMTU) 930 defer c.cleanup() 931 932 c.createEndpointForFlow(flow) 933 934 if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { 935 t.Fatalf("Bind failed: %s", err) 936 } 937 testFailingRead(c, flow, false /* expectReadError */) 938 }) 939 } 940 } 941 942 // TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY 943 // and receive broadcast and unicast data. 944 func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { 945 for _, flow := range []testFlow{broadcast, broadcastIn6} { 946 t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { 947 c := newDualTestContext(t, defaultMTU) 948 defer c.cleanup() 949 950 c.createEndpointForFlow(flow) 951 952 // Bind to wildcard. 953 if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { 954 c.t.Fatalf("Bind failed: %s (", err) 955 } 956 957 // Check that we receive both broadcast and unicast packets. 
958 testRead(c, flow) 959 testRead(c, unicastV4) 960 }) 961 } 962 } 963 964 // testFailingWrite sends a packet of the given test flow into the UDP endpoint 965 // and verifies it fails with the provided error code. 966 func testFailingWrite(c *testContext, flow testFlow, wantErr tcpip.Error) { 967 c.t.Helper() 968 // Take a snapshot of the stats to validate them at the end of the test. 969 epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() 970 h := flow.header4Tuple(outgoing) 971 writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) 972 973 var r bytes.Reader 974 r.Reset(newPayload()) 975 _, gotErr := c.ep.Write(&r, tcpip.WriteOptions{ 976 To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, 977 }) 978 c.checkEndpointWriteStats(1, epstats, gotErr) 979 if gotErr != wantErr { 980 c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr) 981 } 982 } 983 984 // testWrite sends a packet of the given test flow from the UDP endpoint to the 985 // flow's destination address:port. It then receives it from the link endpoint 986 // and verifies its correctness including any additional checker functions 987 // provided. 988 func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { 989 c.t.Helper() 990 return testWriteAndVerifyInternal(c, flow, true, checkers...) 991 } 992 993 // testWriteWithoutDestination sends a packet of the given test flow from the 994 // UDP endpoint without giving a destination address:port. It then receives it 995 // from the link endpoint and verifies its correctness including any additional 996 // checker functions provided. 997 func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { 998 c.t.Helper() 999 return testWriteAndVerifyInternal(c, flow, false, checkers...) 
1000 } 1001 1002 func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View { 1003 c.t.Helper() 1004 // Take a snapshot of the stats to validate them at the end of the test. 1005 epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() 1006 1007 writeOpts := tcpip.WriteOptions{} 1008 if setDest { 1009 h := flow.header4Tuple(outgoing) 1010 writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) 1011 writeOpts = tcpip.WriteOptions{ 1012 To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, 1013 } 1014 } 1015 var r bytes.Reader 1016 payload := newPayload() 1017 r.Reset(payload) 1018 n, err := c.ep.Write(&r, writeOpts) 1019 if err != nil { 1020 c.t.Fatalf("Write failed: %s", err) 1021 } 1022 if n != int64(len(payload)) { 1023 c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) 1024 } 1025 c.checkEndpointWriteStats(1, epstats, err) 1026 return payload 1027 } 1028 1029 func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { 1030 c.t.Helper() 1031 payload := testWriteNoVerify(c, flow, setDest) 1032 // Received the packet and check the payload. 1033 b := c.getPacketAndVerify(flow, checkers...) 
1034 var udpH header.UDP 1035 if flow.isV4() { 1036 udpH = header.IPv4(b).Payload() 1037 } else { 1038 udpH = header.IPv6(b).Payload() 1039 } 1040 if !bytes.Equal(payload, udpH.Payload()) { 1041 c.t.Fatalf("Bad payload: got %x, want %x", udpH.Payload(), payload) 1042 } 1043 1044 return udpH.SourcePort() 1045 } 1046 1047 func testDualWrite(c *testContext) uint16 { 1048 c.t.Helper() 1049 1050 v4Port := testWrite(c, unicastV4in6) 1051 v6Port := testWrite(c, unicastV6) 1052 if v4Port != v6Port { 1053 c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port) 1054 } 1055 1056 return v4Port 1057 } 1058 1059 func TestDualWriteUnbound(t *testing.T) { 1060 c := newDualTestContext(t, defaultMTU) 1061 defer c.cleanup() 1062 1063 c.createEndpoint(ipv6.ProtocolNumber) 1064 1065 testDualWrite(c) 1066 } 1067 1068 func TestDualWriteBoundToWildcard(t *testing.T) { 1069 c := newDualTestContext(t, defaultMTU) 1070 defer c.cleanup() 1071 1072 c.createEndpoint(ipv6.ProtocolNumber) 1073 1074 // Bind to wildcard. 1075 if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { 1076 c.t.Fatalf("Bind failed: %s", err) 1077 } 1078 1079 p := testDualWrite(c) 1080 if p != stackPort { 1081 c.t.Fatalf("Bad port: got %v, want %v", p, stackPort) 1082 } 1083 } 1084 1085 func TestDualWriteConnectedToV6(t *testing.T) { 1086 c := newDualTestContext(t, defaultMTU) 1087 defer c.cleanup() 1088 1089 c.createEndpoint(ipv6.ProtocolNumber) 1090 1091 // Connect to v6 address. 1092 if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { 1093 c.t.Fatalf("Bind failed: %s", err) 1094 } 1095 1096 testWrite(c, unicastV6) 1097 1098 // Write to V4 mapped address. 1099 testFailingWrite(c, unicastV4in6, &tcpip.ErrNetworkUnreachable{}) 1100 const want = 1 1101 if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want { 1102 c.t.Fatalf("Endpoint stat not updated. 
got %d want %d", got, want) 1103 } 1104 } 1105 1106 func TestDualWriteConnectedToV4Mapped(t *testing.T) { 1107 c := newDualTestContext(t, defaultMTU) 1108 defer c.cleanup() 1109 1110 c.createEndpoint(ipv6.ProtocolNumber) 1111 1112 // Connect to v4 mapped address. 1113 if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { 1114 c.t.Fatalf("Bind failed: %s", err) 1115 } 1116 1117 testWrite(c, unicastV4in6) 1118 1119 // Write to v6 address. 1120 testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{}) 1121 } 1122 1123 func TestV4WriteOnV6Only(t *testing.T) { 1124 c := newDualTestContext(t, defaultMTU) 1125 defer c.cleanup() 1126 1127 c.createEndpointForFlow(unicastV6Only) 1128 1129 // Write to V4 mapped address. 1130 testFailingWrite(c, unicastV4in6, &tcpip.ErrNoRoute{}) 1131 } 1132 1133 func TestV6WriteOnBoundToV4Mapped(t *testing.T) { 1134 c := newDualTestContext(t, defaultMTU) 1135 defer c.cleanup() 1136 1137 c.createEndpoint(ipv6.ProtocolNumber) 1138 1139 // Bind to v4 mapped address. 1140 if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { 1141 c.t.Fatalf("Bind failed: %s", err) 1142 } 1143 1144 // Write to v6 address. 1145 testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{}) 1146 } 1147 1148 func TestV6WriteOnConnected(t *testing.T) { 1149 c := newDualTestContext(t, defaultMTU) 1150 defer c.cleanup() 1151 1152 c.createEndpoint(ipv6.ProtocolNumber) 1153 1154 // Connect to v6 address. 1155 if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { 1156 c.t.Fatalf("Connect failed: %s", err) 1157 } 1158 1159 testWriteWithoutDestination(c, unicastV6) 1160 } 1161 1162 func TestV4WriteOnConnected(t *testing.T) { 1163 c := newDualTestContext(t, defaultMTU) 1164 defer c.cleanup() 1165 1166 c.createEndpoint(ipv6.ProtocolNumber) 1167 1168 // Connect to v4 mapped address. 
1169 if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { 1170 c.t.Fatalf("Connect failed: %s", err) 1171 } 1172 1173 testWriteWithoutDestination(c, unicastV4) 1174 } 1175 1176 func TestWriteOnConnectedInvalidPort(t *testing.T) { 1177 protocols := map[string]tcpip.NetworkProtocolNumber{ 1178 "ipv4": ipv4.ProtocolNumber, 1179 "ipv6": ipv6.ProtocolNumber, 1180 } 1181 for name, pn := range protocols { 1182 t.Run(name, func(t *testing.T) { 1183 c := newDualTestContext(t, defaultMTU) 1184 defer c.cleanup() 1185 1186 c.createEndpoint(pn) 1187 if err := c.ep.Connect(tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}); err != nil { 1188 c.t.Fatalf("Connect failed: %s", err) 1189 } 1190 writeOpts := tcpip.WriteOptions{ 1191 To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}, 1192 } 1193 var r bytes.Reader 1194 payload := newPayload() 1195 r.Reset(payload) 1196 n, err := c.ep.Write(&r, writeOpts) 1197 if err != nil { 1198 c.t.Fatalf("c.ep.Write(...) = %s, want nil", err) 1199 } 1200 if got, want := n, int64(len(payload)); got != want { 1201 c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want) 1202 } 1203 1204 { 1205 err := c.ep.LastError() 1206 if _, ok := err.(*tcpip.ErrConnectionRefused); !ok { 1207 c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err) 1208 } 1209 } 1210 }) 1211 } 1212 } 1213 1214 // TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket 1215 // that is bound to a V4 multicast address. 1216 func TestWriteOnBoundToV4Multicast(t *testing.T) { 1217 for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { 1218 t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { 1219 c := newDualTestContext(t, defaultMTU) 1220 defer c.cleanup() 1221 1222 c.createEndpointForFlow(flow) 1223 1224 // Bind to V4 mcast address. 
1225 if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil { 1226 c.t.Fatal("Bind failed:", err) 1227 } 1228 1229 testWrite(c, flow) 1230 }) 1231 } 1232 } 1233 1234 // TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a 1235 // socket that is bound to a V4-mapped multicast address. 1236 func TestWriteOnBoundToV4MappedMulticast(t *testing.T) { 1237 for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { 1238 t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { 1239 c := newDualTestContext(t, defaultMTU) 1240 defer c.cleanup() 1241 1242 c.createEndpointForFlow(flow) 1243 1244 // Bind to V4Mapped mcast address. 1245 if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil { 1246 c.t.Fatalf("Bind failed: %s", err) 1247 } 1248 1249 testWrite(c, flow) 1250 }) 1251 } 1252 } 1253 1254 // TestWriteOnBoundToV6Multicast checks that we can send packets out of a 1255 // socket that is bound to a V6 multicast address. 1256 func TestWriteOnBoundToV6Multicast(t *testing.T) { 1257 for _, flow := range []testFlow{unicastV6, multicastV6} { 1258 t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { 1259 c := newDualTestContext(t, defaultMTU) 1260 defer c.cleanup() 1261 1262 c.createEndpointForFlow(flow) 1263 1264 // Bind to V6 mcast address. 1265 if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { 1266 c.t.Fatalf("Bind failed: %s", err) 1267 } 1268 1269 testWrite(c, flow) 1270 }) 1271 } 1272 } 1273 1274 // TestWriteOnBoundToV6Multicast checks that we can send packets out of a 1275 // V6-only socket that is bound to a V6 multicast address. 
1276 func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) { 1277 for _, flow := range []testFlow{unicastV6Only, multicastV6Only} { 1278 t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { 1279 c := newDualTestContext(t, defaultMTU) 1280 defer c.cleanup() 1281 1282 c.createEndpointForFlow(flow) 1283 1284 // Bind to V6 mcast address. 1285 if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { 1286 c.t.Fatalf("Bind failed: %s", err) 1287 } 1288 1289 testWrite(c, flow) 1290 }) 1291 } 1292 } 1293 1294 // TestWriteOnBoundToBroadcast checks that we can send packets out of a 1295 // socket that is bound to the broadcast address. 1296 func TestWriteOnBoundToBroadcast(t *testing.T) { 1297 for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { 1298 t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { 1299 c := newDualTestContext(t, defaultMTU) 1300 defer c.cleanup() 1301 1302 c.createEndpointForFlow(flow) 1303 1304 // Bind to V4 broadcast address. 1305 if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil { 1306 c.t.Fatal("Bind failed:", err) 1307 } 1308 1309 testWrite(c, flow) 1310 }) 1311 } 1312 } 1313 1314 // TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a 1315 // socket that is bound to the V4-mapped broadcast address. 1316 func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) { 1317 for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { 1318 t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { 1319 c := newDualTestContext(t, defaultMTU) 1320 defer c.cleanup() 1321 1322 c.createEndpointForFlow(flow) 1323 1324 // Bind to V4Mapped mcast address. 
1325 if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil { 1326 c.t.Fatalf("Bind failed: %s", err) 1327 } 1328 1329 testWrite(c, flow) 1330 }) 1331 } 1332 } 1333 1334 func TestReadIncrementsPacketsReceived(t *testing.T) { 1335 c := newDualTestContext(t, defaultMTU) 1336 defer c.cleanup() 1337 1338 // Create IPv4 UDP endpoint 1339 c.createEndpoint(ipv6.ProtocolNumber) 1340 1341 // Bind to wildcard. 1342 if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { 1343 c.t.Fatalf("Bind failed: %s", err) 1344 } 1345 1346 testRead(c, unicastV4) 1347 1348 var want uint64 = 1 1349 if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want { 1350 c.t.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want) 1351 } 1352 } 1353 1354 func TestReadIPPacketInfo(t *testing.T) { 1355 tests := []struct { 1356 name string 1357 proto tcpip.NetworkProtocolNumber 1358 flow testFlow 1359 expectedLocalAddr tcpip.Address 1360 expectedDestAddr tcpip.Address 1361 }{ 1362 { 1363 name: "IPv4 unicast", 1364 proto: header.IPv4ProtocolNumber, 1365 flow: unicastV4, 1366 expectedLocalAddr: stackAddr, 1367 expectedDestAddr: stackAddr, 1368 }, 1369 { 1370 name: "IPv4 multicast", 1371 proto: header.IPv4ProtocolNumber, 1372 flow: multicastV4, 1373 // This should actually be a unicast address assigned to the interface. 1374 // 1375 // TODO(github.com/SagerNet/issue/3556): This check is validating incorrect 1376 // behaviour. We still include the test so that once the bug is 1377 // resolved, this test will start to fail and the individual tasked 1378 // with fixing this bug knows to also fix this test :). 1379 expectedLocalAddr: multicastAddr, 1380 expectedDestAddr: multicastAddr, 1381 }, 1382 { 1383 name: "IPv4 broadcast", 1384 proto: header.IPv4ProtocolNumber, 1385 flow: broadcast, 1386 // This should actually be a unicast address assigned to the interface. 
1387 // 1388 // TODO(github.com/SagerNet/issue/3556): This check is validating incorrect 1389 // behaviour. We still include the test so that once the bug is 1390 // resolved, this test will start to fail and the individual tasked 1391 // with fixing this bug knows to also fix this test :). 1392 expectedLocalAddr: broadcastAddr, 1393 expectedDestAddr: broadcastAddr, 1394 }, 1395 { 1396 name: "IPv6 unicast", 1397 proto: header.IPv6ProtocolNumber, 1398 flow: unicastV6, 1399 expectedLocalAddr: stackV6Addr, 1400 expectedDestAddr: stackV6Addr, 1401 }, 1402 { 1403 name: "IPv6 multicast", 1404 proto: header.IPv6ProtocolNumber, 1405 flow: multicastV6, 1406 // This should actually be a unicast address assigned to the interface. 1407 // 1408 // TODO(github.com/SagerNet/issue/3556): This check is validating incorrect 1409 // behaviour. We still include the test so that once the bug is 1410 // resolved, this test will start to fail and the individual tasked 1411 // with fixing this bug knows to also fix this test :). 
1412 expectedLocalAddr: multicastV6Addr, 1413 expectedDestAddr: multicastV6Addr, 1414 }, 1415 } 1416 1417 for _, test := range tests { 1418 t.Run(test.name, func(t *testing.T) { 1419 c := newDualTestContext(t, defaultMTU) 1420 defer c.cleanup() 1421 1422 c.createEndpoint(test.proto) 1423 1424 bindAddr := tcpip.FullAddress{Port: stackPort} 1425 if err := c.ep.Bind(bindAddr); err != nil { 1426 t.Fatalf("Bind(%+v): %s", bindAddr, err) 1427 } 1428 1429 if test.flow.isMulticast() { 1430 ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()} 1431 if err := c.ep.SetSockOpt(&ifoptSet); err != nil { 1432 c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err) 1433 } 1434 } 1435 1436 c.ep.SocketOptions().SetReceivePacketInfo(true) 1437 1438 testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ 1439 NIC: 1, 1440 LocalAddr: test.expectedLocalAddr, 1441 DestinationAddr: test.expectedDestAddr, 1442 })) 1443 1444 if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 { 1445 t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) 1446 } 1447 }) 1448 } 1449 } 1450 1451 func TestReadRecvOriginalDstAddr(t *testing.T) { 1452 tests := []struct { 1453 name string 1454 proto tcpip.NetworkProtocolNumber 1455 flow testFlow 1456 expectedOriginalDstAddr tcpip.FullAddress 1457 }{ 1458 { 1459 name: "IPv4 unicast", 1460 proto: header.IPv4ProtocolNumber, 1461 flow: unicastV4, 1462 expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackAddr, Port: stackPort}, 1463 }, 1464 { 1465 name: "IPv4 multicast", 1466 proto: header.IPv4ProtocolNumber, 1467 flow: multicastV4, 1468 // This should actually be a unicast address assigned to the interface. 1469 // 1470 // TODO(github.com/SagerNet/issue/3556): This check is validating incorrect 1471 // behaviour. 
We still include the test so that once the bug is 1472 // resolved, this test will start to fail and the individual tasked 1473 // with fixing this bug knows to also fix this test :). 1474 expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastAddr, Port: stackPort}, 1475 }, 1476 { 1477 name: "IPv4 broadcast", 1478 proto: header.IPv4ProtocolNumber, 1479 flow: broadcast, 1480 // This should actually be a unicast address assigned to the interface. 1481 // 1482 // TODO(github.com/SagerNet/issue/3556): This check is validating incorrect 1483 // behaviour. We still include the test so that once the bug is 1484 // resolved, this test will start to fail and the individual tasked 1485 // with fixing this bug knows to also fix this test :). 1486 expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: broadcastAddr, Port: stackPort}, 1487 }, 1488 { 1489 name: "IPv6 unicast", 1490 proto: header.IPv6ProtocolNumber, 1491 flow: unicastV6, 1492 expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackV6Addr, Port: stackPort}, 1493 }, 1494 { 1495 name: "IPv6 multicast", 1496 proto: header.IPv6ProtocolNumber, 1497 flow: multicastV6, 1498 // This should actually be a unicast address assigned to the interface. 1499 // 1500 // TODO(github.com/SagerNet/issue/3556): This check is validating incorrect 1501 // behaviour. We still include the test so that once the bug is 1502 // resolved, this test will start to fail and the individual tasked 1503 // with fixing this bug knows to also fix this test :). 
1504 expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastV6Addr, Port: stackPort}, 1505 }, 1506 } 1507 1508 for _, test := range tests { 1509 t.Run(test.name, func(t *testing.T) { 1510 c := newDualTestContext(t, defaultMTU) 1511 defer c.cleanup() 1512 1513 c.createEndpoint(test.proto) 1514 1515 bindAddr := tcpip.FullAddress{Port: stackPort} 1516 if err := c.ep.Bind(bindAddr); err != nil { 1517 t.Fatalf("Bind(%#v): %s", bindAddr, err) 1518 } 1519 1520 if test.flow.isMulticast() { 1521 ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()} 1522 if err := c.ep.SetSockOpt(&ifoptSet); err != nil { 1523 c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err) 1524 } 1525 } 1526 1527 c.ep.SocketOptions().SetReceiveOriginalDstAddress(true) 1528 1529 testRead(c, test.flow, checker.ReceiveOriginalDstAddr(test.expectedOriginalDstAddr)) 1530 1531 if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 { 1532 t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) 1533 } 1534 }) 1535 } 1536 } 1537 1538 func TestWriteIncrementsPacketsSent(t *testing.T) { 1539 c := newDualTestContext(t, defaultMTU) 1540 defer c.cleanup() 1541 1542 c.createEndpoint(ipv6.ProtocolNumber) 1543 1544 testDualWrite(c) 1545 1546 var want uint64 = 2 1547 if got := c.s.Stats().UDP.PacketsSent.Value(); got != want { 1548 c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want) 1549 } 1550 } 1551 1552 func TestNoChecksum(t *testing.T) { 1553 for _, flow := range []testFlow{unicastV4, unicastV6} { 1554 t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { 1555 c := newDualTestContext(t, defaultMTU) 1556 defer c.cleanup() 1557 1558 c.createEndpointForFlow(flow) 1559 1560 // Disable the checksum generation. 1561 c.ep.SocketOptions().SetNoChecksum(true) 1562 // This option is effective on IPv4 only. 1563 testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4()))) 1564 1565 // Enable the checksum generation. 
1566 c.ep.SocketOptions().SetNoChecksum(false) 1567 testWrite(c, flow, checker.UDP(checker.NoChecksum(false))) 1568 }) 1569 } 1570 } 1571 1572 var _ stack.NetworkInterface = (*testInterface)(nil) 1573 1574 type testInterface struct { 1575 stack.NetworkInterface 1576 } 1577 1578 func (*testInterface) ID() tcpip.NICID { 1579 return 0 1580 } 1581 1582 func (*testInterface) Enabled() bool { 1583 return true 1584 } 1585 1586 func TestTTL(t *testing.T) { 1587 for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { 1588 t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { 1589 c := newDualTestContext(t, defaultMTU) 1590 defer c.cleanup() 1591 1592 c.createEndpointForFlow(flow) 1593 1594 const multicastTTL = 42 1595 if err := c.ep.SetSockOptInt(tcpip.MulticastTTLOption, multicastTTL); err != nil { 1596 c.t.Fatalf("SetSockOptInt failed: %s", err) 1597 } 1598 1599 var wantTTL uint8 1600 if flow.isMulticast() { 1601 wantTTL = multicastTTL 1602 } else { 1603 var p stack.NetworkProtocolFactory 1604 var n tcpip.NetworkProtocolNumber 1605 if flow.isV4() { 1606 p = ipv4.NewProtocol 1607 n = ipv4.ProtocolNumber 1608 } else { 1609 p = ipv6.NewProtocol 1610 n = ipv6.ProtocolNumber 1611 } 1612 s := stack.New(stack.Options{ 1613 NetworkProtocols: []stack.NetworkProtocolFactory{p}, 1614 Clock: &faketime.NullClock{}, 1615 }) 1616 ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil) 1617 wantTTL = ep.DefaultTTL() 1618 ep.Close() 1619 } 1620 1621 testWrite(c, flow, checker.TTL(wantTTL)) 1622 }) 1623 } 1624 } 1625 1626 func TestSetTTL(t *testing.T) { 1627 for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6} { 1628 t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { 1629 for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { 1630 t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { 1631 c := 
newDualTestContext(t, defaultMTU) 1632 defer c.cleanup() 1633 1634 c.createEndpointForFlow(flow) 1635 1636 if err := c.ep.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil { 1637 c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) 1638 } 1639 1640 testWrite(c, flow, checker.TTL(wantTTL)) 1641 }) 1642 } 1643 }) 1644 } 1645 } 1646 1647 func TestSetTOS(t *testing.T) { 1648 for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { 1649 t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { 1650 c := newDualTestContext(t, defaultMTU) 1651 defer c.cleanup() 1652 1653 c.createEndpointForFlow(flow) 1654 1655 const tos = testTOS 1656 v, err := c.ep.GetSockOptInt(tcpip.IPv4TOSOption) 1657 if err != nil { 1658 c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err) 1659 } 1660 // Test for expected default value. 1661 if v != 0 { 1662 c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0) 1663 } 1664 1665 if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { 1666 c.t.Errorf("SetSockOptInt(IPv4TOSOption, 0x%x) failed: %s", tos, err) 1667 } 1668 1669 v, err = c.ep.GetSockOptInt(tcpip.IPv4TOSOption) 1670 if err != nil { 1671 c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err) 1672 } 1673 1674 if v != tos { 1675 c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, tos) 1676 } 1677 1678 testWrite(c, flow, checker.TOS(tos, 0)) 1679 }) 1680 } 1681 } 1682 1683 func TestSetTClass(t *testing.T) { 1684 for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} { 1685 t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { 1686 c := newDualTestContext(t, defaultMTU) 1687 defer c.cleanup() 1688 1689 c.createEndpointForFlow(flow) 1690 1691 const tClass = testTOS 1692 v, err := c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) 1693 if err != nil { 1694 c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err) 1695 } 1696 // Test for 
expected default value. 1697 if v != 0 { 1698 c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, 0) 1699 } 1700 1701 if err := c.ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, tClass); err != nil { 1702 c.t.Errorf("SetSockOptInt(IPv6TrafficClassOption, 0x%x) failed: %s", tClass, err) 1703 } 1704 1705 v, err = c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) 1706 if err != nil { 1707 c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err) 1708 } 1709 1710 if v != tClass { 1711 c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, tClass) 1712 } 1713 1714 // The header getter for TClass is called TOS, so use that checker. 1715 testWrite(c, flow, checker.TOS(tClass, 0)) 1716 }) 1717 } 1718 } 1719 1720 func TestReceiveTosTClass(t *testing.T) { 1721 const RcvTOSOpt = "ReceiveTosOption" 1722 const RcvTClassOpt = "ReceiveTClassOption" 1723 1724 testCases := []struct { 1725 name string 1726 tests []testFlow 1727 }{ 1728 {RcvTOSOpt, []testFlow{unicastV4, broadcast}}, 1729 {RcvTClassOpt, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}}, 1730 } 1731 for _, testCase := range testCases { 1732 for _, flow := range testCase.tests { 1733 t.Run(fmt.Sprintf("%s:flow:%s", testCase.name, flow), func(t *testing.T) { 1734 c := newDualTestContext(t, defaultMTU) 1735 defer c.cleanup() 1736 1737 c.createEndpointForFlow(flow) 1738 name := testCase.name 1739 1740 var optionGetter func() bool 1741 var optionSetter func(bool) 1742 switch name { 1743 case RcvTOSOpt: 1744 optionGetter = c.ep.SocketOptions().GetReceiveTOS 1745 optionSetter = c.ep.SocketOptions().SetReceiveTOS 1746 case RcvTClassOpt: 1747 optionGetter = c.ep.SocketOptions().GetReceiveTClass 1748 optionSetter = c.ep.SocketOptions().SetReceiveTClass 1749 default: 1750 t.Fatalf("unkown test variant: %s", name) 1751 } 1752 1753 // Verify that setting and reading the option works. 1754 v := optionGetter() 1755 // Test for expected default value. 
1756 if v != false { 1757 c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false) 1758 } 1759 1760 const want = true 1761 optionSetter(want) 1762 1763 got := optionGetter() 1764 if got != want { 1765 c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want) 1766 } 1767 1768 // Verify that the correct received TOS or TClass is handed through as 1769 // ancillary data to the ControlMessages struct. 1770 if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { 1771 c.t.Fatalf("Bind failed: %s", err) 1772 } 1773 switch name { 1774 case RcvTClassOpt: 1775 testRead(c, flow, checker.ReceiveTClass(testTOS)) 1776 case RcvTOSOpt: 1777 testRead(c, flow, checker.ReceiveTOS(testTOS)) 1778 default: 1779 t.Fatalf("unknown test variant: %s", name) 1780 } 1781 }) 1782 } 1783 } 1784 } 1785 1786 func TestMulticastInterfaceOption(t *testing.T) { 1787 for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { 1788 t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { 1789 for _, bindTyp := range []string{"bound", "unbound"} { 1790 t.Run(bindTyp, func(t *testing.T) { 1791 for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} { 1792 t.Run(optTyp, func(t *testing.T) { 1793 h := flow.header4Tuple(outgoing) 1794 mcastAddr := h.dstAddr.Addr 1795 localIfAddr := h.srcAddr.Addr 1796 1797 var ifoptSet tcpip.MulticastInterfaceOption 1798 switch optTyp { 1799 case "use local-addr": 1800 ifoptSet.InterfaceAddr = localIfAddr 1801 case "use NICID": 1802 ifoptSet.NIC = 1 1803 case "use local-addr and NIC": 1804 ifoptSet.InterfaceAddr = localIfAddr 1805 ifoptSet.NIC = 1 1806 default: 1807 t.Fatal("unknown test variant") 1808 } 1809 1810 c := newDualTestContext(t, defaultMTU) 1811 defer c.cleanup() 1812 1813 c.createEndpoint(flow.sockProto()) 1814 1815 if bindTyp == "bound" { 1816 // Bind the socket by connecting to the multicast address. 
1817 // This may have an influence on how the multicast interface 1818 // is set. 1819 addr := tcpip.FullAddress{ 1820 Addr: flow.mapAddrIfApplicable(mcastAddr), 1821 Port: stackPort, 1822 } 1823 if err := c.ep.Connect(addr); err != nil { 1824 c.t.Fatalf("Connect failed: %s", err) 1825 } 1826 } 1827 1828 if err := c.ep.SetSockOpt(&ifoptSet); err != nil { 1829 c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err) 1830 } 1831 1832 // Verify multicast interface addr and NIC were set correctly. 1833 // Note that NIC must be 1 since this is our outgoing interface. 1834 var ifoptGot tcpip.MulticastInterfaceOption 1835 if err := c.ep.GetSockOpt(&ifoptGot); err != nil { 1836 c.t.Fatalf("GetSockOpt(&%T): %s", ifoptGot, err) 1837 } else if ifoptWant := (tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}); ifoptGot != ifoptWant { 1838 c.t.Errorf("got multicast interface option = %#v, want = %#v", ifoptGot, ifoptWant) 1839 } 1840 }) 1841 } 1842 }) 1843 } 1844 }) 1845 } 1846 } 1847 1848 // TestV4UnknownDestination verifies that we generate an ICMPv4 Destination 1849 // Unreachable message when a udp datagram is received on ports for which there 1850 // is no bound udp socket. 1851 func TestV4UnknownDestination(t *testing.T) { 1852 c := newDualTestContext(t, defaultMTU) 1853 defer c.cleanup() 1854 1855 testCases := []struct { 1856 flow testFlow 1857 icmpRequired bool 1858 // largePayload if true, will result in a payload large enough 1859 // so that the final generated IPv4 packet is larger than 1860 // header.IPv4MinimumProcessableDatagramSize. 1861 largePayload bool 1862 // badChecksum if true, will set an invalid checksum in the 1863 // header. 
1864 badChecksum bool 1865 }{ 1866 {unicastV4, true, false, false}, 1867 {unicastV4, true, true, false}, 1868 {unicastV4, false, false, true}, 1869 {unicastV4, false, true, true}, 1870 {multicastV4, false, false, false}, 1871 {multicastV4, false, true, false}, 1872 {broadcast, false, false, false}, 1873 {broadcast, false, true, false}, 1874 } 1875 checksumErrors := uint64(0) 1876 for _, tc := range testCases { 1877 t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) { 1878 payload := newPayload() 1879 if tc.largePayload { 1880 payload = newMinPayload(576) 1881 } 1882 c.injectPacket(tc.flow, payload, tc.badChecksum) 1883 if tc.badChecksum { 1884 checksumErrors++ 1885 if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want { 1886 t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) 1887 } 1888 } 1889 if !tc.icmpRequired { 1890 if p, ok := c.linkEP.Read(); ok { 1891 t.Fatalf("unexpected packet received: %+v", p) 1892 } 1893 return 1894 } 1895 1896 // ICMP required. 1897 p, ok := c.linkEP.Read() 1898 if !ok { 1899 t.Fatalf("packet wasn't written out") 1900 return 1901 } 1902 1903 vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) 1904 pkt := vv.ToView() 1905 if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { 1906 t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) 1907 } 1908 1909 hdr := header.IPv4(pkt) 1910 checker.IPv4(t, hdr, checker.ICMPv4( 1911 checker.ICMPv4Type(header.ICMPv4DstUnreachable), 1912 checker.ICMPv4Code(header.ICMPv4PortUnreachable))) 1913 1914 // We need to compare the included data part of the UDP packet that is in 1915 // the ICMP packet with the matching original data. 
1916 icmpPkt := header.ICMPv4(hdr.Payload()) 1917 payloadIPHeader := header.IPv4(icmpPkt.Payload()) 1918 incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize 1919 wantLen := len(payload) 1920 if tc.largePayload { 1921 // To work out the data size we need to simulate what the sender would 1922 // have done. The wanted size is the total available minus the sum of 1923 // the headers in the UDP AND ICMP packets, given that we know the test 1924 // had only a minimal IP header but the ICMP sender will have allowed 1925 // for a maximally sized packet header. 1926 wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength 1927 } 1928 1929 // In the case of large payloads the IP packet may be truncated. Update 1930 // the length field before retrieving the udp datagram payload. 1931 // Add back the two headers within the payload. 1932 payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength)) 1933 1934 origDgram := header.UDP(payloadIPHeader.Payload()) 1935 if got, want := len(origDgram.Payload()), wantLen; got != want { 1936 t.Fatalf("unexpected payload length got: %d, want: %d", got, want) 1937 } 1938 if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { 1939 t.Fatalf("unexpected payload got: %d, want: %d", got, want) 1940 } 1941 }) 1942 } 1943 } 1944 1945 // TestV6UnknownDestination verifies that we generate an ICMPv6 Destination 1946 // Unreachable message when a udp datagram is received on ports for which there 1947 // is no bound udp socket. 1948 func TestV6UnknownDestination(t *testing.T) { 1949 c := newDualTestContext(t, defaultMTU) 1950 defer c.cleanup() 1951 1952 testCases := []struct { 1953 flow testFlow 1954 icmpRequired bool 1955 // largePayload if true will result in a payload large enough to 1956 // create an IPv6 packet > header.IPv6MinimumMTU bytes. 
1957 largePayload bool 1958 // badChecksum if true, will set an invalid checksum in the 1959 // header. 1960 badChecksum bool 1961 }{ 1962 {unicastV6, true, false, false}, 1963 {unicastV6, true, true, false}, 1964 {unicastV6, false, false, true}, 1965 {unicastV6, false, true, true}, 1966 {multicastV6, false, false, false}, 1967 {multicastV6, false, true, false}, 1968 } 1969 checksumErrors := uint64(0) 1970 for _, tc := range testCases { 1971 t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) { 1972 payload := newPayload() 1973 if tc.largePayload { 1974 payload = newMinPayload(1280) 1975 } 1976 c.injectPacket(tc.flow, payload, tc.badChecksum) 1977 if tc.badChecksum { 1978 checksumErrors++ 1979 if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want { 1980 t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) 1981 } 1982 } 1983 if !tc.icmpRequired { 1984 if p, ok := c.linkEP.Read(); ok { 1985 t.Fatalf("unexpected packet received: %+v", p) 1986 } 1987 return 1988 } 1989 1990 // ICMP required. 
1991 p, ok := c.linkEP.Read() 1992 if !ok { 1993 t.Fatalf("packet wasn't written out") 1994 return 1995 } 1996 1997 vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) 1998 pkt := vv.ToView() 1999 if got, want := len(pkt), header.IPv6MinimumMTU; got > want { 2000 t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) 2001 } 2002 2003 hdr := header.IPv6(pkt) 2004 checker.IPv6(t, hdr, checker.ICMPv6( 2005 checker.ICMPv6Type(header.ICMPv6DstUnreachable), 2006 checker.ICMPv6Code(header.ICMPv6PortUnreachable))) 2007 2008 icmpPkt := header.ICMPv6(hdr.Payload()) 2009 payloadIPHeader := header.IPv6(icmpPkt.Payload()) 2010 wantLen := len(payload) 2011 if tc.largePayload { 2012 wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize 2013 } 2014 // In case of large payloads the IP packet may be truncated. Update 2015 // the length field before retrieving the udp datagram payload. 2016 payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) 2017 2018 origDgram := header.UDP(payloadIPHeader.Payload()) 2019 if got, want := len(origDgram.Payload()), wantLen; got != want { 2020 t.Fatalf("unexpected payload length got: %d, want: %d", got, want) 2021 } 2022 if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { 2023 t.Fatalf("unexpected payload got: %v, want: %v", got, want) 2024 } 2025 }) 2026 } 2027 } 2028 2029 // TestIncrementMalformedPacketsReceived verifies if the malformed received 2030 // global and endpoint stats are incremented. 2031 func TestIncrementMalformedPacketsReceived(t *testing.T) { 2032 c := newDualTestContext(t, defaultMTU) 2033 defer c.cleanup() 2034 2035 c.createEndpoint(ipv6.ProtocolNumber) 2036 // Bind to wildcard. 
2037 if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { 2038 c.t.Fatalf("Bind failed: %s", err) 2039 } 2040 2041 payload := newPayload() 2042 h := unicastV6.header4Tuple(incoming) 2043 buf := c.buildV6Packet(payload, &h) 2044 2045 // Invalidate the UDP header length field. 2046 u := header.UDP(buf[header.IPv6MinimumSize:]) 2047 u.SetLength(u.Length() + 1) 2048 2049 c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ 2050 Data: buf.ToVectorisedView(), 2051 })) 2052 2053 const want = 1 2054 if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { 2055 t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want) 2056 } 2057 if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { 2058 t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) 2059 } 2060 } 2061 2062 // TestShortHeader verifies that when a packet with a too-short UDP header is 2063 // received, the malformed received global stat gets incremented. 2064 func TestShortHeader(t *testing.T) { 2065 c := newDualTestContext(t, defaultMTU) 2066 defer c.cleanup() 2067 2068 c.createEndpoint(ipv6.ProtocolNumber) 2069 // Bind to wildcard. 2070 if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { 2071 c.t.Fatalf("Bind failed: %s", err) 2072 } 2073 2074 h := unicastV6.header4Tuple(incoming) 2075 2076 // Allocate a buffer for an IPv6 and too-short UDP header. 2077 const udpSize = header.UDPMinimumSize - 1 2078 buf := buffer.NewView(header.IPv6MinimumSize + udpSize) 2079 // Initialize the IP header. 2080 ip := header.IPv6(buf) 2081 ip.Encode(&header.IPv6Fields{ 2082 TrafficClass: testTOS, 2083 PayloadLength: uint16(udpSize), 2084 TransportProtocol: udp.ProtocolNumber, 2085 HopLimit: 65, 2086 SrcAddr: h.srcAddr.Addr, 2087 DstAddr: h.dstAddr.Addr, 2088 }) 2089 2090 // Initialize the UDP header. 
2091 udpHdr := header.UDP(buffer.NewView(header.UDPMinimumSize)) 2092 udpHdr.Encode(&header.UDPFields{ 2093 SrcPort: h.srcAddr.Port, 2094 DstPort: h.dstAddr.Port, 2095 Length: header.UDPMinimumSize, 2096 }) 2097 // Calculate the UDP pseudo-header checksum. 2098 xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(udpHdr))) 2099 udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum)) 2100 // Copy all but the last byte of the UDP header into the packet. 2101 copy(buf[header.IPv6MinimumSize:], udpHdr) 2102 2103 // Inject packet. 2104 c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ 2105 Data: buf.ToVectorisedView(), 2106 })) 2107 2108 if got, want := c.s.Stats().NICs.MalformedL4RcvdPackets.Value(), uint64(1); got != want { 2109 t.Errorf("got c.s.Stats().NIC.MalformedL4RcvdPackets.Value() = %d, want = %d", got, want) 2110 } 2111 } 2112 2113 // TestBadChecksumErrors verifies if a checksum error is detected, 2114 // global and endpoint stats are incremented. 2115 func TestBadChecksumErrors(t *testing.T) { 2116 for _, flow := range []testFlow{unicastV4, unicastV6} { 2117 t.Run(flow.String(), func(t *testing.T) { 2118 c := newDualTestContext(t, defaultMTU) 2119 defer c.cleanup() 2120 2121 c.createEndpoint(flow.sockProto()) 2122 // Bind to wildcard. 
2123 if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { 2124 c.t.Fatalf("Bind failed: %s", err) 2125 } 2126 2127 payload := newPayload() 2128 c.injectPacket(flow, payload, true /* badChecksum */) 2129 2130 const want = 1 2131 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { 2132 t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) 2133 } 2134 if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { 2135 t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) 2136 } 2137 }) 2138 } 2139 } 2140 2141 // TestPayloadModifiedV4 verifies if a checksum error is detected, 2142 // global and endpoint stats are incremented. 2143 func TestPayloadModifiedV4(t *testing.T) { 2144 c := newDualTestContext(t, defaultMTU) 2145 defer c.cleanup() 2146 2147 c.createEndpoint(ipv4.ProtocolNumber) 2148 // Bind to wildcard. 2149 if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { 2150 c.t.Fatalf("Bind failed: %s", err) 2151 } 2152 2153 payload := newPayload() 2154 h := unicastV4.header4Tuple(incoming) 2155 buf := c.buildV4Packet(payload, &h) 2156 // Modify the payload so that the checksum value in the UDP header will be 2157 // incorrect. 2158 buf[len(buf)-1]++ 2159 c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ 2160 Data: buf.ToVectorisedView(), 2161 })) 2162 2163 const want = 1 2164 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { 2165 t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) 2166 } 2167 if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { 2168 t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) 2169 } 2170 } 2171 2172 // TestPayloadModifiedV6 verifies if a checksum error is detected, 2173 // global and endpoint stats are incremented. 
2174 func TestPayloadModifiedV6(t *testing.T) { 2175 c := newDualTestContext(t, defaultMTU) 2176 defer c.cleanup() 2177 2178 c.createEndpoint(ipv6.ProtocolNumber) 2179 // Bind to wildcard. 2180 if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { 2181 c.t.Fatalf("Bind failed: %s", err) 2182 } 2183 2184 payload := newPayload() 2185 h := unicastV6.header4Tuple(incoming) 2186 buf := c.buildV6Packet(payload, &h) 2187 // Modify the payload so that the checksum value in the UDP header will be 2188 // incorrect. 2189 buf[len(buf)-1]++ 2190 c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ 2191 Data: buf.ToVectorisedView(), 2192 })) 2193 2194 const want = 1 2195 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { 2196 t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) 2197 } 2198 if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { 2199 t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) 2200 } 2201 } 2202 2203 // TestChecksumZeroV4 verifies if the checksum value is zero, global and 2204 // endpoint states are *not* incremented (UDP checksum is optional on IPv4). 2205 func TestChecksumZeroV4(t *testing.T) { 2206 c := newDualTestContext(t, defaultMTU) 2207 defer c.cleanup() 2208 2209 c.createEndpoint(ipv4.ProtocolNumber) 2210 // Bind to wildcard. 2211 if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { 2212 c.t.Fatalf("Bind failed: %s", err) 2213 } 2214 2215 payload := newPayload() 2216 h := unicastV4.header4Tuple(incoming) 2217 buf := c.buildV4Packet(payload, &h) 2218 // Set the checksum field in the UDP header to zero. 
2219 u := header.UDP(buf[header.IPv4MinimumSize:]) 2220 u.SetChecksum(0) 2221 c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ 2222 Data: buf.ToVectorisedView(), 2223 })) 2224 2225 const want = 0 2226 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { 2227 t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) 2228 } 2229 if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { 2230 t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) 2231 } 2232 } 2233 2234 // TestChecksumZeroV6 verifies if the checksum value is zero, global and 2235 // endpoint states are incremented (UDP checksum is *not* optional on IPv6). 2236 func TestChecksumZeroV6(t *testing.T) { 2237 c := newDualTestContext(t, defaultMTU) 2238 defer c.cleanup() 2239 2240 c.createEndpoint(ipv6.ProtocolNumber) 2241 // Bind to wildcard. 2242 if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { 2243 c.t.Fatalf("Bind failed: %s", err) 2244 } 2245 2246 payload := newPayload() 2247 h := unicastV6.header4Tuple(incoming) 2248 buf := c.buildV6Packet(payload, &h) 2249 // Set the checksum field in the UDP header to zero. 2250 u := header.UDP(buf[header.IPv6MinimumSize:]) 2251 u.SetChecksum(0) 2252 c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ 2253 Data: buf.ToVectorisedView(), 2254 })) 2255 2256 const want = 1 2257 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { 2258 t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) 2259 } 2260 if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { 2261 t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) 2262 } 2263 } 2264 2265 // TestShutdownRead verifies endpoint read shutdown and error 2266 // stats increment on packet receive. 
2267 func TestShutdownRead(t *testing.T) { 2268 c := newDualTestContext(t, defaultMTU) 2269 defer c.cleanup() 2270 2271 c.createEndpoint(ipv6.ProtocolNumber) 2272 2273 // Bind to wildcard. 2274 if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { 2275 c.t.Fatalf("Bind failed: %s", err) 2276 } 2277 2278 if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { 2279 c.t.Fatalf("Connect failed: %s", err) 2280 } 2281 2282 if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil { 2283 t.Fatalf("Shutdown failed: %s", err) 2284 } 2285 2286 testFailingRead(c, unicastV6, true /* expectReadError */) 2287 2288 var want uint64 = 1 2289 if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want { 2290 t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want) 2291 } 2292 if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want { 2293 t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want) 2294 } 2295 } 2296 2297 // TestShutdownWrite verifies endpoint write shutdown and error 2298 // stats increment on packet write. 
2299 func TestShutdownWrite(t *testing.T) { 2300 c := newDualTestContext(t, defaultMTU) 2301 defer c.cleanup() 2302 2303 c.createEndpoint(ipv6.ProtocolNumber) 2304 2305 if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { 2306 c.t.Fatalf("Connect failed: %s", err) 2307 } 2308 2309 if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil { 2310 t.Fatalf("Shutdown failed: %s", err) 2311 } 2312 2313 testFailingWrite(c, unicastV6, &tcpip.ErrClosedForSend{}) 2314 } 2315 2316 func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) { 2317 got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() 2318 switch err.(type) { 2319 case nil: 2320 want.PacketsSent.IncrementBy(incr) 2321 case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue: 2322 want.WriteErrors.InvalidArgs.IncrementBy(incr) 2323 case *tcpip.ErrClosedForSend: 2324 want.WriteErrors.WriteClosed.IncrementBy(incr) 2325 case *tcpip.ErrInvalidEndpointState: 2326 want.WriteErrors.InvalidEndpointState.IncrementBy(incr) 2327 case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable: 2328 want.SendErrors.NoRoute.IncrementBy(incr) 2329 default: 2330 want.SendErrors.SendToNetworkFailed.IncrementBy(incr) 2331 } 2332 if got != want { 2333 c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) 2334 } 2335 } 2336 2337 func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) { 2338 got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() 2339 switch err.(type) { 2340 case nil, *tcpip.ErrWouldBlock: 2341 case *tcpip.ErrClosedForReceive: 2342 want.ReadErrors.ReadClosed.IncrementBy(incr) 2343 default: 2344 c.t.Errorf("Endpoint error missing stats update err %v", err) 2345 } 2346 if got != want { 2347 c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) 2348 } 2349 } 2350 2351 func 
TestOutgoingSubnetBroadcast(t *testing.T) { 2352 const nicID1 = 1 2353 2354 ipv4Addr := tcpip.AddressWithPrefix{ 2355 Address: "\xc0\xa8\x01\x3a", 2356 PrefixLen: 24, 2357 } 2358 ipv4Subnet := ipv4Addr.Subnet() 2359 ipv4SubnetBcast := ipv4Subnet.Broadcast() 2360 ipv4Gateway := testutil.MustParse4("192.168.1.1") 2361 ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ 2362 Address: "\xc0\xa8\x01\x3a", 2363 PrefixLen: 31, 2364 } 2365 ipv4Subnet31 := ipv4AddrPrefix31.Subnet() 2366 ipv4Subnet31Bcast := ipv4Subnet31.Broadcast() 2367 ipv4AddrPrefix32 := tcpip.AddressWithPrefix{ 2368 Address: "\xc0\xa8\x01\x3a", 2369 PrefixLen: 32, 2370 } 2371 ipv4Subnet32 := ipv4AddrPrefix32.Subnet() 2372 ipv4Subnet32Bcast := ipv4Subnet32.Broadcast() 2373 ipv6Addr := tcpip.AddressWithPrefix{ 2374 Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", 2375 PrefixLen: 64, 2376 } 2377 ipv6Subnet := ipv6Addr.Subnet() 2378 ipv6SubnetBcast := ipv6Subnet.Broadcast() 2379 remNetAddr := tcpip.AddressWithPrefix{ 2380 Address: "\x64\x0a\x7b\x18", 2381 PrefixLen: 24, 2382 } 2383 remNetSubnet := remNetAddr.Subnet() 2384 remNetSubnetBcast := remNetSubnet.Broadcast() 2385 2386 tests := []struct { 2387 name string 2388 nicAddr tcpip.ProtocolAddress 2389 routes []tcpip.Route 2390 remoteAddr tcpip.Address 2391 requiresBroadcastOpt bool 2392 }{ 2393 { 2394 name: "IPv4 Broadcast to local subnet", 2395 nicAddr: tcpip.ProtocolAddress{ 2396 Protocol: header.IPv4ProtocolNumber, 2397 AddressWithPrefix: ipv4Addr, 2398 }, 2399 routes: []tcpip.Route{ 2400 { 2401 Destination: ipv4Subnet, 2402 NIC: nicID1, 2403 }, 2404 }, 2405 remoteAddr: ipv4SubnetBcast, 2406 requiresBroadcastOpt: true, 2407 }, 2408 { 2409 name: "IPv4 Broadcast to local /31 subnet", 2410 nicAddr: tcpip.ProtocolAddress{ 2411 Protocol: header.IPv4ProtocolNumber, 2412 AddressWithPrefix: ipv4AddrPrefix31, 2413 }, 2414 routes: []tcpip.Route{ 2415 { 2416 Destination: ipv4Subnet31, 2417 NIC: nicID1, 2418 }, 2419 }, 2420 remoteAddr: 
ipv4Subnet31Bcast, 2421 requiresBroadcastOpt: false, 2422 }, 2423 { 2424 name: "IPv4 Broadcast to local /32 subnet", 2425 nicAddr: tcpip.ProtocolAddress{ 2426 Protocol: header.IPv4ProtocolNumber, 2427 AddressWithPrefix: ipv4AddrPrefix32, 2428 }, 2429 routes: []tcpip.Route{ 2430 { 2431 Destination: ipv4Subnet32, 2432 NIC: nicID1, 2433 }, 2434 }, 2435 remoteAddr: ipv4Subnet32Bcast, 2436 requiresBroadcastOpt: false, 2437 }, 2438 // IPv6 has no notion of a broadcast. 2439 { 2440 name: "IPv6 'Broadcast' to local subnet", 2441 nicAddr: tcpip.ProtocolAddress{ 2442 Protocol: header.IPv6ProtocolNumber, 2443 AddressWithPrefix: ipv6Addr, 2444 }, 2445 routes: []tcpip.Route{ 2446 { 2447 Destination: ipv6Subnet, 2448 NIC: nicID1, 2449 }, 2450 }, 2451 remoteAddr: ipv6SubnetBcast, 2452 requiresBroadcastOpt: false, 2453 }, 2454 { 2455 name: "IPv4 Broadcast to remote subnet", 2456 nicAddr: tcpip.ProtocolAddress{ 2457 Protocol: header.IPv4ProtocolNumber, 2458 AddressWithPrefix: ipv4Addr, 2459 }, 2460 routes: []tcpip.Route{ 2461 { 2462 Destination: remNetSubnet, 2463 Gateway: ipv4Gateway, 2464 NIC: nicID1, 2465 }, 2466 }, 2467 remoteAddr: remNetSubnetBcast, 2468 // TODO(github.com/SagerNet/issue/3938): Once we support marking a route as 2469 // broadcast, this test should require the broadcast option to be set. 
2470 requiresBroadcastOpt: false, 2471 }, 2472 } 2473 2474 for _, test := range tests { 2475 t.Run(test.name, func(t *testing.T) { 2476 s := stack.New(stack.Options{ 2477 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 2478 TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, 2479 Clock: &faketime.NullClock{}, 2480 }) 2481 e := channel.New(0, defaultMTU, "") 2482 if err := s.CreateNIC(nicID1, e); err != nil { 2483 t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) 2484 } 2485 if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { 2486 t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) 2487 } 2488 2489 s.SetRouteTable(test.routes) 2490 2491 var netProto tcpip.NetworkProtocolNumber 2492 switch l := len(test.remoteAddr); l { 2493 case header.IPv4AddressSize: 2494 netProto = header.IPv4ProtocolNumber 2495 case header.IPv6AddressSize: 2496 netProto = header.IPv6ProtocolNumber 2497 default: 2498 t.Fatalf("got unexpected address length = %d bytes", l) 2499 } 2500 2501 wq := waiter.Queue{} 2502 ep, err := s.NewEndpoint(udp.ProtocolNumber, netProto, &wq) 2503 if err != nil { 2504 t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netProto, err) 2505 } 2506 defer ep.Close() 2507 2508 var r bytes.Reader 2509 data := []byte{1, 2, 3, 4} 2510 to := tcpip.FullAddress{ 2511 Addr: test.remoteAddr, 2512 Port: 80, 2513 } 2514 opts := tcpip.WriteOptions{To: &to} 2515 expectedErrWithoutBcastOpt := func(err tcpip.Error) tcpip.Error { 2516 if _, ok := err.(*tcpip.ErrBroadcastDisabled); ok { 2517 return nil 2518 } 2519 return &tcpip.ErrBroadcastDisabled{} 2520 } 2521 if !test.requiresBroadcastOpt { 2522 expectedErrWithoutBcastOpt = nil 2523 } 2524 2525 r.Reset(data) 2526 { 2527 n, err := ep.Write(&r, opts) 2528 if expectedErrWithoutBcastOpt != nil { 2529 if want := expectedErrWithoutBcastOpt(err); want != nil { 2530 t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want) 2531 
} 2532 } else if err != nil { 2533 t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) 2534 } 2535 } 2536 2537 ep.SocketOptions().SetBroadcast(true) 2538 2539 r.Reset(data) 2540 if n, err := ep.Write(&r, opts); err != nil { 2541 t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) 2542 } 2543 2544 ep.SocketOptions().SetBroadcast(false) 2545 2546 r.Reset(data) 2547 { 2548 n, err := ep.Write(&r, opts) 2549 if expectedErrWithoutBcastOpt != nil { 2550 if want := expectedErrWithoutBcastOpt(err); want != nil { 2551 t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want) 2552 } 2553 } else if err != nil { 2554 t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) 2555 } 2556 } 2557 }) 2558 } 2559 }