// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ip_test

import (
	"fmt"
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/SagerNet/gvisor/pkg/sync"
	"github.com/SagerNet/gvisor/pkg/tcpip"
	"github.com/SagerNet/gvisor/pkg/tcpip/buffer"
	"github.com/SagerNet/gvisor/pkg/tcpip/checker"
	"github.com/SagerNet/gvisor/pkg/tcpip/header"
	"github.com/SagerNet/gvisor/pkg/tcpip/link/channel"
	"github.com/SagerNet/gvisor/pkg/tcpip/link/loopback"
	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv4"
	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv6"
	"github.com/SagerNet/gvisor/pkg/tcpip/stack"
	"github.com/SagerNet/gvisor/pkg/tcpip/testutil"
	"github.com/SagerNet/gvisor/pkg/tcpip/transport/icmp"
	"github.com/SagerNet/gvisor/pkg/tcpip/transport/tcp"
	"github.com/SagerNet/gvisor/pkg/tcpip/transport/udp"
)

// nicID is the NIC identifier used when creating NICs in the helpers of this
// file; testInterface.ID reports the same value.
const nicID = 1

// Fixed addresses shared by the tests. The IPv4 and IPv6 sets mirror each
// other: a local address, a remote peer, the enclosing subnet with its mask,
// and a gateway.
var (
	localIPv4Addr  = testutil.MustParse4("10.0.0.1")
	remoteIPv4Addr = testutil.MustParse4("10.0.0.2")
	ipv4SubnetAddr = testutil.MustParse4("10.0.0.0")
	ipv4SubnetMask = testutil.MustParse4("255.255.255.0")
	ipv4Gateway    = testutil.MustParse4("10.0.0.3")
	localIPv6Addr  = testutil.MustParse6("a00::1")
	remoteIPv6Addr = testutil.MustParse6("a00::2")
	ipv6SubnetAddr = testutil.MustParse6("a00::")
	ipv6SubnetMask = testutil.MustParse6("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ff00")
	ipv6Gateway    = testutil.MustParse6("a00::3")
)

// localIPv4AddrWithPrefix is localIPv4Addr paired with a /24 prefix length.
var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{
	Address:   localIPv4Addr,
	PrefixLen: 24,
}

// localIPv6AddrWithPrefix is localIPv6Addr paired with a /120 prefix length.
var localIPv6AddrWithPrefix = tcpip.AddressWithPrefix{
	Address:   localIPv6Addr,
	PrefixLen: 120,
}

// transportError is a plain-value snapshot of a stack.TransportError. The
// test object fills one from the accessors of the error it is handed and
// compares it against the expected value with cmp.Diff.
type transportError struct {
	origin tcpip.SockErrOrigin
	typ    uint8
	code   uint8
	info   uint32
	kind   stack.TransportErrorKind
}

// testObject implements two interfaces: LinkEndpoint and TransportDispatcher.
// The former is used to pretend that it's a link endpoint so that we can
// inspect packets written by the network endpoints. The latter is used to
// pretend that it's the network stack so that it can inspect incoming packets
// that have been handled by the network endpoints.
//
// Packets are checked by comparing their fields/values against the expected
// values stored in the test object itself.
type testObject struct {
	// t is the test that failures are reported to.
	t *testing.T

	// Expected values that observed packets are compared against.
	protocol tcpip.TransportProtocolNumber
	contents []byte
	srcAddr  tcpip.Address
	dstAddr  tcpip.Address
	// v4 selects whether WritePacket parses the network header as IPv4
	// (true) or IPv6 (false).
	v4 bool
	// transErr is the transport error expected by DeliverTransportError.
	transErr transportError

	// Number of times each Deliver* callback has been invoked.
	dataCalls    int
	controlCalls int
	rawCalls     int
}

// checkValues verifies that the transport protocol, data contents, src & dst
// addresses of a packet match what's expected. If any field doesn't match, the
// test fails.
97 func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, v buffer.View, srcAddr, dstAddr tcpip.Address) { 98 if protocol != t.protocol { 99 t.t.Errorf("protocol = %v, want %v", protocol, t.protocol) 100 } 101 102 if srcAddr != t.srcAddr { 103 t.t.Errorf("srcAddr = %v, want %v", srcAddr, t.srcAddr) 104 } 105 106 if dstAddr != t.dstAddr { 107 t.t.Errorf("dstAddr = %v, want %v", dstAddr, t.dstAddr) 108 } 109 110 if len(v) != len(t.contents) { 111 t.t.Fatalf("len(payload) = %v, want %v", len(v), len(t.contents)) 112 } 113 114 for i := range t.contents { 115 if t.contents[i] != v[i] { 116 t.t.Fatalf("payload[%v] = %v, want %v", i, v[i], t.contents[i]) 117 } 118 } 119 } 120 121 // DeliverTransportPacket is called by network endpoints after parsing incoming 122 // packets. This is used by the test object to verify that the results of the 123 // parsing are expected. 124 func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition { 125 netHdr := pkt.Network() 126 t.checkValues(protocol, pkt.Data().AsRange().ToOwnedView(), netHdr.SourceAddress(), netHdr.DestinationAddress()) 127 t.dataCalls++ 128 return stack.TransportPacketHandled 129 } 130 131 // DeliverTransportError is called by network endpoints after parsing 132 // incoming control (ICMP) packets. This is used by the test object to verify 133 // that the results of the parsing are expected. 
134 func (t *testObject) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr stack.TransportError, pkt *stack.PacketBuffer) { 135 t.checkValues(trans, pkt.Data().AsRange().ToOwnedView(), remote, local) 136 if diff := cmp.Diff( 137 t.transErr, 138 transportError{ 139 origin: transErr.Origin(), 140 typ: transErr.Type(), 141 code: transErr.Code(), 142 info: transErr.Info(), 143 kind: transErr.Kind(), 144 }, 145 cmp.AllowUnexported(transportError{}), 146 ); diff != "" { 147 t.t.Errorf("transport error mismatch (-want +got):\n%s", diff) 148 } 149 t.controlCalls++ 150 } 151 152 func (t *testObject) DeliverRawPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) { 153 t.rawCalls++ 154 } 155 156 // Attach is only implemented to satisfy the LinkEndpoint interface. 157 func (*testObject) Attach(stack.NetworkDispatcher) {} 158 159 // IsAttached implements stack.LinkEndpoint.IsAttached. 160 func (*testObject) IsAttached() bool { 161 return true 162 } 163 164 // MTU implements stack.LinkEndpoint.MTU. It just returns a constant that 165 // matches the linux loopback MTU. 166 func (*testObject) MTU() uint32 { 167 return 65536 168 } 169 170 // Capabilities implements stack.LinkEndpoint.Capabilities. 171 func (*testObject) Capabilities() stack.LinkEndpointCapabilities { 172 return 0 173 } 174 175 // MaxHeaderLength is only implemented to satisfy the LinkEndpoint interface. 176 func (*testObject) MaxHeaderLength() uint16 { 177 return 0 178 } 179 180 // LinkAddress returns the link address of this endpoint. 181 func (*testObject) LinkAddress() tcpip.LinkAddress { 182 return "" 183 } 184 185 // Wait implements stack.LinkEndpoint.Wait. 186 func (*testObject) Wait() {} 187 188 // WritePacket is called by network endpoints after producing a packet and 189 // writing it to the link endpoint. This is used by the test object to verify 190 // that the produced packet is as expected. 
191 func (t *testObject) WritePacket(_ *stack.Route, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { 192 var prot tcpip.TransportProtocolNumber 193 var srcAddr tcpip.Address 194 var dstAddr tcpip.Address 195 196 if t.v4 { 197 h := header.IPv4(pkt.NetworkHeader().View()) 198 prot = tcpip.TransportProtocolNumber(h.Protocol()) 199 srcAddr = h.SourceAddress() 200 dstAddr = h.DestinationAddress() 201 202 } else { 203 h := header.IPv6(pkt.NetworkHeader().View()) 204 prot = tcpip.TransportProtocolNumber(h.NextHeader()) 205 srcAddr = h.SourceAddress() 206 dstAddr = h.DestinationAddress() 207 } 208 t.checkValues(prot, pkt.Data().AsRange().ToOwnedView(), srcAddr, dstAddr) 209 return nil 210 } 211 212 // WritePackets implements stack.LinkEndpoint.WritePackets. 213 func (*testObject) WritePackets(_ *stack.Route, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { 214 panic("not implemented") 215 } 216 217 // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. 218 func (*testObject) ARPHardwareType() header.ARPHardwareType { 219 panic("not implemented") 220 } 221 222 // AddHeader implements stack.LinkEndpoint.AddHeader. 
223 func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { 224 panic("not implemented") 225 } 226 227 func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) { 228 s := stack.New(stack.Options{ 229 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, 230 TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, 231 }) 232 s.CreateNIC(nicID, loopback.New()) 233 s.AddAddress(nicID, ipv4.ProtocolNumber, local) 234 s.SetRouteTable([]tcpip.Route{{ 235 Destination: header.IPv4EmptySubnet, 236 Gateway: ipv4Gateway, 237 NIC: 1, 238 }}) 239 240 return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) 241 } 242 243 func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) { 244 s := stack.New(stack.Options{ 245 NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, 246 TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, 247 }) 248 s.CreateNIC(nicID, loopback.New()) 249 s.AddAddress(nicID, ipv6.ProtocolNumber, local) 250 s.SetRouteTable([]tcpip.Route{{ 251 Destination: header.IPv6EmptySubnet, 252 Gateway: ipv6Gateway, 253 NIC: 1, 254 }}) 255 256 return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) 257 } 258 259 func buildDummyStackWithLinkEndpoint(t *testing.T, mtu uint32) (*stack.Stack, *channel.Endpoint) { 260 t.Helper() 261 262 s := stack.New(stack.Options{ 263 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 264 TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, 265 }) 266 e := channel.New(1, mtu, "") 267 if err := s.CreateNIC(nicID, e); err != nil { 268 t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) 269 } 270 271 v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix} 272 if 
err := s.AddProtocolAddress(nicID, v4Addr); err != nil { 273 t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err) 274 } 275 276 v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix} 277 if err := s.AddProtocolAddress(nicID, v6Addr); err != nil { 278 t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err) 279 } 280 281 return s, e 282 } 283 284 func buildDummyStack(t *testing.T) *stack.Stack { 285 t.Helper() 286 287 s, _ := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU) 288 return s 289 } 290 291 var _ stack.NetworkInterface = (*testInterface)(nil) 292 293 type testInterface struct { 294 testObject 295 296 mu struct { 297 sync.RWMutex 298 disabled bool 299 } 300 } 301 302 func (*testInterface) ID() tcpip.NICID { 303 return nicID 304 } 305 306 func (*testInterface) IsLoopback() bool { 307 return false 308 } 309 310 func (*testInterface) Name() string { 311 return "" 312 } 313 314 func (t *testInterface) Enabled() bool { 315 t.mu.RLock() 316 defer t.mu.RUnlock() 317 return !t.mu.disabled 318 } 319 320 func (*testInterface) Promiscuous() bool { 321 return false 322 } 323 324 func (*testInterface) Spoofing() bool { 325 return false 326 } 327 328 func (t *testInterface) setEnabled(v bool) { 329 t.mu.Lock() 330 defer t.mu.Unlock() 331 t.mu.disabled = !v 332 } 333 334 func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { 335 return &tcpip.ErrNotSupported{} 336 } 337 338 func (*testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error { 339 return nil 340 } 341 342 func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error { 343 return nil 344 } 345 346 func (*testInterface) PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) { 347 
return tcpip.AddressWithPrefix{}, nil 348 } 349 350 func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool { 351 return false 352 } 353 354 func TestSourceAddressValidation(t *testing.T) { 355 rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) { 356 totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize 357 hdr := buffer.NewPrependable(totalLen) 358 pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) 359 pkt.SetType(header.ICMPv4Echo) 360 pkt.SetCode(0) 361 pkt.SetChecksum(0) 362 pkt.SetChecksum(^header.Checksum(pkt, 0)) 363 ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) 364 ip.Encode(&header.IPv4Fields{ 365 TotalLength: uint16(totalLen), 366 Protocol: uint8(icmp.ProtocolNumber4), 367 TTL: ipv4.DefaultTTL, 368 SrcAddr: src, 369 DstAddr: localIPv4Addr, 370 }) 371 ip.SetChecksum(^ip.CalculateChecksum()) 372 373 e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ 374 Data: hdr.View().ToVectorisedView(), 375 })) 376 } 377 378 rxIPv6ICMP := func(e *channel.Endpoint, src tcpip.Address) { 379 totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize 380 hdr := buffer.NewPrependable(totalLen) 381 pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) 382 pkt.SetType(header.ICMPv6EchoRequest) 383 pkt.SetCode(0) 384 pkt.SetChecksum(0) 385 pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ 386 Header: pkt, 387 Src: src, 388 Dst: localIPv6Addr, 389 })) 390 ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) 391 ip.Encode(&header.IPv6Fields{ 392 PayloadLength: header.ICMPv6MinimumSize, 393 TransportProtocol: icmp.ProtocolNumber6, 394 HopLimit: ipv6.DefaultTTL, 395 SrcAddr: src, 396 DstAddr: localIPv6Addr, 397 }) 398 e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ 399 Data: hdr.View().ToVectorisedView(), 400 })) 401 } 402 403 tests := []struct { 404 name string 405 srcAddress tcpip.Address 406 rxICMP 
func(*channel.Endpoint, tcpip.Address) 407 valid bool 408 }{ 409 { 410 name: "IPv4 valid", 411 srcAddress: "\x01\x02\x03\x04", 412 rxICMP: rxIPv4ICMP, 413 valid: true, 414 }, 415 { 416 name: "IPv6 valid", 417 srcAddress: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10", 418 rxICMP: rxIPv6ICMP, 419 valid: true, 420 }, 421 { 422 name: "IPv4 unspecified", 423 srcAddress: header.IPv4Any, 424 rxICMP: rxIPv4ICMP, 425 valid: true, 426 }, 427 { 428 name: "IPv6 unspecified", 429 srcAddress: header.IPv4Any, 430 rxICMP: rxIPv6ICMP, 431 valid: true, 432 }, 433 { 434 name: "IPv4 multicast", 435 srcAddress: "\xe0\x00\x00\x01", 436 rxICMP: rxIPv4ICMP, 437 valid: false, 438 }, 439 { 440 name: "IPv6 multicast", 441 srcAddress: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", 442 rxICMP: rxIPv6ICMP, 443 valid: false, 444 }, 445 { 446 name: "IPv4 broadcast", 447 srcAddress: header.IPv4Broadcast, 448 rxICMP: rxIPv4ICMP, 449 valid: false, 450 }, 451 { 452 name: "IPv4 subnet broadcast", 453 srcAddress: func() tcpip.Address { 454 subnet := localIPv4AddrWithPrefix.Subnet() 455 return subnet.Broadcast() 456 }(), 457 rxICMP: rxIPv4ICMP, 458 valid: false, 459 }, 460 } 461 462 for _, test := range tests { 463 t.Run(test.name, func(t *testing.T) { 464 s, e := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU) 465 test.rxICMP(e, test.srcAddress) 466 467 var wantValid uint64 468 if test.valid { 469 wantValid = 1 470 } 471 472 if got, want := s.Stats().IP.InvalidSourceAddressesReceived.Value(), 1-wantValid; got != want { 473 t.Errorf("got s.Stats().IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want) 474 } 475 if got := s.Stats().IP.PacketsDelivered.Value(); got != wantValid { 476 t.Errorf("got s.Stats().IP.PacketsDelivered.Value() = %d, want = %d", got, wantValid) 477 } 478 }) 479 } 480 } 481 482 func TestEnableWhenNICDisabled(t *testing.T) { 483 tests := []struct { 484 name string 485 protocolFactory stack.NetworkProtocolFactory 486 
protoNum tcpip.NetworkProtocolNumber 487 }{ 488 { 489 name: "IPv4", 490 protocolFactory: ipv4.NewProtocol, 491 protoNum: ipv4.ProtocolNumber, 492 }, 493 { 494 name: "IPv6", 495 protocolFactory: ipv6.NewProtocol, 496 protoNum: ipv6.ProtocolNumber, 497 }, 498 } 499 500 for _, test := range tests { 501 t.Run(test.name, func(t *testing.T) { 502 var nic testInterface 503 nic.setEnabled(false) 504 505 s := stack.New(stack.Options{ 506 NetworkProtocols: []stack.NetworkProtocolFactory{test.protocolFactory}, 507 }) 508 p := s.NetworkProtocolInstance(test.protoNum) 509 510 // We pass nil for all parameters except the NetworkInterface and Stack 511 // since Enable only depends on these. 512 ep := p.NewEndpoint(&nic, nil) 513 514 // The endpoint should initially be disabled, regardless the NIC's enabled 515 // status. 516 if ep.Enabled() { 517 t.Fatal("got ep.Enabled() = true, want = false") 518 } 519 nic.setEnabled(true) 520 if ep.Enabled() { 521 t.Fatal("got ep.Enabled() = true, want = false") 522 } 523 524 // Attempting to enable the endpoint while the NIC is disabled should 525 // fail. 526 nic.setEnabled(false) 527 err := ep.Enable() 528 if _, ok := err.(*tcpip.ErrNotPermitted); !ok { 529 t.Fatalf("got ep.Enable() = %s, want = %s", err, &tcpip.ErrNotPermitted{}) 530 } 531 // ep should consider the NIC's enabled status when determining its own 532 // enabled status so we "enable" the NIC to read just the endpoint's 533 // enabled status. 534 nic.setEnabled(true) 535 if ep.Enabled() { 536 t.Fatal("got ep.Enabled() = true, want = false") 537 } 538 539 // Enabling the interface after the NIC has been enabled should succeed. 540 if err := ep.Enable(); err != nil { 541 t.Fatalf("ep.Enable(): %s", err) 542 } 543 if !ep.Enabled() { 544 t.Fatal("got ep.Enabled() = false, want = true") 545 } 546 547 // ep should consider the NIC's enabled status when determining its own 548 // enabled status. 
549 nic.setEnabled(false) 550 if ep.Enabled() { 551 t.Fatal("got ep.Enabled() = true, want = false") 552 } 553 554 // Disabling the endpoint when the NIC is enabled should make the endpoint 555 // disabled. 556 nic.setEnabled(true) 557 ep.Disable() 558 if ep.Enabled() { 559 t.Fatal("got ep.Enabled() = true, want = false") 560 } 561 }) 562 } 563 } 564 565 func TestIPv4Send(t *testing.T) { 566 s := buildDummyStack(t) 567 proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) 568 nic := testInterface{ 569 testObject: testObject{ 570 t: t, 571 v4: true, 572 }, 573 } 574 ep := proto.NewEndpoint(&nic, nil) 575 defer ep.Close() 576 577 // Allocate and initialize the payload view. 578 payload := buffer.NewView(100) 579 for i := 0; i < len(payload); i++ { 580 payload[i] = uint8(i) 581 } 582 583 // Setup the packet buffer. 584 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 585 ReserveHeaderBytes: int(ep.MaxHeaderLength()), 586 Data: payload.ToVectorisedView(), 587 }) 588 589 // Issue the write. 
590 nic.testObject.protocol = 123 591 nic.testObject.srcAddr = localIPv4Addr 592 nic.testObject.dstAddr = remoteIPv4Addr 593 nic.testObject.contents = payload 594 595 r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) 596 if err != nil { 597 t.Fatalf("could not find route: %v", err) 598 } 599 if err := ep.WritePacket(r, stack.NetworkHeaderParams{ 600 Protocol: 123, 601 TTL: 123, 602 TOS: stack.DefaultTOS, 603 }, pkt); err != nil { 604 t.Fatalf("WritePacket failed: %v", err) 605 } 606 } 607 608 func TestReceive(t *testing.T) { 609 tests := []struct { 610 name string 611 protoFactory stack.NetworkProtocolFactory 612 protoNum tcpip.NetworkProtocolNumber 613 v4 bool 614 epAddr tcpip.AddressWithPrefix 615 handlePacket func(*testing.T, stack.NetworkEndpoint, *testInterface) 616 }{ 617 { 618 name: "IPv4", 619 protoFactory: ipv4.NewProtocol, 620 protoNum: ipv4.ProtocolNumber, 621 v4: true, 622 epAddr: localIPv4Addr.WithPrefix(), 623 handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) { 624 const totalLen = header.IPv4MinimumSize + 30 /* payload length */ 625 626 view := buffer.NewView(totalLen) 627 ip := header.IPv4(view) 628 ip.Encode(&header.IPv4Fields{ 629 TotalLength: totalLen, 630 TTL: ipv4.DefaultTTL, 631 Protocol: 10, 632 SrcAddr: remoteIPv4Addr, 633 DstAddr: localIPv4Addr, 634 }) 635 ip.SetChecksum(^ip.CalculateChecksum()) 636 637 // Make payload be non-zero. 638 for i := header.IPv4MinimumSize; i < len(view); i++ { 639 view[i] = uint8(i) 640 } 641 642 // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. 
643 nic.testObject.protocol = 10 644 nic.testObject.srcAddr = remoteIPv4Addr 645 nic.testObject.dstAddr = localIPv4Addr 646 nic.testObject.contents = view[header.IPv4MinimumSize:totalLen] 647 648 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 649 Data: view.ToVectorisedView(), 650 }) 651 ep.HandlePacket(pkt) 652 }, 653 }, 654 { 655 name: "IPv6", 656 protoFactory: ipv6.NewProtocol, 657 protoNum: ipv6.ProtocolNumber, 658 v4: false, 659 epAddr: localIPv6Addr.WithPrefix(), 660 handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) { 661 const payloadLen = 30 662 view := buffer.NewView(header.IPv6MinimumSize + payloadLen) 663 ip := header.IPv6(view) 664 ip.Encode(&header.IPv6Fields{ 665 PayloadLength: payloadLen, 666 TransportProtocol: 10, 667 HopLimit: ipv6.DefaultTTL, 668 SrcAddr: remoteIPv6Addr, 669 DstAddr: localIPv6Addr, 670 }) 671 672 // Make payload be non-zero. 673 for i := header.IPv6MinimumSize; i < len(view); i++ { 674 view[i] = uint8(i) 675 } 676 677 // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. 
678 nic.testObject.protocol = 10 679 nic.testObject.srcAddr = remoteIPv6Addr 680 nic.testObject.dstAddr = localIPv6Addr 681 nic.testObject.contents = view[header.IPv6MinimumSize:][:payloadLen] 682 683 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 684 Data: view.ToVectorisedView(), 685 }) 686 ep.HandlePacket(pkt) 687 }, 688 }, 689 } 690 691 for _, test := range tests { 692 t.Run(test.name, func(t *testing.T) { 693 s := stack.New(stack.Options{ 694 NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, 695 }) 696 nic := testInterface{ 697 testObject: testObject{ 698 t: t, 699 v4: test.v4, 700 }, 701 } 702 ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, &nic.testObject) 703 defer ep.Close() 704 705 if err := ep.Enable(); err != nil { 706 t.Fatalf("ep.Enable(): %s", err) 707 } 708 709 addressableEndpoint, ok := ep.(stack.AddressableEndpoint) 710 if !ok { 711 t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum) 712 } 713 if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { 714 t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err) 715 } else { 716 ep.DecRef() 717 } 718 719 stat := s.Stats().IP.PacketsReceived 720 if got := stat.Value(); got != 0 { 721 t.Fatalf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 0", got) 722 } 723 test.handlePacket(t, ep, &nic) 724 if nic.testObject.dataCalls != 1 { 725 t.Errorf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls) 726 } 727 if nic.testObject.rawCalls != 1 { 728 t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls) 729 } 730 if got := stat.Value(); got != 1 { 731 t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got) 732 } 733 }) 734 } 735 } 736 737 func 
TestIPv4ReceiveControl(t *testing.T) { 738 const ( 739 mtu = 0xbeef - header.IPv4MinimumSize 740 dataLen = 8 741 ) 742 743 cases := []struct { 744 name string 745 expectedCount int 746 fragmentOffset uint16 747 code header.ICMPv4Code 748 transErr transportError 749 trunc int 750 }{ 751 { 752 name: "FragmentationNeeded", 753 expectedCount: 1, 754 fragmentOffset: 0, 755 code: header.ICMPv4FragmentationNeeded, 756 transErr: transportError{ 757 origin: tcpip.SockExtErrorOriginICMP, 758 typ: uint8(header.ICMPv4DstUnreachable), 759 code: uint8(header.ICMPv4FragmentationNeeded), 760 info: mtu, 761 kind: stack.PacketTooBigTransportError, 762 }, 763 trunc: 0, 764 }, 765 { 766 name: "Truncated (missing IPv4 header)", 767 expectedCount: 0, 768 fragmentOffset: 0, 769 code: header.ICMPv4FragmentationNeeded, 770 trunc: header.IPv4MinimumSize + header.ICMPv4MinimumSize, 771 }, 772 { 773 name: "Truncated (partial offending packet's IP header)", 774 expectedCount: 0, 775 fragmentOffset: 0, 776 code: header.ICMPv4FragmentationNeeded, 777 trunc: header.IPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize - 1, 778 }, 779 { 780 name: "Truncated (partial offending packet's data)", 781 expectedCount: 0, 782 fragmentOffset: 0, 783 code: header.ICMPv4FragmentationNeeded, 784 trunc: header.ICMPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize + dataLen - 1, 785 }, 786 { 787 name: "Port unreachable", 788 expectedCount: 1, 789 fragmentOffset: 0, 790 code: header.ICMPv4PortUnreachable, 791 transErr: transportError{ 792 origin: tcpip.SockExtErrorOriginICMP, 793 typ: uint8(header.ICMPv4DstUnreachable), 794 code: uint8(header.ICMPv4PortUnreachable), 795 kind: stack.DestinationPortUnreachableTransportError, 796 }, 797 trunc: 0, 798 }, 799 { 800 name: "Non-zero fragment offset", 801 expectedCount: 0, 802 fragmentOffset: 100, 803 code: header.ICMPv4PortUnreachable, 804 trunc: 0, 805 }, 806 { 807 name: "Zero-length packet", 808 expectedCount: 0, 809 fragmentOffset: 100, 
810 code: header.ICMPv4PortUnreachable, 811 trunc: 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + dataLen, 812 }, 813 } 814 for _, c := range cases { 815 t.Run(c.name, func(t *testing.T) { 816 s := buildDummyStack(t) 817 proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) 818 nic := testInterface{ 819 testObject: testObject{ 820 t: t, 821 }, 822 } 823 ep := proto.NewEndpoint(&nic, &nic.testObject) 824 defer ep.Close() 825 826 if err := ep.Enable(); err != nil { 827 t.Fatalf("ep.Enable(): %s", err) 828 } 829 830 const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize 831 view := buffer.NewView(dataOffset + dataLen) 832 833 // Create the outer IPv4 header. 834 ip := header.IPv4(view) 835 ip.Encode(&header.IPv4Fields{ 836 TotalLength: uint16(len(view) - c.trunc), 837 TTL: 20, 838 Protocol: uint8(header.ICMPv4ProtocolNumber), 839 SrcAddr: "\x0a\x00\x00\xbb", 840 DstAddr: localIPv4Addr, 841 }) 842 ip.SetChecksum(^ip.CalculateChecksum()) 843 844 // Create the ICMP header. 845 icmp := header.ICMPv4(view[header.IPv4MinimumSize:]) 846 icmp.SetType(header.ICMPv4DstUnreachable) 847 icmp.SetCode(c.code) 848 icmp.SetIdent(0xdead) 849 icmp.SetSequence(0xbeef) 850 851 // Create the inner IPv4 header. 852 ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:]) 853 ip.Encode(&header.IPv4Fields{ 854 TotalLength: 100, 855 TTL: 20, 856 Protocol: 10, 857 FragmentOffset: c.fragmentOffset, 858 SrcAddr: localIPv4Addr, 859 DstAddr: remoteIPv4Addr, 860 }) 861 ip.SetChecksum(^ip.CalculateChecksum()) 862 863 // Make payload be non-zero. 864 for i := dataOffset; i < len(view); i++ { 865 view[i] = uint8(i) 866 } 867 868 icmp.SetChecksum(0) 869 checksum := ^header.Checksum(icmp, 0 /* initial */) 870 icmp.SetChecksum(checksum) 871 872 // Give packet to IPv4 endpoint, dispatcher will validate that 873 // it's ok. 
874 nic.testObject.protocol = 10 875 nic.testObject.srcAddr = remoteIPv4Addr 876 nic.testObject.dstAddr = localIPv4Addr 877 nic.testObject.contents = view[dataOffset:] 878 nic.testObject.transErr = c.transErr 879 880 addressableEndpoint, ok := ep.(stack.AddressableEndpoint) 881 if !ok { 882 t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") 883 } 884 addr := localIPv4Addr.WithPrefix() 885 if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { 886 t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) 887 } else { 888 ep.DecRef() 889 } 890 891 pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize) 892 ep.HandlePacket(pkt) 893 if want := c.expectedCount; nic.testObject.controlCalls != want { 894 t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) 895 } 896 }) 897 } 898 } 899 900 func TestIPv4FragmentationReceive(t *testing.T) { 901 s := stack.New(stack.Options{ 902 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, 903 }) 904 proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) 905 nic := testInterface{ 906 testObject: testObject{ 907 t: t, 908 v4: true, 909 }, 910 } 911 ep := proto.NewEndpoint(&nic, &nic.testObject) 912 defer ep.Close() 913 914 if err := ep.Enable(); err != nil { 915 t.Fatalf("ep.Enable(): %s", err) 916 } 917 918 totalLen := header.IPv4MinimumSize + 24 919 920 frag1 := buffer.NewView(totalLen) 921 ip1 := header.IPv4(frag1) 922 ip1.Encode(&header.IPv4Fields{ 923 TotalLength: uint16(totalLen), 924 TTL: 20, 925 Protocol: 10, 926 FragmentOffset: 0, 927 Flags: header.IPv4FlagMoreFragments, 928 SrcAddr: remoteIPv4Addr, 929 DstAddr: localIPv4Addr, 930 }) 931 ip1.SetChecksum(^ip1.CalculateChecksum()) 932 933 // Make payload be non-zero. 
934 for i := header.IPv4MinimumSize; i < totalLen; i++ { 935 frag1[i] = uint8(i) 936 } 937 938 frag2 := buffer.NewView(totalLen) 939 ip2 := header.IPv4(frag2) 940 ip2.Encode(&header.IPv4Fields{ 941 TotalLength: uint16(totalLen), 942 TTL: 20, 943 Protocol: 10, 944 FragmentOffset: 24, 945 SrcAddr: remoteIPv4Addr, 946 DstAddr: localIPv4Addr, 947 }) 948 ip2.SetChecksum(^ip2.CalculateChecksum()) 949 950 // Make payload be non-zero. 951 for i := header.IPv4MinimumSize; i < totalLen; i++ { 952 frag2[i] = uint8(i) 953 } 954 955 // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. 956 nic.testObject.protocol = 10 957 nic.testObject.srcAddr = remoteIPv4Addr 958 nic.testObject.dstAddr = localIPv4Addr 959 nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) 960 961 // Send first segment. 962 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 963 Data: frag1.ToVectorisedView(), 964 }) 965 966 addressableEndpoint, ok := ep.(stack.AddressableEndpoint) 967 if !ok { 968 t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") 969 } 970 addr := localIPv4Addr.WithPrefix() 971 if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { 972 t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) 973 } else { 974 ep.DecRef() 975 } 976 977 ep.HandlePacket(pkt) 978 if nic.testObject.dataCalls != 0 { 979 t.Fatalf("Bad number of data calls: got %d, want 0", nic.testObject.dataCalls) 980 } 981 if nic.testObject.rawCalls != 0 { 982 t.Errorf("Bad number of raw calls: got %d, want 0", nic.testObject.rawCalls) 983 } 984 985 // Send second segment. 
986 pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ 987 Data: frag2.ToVectorisedView(), 988 }) 989 ep.HandlePacket(pkt) 990 if nic.testObject.dataCalls != 1 { 991 t.Fatalf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls) 992 } 993 if nic.testObject.rawCalls != 1 { 994 t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls) 995 } 996 } 997 998 func TestIPv6Send(t *testing.T) { 999 s := buildDummyStack(t) 1000 proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) 1001 nic := testInterface{ 1002 testObject: testObject{ 1003 t: t, 1004 }, 1005 } 1006 ep := proto.NewEndpoint(&nic, nil) 1007 defer ep.Close() 1008 1009 if err := ep.Enable(); err != nil { 1010 t.Fatalf("ep.Enable(): %s", err) 1011 } 1012 1013 // Allocate and initialize the payload view. 1014 payload := buffer.NewView(100) 1015 for i := 0; i < len(payload); i++ { 1016 payload[i] = uint8(i) 1017 } 1018 1019 // Setup the packet buffer. 1020 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 1021 ReserveHeaderBytes: int(ep.MaxHeaderLength()), 1022 Data: payload.ToVectorisedView(), 1023 }) 1024 1025 // Issue the write. 
1026 nic.testObject.protocol = 123 1027 nic.testObject.srcAddr = localIPv6Addr 1028 nic.testObject.dstAddr = remoteIPv6Addr 1029 nic.testObject.contents = payload 1030 1031 r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) 1032 if err != nil { 1033 t.Fatalf("could not find route: %v", err) 1034 } 1035 if err := ep.WritePacket(r, stack.NetworkHeaderParams{ 1036 Protocol: 123, 1037 TTL: 123, 1038 TOS: stack.DefaultTOS, 1039 }, pkt); err != nil { 1040 t.Fatalf("WritePacket failed: %v", err) 1041 } 1042 } 1043 1044 func TestIPv6ReceiveControl(t *testing.T) { 1045 const ( 1046 mtu = 0xffff 1047 outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa" 1048 dataLen = 8 1049 ) 1050 1051 newUint16 := func(v uint16) *uint16 { return &v } 1052 1053 portUnreachableTransErr := transportError{ 1054 origin: tcpip.SockExtErrorOriginICMP6, 1055 typ: uint8(header.ICMPv6DstUnreachable), 1056 code: uint8(header.ICMPv6PortUnreachable), 1057 kind: stack.DestinationPortUnreachableTransportError, 1058 } 1059 1060 cases := []struct { 1061 name string 1062 expectedCount int 1063 fragmentOffset *uint16 1064 typ header.ICMPv6Type 1065 code header.ICMPv6Code 1066 transErr transportError 1067 trunc int 1068 }{ 1069 { 1070 name: "PacketTooBig", 1071 expectedCount: 1, 1072 fragmentOffset: nil, 1073 typ: header.ICMPv6PacketTooBig, 1074 code: header.ICMPv6UnusedCode, 1075 transErr: transportError{ 1076 origin: tcpip.SockExtErrorOriginICMP6, 1077 typ: uint8(header.ICMPv6PacketTooBig), 1078 code: uint8(header.ICMPv6UnusedCode), 1079 info: mtu, 1080 kind: stack.PacketTooBigTransportError, 1081 }, 1082 trunc: 0, 1083 }, 1084 { 1085 name: "Truncated (missing offending packet's IPv6 header)", 1086 expectedCount: 0, 1087 fragmentOffset: nil, 1088 typ: header.ICMPv6PacketTooBig, 1089 code: header.ICMPv6UnusedCode, 1090 trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize, 1091 }, 1092 { 1093 name: "Truncated PacketTooBig (partial offending packet's IPv6 
header)", 1094 expectedCount: 0, 1095 fragmentOffset: nil, 1096 typ: header.ICMPv6PacketTooBig, 1097 code: header.ICMPv6UnusedCode, 1098 trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize - 1, 1099 }, 1100 { 1101 name: "Truncated (partial offending packet's data)", 1102 expectedCount: 0, 1103 fragmentOffset: nil, 1104 typ: header.ICMPv6PacketTooBig, 1105 code: header.ICMPv6UnusedCode, 1106 trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + dataLen - 1, 1107 }, 1108 { 1109 name: "Port unreachable", 1110 expectedCount: 1, 1111 fragmentOffset: nil, 1112 typ: header.ICMPv6DstUnreachable, 1113 code: header.ICMPv6PortUnreachable, 1114 transErr: portUnreachableTransErr, 1115 trunc: 0, 1116 }, 1117 { 1118 name: "Truncated DstPortUnreachable (partial offending packet's IP header)", 1119 expectedCount: 0, 1120 fragmentOffset: nil, 1121 typ: header.ICMPv6DstUnreachable, 1122 code: header.ICMPv6PortUnreachable, 1123 trunc: header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + header.IPv6MinimumSize - 1, 1124 }, 1125 { 1126 name: "DstPortUnreachable for Fragmented, zero offset", 1127 expectedCount: 1, 1128 fragmentOffset: newUint16(0), 1129 typ: header.ICMPv6DstUnreachable, 1130 code: header.ICMPv6PortUnreachable, 1131 transErr: portUnreachableTransErr, 1132 trunc: 0, 1133 }, 1134 { 1135 name: "DstPortUnreachable for Non-zero fragment offset", 1136 expectedCount: 0, 1137 fragmentOffset: newUint16(100), 1138 typ: header.ICMPv6DstUnreachable, 1139 code: header.ICMPv6PortUnreachable, 1140 transErr: portUnreachableTransErr, 1141 trunc: 0, 1142 }, 1143 { 1144 name: "Zero-length packet", 1145 expectedCount: 0, 1146 fragmentOffset: nil, 1147 typ: header.ICMPv6DstUnreachable, 1148 code: header.ICMPv6PortUnreachable, 1149 trunc: 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + dataLen, 1150 }, 1151 } 1152 for _, c := range cases { 1153 t.Run(c.name, func(t *testing.T) { 
1154 s := buildDummyStack(t) 1155 proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) 1156 nic := testInterface{ 1157 testObject: testObject{ 1158 t: t, 1159 }, 1160 } 1161 ep := proto.NewEndpoint(&nic, &nic.testObject) 1162 defer ep.Close() 1163 1164 if err := ep.Enable(); err != nil { 1165 t.Fatalf("ep.Enable(): %s", err) 1166 } 1167 1168 dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize 1169 if c.fragmentOffset != nil { 1170 dataOffset += header.IPv6FragmentHeaderSize 1171 } 1172 view := buffer.NewView(dataOffset + dataLen) 1173 1174 // Create the outer IPv6 header. 1175 ip := header.IPv6(view) 1176 ip.Encode(&header.IPv6Fields{ 1177 PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), 1178 TransportProtocol: header.ICMPv6ProtocolNumber, 1179 HopLimit: 20, 1180 SrcAddr: outerSrcAddr, 1181 DstAddr: localIPv6Addr, 1182 }) 1183 1184 // Create the ICMP header. 1185 icmp := header.ICMPv6(view[header.IPv6MinimumSize:]) 1186 icmp.SetType(c.typ) 1187 icmp.SetCode(c.code) 1188 icmp.SetIdent(0xdead) 1189 icmp.SetSequence(0xbeef) 1190 1191 var extHdrs header.IPv6ExtHdrSerializer 1192 // Build the fragmentation header if needed. 1193 if c.fragmentOffset != nil { 1194 extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{ 1195 FragmentOffset: *c.fragmentOffset, 1196 M: true, 1197 Identification: 0x12345678, 1198 }) 1199 } 1200 1201 // Create the inner IPv6 header. 1202 ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:]) 1203 ip.Encode(&header.IPv6Fields{ 1204 PayloadLength: 100, 1205 TransportProtocol: 10, 1206 HopLimit: 20, 1207 SrcAddr: localIPv6Addr, 1208 DstAddr: remoteIPv6Addr, 1209 ExtensionHeaders: extHdrs, 1210 }) 1211 1212 // Make payload be non-zero. 1213 for i := dataOffset; i < len(view); i++ { 1214 view[i] = uint8(i) 1215 } 1216 1217 // Give packet to IPv6 endpoint, dispatcher will validate that 1218 // it's ok. 
1219 nic.testObject.protocol = 10 1220 nic.testObject.srcAddr = remoteIPv6Addr 1221 nic.testObject.dstAddr = localIPv6Addr 1222 nic.testObject.contents = view[dataOffset:] 1223 nic.testObject.transErr = c.transErr 1224 1225 // Set ICMPv6 checksum. 1226 icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ 1227 Header: icmp, 1228 Src: outerSrcAddr, 1229 Dst: localIPv6Addr, 1230 })) 1231 1232 addressableEndpoint, ok := ep.(stack.AddressableEndpoint) 1233 if !ok { 1234 t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint") 1235 } 1236 addr := localIPv6Addr.WithPrefix() 1237 if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { 1238 t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) 1239 } else { 1240 ep.DecRef() 1241 } 1242 pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize) 1243 ep.HandlePacket(pkt) 1244 if want := c.expectedCount; nic.testObject.controlCalls != want { 1245 t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) 1246 } 1247 }) 1248 } 1249 } 1250 1251 // truncatedPacket returns a PacketBuffer based on a truncated view. If view, 1252 // after truncation, is large enough to hold a network header, it makes part of 1253 // view the packet's NetworkHeader and the rest its Data. Otherwise all of view 1254 // becomes Data. 
1255 func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer { 1256 v := view[:len(view)-trunc] 1257 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 1258 Data: v.ToVectorisedView(), 1259 }) 1260 return pkt 1261 } 1262 1263 func TestWriteHeaderIncludedPacket(t *testing.T) { 1264 const ( 1265 nicID = 1 1266 transportProto = 5 1267 1268 dataLen = 4 1269 ) 1270 1271 dataBuf := [dataLen]byte{1, 2, 3, 4} 1272 data := dataBuf[:] 1273 1274 ipv4Options := header.IPv4OptionsSerializer{ 1275 &header.IPv4SerializableListEndOption{}, 1276 &header.IPv4SerializableNOPOption{}, 1277 &header.IPv4SerializableListEndOption{}, 1278 &header.IPv4SerializableNOPOption{}, 1279 } 1280 1281 expectOptions := header.IPv4Options{ 1282 byte(header.IPv4OptionListEndType), 1283 byte(header.IPv4OptionNOPType), 1284 byte(header.IPv4OptionListEndType), 1285 byte(header.IPv4OptionNOPType), 1286 } 1287 1288 ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4} 1289 ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:] 1290 1291 var ipv6PayloadWithExtHdrBuf [dataLen + header.IPv6FragmentExtHdrLength]byte 1292 ipv6PayloadWithExtHdr := ipv6PayloadWithExtHdrBuf[:] 1293 if n := copy(ipv6PayloadWithExtHdr, ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) { 1294 t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr)) 1295 } 1296 if n := copy(ipv6PayloadWithExtHdr[header.IPv6FragmentExtHdrLength:], data); n != len(data) { 1297 t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) 1298 } 1299 1300 tests := []struct { 1301 name string 1302 protoFactory stack.NetworkProtocolFactory 1303 protoNum tcpip.NetworkProtocolNumber 1304 nicAddr tcpip.Address 1305 remoteAddr tcpip.Address 1306 pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView 1307 checker func(*testing.T, *stack.PacketBuffer, tcpip.Address) 1308 expectedErr tcpip.Error 1309 }{ 1310 { 1311 name: "IPv4", 1312 protoFactory: ipv4.NewProtocol, 1313 
protoNum: ipv4.ProtocolNumber, 1314 nicAddr: localIPv4Addr, 1315 remoteAddr: remoteIPv4Addr, 1316 pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { 1317 totalLen := header.IPv4MinimumSize + len(data) 1318 hdr := buffer.NewPrependable(totalLen) 1319 if n := copy(hdr.Prepend(len(data)), data); n != len(data) { 1320 t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) 1321 } 1322 ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) 1323 ip.Encode(&header.IPv4Fields{ 1324 Protocol: transportProto, 1325 TTL: ipv4.DefaultTTL, 1326 SrcAddr: src, 1327 DstAddr: remoteIPv4Addr, 1328 }) 1329 return hdr.View().ToVectorisedView() 1330 }, 1331 checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { 1332 if src == header.IPv4Any { 1333 src = localIPv4Addr 1334 } 1335 1336 netHdr := pkt.NetworkHeader() 1337 1338 if len(netHdr.View()) != header.IPv4MinimumSize { 1339 t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize) 1340 } 1341 1342 checker.IPv4(t, stack.PayloadSince(netHdr), 1343 checker.SrcAddr(src), 1344 checker.DstAddr(remoteIPv4Addr), 1345 checker.IPv4HeaderLength(header.IPv4MinimumSize), 1346 checker.IPFullLength(uint16(header.IPv4MinimumSize+len(data))), 1347 checker.IPPayload(data), 1348 ) 1349 }, 1350 }, 1351 { 1352 name: "IPv4 with IHL too small", 1353 protoFactory: ipv4.NewProtocol, 1354 protoNum: ipv4.ProtocolNumber, 1355 nicAddr: localIPv4Addr, 1356 remoteAddr: remoteIPv4Addr, 1357 pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { 1358 totalLen := header.IPv4MinimumSize + len(data) 1359 hdr := buffer.NewPrependable(totalLen) 1360 if n := copy(hdr.Prepend(len(data)), data); n != len(data) { 1361 t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) 1362 } 1363 ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) 1364 ip.Encode(&header.IPv4Fields{ 1365 Protocol: transportProto, 1366 TTL: ipv4.DefaultTTL, 1367 SrcAddr: src, 1368 DstAddr: 
remoteIPv4Addr, 1369 }) 1370 ip.SetHeaderLength(header.IPv4MinimumSize - 1) 1371 return hdr.View().ToVectorisedView() 1372 }, 1373 expectedErr: &tcpip.ErrMalformedHeader{}, 1374 }, 1375 { 1376 name: "IPv4 too small", 1377 protoFactory: ipv4.NewProtocol, 1378 protoNum: ipv4.ProtocolNumber, 1379 nicAddr: localIPv4Addr, 1380 remoteAddr: remoteIPv4Addr, 1381 pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { 1382 ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) 1383 ip.Encode(&header.IPv4Fields{ 1384 Protocol: transportProto, 1385 TTL: ipv4.DefaultTTL, 1386 SrcAddr: src, 1387 DstAddr: remoteIPv4Addr, 1388 }) 1389 return buffer.View(ip[:len(ip)-1]).ToVectorisedView() 1390 }, 1391 expectedErr: &tcpip.ErrMalformedHeader{}, 1392 }, 1393 { 1394 name: "IPv4 minimum size", 1395 protoFactory: ipv4.NewProtocol, 1396 protoNum: ipv4.ProtocolNumber, 1397 nicAddr: localIPv4Addr, 1398 remoteAddr: remoteIPv4Addr, 1399 pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { 1400 ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) 1401 ip.Encode(&header.IPv4Fields{ 1402 Protocol: transportProto, 1403 TTL: ipv4.DefaultTTL, 1404 SrcAddr: src, 1405 DstAddr: remoteIPv4Addr, 1406 }) 1407 return buffer.View(ip).ToVectorisedView() 1408 }, 1409 checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { 1410 if src == header.IPv4Any { 1411 src = localIPv4Addr 1412 } 1413 1414 netHdr := pkt.NetworkHeader() 1415 1416 if len(netHdr.View()) != header.IPv4MinimumSize { 1417 t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize) 1418 } 1419 1420 checker.IPv4(t, stack.PayloadSince(netHdr), 1421 checker.SrcAddr(src), 1422 checker.DstAddr(remoteIPv4Addr), 1423 checker.IPv4HeaderLength(header.IPv4MinimumSize), 1424 checker.IPFullLength(header.IPv4MinimumSize), 1425 checker.IPPayload(nil), 1426 ) 1427 }, 1428 }, 1429 { 1430 name: "IPv4 with options", 1431 protoFactory: ipv4.NewProtocol, 1432 protoNum: 
ipv4.ProtocolNumber, 1433 nicAddr: localIPv4Addr, 1434 remoteAddr: remoteIPv4Addr, 1435 pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { 1436 ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) 1437 totalLen := ipHdrLen + len(data) 1438 hdr := buffer.NewPrependable(totalLen) 1439 if n := copy(hdr.Prepend(len(data)), data); n != len(data) { 1440 t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) 1441 } 1442 ip := header.IPv4(hdr.Prepend(ipHdrLen)) 1443 ip.Encode(&header.IPv4Fields{ 1444 Protocol: transportProto, 1445 TTL: ipv4.DefaultTTL, 1446 SrcAddr: src, 1447 DstAddr: remoteIPv4Addr, 1448 Options: ipv4Options, 1449 }) 1450 return hdr.View().ToVectorisedView() 1451 }, 1452 checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { 1453 if src == header.IPv4Any { 1454 src = localIPv4Addr 1455 } 1456 1457 netHdr := pkt.NetworkHeader() 1458 1459 hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) 1460 if len(netHdr.View()) != hdrLen { 1461 t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) 1462 } 1463 1464 checker.IPv4(t, stack.PayloadSince(netHdr), 1465 checker.SrcAddr(src), 1466 checker.DstAddr(remoteIPv4Addr), 1467 checker.IPv4HeaderLength(hdrLen), 1468 checker.IPFullLength(uint16(hdrLen+len(data))), 1469 checker.IPv4Options(expectOptions), 1470 checker.IPPayload(data), 1471 ) 1472 }, 1473 }, 1474 { 1475 name: "IPv4 with options and data across views", 1476 protoFactory: ipv4.NewProtocol, 1477 protoNum: ipv4.ProtocolNumber, 1478 nicAddr: localIPv4Addr, 1479 remoteAddr: remoteIPv4Addr, 1480 pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { 1481 ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length())) 1482 ip.Encode(&header.IPv4Fields{ 1483 Protocol: transportProto, 1484 TTL: ipv4.DefaultTTL, 1485 SrcAddr: src, 1486 DstAddr: remoteIPv4Addr, 1487 Options: ipv4Options, 1488 }) 1489 vv := buffer.View(ip).ToVectorisedView() 1490 
vv.AppendView(data) 1491 return vv 1492 }, 1493 checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { 1494 if src == header.IPv4Any { 1495 src = localIPv4Addr 1496 } 1497 1498 netHdr := pkt.NetworkHeader() 1499 1500 hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) 1501 if len(netHdr.View()) != hdrLen { 1502 t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) 1503 } 1504 1505 checker.IPv4(t, stack.PayloadSince(netHdr), 1506 checker.SrcAddr(src), 1507 checker.DstAddr(remoteIPv4Addr), 1508 checker.IPv4HeaderLength(hdrLen), 1509 checker.IPFullLength(uint16(hdrLen+len(data))), 1510 checker.IPv4Options(expectOptions), 1511 checker.IPPayload(data), 1512 ) 1513 }, 1514 }, 1515 { 1516 name: "IPv6", 1517 protoFactory: ipv6.NewProtocol, 1518 protoNum: ipv6.ProtocolNumber, 1519 nicAddr: localIPv6Addr, 1520 remoteAddr: remoteIPv6Addr, 1521 pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { 1522 totalLen := header.IPv6MinimumSize + len(data) 1523 hdr := buffer.NewPrependable(totalLen) 1524 if n := copy(hdr.Prepend(len(data)), data); n != len(data) { 1525 t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) 1526 } 1527 ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) 1528 ip.Encode(&header.IPv6Fields{ 1529 TransportProtocol: transportProto, 1530 HopLimit: ipv6.DefaultTTL, 1531 SrcAddr: src, 1532 DstAddr: remoteIPv6Addr, 1533 }) 1534 return hdr.View().ToVectorisedView() 1535 }, 1536 checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { 1537 if src == header.IPv6Any { 1538 src = localIPv6Addr 1539 } 1540 1541 netHdr := pkt.NetworkHeader() 1542 1543 if len(netHdr.View()) != header.IPv6MinimumSize { 1544 t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize) 1545 } 1546 1547 checker.IPv6(t, stack.PayloadSince(netHdr), 1548 checker.SrcAddr(src), 1549 checker.DstAddr(remoteIPv6Addr), 1550 
checker.IPFullLength(uint16(header.IPv6MinimumSize+len(data))), 1551 checker.IPPayload(data), 1552 ) 1553 }, 1554 }, 1555 { 1556 name: "IPv6 with extension header", 1557 protoFactory: ipv6.NewProtocol, 1558 protoNum: ipv6.ProtocolNumber, 1559 nicAddr: localIPv6Addr, 1560 remoteAddr: remoteIPv6Addr, 1561 pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { 1562 totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data) 1563 hdr := buffer.NewPrependable(totalLen) 1564 if n := copy(hdr.Prepend(len(data)), data); n != len(data) { 1565 t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) 1566 } 1567 if n := copy(hdr.Prepend(len(ipv6FragmentExtHdr)), ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) { 1568 t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr)) 1569 } 1570 ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) 1571 ip.Encode(&header.IPv6Fields{ 1572 // NB: we're lying about transport protocol here to verify the raw 1573 // fragment header bytes. 
1574 TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier), 1575 HopLimit: ipv6.DefaultTTL, 1576 SrcAddr: src, 1577 DstAddr: remoteIPv6Addr, 1578 }) 1579 return hdr.View().ToVectorisedView() 1580 }, 1581 checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { 1582 if src == header.IPv6Any { 1583 src = localIPv6Addr 1584 } 1585 1586 netHdr := pkt.NetworkHeader() 1587 1588 if want := header.IPv6MinimumSize + len(ipv6FragmentExtHdr); len(netHdr.View()) != want { 1589 t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), want) 1590 } 1591 1592 checker.IPv6(t, stack.PayloadSince(netHdr), 1593 checker.SrcAddr(src), 1594 checker.DstAddr(remoteIPv6Addr), 1595 checker.IPFullLength(uint16(header.IPv6MinimumSize+len(ipv6PayloadWithExtHdr))), 1596 checker.IPPayload(ipv6PayloadWithExtHdr), 1597 ) 1598 }, 1599 }, 1600 { 1601 name: "IPv6 minimum size", 1602 protoFactory: ipv6.NewProtocol, 1603 protoNum: ipv6.ProtocolNumber, 1604 nicAddr: localIPv6Addr, 1605 remoteAddr: remoteIPv6Addr, 1606 pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { 1607 ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) 1608 ip.Encode(&header.IPv6Fields{ 1609 TransportProtocol: transportProto, 1610 HopLimit: ipv6.DefaultTTL, 1611 SrcAddr: src, 1612 DstAddr: remoteIPv6Addr, 1613 }) 1614 return buffer.View(ip).ToVectorisedView() 1615 }, 1616 checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { 1617 if src == header.IPv6Any { 1618 src = localIPv6Addr 1619 } 1620 1621 netHdr := pkt.NetworkHeader() 1622 1623 if len(netHdr.View()) != header.IPv6MinimumSize { 1624 t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize) 1625 } 1626 1627 checker.IPv6(t, stack.PayloadSince(netHdr), 1628 checker.SrcAddr(src), 1629 checker.DstAddr(remoteIPv6Addr), 1630 checker.IPFullLength(header.IPv6MinimumSize), 1631 checker.IPPayload(nil), 1632 ) 1633 }, 1634 }, 1635 { 1636 name: 
"IPv6 too small", 1637 protoFactory: ipv6.NewProtocol, 1638 protoNum: ipv6.ProtocolNumber, 1639 nicAddr: localIPv6Addr, 1640 remoteAddr: remoteIPv6Addr, 1641 pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { 1642 ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) 1643 ip.Encode(&header.IPv6Fields{ 1644 TransportProtocol: transportProto, 1645 HopLimit: ipv6.DefaultTTL, 1646 SrcAddr: src, 1647 DstAddr: remoteIPv4Addr, 1648 }) 1649 return buffer.View(ip[:len(ip)-1]).ToVectorisedView() 1650 }, 1651 expectedErr: &tcpip.ErrMalformedHeader{}, 1652 }, 1653 } 1654 1655 for _, test := range tests { 1656 t.Run(test.name, func(t *testing.T) { 1657 subTests := []struct { 1658 name string 1659 srcAddr tcpip.Address 1660 }{ 1661 { 1662 name: "unspecified source", 1663 srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))), 1664 }, 1665 { 1666 name: "random source", 1667 srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))), 1668 }, 1669 } 1670 1671 for _, subTest := range subTests { 1672 t.Run(subTest.name, func(t *testing.T) { 1673 s := stack.New(stack.Options{ 1674 NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, 1675 }) 1676 e := channel.New(1, header.IPv6MinimumMTU, "") 1677 if err := s.CreateNIC(nicID, e); err != nil { 1678 t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) 1679 } 1680 if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil { 1681 t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err) 1682 } 1683 1684 s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}}) 1685 1686 r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */) 1687 if err != nil { 1688 t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err) 1689 } 1690 defer r.Release() 1691 1692 { 1693 err := 
r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ 1694 Data: test.pktGen(t, subTest.srcAddr), 1695 })) 1696 if diff := cmp.Diff(test.expectedErr, err); diff != "" { 1697 t.Fatalf("unexpected error from r.WriteHeaderIncludedPacket(_), (-want, +got):\n%s", diff) 1698 } 1699 } 1700 1701 if test.expectedErr != nil { 1702 return 1703 } 1704 1705 pkt, ok := e.Read() 1706 if !ok { 1707 t.Fatal("expected a packet to be written") 1708 } 1709 test.checker(t, pkt.Pkt, subTest.srcAddr) 1710 }) 1711 } 1712 }) 1713 } 1714 } 1715 1716 // Test that the included data in an ICMP error packet conforms to the 1717 // requirements of RFC 972, RFC 4443 section 2.4 and RFC 1812 Section 4.3.2.3 1718 func TestICMPInclusionSize(t *testing.T) { 1719 const ( 1720 replyHeaderLength4 = header.IPv4MinimumSize + header.IPv4MinimumSize + header.ICMPv4MinimumSize 1721 replyHeaderLength6 = header.IPv6MinimumSize + header.IPv6MinimumSize + header.ICMPv6MinimumSize 1722 targetSize4 = header.IPv4MinimumProcessableDatagramSize 1723 targetSize6 = header.IPv6MinimumMTU 1724 // A protocol number that will cause an error response. 1725 reservedProtocol = 254 1726 ) 1727 1728 // IPv4 function to create a IP packet and send it to the stack. 1729 // The packet should generate an error response. We can do that by using an 1730 // unknown transport protocol (254). 
1731 rxIPv4Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View { 1732 totalLen := header.IPv4MinimumSize + len(payload) 1733 hdr := buffer.NewPrependable(header.IPv4MinimumSize) 1734 ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) 1735 ip.Encode(&header.IPv4Fields{ 1736 TotalLength: uint16(totalLen), 1737 Protocol: reservedProtocol, 1738 TTL: ipv4.DefaultTTL, 1739 SrcAddr: src, 1740 DstAddr: localIPv4Addr, 1741 }) 1742 ip.SetChecksum(^ip.CalculateChecksum()) 1743 vv := hdr.View().ToVectorisedView() 1744 vv.AppendView(buffer.View(payload)) 1745 // Take a copy before InjectInbound takes ownership of vv 1746 // as vv may be changed during the call. 1747 v := vv.ToView() 1748 e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ 1749 Data: vv, 1750 })) 1751 return v 1752 } 1753 1754 // IPv6 function to create a packet and send it to the stack. 1755 // The packet should be errant in a way that causes the stack to send an 1756 // ICMP error response and have enough data to allow the testing of the 1757 // inclusion of the errant packet. Use `unknown next header' to generate 1758 // the error. 1759 rxIPv6Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View { 1760 hdr := buffer.NewPrependable(header.IPv6MinimumSize) 1761 ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) 1762 ip.Encode(&header.IPv6Fields{ 1763 PayloadLength: uint16(len(payload)), 1764 TransportProtocol: reservedProtocol, 1765 HopLimit: ipv6.DefaultTTL, 1766 SrcAddr: src, 1767 DstAddr: localIPv6Addr, 1768 }) 1769 vv := hdr.View().ToVectorisedView() 1770 vv.AppendView(buffer.View(payload)) 1771 // Take a copy before InjectInbound takes ownership of vv 1772 // as vv may be changed during the call. 
1773 v := vv.ToView() 1774 1775 e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ 1776 Data: vv, 1777 })) 1778 return v 1779 } 1780 1781 v4Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) { 1782 // We already know the entire packet is the right size so we can use its 1783 // length to calculate the right payload size to check. 1784 expectedPayloadLength := pkt.Size() - header.IPv4MinimumSize - header.ICMPv4MinimumSize 1785 checker.IPv4(t, stack.PayloadSince(pkt.NetworkHeader()), 1786 checker.SrcAddr(localIPv4Addr), 1787 checker.DstAddr(remoteIPv4Addr), 1788 checker.IPv4HeaderLength(header.IPv4MinimumSize), 1789 checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+expectedPayloadLength)), 1790 checker.ICMPv4( 1791 checker.ICMPv4Checksum(), 1792 checker.ICMPv4Type(header.ICMPv4DstUnreachable), 1793 checker.ICMPv4Code(header.ICMPv4ProtoUnreachable), 1794 checker.ICMPv4Payload(payload[:expectedPayloadLength]), 1795 ), 1796 ) 1797 } 1798 1799 v6Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) { 1800 // We already know the entire packet is the right size so we can use its 1801 // length to calculate the right payload size to check. 
1802 expectedPayloadLength := pkt.Size() - header.IPv6MinimumSize - header.ICMPv6MinimumSize 1803 checker.IPv6(t, stack.PayloadSince(pkt.NetworkHeader()), 1804 checker.SrcAddr(localIPv6Addr), 1805 checker.DstAddr(remoteIPv6Addr), 1806 checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectedPayloadLength)), 1807 checker.ICMPv6( 1808 checker.ICMPv6Type(header.ICMPv6ParamProblem), 1809 checker.ICMPv6Code(header.ICMPv6UnknownHeader), 1810 checker.ICMPv6Payload(payload[:expectedPayloadLength]), 1811 ), 1812 ) 1813 } 1814 tests := []struct { 1815 name string 1816 srcAddress tcpip.Address 1817 injector func(*channel.Endpoint, tcpip.Address, []byte) buffer.View 1818 checker func(*testing.T, *stack.PacketBuffer, buffer.View) 1819 payloadLength int // Not including IP header. 1820 linkMTU uint32 // Largest IP packet that the link can send as payload. 1821 replyLength int // Total size of IP/ICMP packet expected back. 1822 }{ 1823 { 1824 name: "IPv4 exact match", 1825 srcAddress: remoteIPv4Addr, 1826 injector: rxIPv4Bad, 1827 checker: v4Checker, 1828 payloadLength: targetSize4 - replyHeaderLength4, 1829 linkMTU: targetSize4, 1830 replyLength: targetSize4, 1831 }, 1832 { 1833 name: "IPv4 larger MTU", 1834 srcAddress: remoteIPv4Addr, 1835 injector: rxIPv4Bad, 1836 checker: v4Checker, 1837 payloadLength: targetSize4, 1838 linkMTU: targetSize4 + 1000, 1839 replyLength: targetSize4, 1840 }, 1841 { 1842 name: "IPv4 smaller MTU", 1843 srcAddress: remoteIPv4Addr, 1844 injector: rxIPv4Bad, 1845 checker: v4Checker, 1846 payloadLength: targetSize4, 1847 linkMTU: targetSize4 - 50, 1848 replyLength: targetSize4 - 50, 1849 }, 1850 { 1851 name: "IPv4 payload exceeds", 1852 srcAddress: remoteIPv4Addr, 1853 injector: rxIPv4Bad, 1854 checker: v4Checker, 1855 payloadLength: targetSize4 + 10, 1856 linkMTU: targetSize4, 1857 replyLength: targetSize4, 1858 }, 1859 { 1860 name: "IPv4 1 byte less", 1861 srcAddress: remoteIPv4Addr, 1862 injector: rxIPv4Bad, 1863 checker: 
v4Checker, 1864 payloadLength: targetSize4 - replyHeaderLength4 - 1, 1865 linkMTU: targetSize4, 1866 replyLength: targetSize4 - 1, 1867 }, 1868 { 1869 name: "IPv4 No payload", 1870 srcAddress: remoteIPv4Addr, 1871 injector: rxIPv4Bad, 1872 checker: v4Checker, 1873 payloadLength: 0, 1874 linkMTU: targetSize4, 1875 replyLength: replyHeaderLength4, 1876 }, 1877 { 1878 name: "IPv6 exact match", 1879 srcAddress: remoteIPv6Addr, 1880 injector: rxIPv6Bad, 1881 checker: v6Checker, 1882 payloadLength: targetSize6 - replyHeaderLength6, 1883 linkMTU: targetSize6, 1884 replyLength: targetSize6, 1885 }, 1886 { 1887 name: "IPv6 larger MTU", 1888 srcAddress: remoteIPv6Addr, 1889 injector: rxIPv6Bad, 1890 checker: v6Checker, 1891 payloadLength: targetSize6, 1892 linkMTU: targetSize6 + 400, 1893 replyLength: targetSize6, 1894 }, 1895 // NB. No "smaller MTU" test here as less than 1280 is not permitted 1896 // in IPv6. 1897 { 1898 name: "IPv6 payload exceeds", 1899 srcAddress: remoteIPv6Addr, 1900 injector: rxIPv6Bad, 1901 checker: v6Checker, 1902 payloadLength: targetSize6, 1903 linkMTU: targetSize6, 1904 replyLength: targetSize6, 1905 }, 1906 { 1907 name: "IPv6 1 byte less", 1908 srcAddress: remoteIPv6Addr, 1909 injector: rxIPv6Bad, 1910 checker: v6Checker, 1911 payloadLength: targetSize6 - replyHeaderLength6 - 1, 1912 linkMTU: targetSize6, 1913 replyLength: targetSize6 - 1, 1914 }, 1915 { 1916 name: "IPv6 no payload", 1917 srcAddress: remoteIPv6Addr, 1918 injector: rxIPv6Bad, 1919 checker: v6Checker, 1920 payloadLength: 0, 1921 linkMTU: targetSize6, 1922 replyLength: replyHeaderLength6, 1923 }, 1924 } 1925 1926 for _, test := range tests { 1927 t.Run(test.name, func(t *testing.T) { 1928 s, e := buildDummyStackWithLinkEndpoint(t, test.linkMTU) 1929 // Allocate and initialize the payload view. 
1930 payload := buffer.NewView(test.payloadLength) 1931 for i := 0; i < len(payload); i++ { 1932 payload[i] = uint8(i) 1933 } 1934 // Default routes for IPv4&6 so ICMP can find a route to the remote 1935 // node when attempting to send the ICMP error Reply. 1936 s.SetRouteTable([]tcpip.Route{ 1937 { 1938 Destination: header.IPv4EmptySubnet, 1939 NIC: nicID, 1940 }, 1941 { 1942 Destination: header.IPv6EmptySubnet, 1943 NIC: nicID, 1944 }, 1945 }) 1946 v := test.injector(e, test.srcAddress, payload) 1947 pkt, ok := e.Read() 1948 if !ok { 1949 t.Fatal("expected a packet to be written") 1950 } 1951 if got, want := pkt.Pkt.Size(), test.replyLength; got != want { 1952 t.Fatalf("got %d bytes of icmp error packet, want %d", got, want) 1953 } 1954 test.checker(t, pkt.Pkt, v) 1955 }) 1956 } 1957 } 1958 1959 func TestJoinLeaveAllRoutersGroup(t *testing.T) { 1960 const nicID = 1 1961 1962 tests := []struct { 1963 name string 1964 netProto tcpip.NetworkProtocolNumber 1965 protoFactory stack.NetworkProtocolFactory 1966 allRoutersAddr tcpip.Address 1967 }{ 1968 { 1969 name: "IPv4", 1970 netProto: ipv4.ProtocolNumber, 1971 protoFactory: ipv4.NewProtocol, 1972 allRoutersAddr: header.IPv4AllRoutersGroup, 1973 }, 1974 { 1975 name: "IPv6 Interface Local", 1976 netProto: ipv6.ProtocolNumber, 1977 protoFactory: ipv6.NewProtocol, 1978 allRoutersAddr: header.IPv6AllRoutersInterfaceLocalMulticastAddress, 1979 }, 1980 { 1981 name: "IPv6 Link Local", 1982 netProto: ipv6.ProtocolNumber, 1983 protoFactory: ipv6.NewProtocol, 1984 allRoutersAddr: header.IPv6AllRoutersLinkLocalMulticastAddress, 1985 }, 1986 { 1987 name: "IPv6 Site Local", 1988 netProto: ipv6.ProtocolNumber, 1989 protoFactory: ipv6.NewProtocol, 1990 allRoutersAddr: header.IPv6AllRoutersSiteLocalMulticastAddress, 1991 }, 1992 } 1993 1994 for _, test := range tests { 1995 t.Run(test.name, func(t *testing.T) { 1996 for _, nicDisabled := range [...]bool{true, false} { 1997 t.Run(fmt.Sprintf("NIC Disabled = %t", nicDisabled), func(t 
*testing.T) { 1998 s := stack.New(stack.Options{ 1999 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 2000 TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, 2001 }) 2002 opts := stack.NICOptions{Disabled: nicDisabled} 2003 if err := s.CreateNICWithOptions(nicID, channel.New(0, 0, ""), opts); err != nil { 2004 t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err) 2005 } 2006 2007 if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { 2008 t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) 2009 } else if got { 2010 t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) 2011 } 2012 2013 if err := s.SetForwardingDefaultAndAllNICs(test.netProto, true); err != nil { 2014 t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", test.netProto, err) 2015 } 2016 if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { 2017 t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) 2018 } else if !got { 2019 t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr) 2020 } 2021 2022 if err := s.SetForwardingDefaultAndAllNICs(test.netProto, false); err != nil { 2023 t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, false): %s", test.netProto, err) 2024 } 2025 if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { 2026 t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) 2027 } else if got { 2028 t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) 2029 } 2030 }) 2031 } 2032 }) 2033 } 2034 }