gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/network/ip_test.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package ip_test 16 17 import ( 18 "bytes" 19 "fmt" 20 "strings" 21 "testing" 22 23 "github.com/google/go-cmp/cmp" 24 "gvisor.dev/gvisor/pkg/buffer" 25 "gvisor.dev/gvisor/pkg/refs" 26 "gvisor.dev/gvisor/pkg/sync" 27 "gvisor.dev/gvisor/pkg/tcpip" 28 "gvisor.dev/gvisor/pkg/tcpip/checker" 29 "gvisor.dev/gvisor/pkg/tcpip/checksum" 30 "gvisor.dev/gvisor/pkg/tcpip/header" 31 "gvisor.dev/gvisor/pkg/tcpip/link/channel" 32 "gvisor.dev/gvisor/pkg/tcpip/link/loopback" 33 "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" 34 "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" 35 "gvisor.dev/gvisor/pkg/tcpip/prependable" 36 "gvisor.dev/gvisor/pkg/tcpip/stack" 37 "gvisor.dev/gvisor/pkg/tcpip/testutil" 38 "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" 39 "gvisor.dev/gvisor/pkg/tcpip/transport/raw" 40 "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" 41 "gvisor.dev/gvisor/pkg/tcpip/transport/udp" 42 "gvisor.dev/gvisor/pkg/waiter" 43 ) 44 45 const nicID = 1 46 47 var ( 48 localIPv4Addr = testutil.MustParse4("10.0.0.1") 49 remoteIPv4Addr = testutil.MustParse4("10.0.0.2") 50 ipv4SubnetAddr = testutil.MustParse4("10.0.0.0") 51 ipv4SubnetMask = testutil.MustParse4("255.255.255.0") 52 ipv4Gateway = testutil.MustParse4("10.0.0.3") 53 localIPv6Addr = testutil.MustParse6("a00::1") 54 remoteIPv6Addr = testutil.MustParse6("a00::2") 55 ipv6SubnetAddr = testutil.MustParse6("a00::") 56 ipv6SubnetMask = testutil.MustParse6("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ff00") 57 ipv6Gateway = testutil.MustParse6("a00::3") 58 ) 59 60 var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{ 61 Address: localIPv4Addr, 62 PrefixLen: 24, 63 } 64 65 var localIPv6AddrWithPrefix = tcpip.AddressWithPrefix{ 66 Address: localIPv6Addr, 67 PrefixLen: 120, 68 } 69 70 type transportError struct { 71 origin tcpip.SockErrOrigin 72 typ uint8 73 code uint8 74 info uint32 75 kind stack.TransportErrorKind 76 } 77 78 // testObject implements two interfaces: LinkEndpoint and TransportDispatcher. 79 // The former is used to pretend that it's a link endpoint so that we can 80 // inspect packets written by the network endpoints. The latter is used to 81 // pretend that it's the network stack so that it can inspect incoming packets 82 // that have been handled by the network endpoints. 83 // 84 // Packets are checked by comparing their fields/values against the expected 85 // values stored in the test object itself. 86 type testObject struct { 87 t *testing.T 88 protocol tcpip.TransportProtocolNumber 89 contents []byte 90 srcAddr tcpip.Address 91 dstAddr tcpip.Address 92 v4 bool 93 transErr transportError 94 95 dataCalls int 96 controlCalls int 97 rawCalls int 98 } 99 100 // checkValues verifies that the transport protocol, data contents, src & dst 101 // addresses of a packet match what's expected. If any field doesn't match, the 102 // test fails. 103 func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, v []byte, srcAddr, dstAddr tcpip.Address) { 104 if protocol != t.protocol { 105 t.t.Errorf("protocol = %v, want %v", protocol, t.protocol) 106 } 107 108 if srcAddr != t.srcAddr { 109 t.t.Errorf("srcAddr = %v, want %v", srcAddr, t.srcAddr) 110 } 111 112 if dstAddr != t.dstAddr { 113 t.t.Errorf("dstAddr = %v, want %v", dstAddr, t.dstAddr) 114 } 115 116 if len(v) != len(t.contents) { 117 t.t.Fatalf("len(payload) = %v, want %v", len(v), len(t.contents)) 118 } 119 120 for i := range t.contents { 121 if t.contents[i] != v[i] { 122 t.t.Fatalf("payload[%v] = %v, want %v", i, v[i], t.contents[i]) 123 } 124 } 125 } 126 127 // DeliverTransportPacket is called by network endpoints after parsing incoming 128 // packets. This is used by the test object to verify that the results of the 129 // parsing are expected. 130 func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition { 131 netHdr := pkt.Network() 132 v := pkt.Data().AsRange().ToView() 133 defer v.Release() 134 t.checkValues(protocol, v.AsSlice(), netHdr.SourceAddress(), netHdr.DestinationAddress()) 135 t.dataCalls++ 136 return stack.TransportPacketHandled 137 } 138 139 // DeliverTransportError is called by network endpoints after parsing 140 // incoming control (ICMP) packets. This is used by the test object to verify 141 // that the results of the parsing are expected. 142 func (t *testObject) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr stack.TransportError, pkt *stack.PacketBuffer) { 143 v := pkt.Data().AsRange().ToView() 144 defer v.Release() 145 t.checkValues(trans, v.AsSlice(), remote, local) 146 if diff := cmp.Diff( 147 t.transErr, 148 transportError{ 149 origin: transErr.Origin(), 150 typ: transErr.Type(), 151 code: transErr.Code(), 152 info: transErr.Info(), 153 kind: transErr.Kind(), 154 }, 155 cmp.AllowUnexported(transportError{}), 156 ); diff != "" { 157 t.t.Errorf("transport error mismatch (-want +got):\n%s", diff) 158 } 159 t.controlCalls++ 160 } 161 162 func (t *testObject) DeliverRawPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) { 163 t.rawCalls++ 164 } 165 166 // Attach is only implemented to satisfy the LinkEndpoint interface. 167 func (*testObject) Attach(stack.NetworkDispatcher) {} 168 169 // IsAttached implements stack.LinkEndpoint.IsAttached. 170 func (*testObject) IsAttached() bool { 171 return true 172 } 173 174 // MTU implements stack.LinkEndpoint.MTU. It just returns a constant that 175 // matches the linux loopback MTU. 176 func (*testObject) MTU() uint32 { 177 return 65536 178 } 179 180 // Capabilities implements stack.LinkEndpoint.Capabilities. 181 func (*testObject) Capabilities() stack.LinkEndpointCapabilities { 182 return 0 183 } 184 185 // MaxHeaderLength is only implemented to satisfy the LinkEndpoint interface. 186 func (*testObject) MaxHeaderLength() uint16 { 187 return 0 188 } 189 190 // LinkAddress returns the link address of this endpoint. 191 func (*testObject) LinkAddress() tcpip.LinkAddress { 192 return "" 193 } 194 195 // Wait implements stack.LinkEndpoint.Wait. 196 func (*testObject) Wait() {} 197 198 // WritePacket is called by network endpoints after producing a packet and 199 // writing it to the link endpoint. This is used by the test object to verify 200 // that the produced packet is as expected. 201 func (t *testObject) WritePacket(_ *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { 202 var prot tcpip.TransportProtocolNumber 203 var srcAddr tcpip.Address 204 var dstAddr tcpip.Address 205 206 if t.v4 { 207 h := header.IPv4(pkt.NetworkHeader().Slice()) 208 prot = tcpip.TransportProtocolNumber(h.Protocol()) 209 srcAddr = h.SourceAddress() 210 dstAddr = h.DestinationAddress() 211 212 } else { 213 h := header.IPv6(pkt.NetworkHeader().Slice()) 214 prot = tcpip.TransportProtocolNumber(h.NextHeader()) 215 srcAddr = h.SourceAddress() 216 dstAddr = h.DestinationAddress() 217 } 218 t.checkValues(prot, pkt.Data().AsRange().ToSlice(), srcAddr, dstAddr) 219 return nil 220 } 221 222 // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. 223 func (*testObject) ARPHardwareType() header.ARPHardwareType { 224 panic("not implemented") 225 } 226 227 // AddHeader implements stack.LinkEndpoint.AddHeader. 228 func (*testObject) AddHeader(*stack.PacketBuffer) { 229 panic("not implemented") 230 } 231 232 // ParseHeader implements stack.LinkEndpoint.ParseHeader. 233 func (*testObject) ParseHeader(*stack.PacketBuffer) bool { 234 panic("not implemented") 235 } 236 237 type testContext struct { 238 s *stack.Stack 239 } 240 241 func newTestContext() testContext { 242 s := stack.New(stack.Options{ 243 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 244 TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, 245 RawFactory: raw.EndpointFactory{}, 246 }) 247 return testContext{s: s} 248 } 249 250 func (ctx *testContext) cleanup() { 251 ctx.s.Close() 252 ctx.s.Wait() 253 refs.DoRepeatedLeakCheck() 254 } 255 256 func buildIPv4Route(ctx testContext, local, remote tcpip.Address) (*stack.Route, tcpip.Error) { 257 s := ctx.s 258 s.CreateNIC(nicID, loopback.New()) 259 protocolAddr := tcpip.ProtocolAddress{ 260 Protocol: ipv4.ProtocolNumber, 261 AddressWithPrefix: local.WithPrefix(), 262 } 263 if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { 264 return nil, err 265 } 266 s.SetRouteTable([]tcpip.Route{{ 267 Destination: header.IPv4EmptySubnet, 268 Gateway: ipv4Gateway, 269 NIC: 1, 270 }}) 271 272 return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) 273 } 274 275 func buildIPv6Route(ctx testContext, local, remote tcpip.Address) (*stack.Route, tcpip.Error) { 276 s := ctx.s 277 s.CreateNIC(nicID, loopback.New()) 278 protocolAddr := tcpip.ProtocolAddress{ 279 Protocol: ipv6.ProtocolNumber, 280 AddressWithPrefix: local.WithPrefix(), 281 } 282 if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { 283 return nil, err 284 } 285 s.SetRouteTable([]tcpip.Route{{ 286 Destination: header.IPv6EmptySubnet, 287 Gateway: ipv6Gateway, 288 NIC: 1, 289 }}) 290 291 return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) 292 } 293 294 func addLinkEndpointToStackWithMTU(t *testing.T, s *stack.Stack, mtu uint32) *channel.Endpoint { 295 t.Helper() 296 e := channel.New(1, mtu, "") 297 if err := s.CreateNIC(nicID, e); err != nil { 298 t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) 299 } 300 301 v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix} 302 if err := s.AddProtocolAddress(nicID, v4Addr, stack.AddressProperties{}); err != nil { 303 t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v4Addr, err) 304 } 305 306 v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix} 307 if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil { 308 t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v6Addr, err) 309 } 310 311 return e 312 } 313 314 func addLinkEndpointToStack(t *testing.T, s *stack.Stack) *channel.Endpoint { 315 t.Helper() 316 return addLinkEndpointToStackWithMTU(t, s, header.IPv6MinimumMTU) 317 } 318 319 var _ stack.NetworkInterface = (*testInterface)(nil) 320 321 type testInterface struct { 322 testObject 323 324 mu struct { 325 sync.RWMutex 326 disabled bool 327 } 328 } 329 330 func (*testInterface) ID() tcpip.NICID { 331 return nicID 332 } 333 334 func (*testInterface) IsLoopback() bool { 335 return false 336 } 337 338 func (*testInterface) Name() string { 339 return "" 340 } 341 342 func (t *testInterface) Enabled() bool { 343 t.mu.RLock() 344 defer t.mu.RUnlock() 345 return !t.mu.disabled 346 } 347 348 func (*testInterface) Promiscuous() bool { 349 return false 350 } 351 352 func (*testInterface) Spoofing() bool { 353 return false 354 } 355 356 func (t *testInterface) setEnabled(v bool) { 357 t.mu.Lock() 358 defer t.mu.Unlock() 359 t.mu.disabled = !v 360 } 361 362 func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.PacketBuffer) tcpip.Error { 363 return &tcpip.ErrNotSupported{} 364 } 365 366 func (*testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error { 367 return nil 368 } 369 370 func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error { 371 return nil 372 } 373 374 func (*testInterface) PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) { 375 return tcpip.AddressWithPrefix{}, nil 376 } 377 378 func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool { 379 return false 380 } 381 382 func TestSourceAddressValidation(t *testing.T) { 383 rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) { 384 totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize 385 hdr := prependable.New(totalLen) 386 pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) 387 pkt.SetType(header.ICMPv4Echo) 388 pkt.SetCode(0) 389 pkt.SetChecksum(0) 390 pkt.SetChecksum(^checksum.Checksum(pkt, 0)) 391 ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) 392 ip.Encode(&header.IPv4Fields{ 393 TotalLength: uint16(totalLen), 394 Protocol: uint8(icmp.ProtocolNumber4), 395 TTL: ipv4.DefaultTTL, 396 SrcAddr: src, 397 DstAddr: localIPv4Addr, 398 }) 399 ip.SetChecksum(^ip.CalculateChecksum()) 400 401 pktBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ 402 Payload: buffer.MakeWithData(hdr.View()), 403 }) 404 e.InjectInbound(header.IPv4ProtocolNumber, pktBuf) 405 pktBuf.DecRef() 406 } 407 408 rxIPv6ICMP := func(e *channel.Endpoint, src tcpip.Address) { 409 totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize 410 hdr := prependable.New(totalLen) 411 pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) 412 pkt.SetType(header.ICMPv6EchoRequest) 413 pkt.SetCode(0) 414 pkt.SetChecksum(0) 415 pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ 416 Header: pkt, 417 Src: src, 418 Dst: localIPv6Addr, 419 })) 420 ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) 421 ip.Encode(&header.IPv6Fields{ 422 PayloadLength: header.ICMPv6MinimumSize, 423 TransportProtocol: icmp.ProtocolNumber6, 424 HopLimit: ipv6.DefaultTTL, 425 SrcAddr: src, 426 DstAddr: localIPv6Addr, 427 }) 428 pktBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ 429 Payload: buffer.MakeWithData(hdr.View()), 430 }) 431 e.InjectInbound(header.IPv6ProtocolNumber, pktBuf) 432 pktBuf.DecRef() 433 } 434 435 tests := []struct { 436 name string 437 srcAddress tcpip.Address 438 rxICMP func(*channel.Endpoint, tcpip.Address) 439 valid bool 440 }{ 441 { 442 name: "IPv4 valid", 443 srcAddress: tcpip.AddrFromSlice([]byte("\x01\x02\x03\x04")), 444 rxICMP: rxIPv4ICMP, 445 valid: true, 446 }, 447 { 448 name: "IPv6 valid", 449 srcAddress: tcpip.AddrFromSlice([]byte("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10")), 450 rxICMP: rxIPv6ICMP, 451 valid: true, 452 }, 453 { 454 name: "IPv4 unspecified", 455 srcAddress: header.IPv4Any, 456 rxICMP: rxIPv4ICMP, 457 valid: true, 458 }, 459 { 460 name: "IPv6 unspecified", 461 srcAddress: header.IPv4Any, 462 rxICMP: rxIPv6ICMP, 463 valid: true, 464 }, 465 { 466 name: "IPv4 multicast", 467 srcAddress: tcpip.AddrFromSlice([]byte("\xe0\x00\x00\x01")), 468 rxICMP: rxIPv4ICMP, 469 valid: false, 470 }, 471 { 472 name: "IPv6 multicast", 473 srcAddress: tcpip.AddrFromSlice([]byte("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")), 474 rxICMP: rxIPv6ICMP, 475 valid: false, 476 }, 477 { 478 name: "IPv4 broadcast", 479 srcAddress: header.IPv4Broadcast, 480 rxICMP: rxIPv4ICMP, 481 valid: false, 482 }, 483 { 484 name: "IPv4 subnet broadcast", 485 srcAddress: func() tcpip.Address { 486 subnet := localIPv4AddrWithPrefix.Subnet() 487 return subnet.Broadcast() 488 }(), 489 rxICMP: rxIPv4ICMP, 490 valid: false, 491 }, 492 } 493 494 for _, test := range tests { 495 t.Run(test.name, func(t *testing.T) { 496 ctx := newTestContext() 497 defer ctx.cleanup() 498 s := ctx.s 499 500 e := addLinkEndpointToStack(t, s) 501 defer e.Close() 502 test.rxICMP(e, test.srcAddress) 503 504 var wantValid uint64 505 if test.valid { 506 wantValid = 1 507 } 508 509 if got, want := s.Stats().IP.InvalidSourceAddressesReceived.Value(), 1-wantValid; got != want { 510 t.Errorf("got s.Stats().IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want) 511 } 512 if got := s.Stats().IP.PacketsDelivered.Value(); got != wantValid { 513 t.Errorf("got s.Stats().IP.PacketsDelivered.Value() = %d, want = %d", got, wantValid) 514 } 515 }) 516 } 517 } 518 519 func TestEnableWhenNICDisabled(t *testing.T) { 520 tests := []struct { 521 name string 522 protocolFactory stack.NetworkProtocolFactory 523 protoNum tcpip.NetworkProtocolNumber 524 }{ 525 { 526 name: "IPv4", 527 protocolFactory: ipv4.NewProtocol, 528 protoNum: ipv4.ProtocolNumber, 529 }, 530 { 531 name: "IPv6", 532 protocolFactory: ipv6.NewProtocol, 533 protoNum: ipv6.ProtocolNumber, 534 }, 535 } 536 537 for _, test := range tests { 538 t.Run(test.name, func(t *testing.T) { 539 var nic testInterface 540 nic.setEnabled(false) 541 542 s := stack.New(stack.Options{ 543 NetworkProtocols: []stack.NetworkProtocolFactory{test.protocolFactory}, 544 }) 545 defer func() { 546 s.Close() 547 s.Wait() 548 }() 549 550 p := s.NetworkProtocolInstance(test.protoNum) 551 552 // We pass nil for all parameters except the NetworkInterface and Stack 553 // since Enable only depends on these. 554 ep := p.NewEndpoint(&nic, nil) 555 556 // The endpoint should initially be disabled, regardless the NIC's enabled 557 // status. 558 if ep.Enabled() { 559 t.Fatal("got ep.Enabled() = true, want = false") 560 } 561 nic.setEnabled(true) 562 if ep.Enabled() { 563 t.Fatal("got ep.Enabled() = true, want = false") 564 } 565 566 // Attempting to enable the endpoint while the NIC is disabled should 567 // fail. 568 nic.setEnabled(false) 569 err := ep.Enable() 570 if _, ok := err.(*tcpip.ErrNotPermitted); !ok { 571 t.Fatalf("got ep.Enable() = %s, want = %s", err, &tcpip.ErrNotPermitted{}) 572 } 573 // ep should consider the NIC's enabled status when determining its own 574 // enabled status so we "enable" the NIC to read just the endpoint's 575 // enabled status. 576 nic.setEnabled(true) 577 if ep.Enabled() { 578 t.Fatal("got ep.Enabled() = true, want = false") 579 } 580 581 // Enabling the interface after the NIC has been enabled should succeed. 582 if err := ep.Enable(); err != nil { 583 t.Fatalf("ep.Enable(): %s", err) 584 } 585 if !ep.Enabled() { 586 t.Fatal("got ep.Enabled() = false, want = true") 587 } 588 589 // ep should consider the NIC's enabled status when determining its own 590 // enabled status. 591 nic.setEnabled(false) 592 if ep.Enabled() { 593 t.Fatal("got ep.Enabled() = true, want = false") 594 } 595 596 // Disabling the endpoint when the NIC is enabled should make the endpoint 597 // disabled. 598 nic.setEnabled(true) 599 ep.Disable() 600 if ep.Enabled() { 601 t.Fatal("got ep.Enabled() = true, want = false") 602 } 603 }) 604 } 605 } 606 607 func TestIPv4Send(t *testing.T) { 608 ctx := newTestContext() 609 defer ctx.cleanup() 610 s := ctx.s 611 612 proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) 613 nic := testInterface{ 614 testObject: testObject{ 615 t: t, 616 v4: true, 617 }, 618 } 619 ep := proto.NewEndpoint(&nic, nil) 620 defer ep.Close() 621 622 // Allocate and initialize the payload view. 623 payload := make([]byte, 100) 624 for i := 0; i < len(payload); i++ { 625 payload[i] = uint8(i) 626 } 627 628 // Setup the packet buffer. 629 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 630 ReserveHeaderBytes: int(ep.MaxHeaderLength()), 631 Payload: buffer.MakeWithData(payload), 632 }) 633 defer pkt.DecRef() 634 635 // Issue the write. 636 nic.testObject.protocol = 123 637 nic.testObject.srcAddr = localIPv4Addr 638 nic.testObject.dstAddr = remoteIPv4Addr 639 nic.testObject.contents = payload 640 641 r, err := buildIPv4Route(ctx, localIPv4Addr, remoteIPv4Addr) 642 if err != nil { 643 t.Fatalf("could not find route: %v", err) 644 } 645 defer r.Release() 646 if err := ep.WritePacket(r, stack.NetworkHeaderParams{ 647 Protocol: 123, 648 TTL: 123, 649 TOS: stack.DefaultTOS, 650 }, pkt); err != nil { 651 t.Fatalf("WritePacket failed: %v", err) 652 } 653 } 654 655 func TestReceive(t *testing.T) { 656 tests := []struct { 657 name string 658 protoFactory stack.NetworkProtocolFactory 659 protoNum tcpip.NetworkProtocolNumber 660 v4 bool 661 epAddr tcpip.AddressWithPrefix 662 handlePacket func(*testing.T, stack.NetworkEndpoint, *testInterface) 663 }{ 664 { 665 name: "IPv4", 666 protoFactory: ipv4.NewProtocol, 667 protoNum: ipv4.ProtocolNumber, 668 v4: true, 669 epAddr: localIPv4Addr.WithPrefix(), 670 handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) { 671 const totalLen = header.IPv4MinimumSize + 30 /* payload length */ 672 673 view := make([]byte, totalLen) 674 ip := header.IPv4(view) 675 ip.Encode(&header.IPv4Fields{ 676 TotalLength: totalLen, 677 TTL: ipv4.DefaultTTL, 678 Protocol: 10, 679 SrcAddr: remoteIPv4Addr, 680 DstAddr: localIPv4Addr, 681 }) 682 ip.SetChecksum(^ip.CalculateChecksum()) 683 684 // Make payload be non-zero. 685 for i := header.IPv4MinimumSize; i < len(view); i++ { 686 view[i] = uint8(i) 687 } 688 689 // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. 690 nic.testObject.protocol = 10 691 nic.testObject.srcAddr = remoteIPv4Addr 692 nic.testObject.dstAddr = localIPv4Addr 693 nic.testObject.contents = view[header.IPv4MinimumSize:totalLen] 694 695 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 696 Payload: buffer.MakeWithData(view), 697 }) 698 ep.HandlePacket(pkt) 699 pkt.DecRef() 700 }, 701 }, 702 { 703 name: "IPv6", 704 protoFactory: ipv6.NewProtocol, 705 protoNum: ipv6.ProtocolNumber, 706 v4: false, 707 epAddr: localIPv6Addr.WithPrefix(), 708 handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) { 709 const payloadLen = 30 710 view := make([]byte, header.IPv6MinimumSize+payloadLen) 711 ip := header.IPv6(view) 712 ip.Encode(&header.IPv6Fields{ 713 PayloadLength: payloadLen, 714 TransportProtocol: 10, 715 HopLimit: ipv6.DefaultTTL, 716 SrcAddr: remoteIPv6Addr, 717 DstAddr: localIPv6Addr, 718 }) 719 720 // Make payload be non-zero. 721 for i := header.IPv6MinimumSize; i < len(view); i++ { 722 view[i] = uint8(i) 723 } 724 725 // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. 726 nic.testObject.protocol = 10 727 nic.testObject.srcAddr = remoteIPv6Addr 728 nic.testObject.dstAddr = localIPv6Addr 729 nic.testObject.contents = view[header.IPv6MinimumSize:][:payloadLen] 730 731 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 732 Payload: buffer.MakeWithData(view), 733 }) 734 ep.HandlePacket(pkt) 735 pkt.DecRef() 736 }, 737 }, 738 } 739 740 for _, test := range tests { 741 t.Run(test.name, func(t *testing.T) { 742 s := stack.New(stack.Options{ 743 NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, 744 }) 745 defer func() { 746 s.Close() 747 s.Wait() 748 }() 749 750 nic := testInterface{ 751 testObject: testObject{ 752 t: t, 753 v4: test.v4, 754 }, 755 } 756 ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, &nic.testObject) 757 defer ep.Close() 758 759 if err := ep.Enable(); err != nil { 760 t.Fatalf("ep.Enable(): %s", err) 761 } 762 763 addressableEndpoint, ok := ep.(stack.AddressableEndpoint) 764 if !ok { 765 t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum) 766 } 767 if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.AddressProperties{}); err != nil { 768 t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", test.epAddr, err) 769 } else { 770 ep.DecRef() 771 } 772 773 stat := s.Stats().IP.PacketsReceived 774 if got := stat.Value(); got != 0 { 775 t.Fatalf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 0", got) 776 } 777 test.handlePacket(t, ep, &nic) 778 if nic.testObject.dataCalls != 1 { 779 t.Errorf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls) 780 } 781 if nic.testObject.rawCalls != 1 { 782 t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls) 783 } 784 if got := stat.Value(); got != 1 { 785 t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got) 786 } 787 }) 788 } 789 } 790 791 func TestIPv4ReceiveControl(t *testing.T) { 792 const ( 793 mtu = 0xbeef - header.IPv4MinimumSize 794 dataLen = 8 795 ) 796 797 cases := []struct { 798 name string 799 expectedCount int 800 fragmentOffset uint16 801 code header.ICMPv4Code 802 transErr transportError 803 trunc int 804 }{ 805 { 806 name: "FragmentationNeeded", 807 expectedCount: 1, 808 fragmentOffset: 0, 809 code: header.ICMPv4FragmentationNeeded, 810 transErr: transportError{ 811 origin: tcpip.SockExtErrorOriginICMP, 812 typ: uint8(header.ICMPv4DstUnreachable), 813 code: uint8(header.ICMPv4FragmentationNeeded), 814 info: mtu, 815 kind: stack.PacketTooBigTransportError, 816 }, 817 trunc: 0, 818 }, 819 { 820 name: "Truncated (missing IPv4 header)", 821 expectedCount: 0, 822 fragmentOffset: 0, 823 code: header.ICMPv4FragmentationNeeded, 824 trunc: header.IPv4MinimumSize + header.ICMPv4MinimumSize, 825 }, 826 { 827 name: "Truncated (partial offending packet's IP header)", 828 expectedCount: 0, 829 fragmentOffset: 0, 830 code: header.ICMPv4FragmentationNeeded, 831 trunc: header.IPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize - 1, 832 }, 833 { 834 name: "Truncated (partial offending packet's data)", 835 expectedCount: 0, 836 fragmentOffset: 0, 837 code: header.ICMPv4FragmentationNeeded, 838 trunc: header.ICMPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize + dataLen - 1, 839 }, 840 { 841 name: "Port unreachable", 842 expectedCount: 1, 843 fragmentOffset: 0, 844 code: header.ICMPv4PortUnreachable, 845 transErr: transportError{ 846 origin: tcpip.SockExtErrorOriginICMP, 847 typ: uint8(header.ICMPv4DstUnreachable), 848 code: uint8(header.ICMPv4PortUnreachable), 849 kind: stack.DestinationPortUnreachableTransportError, 850 }, 851 trunc: 0, 852 }, 853 { 854 name: "Non-zero fragment offset", 855 expectedCount: 0, 856 fragmentOffset: 100, 857 code: header.ICMPv4PortUnreachable, 858 trunc: 0, 859 }, 860 { 861 name: "Zero-length packet", 862 expectedCount: 0, 863 fragmentOffset: 100, 864 code: header.ICMPv4PortUnreachable, 865 trunc: 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + dataLen, 866 }, 867 } 868 for _, c := range cases { 869 t.Run(c.name, func(t *testing.T) { 870 ctx := newTestContext() 871 defer ctx.cleanup() 872 s := ctx.s 873 874 proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) 875 nic := testInterface{ 876 testObject: testObject{ 877 t: t, 878 }, 879 } 880 ep := proto.NewEndpoint(&nic, &nic.testObject) 881 defer ep.Close() 882 883 if err := ep.Enable(); err != nil { 884 t.Fatalf("ep.Enable(): %s", err) 885 } 886 887 const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize 888 view := make([]byte, dataOffset+dataLen) 889 890 // Create the outer IPv4 header. 891 ip := header.IPv4(view) 892 ip.Encode(&header.IPv4Fields{ 893 TotalLength: uint16(len(view) - c.trunc), 894 TTL: 20, 895 Protocol: uint8(header.ICMPv4ProtocolNumber), 896 SrcAddr: tcpip.AddrFromSlice([]byte("\x0a\x00\x00\xbb")), 897 DstAddr: localIPv4Addr, 898 }) 899 ip.SetChecksum(^ip.CalculateChecksum()) 900 901 // Create the ICMP header. 902 icmp := header.ICMPv4(view[header.IPv4MinimumSize:]) 903 icmp.SetType(header.ICMPv4DstUnreachable) 904 icmp.SetCode(c.code) 905 icmp.SetIdent(0xdead) 906 icmp.SetSequence(0xbeef) 907 908 // Create the inner IPv4 header. 909 ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:]) 910 ip.Encode(&header.IPv4Fields{ 911 TotalLength: 100, 912 TTL: 20, 913 Protocol: 10, 914 FragmentOffset: c.fragmentOffset, 915 SrcAddr: localIPv4Addr, 916 DstAddr: remoteIPv4Addr, 917 }) 918 ip.SetChecksum(^ip.CalculateChecksum()) 919 920 // Make payload be non-zero. 921 for i := dataOffset; i < len(view); i++ { 922 view[i] = uint8(i) 923 } 924 925 icmp.SetChecksum(0) 926 xsum := ^checksum.Checksum(icmp, 0 /* initial */) 927 icmp.SetChecksum(xsum) 928 929 // Give packet to IPv4 endpoint, dispatcher will validate that 930 // it's ok. 931 nic.testObject.protocol = 10 932 nic.testObject.srcAddr = remoteIPv4Addr 933 nic.testObject.dstAddr = localIPv4Addr 934 nic.testObject.contents = view[dataOffset:] 935 nic.testObject.transErr = c.transErr 936 937 addressableEndpoint, ok := ep.(stack.AddressableEndpoint) 938 if !ok { 939 t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") 940 } 941 addr := localIPv4Addr.WithPrefix() 942 if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { 943 t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) 944 } else { 945 ep.DecRef() 946 } 947 948 pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize) 949 ep.HandlePacket(pkt) 950 pkt.DecRef() 951 if want := c.expectedCount; nic.testObject.controlCalls != want { 952 t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) 953 } 954 }) 955 } 956 } 957 958 func TestIPv4FragmentationReceive(t *testing.T) { 959 ctx := newTestContext() 960 defer ctx.cleanup() 961 s := ctx.s 962 963 proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) 964 nic := testInterface{ 965 testObject: testObject{ 966 t: t, 967 v4: true, 968 }, 969 } 970 ep := proto.NewEndpoint(&nic, &nic.testObject) 971 defer ep.Close() 972 973 if err := ep.Enable(); err != nil { 974 t.Fatalf("ep.Enable(): %s", err) 975 } 976 977 totalLen := header.IPv4MinimumSize + 24 978 979 frag1 := make([]byte, totalLen) 980 ip1 := header.IPv4(frag1) 981 ip1.Encode(&header.IPv4Fields{ 982 TotalLength: uint16(totalLen), 983 TTL: 20, 984 Protocol: 10, 985 FragmentOffset: 0, 986 Flags: header.IPv4FlagMoreFragments, 987 SrcAddr: remoteIPv4Addr, 988 DstAddr: localIPv4Addr, 989 }) 990 ip1.SetChecksum(^ip1.CalculateChecksum()) 991 992 // Make payload be non-zero. 993 for i := header.IPv4MinimumSize; i < totalLen; i++ { 994 frag1[i] = uint8(i) 995 } 996 997 frag2 := make([]byte, totalLen) 998 ip2 := header.IPv4(frag2) 999 ip2.Encode(&header.IPv4Fields{ 1000 TotalLength: uint16(totalLen), 1001 TTL: 20, 1002 Protocol: 10, 1003 FragmentOffset: 24, 1004 SrcAddr: remoteIPv4Addr, 1005 DstAddr: localIPv4Addr, 1006 }) 1007 ip2.SetChecksum(^ip2.CalculateChecksum()) 1008 1009 // Make payload be non-zero. 1010 for i := header.IPv4MinimumSize; i < totalLen; i++ { 1011 frag2[i] = uint8(i) 1012 } 1013 1014 // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. 1015 nic.testObject.protocol = 10 1016 nic.testObject.srcAddr = remoteIPv4Addr 1017 nic.testObject.dstAddr = localIPv4Addr 1018 nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) 1019 1020 addressableEndpoint, ok := ep.(stack.AddressableEndpoint) 1021 if !ok { 1022 t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") 1023 } 1024 addr := localIPv4Addr.WithPrefix() 1025 if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { 1026 t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) 1027 } else { 1028 ep.DecRef() 1029 } 1030 1031 // Send first segment. 1032 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 1033 Payload: buffer.MakeWithData(frag1), 1034 }) 1035 ep.HandlePacket(pkt) 1036 pkt.DecRef() 1037 1038 if nic.testObject.dataCalls != 0 { 1039 t.Fatalf("Bad number of data calls: got %d, want 0", nic.testObject.dataCalls) 1040 } 1041 if nic.testObject.rawCalls != 0 { 1042 t.Errorf("Bad number of raw calls: got %d, want 0", nic.testObject.rawCalls) 1043 } 1044 1045 // Send second segment. 1046 pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ 1047 Payload: buffer.MakeWithData(frag2), 1048 }) 1049 ep.HandlePacket(pkt) 1050 pkt.DecRef() 1051 1052 if nic.testObject.dataCalls != 1 { 1053 t.Fatalf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls) 1054 } 1055 if nic.testObject.rawCalls != 1 { 1056 t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls) 1057 } 1058 } 1059 1060 func TestIPv6Send(t *testing.T) { 1061 ctx := newTestContext() 1062 defer ctx.cleanup() 1063 s := ctx.s 1064 1065 proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) 1066 nic := testInterface{ 1067 testObject: testObject{ 1068 t: t, 1069 }, 1070 } 1071 ep := proto.NewEndpoint(&nic, nil) 1072 defer ep.Close() 1073 1074 if err := ep.Enable(); err != nil { 1075 t.Fatalf("ep.Enable(): %s", err) 1076 } 1077 1078 // Allocate and initialize the payload view. 1079 payload := make([]byte, 100) 1080 for i := 0; i < len(payload); i++ { 1081 payload[i] = uint8(i) 1082 } 1083 1084 // Setup the packet buffer. 1085 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 1086 ReserveHeaderBytes: int(ep.MaxHeaderLength()), 1087 Payload: buffer.MakeWithData(payload), 1088 }) 1089 defer pkt.DecRef() 1090 // Issue the write. 1091 nic.testObject.protocol = 123 1092 nic.testObject.srcAddr = localIPv6Addr 1093 nic.testObject.dstAddr = remoteIPv6Addr 1094 nic.testObject.contents = payload 1095 1096 r, err := buildIPv6Route(ctx, localIPv6Addr, remoteIPv6Addr) 1097 if err != nil { 1098 t.Fatalf("could not find route: %v", err) 1099 } 1100 defer r.Release() 1101 if err := ep.WritePacket(r, stack.NetworkHeaderParams{ 1102 Protocol: 123, 1103 TTL: 123, 1104 TOS: stack.DefaultTOS, 1105 }, pkt); err != nil { 1106 t.Fatalf("WritePacket failed: %v", err) 1107 } 1108 } 1109 1110 func TestIPv6ReceiveControl(t *testing.T) { 1111 const ( 1112 mtu = 0xffff 1113 dataLen = 8 1114 ) 1115 outerSrcAddr := tcpip.AddrFromSlice([]byte("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa\x00\x00\x00")) 1116 1117 newUint16 := func(v uint16) *uint16 { return &v } 1118 1119 portUnreachableTransErr := transportError{ 1120 origin: tcpip.SockExtErrorOriginICMP6, 1121 typ: uint8(header.ICMPv6DstUnreachable), 1122 code: uint8(header.ICMPv6PortUnreachable), 1123 kind: stack.DestinationPortUnreachableTransportError, 1124 } 1125 1126 cases := []struct { 1127 name string 1128 expectedCount int 1129 fragmentOffset *uint16 1130 typ header.ICMPv6Type 1131 code header.ICMPv6Code 1132 transErr transportError 1133 trunc int 1134 }{ 1135 { 1136 name: "PacketTooBig", 1137 expectedCount: 1, 1138 fragmentOffset: nil, 1139 typ: header.ICMPv6PacketTooBig, 1140 code: header.ICMPv6UnusedCode, 1141 transErr: transportError{ 1142 origin: tcpip.SockExtErrorOriginICMP6, 1143 typ: uint8(header.ICMPv6PacketTooBig), 1144 code: uint8(header.ICMPv6UnusedCode), 1145 info: mtu, 1146 kind: stack.PacketTooBigTransportError, 1147 }, 1148 trunc: 0, 1149 }, 1150 { 1151 name: "Truncated (missing offending packet's IPv6 header)", 1152 expectedCount: 0, 1153 fragmentOffset: nil, 1154 typ: header.ICMPv6PacketTooBig, 1155 code: header.ICMPv6UnusedCode, 1156 trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize, 1157 }, 1158 { 1159 name: "Truncated PacketTooBig (partial offending packet's IPv6 header)", 1160 expectedCount: 0, 1161 fragmentOffset: nil, 1162 typ: header.ICMPv6PacketTooBig, 1163 code: header.ICMPv6UnusedCode, 1164 trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize - 1, 1165 }, 1166 { 1167 name: "Truncated (partial offending packet's data)", 1168 expectedCount: 0, 1169 fragmentOffset: nil, 1170 typ: header.ICMPv6PacketTooBig, 1171 code: header.ICMPv6UnusedCode, 1172 trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + dataLen - 1, 1173 }, 1174 { 1175 name: "Port unreachable", 1176 expectedCount: 1, 1177 fragmentOffset: nil, 1178 typ: header.ICMPv6DstUnreachable, 1179 code: header.ICMPv6PortUnreachable, 1180 transErr: portUnreachableTransErr, 1181 trunc: 0, 1182 }, 1183 { 1184 name: "Truncated DstPortUnreachable (partial offending packet's IP header)", 1185 expectedCount: 0, 1186 fragmentOffset: nil, 1187 typ: header.ICMPv6DstUnreachable, 1188 code: header.ICMPv6PortUnreachable, 1189 trunc: header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + header.IPv6MinimumSize - 1, 1190 }, 1191 { 1192 name: "DstPortUnreachable for Fragmented, zero offset", 1193 expectedCount: 1, 1194 fragmentOffset: newUint16(0), 1195 typ: header.ICMPv6DstUnreachable, 1196 code: header.ICMPv6PortUnreachable, 1197 transErr: portUnreachableTransErr, 1198 trunc: 0, 1199 }, 1200 { 1201 name: "DstPortUnreachable for Non-zero fragment offset", 1202 expectedCount: 0, 1203 fragmentOffset: newUint16(100), 1204 typ: header.ICMPv6DstUnreachable, 1205 code: header.ICMPv6PortUnreachable, 1206 transErr: portUnreachableTransErr, 1207 trunc: 0, 1208 }, 1209 { 1210 name: "Zero-length packet", 1211 expectedCount: 0, 1212 fragmentOffset: nil, 1213 typ: header.ICMPv6DstUnreachable, 1214 code: header.ICMPv6PortUnreachable, 1215 trunc: 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + dataLen, 1216 }, 1217 } 1218 for _, c := range cases { 1219 t.Run(c.name, func(t *testing.T) { 1220 ctx := newTestContext() 1221 defer ctx.cleanup() 1222 s := ctx.s 1223 1224 proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) 1225 nic := testInterface{ 1226 testObject: testObject{ 1227 t: t, 1228 }, 1229 } 1230 ep := proto.NewEndpoint(&nic, &nic.testObject) 1231 defer ep.Close() 1232 1233 if err := ep.Enable(); err != nil { 1234 t.Fatalf("ep.Enable(): %s", err) 1235 } 1236 1237 dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize 1238 if c.fragmentOffset != nil { 1239 dataOffset += header.IPv6FragmentHeaderSize 1240 } 1241 view := make([]byte, dataOffset+dataLen) 1242 1243 // Create the outer IPv6 header. 1244 ip := header.IPv6(view) 1245 ip.Encode(&header.IPv6Fields{ 1246 PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), 1247 TransportProtocol: header.ICMPv6ProtocolNumber, 1248 HopLimit: 20, 1249 SrcAddr: outerSrcAddr, 1250 DstAddr: localIPv6Addr, 1251 }) 1252 1253 // Create the ICMP header. 1254 icmp := header.ICMPv6(view[header.IPv6MinimumSize:]) 1255 icmp.SetType(c.typ) 1256 icmp.SetCode(c.code) 1257 icmp.SetIdent(0xdead) 1258 icmp.SetSequence(0xbeef) 1259 1260 var extHdrs header.IPv6ExtHdrSerializer 1261 // Build the fragmentation header if needed. 1262 if c.fragmentOffset != nil { 1263 extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{ 1264 FragmentOffset: *c.fragmentOffset, 1265 M: true, 1266 Identification: 0x12345678, 1267 }) 1268 } 1269 1270 // Create the inner IPv6 header. 1271 ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:]) 1272 ip.Encode(&header.IPv6Fields{ 1273 PayloadLength: 100, 1274 TransportProtocol: 10, 1275 HopLimit: 20, 1276 SrcAddr: localIPv6Addr, 1277 DstAddr: remoteIPv6Addr, 1278 ExtensionHeaders: extHdrs, 1279 }) 1280 1281 // Make payload be non-zero. 1282 for i := dataOffset; i < len(view); i++ { 1283 view[i] = uint8(i) 1284 } 1285 1286 // Give packet to IPv6 endpoint, dispatcher will validate that 1287 // it's ok. 1288 nic.testObject.protocol = 10 1289 nic.testObject.srcAddr = remoteIPv6Addr 1290 nic.testObject.dstAddr = localIPv6Addr 1291 nic.testObject.contents = view[dataOffset:] 1292 nic.testObject.transErr = c.transErr 1293 1294 // Set ICMPv6 checksum. 1295 icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ 1296 Header: icmp, 1297 Src: outerSrcAddr, 1298 Dst: localIPv6Addr, 1299 })) 1300 1301 addressableEndpoint, ok := ep.(stack.AddressableEndpoint) 1302 if !ok { 1303 t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint") 1304 } 1305 addr := localIPv6Addr.WithPrefix() 1306 if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil { 1307 t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err) 1308 } else { 1309 ep.DecRef() 1310 } 1311 pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize) 1312 ep.HandlePacket(pkt) 1313 pkt.DecRef() 1314 if want := c.expectedCount; nic.testObject.controlCalls != want { 1315 t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) 1316 } 1317 }) 1318 } 1319 } 1320 1321 // truncatedPacket returns a PacketBuffer based on a truncated view. If view, 1322 // after truncation, is large enough to hold a network header, it makes part of 1323 // view the packet's NetworkHeader and the rest its Data. Otherwise all of view 1324 // becomes Data. 1325 func truncatedPacket(view []byte, trunc, netHdrLen int) *stack.PacketBuffer { 1326 v := view[:len(view)-trunc] 1327 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 1328 Payload: buffer.MakeWithData(v), 1329 }) 1330 return pkt 1331 } 1332 1333 func TestWriteHeaderIncludedPacket(t *testing.T) { 1334 const ( 1335 nicID = 1 1336 transportProto = 5 1337 1338 dataLen = 4 1339 ) 1340 1341 dataBuf := [dataLen]byte{1, 2, 3, 4} 1342 data := dataBuf[:] 1343 1344 ipv4Options := header.IPv4OptionsSerializer{ 1345 &header.IPv4SerializableListEndOption{}, 1346 &header.IPv4SerializableNOPOption{}, 1347 &header.IPv4SerializableListEndOption{}, 1348 &header.IPv4SerializableNOPOption{}, 1349 } 1350 1351 expectOptions := header.IPv4Options{ 1352 byte(header.IPv4OptionListEndType), 1353 byte(header.IPv4OptionNOPType), 1354 byte(header.IPv4OptionListEndType), 1355 byte(header.IPv4OptionNOPType), 1356 } 1357 1358 ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4} 1359 ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:] 1360 1361 var ipv6PayloadWithExtHdrBuf [dataLen + header.IPv6FragmentExtHdrLength]byte 1362 ipv6PayloadWithExtHdr := ipv6PayloadWithExtHdrBuf[:] 1363 if n := copy(ipv6PayloadWithExtHdr, ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) { 1364 t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr)) 1365 } 1366 if n := copy(ipv6PayloadWithExtHdr[header.IPv6FragmentExtHdrLength:], data); n != len(data) { 1367 t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) 1368 } 1369 1370 tests := []struct { 1371 name string 1372 protoFactory stack.NetworkProtocolFactory 1373 protoNum tcpip.NetworkProtocolNumber 1374 nicAddr tcpip.AddressWithPrefix 1375 remoteAddr tcpip.Address 1376 pktGen func(*testing.T, tcpip.Address) buffer.Buffer 1377 checker func(*testing.T, *stack.PacketBuffer, tcpip.Address) 1378 expectedErr tcpip.Error 1379 }{ 1380 { 1381 name: "IPv4", 1382 protoFactory: ipv4.NewProtocol, 1383 protoNum: ipv4.ProtocolNumber, 1384 nicAddr: localIPv4AddrWithPrefix, 1385 remoteAddr: remoteIPv4Addr, 1386 pktGen: func(t *testing.T, src tcpip.Address) buffer.Buffer { 1387 totalLen := header.IPv4MinimumSize + len(data) 1388 hdr := prependable.New(totalLen) 1389 if n := copy(hdr.Prepend(len(data)), data); n != len(data) { 1390 t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) 1391 } 1392 ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) 1393 ip.Encode(&header.IPv4Fields{ 1394 Protocol: transportProto, 1395 TTL: ipv4.DefaultTTL, 1396 SrcAddr: src, 1397 DstAddr: remoteIPv4Addr, 1398 }) 1399 return buffer.MakeWithData(hdr.View()) 1400 }, 1401 checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { 1402 if src == header.IPv4Any { 1403 src = localIPv4Addr 1404 } 1405 1406 netHdr := pkt.NetworkHeader() 1407 1408 if len(netHdr.Slice()) != header.IPv4MinimumSize { 1409 t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.Slice()), header.IPv4MinimumSize) 1410 } 1411 1412 payload := stack.PayloadSince(netHdr) 1413 defer payload.Release() 1414 checker.IPv4(t, payload, 1415 checker.SrcAddr(src), 1416 checker.DstAddr(remoteIPv4Addr), 1417 checker.IPv4HeaderLength(header.IPv4MinimumSize), 1418 checker.IPFullLength(uint16(header.IPv4MinimumSize+len(data))), 1419 checker.IPPayload(data), 1420 ) 1421 }, 1422 }, 1423 { 1424 name: "IPv4 with IHL too small", 1425 protoFactory: ipv4.NewProtocol, 1426 protoNum: ipv4.ProtocolNumber, 1427 nicAddr: localIPv4AddrWithPrefix, 1428 remoteAddr: remoteIPv4Addr, 1429 pktGen: func(t *testing.T, src tcpip.Address) buffer.Buffer { 1430 totalLen := header.IPv4MinimumSize + len(data) 1431 hdr := prependable.New(totalLen) 1432 if n := copy(hdr.Prepend(len(data)), data); n != len(data) { 1433 t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) 1434 } 1435 ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) 1436 ip.Encode(&header.IPv4Fields{ 1437 Protocol: transportProto, 1438 TTL: ipv4.DefaultTTL, 1439 SrcAddr: src, 1440 DstAddr: remoteIPv4Addr, 1441 }) 1442 ip.SetHeaderLength(header.IPv4MinimumSize - 1) 1443 return buffer.MakeWithData(hdr.View()) 1444 }, 1445 expectedErr: &tcpip.ErrMalformedHeader{}, 1446 }, 1447 { 1448 name: "IPv4 too small", 1449 protoFactory: ipv4.NewProtocol, 1450 protoNum: ipv4.ProtocolNumber, 1451 nicAddr: localIPv4AddrWithPrefix, 1452 remoteAddr: remoteIPv4Addr, 1453 pktGen: func(t *testing.T, src tcpip.Address) buffer.Buffer { 1454 ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) 1455 ip.Encode(&header.IPv4Fields{ 1456 Protocol: transportProto, 1457 TTL: ipv4.DefaultTTL, 1458 SrcAddr: src, 1459 DstAddr: remoteIPv4Addr, 1460 }) 1461 return buffer.MakeWithData(ip[:len(ip)-1]) 1462 }, 1463 expectedErr: &tcpip.ErrMalformedHeader{}, 1464 }, 1465 { 1466 name: "IPv4 minimum size", 1467 protoFactory: ipv4.NewProtocol, 1468 protoNum: ipv4.ProtocolNumber, 1469 nicAddr: localIPv4AddrWithPrefix, 1470 remoteAddr: remoteIPv4Addr, 1471 pktGen: func(t *testing.T, src tcpip.Address) buffer.Buffer { 1472 ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) 1473 ip.Encode(&header.IPv4Fields{ 1474 Protocol: transportProto, 1475 TTL: ipv4.DefaultTTL, 1476 SrcAddr: src, 1477 DstAddr: remoteIPv4Addr, 1478 }) 1479 return buffer.MakeWithData(ip) 1480 }, 1481 checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { 1482 if src == header.IPv4Any { 1483 src = localIPv4Addr 1484 } 1485 1486 netHdr := pkt.NetworkHeader() 1487 1488 if len(netHdr.Slice()) != header.IPv4MinimumSize { 1489 t.Errorf("got len(netHdr.Slice()) = %d, want = %d", len(netHdr.Slice()), header.IPv4MinimumSize) 1490 } 1491 1492 payload := stack.PayloadSince(netHdr) 1493 defer payload.Release() 1494 checker.IPv4(t, payload, 1495 checker.SrcAddr(src), 1496 checker.DstAddr(remoteIPv4Addr), 1497 checker.IPv4HeaderLength(header.IPv4MinimumSize), 1498 checker.IPFullLength(header.IPv4MinimumSize), 1499 checker.IPPayload(nil), 1500 ) 1501 }, 1502 }, 1503 { 1504 name: "IPv4 with options", 1505 protoFactory: ipv4.NewProtocol, 1506 protoNum: ipv4.ProtocolNumber, 1507 nicAddr: localIPv4AddrWithPrefix, 1508 remoteAddr: remoteIPv4Addr, 1509 pktGen: func(t *testing.T, src tcpip.Address) buffer.Buffer { 1510 ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) 1511 totalLen := ipHdrLen + len(data) 1512 hdr := prependable.New(totalLen) 1513 if n := copy(hdr.Prepend(len(data)), data); n != len(data) { 1514 t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) 1515 } 1516 ip := header.IPv4(hdr.Prepend(ipHdrLen)) 1517 ip.Encode(&header.IPv4Fields{ 1518 Protocol: transportProto, 1519 TTL: ipv4.DefaultTTL, 1520 SrcAddr: src, 1521 DstAddr: remoteIPv4Addr, 1522 Options: ipv4Options, 1523 }) 1524 return buffer.MakeWithData(hdr.View()) 1525 }, 1526 checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { 1527 if src == header.IPv4Any { 1528 src = localIPv4Addr 1529 } 1530 1531 netHdr := pkt.NetworkHeader() 1532 1533 hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) 1534 if len(netHdr.Slice()) != hdrLen { 1535 t.Errorf("got len(netHdr.Slice()) = %d, want = %d", len(netHdr.Slice()), hdrLen) 1536 } 1537 1538 payload := stack.PayloadSince(netHdr) 1539 defer payload.Release() 1540 checker.IPv4(t, payload, 1541 checker.SrcAddr(src), 1542 checker.DstAddr(remoteIPv4Addr), 1543 checker.IPv4HeaderLength(hdrLen), 1544 checker.IPFullLength(uint16(hdrLen+len(data))), 1545 checker.IPv4Options(expectOptions), 1546 checker.IPPayload(data), 1547 ) 1548 }, 1549 }, 1550 { 1551 name: "IPv4 with options and data across views", 1552 protoFactory: ipv4.NewProtocol, 1553 protoNum: ipv4.ProtocolNumber, 1554 nicAddr: localIPv4AddrWithPrefix, 1555 remoteAddr: remoteIPv4Addr, 1556 pktGen: func(t *testing.T, src tcpip.Address) buffer.Buffer { 1557 ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length())) 1558 ip.Encode(&header.IPv4Fields{ 1559 Protocol: transportProto, 1560 TTL: ipv4.DefaultTTL, 1561 SrcAddr: src, 1562 DstAddr: remoteIPv4Addr, 1563 Options: ipv4Options, 1564 }) 1565 buf := buffer.MakeWithData(ip) 1566 buf.Append(buffer.NewViewWithData(data)) 1567 return buf 1568 }, 1569 checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { 1570 if src == header.IPv4Any { 1571 src = localIPv4Addr 1572 } 1573 1574 netHdr := pkt.NetworkHeader() 1575 1576 hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) 1577 if len(netHdr.Slice()) != hdrLen { 1578 t.Errorf("got len(netHdr.Slice()) = %d, want = %d", len(netHdr.Slice()), hdrLen) 1579 } 1580 1581 payload := stack.PayloadSince(netHdr) 1582 defer payload.Release() 1583 checker.IPv4(t, payload, 1584 checker.SrcAddr(src), 1585 checker.DstAddr(remoteIPv4Addr), 1586 checker.IPv4HeaderLength(hdrLen), 1587 checker.IPFullLength(uint16(hdrLen+len(data))), 1588 checker.IPv4Options(expectOptions), 1589 checker.IPPayload(data), 1590 ) 1591 }, 1592 }, 1593 { 1594 name: "IPv6", 1595 protoFactory: ipv6.NewProtocol, 1596 protoNum: ipv6.ProtocolNumber, 1597 nicAddr: localIPv6AddrWithPrefix, 1598 remoteAddr: remoteIPv6Addr, 1599 pktGen: func(t *testing.T, src tcpip.Address) buffer.Buffer { 1600 totalLen := header.IPv6MinimumSize + len(data) 1601 hdr := prependable.New(totalLen) 1602 if n := copy(hdr.Prepend(len(data)), data); n != len(data) { 1603 t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) 1604 } 1605 ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) 1606 ip.Encode(&header.IPv6Fields{ 1607 TransportProtocol: transportProto, 1608 HopLimit: ipv6.DefaultTTL, 1609 SrcAddr: src, 1610 DstAddr: remoteIPv6Addr, 1611 }) 1612 return buffer.MakeWithData(hdr.View()) 1613 }, 1614 checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { 1615 if src == header.IPv6Any { 1616 src = localIPv6Addr 1617 } 1618 1619 netHdr := pkt.NetworkHeader() 1620 1621 if len(netHdr.Slice()) != header.IPv6MinimumSize { 1622 t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.Slice()), header.IPv6MinimumSize) 1623 } 1624 1625 payload := stack.PayloadSince(netHdr) 1626 defer payload.Release() 1627 checker.IPv6(t, payload, 1628 checker.SrcAddr(src), 1629 checker.DstAddr(remoteIPv6Addr), 1630 checker.IPFullLength(uint16(header.IPv6MinimumSize+len(data))), 1631 checker.IPPayload(data), 1632 ) 1633 }, 1634 }, 1635 { 1636 name: "IPv6 with extension header", 1637 protoFactory: ipv6.NewProtocol, 1638 protoNum: ipv6.ProtocolNumber, 1639 nicAddr: localIPv6AddrWithPrefix, 1640 remoteAddr: remoteIPv6Addr, 1641 pktGen: func(t *testing.T, src tcpip.Address) buffer.Buffer { 1642 totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data) 1643 hdr := prependable.New(totalLen) 1644 if n := copy(hdr.Prepend(len(data)), data); n != len(data) { 1645 t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) 1646 } 1647 if n := copy(hdr.Prepend(len(ipv6FragmentExtHdr)), ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) { 1648 t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr)) 1649 } 1650 ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) 1651 ip.Encode(&header.IPv6Fields{ 1652 // NB: we're lying about transport protocol here to verify the raw 1653 // fragment header bytes. 1654 TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier), 1655 HopLimit: ipv6.DefaultTTL, 1656 SrcAddr: src, 1657 DstAddr: remoteIPv6Addr, 1658 }) 1659 return buffer.MakeWithData(hdr.View()) 1660 }, 1661 checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { 1662 if src == header.IPv6Any { 1663 src = localIPv6Addr 1664 } 1665 1666 netHdr := pkt.NetworkHeader() 1667 1668 if want := header.IPv6MinimumSize + len(ipv6FragmentExtHdr); len(netHdr.Slice()) != want { 1669 t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.Slice()), want) 1670 } 1671 1672 payload := stack.PayloadSince(netHdr) 1673 defer payload.Release() 1674 checker.IPv6(t, payload, 1675 checker.SrcAddr(src), 1676 checker.DstAddr(remoteIPv6Addr), 1677 checker.IPFullLength(uint16(header.IPv6MinimumSize+len(ipv6PayloadWithExtHdr))), 1678 checker.IPPayload(ipv6PayloadWithExtHdr), 1679 ) 1680 }, 1681 }, 1682 { 1683 name: "IPv6 minimum size", 1684 protoFactory: ipv6.NewProtocol, 1685 protoNum: ipv6.ProtocolNumber, 1686 nicAddr: localIPv6AddrWithPrefix, 1687 remoteAddr: remoteIPv6Addr, 1688 pktGen: func(t *testing.T, src tcpip.Address) buffer.Buffer { 1689 ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) 1690 ip.Encode(&header.IPv6Fields{ 1691 TransportProtocol: transportProto, 1692 HopLimit: ipv6.DefaultTTL, 1693 SrcAddr: src, 1694 DstAddr: remoteIPv6Addr, 1695 }) 1696 return buffer.MakeWithData(ip) 1697 }, 1698 checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { 1699 if src == header.IPv6Any { 1700 src = localIPv6Addr 1701 } 1702 1703 netHdr := pkt.NetworkHeader() 1704 1705 if len(netHdr.Slice()) != header.IPv6MinimumSize { 1706 t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.Slice()), header.IPv6MinimumSize) 1707 } 1708 1709 payload := stack.PayloadSince(netHdr) 1710 defer payload.Release() 1711 checker.IPv6(t, payload, 1712 checker.SrcAddr(src), 1713 checker.DstAddr(remoteIPv6Addr), 1714 checker.IPFullLength(header.IPv6MinimumSize), 1715 checker.IPPayload(nil), 1716 ) 1717 }, 1718 }, 1719 { 1720 name: "IPv6 too small", 1721 protoFactory: ipv6.NewProtocol, 1722 protoNum: ipv6.ProtocolNumber, 1723 nicAddr: localIPv6AddrWithPrefix, 1724 remoteAddr: remoteIPv6Addr, 1725 pktGen: func(t *testing.T, src tcpip.Address) buffer.Buffer { 1726 ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) 1727 ip.Encode(&header.IPv6Fields{ 1728 TransportProtocol: transportProto, 1729 HopLimit: ipv6.DefaultTTL, 1730 SrcAddr: src, 1731 DstAddr: remoteIPv4Addr, 1732 }) 1733 return buffer.MakeWithData(ip[:len(ip)-1]) 1734 }, 1735 expectedErr: &tcpip.ErrMalformedHeader{}, 1736 }, 1737 } 1738 1739 for _, test := range tests { 1740 t.Run(test.name, func(t *testing.T) { 1741 subTests := []struct { 1742 name string 1743 srcAddr tcpip.Address 1744 }{ 1745 { 1746 name: "unspecified source", 1747 srcAddr: tcpip.AddrFromSlice([]byte(strings.Repeat("\x00", test.nicAddr.Address.Len()))), 1748 }, 1749 { 1750 name: "random source", 1751 srcAddr: tcpip.AddrFromSlice([]byte(strings.Repeat("\xab", test.nicAddr.Address.Len()))), 1752 }, 1753 } 1754 1755 for _, subTest := range subTests { 1756 t.Run(subTest.name, func(t *testing.T) { 1757 s := stack.New(stack.Options{ 1758 NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, 1759 }) 1760 defer func() { 1761 s.Close() 1762 s.Wait() 1763 }() 1764 1765 e := channel.New(1, header.IPv6MinimumMTU, "") 1766 defer e.Close() 1767 if err := s.CreateNIC(nicID, e); err != nil { 1768 t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) 1769 } 1770 protocolAddr := tcpip.ProtocolAddress{ 1771 Protocol: test.protoNum, 1772 AddressWithPrefix: test.nicAddr, 1773 } 1774 if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { 1775 t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) 1776 } 1777 1778 s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}}) 1779 1780 r, err := s.FindRoute(nicID, test.nicAddr.Address, test.remoteAddr, test.protoNum, false /* multicastLoop */) 1781 if err != nil { 1782 t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr.Address, test.protoNum, err) 1783 } 1784 defer r.Release() 1785 1786 { 1787 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 1788 Payload: test.pktGen(t, subTest.srcAddr), 1789 }) 1790 err := r.WriteHeaderIncludedPacket(pkt) 1791 pkt.DecRef() 1792 if diff := cmp.Diff(test.expectedErr, err); diff != "" { 1793 t.Fatalf("unexpected error from r.WriteHeaderIncludedPacket(_), (-want, +got):\n%s", diff) 1794 } 1795 } 1796 1797 if test.expectedErr != nil { 1798 return 1799 } 1800 1801 pkt := e.Read() 1802 if pkt == nil { 1803 t.Fatal("expected a packet to be written") 1804 } 1805 test.checker(t, pkt, subTest.srcAddr) 1806 pkt.DecRef() 1807 }) 1808 } 1809 }) 1810 } 1811 } 1812 1813 // Test that the included data in an ICMP error packet conforms to the 1814 // requirements of RFC 972, RFC 4443 section 2.4 and RFC 1812 Section 4.3.2.3 1815 func TestICMPInclusionSize(t *testing.T) { 1816 const ( 1817 replyHeaderLength4 = header.IPv4MinimumSize + header.IPv4MinimumSize + header.ICMPv4MinimumSize 1818 replyHeaderLength6 = header.IPv6MinimumSize + header.IPv6MinimumSize + header.ICMPv6MinimumSize 1819 targetSize4 = header.IPv4MinimumProcessableDatagramSize 1820 targetSize6 = header.IPv6MinimumMTU 1821 // A protocol number that will cause an error response. 1822 reservedProtocol = 254 1823 ) 1824 1825 // IPv4 function to create a IP packet and send it to the stack. 1826 // The packet should generate an error response. We can do that by using an 1827 // unknown transport protocol (254). 1828 rxIPv4Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) []byte { 1829 totalLen := header.IPv4MinimumSize + len(payload) 1830 hdr := prependable.New(header.IPv4MinimumSize) 1831 ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) 1832 ip.Encode(&header.IPv4Fields{ 1833 TotalLength: uint16(totalLen), 1834 Protocol: reservedProtocol, 1835 TTL: ipv4.DefaultTTL, 1836 SrcAddr: src, 1837 DstAddr: localIPv4Addr, 1838 }) 1839 ip.SetChecksum(^ip.CalculateChecksum()) 1840 buf := buffer.MakeWithData(hdr.View()) 1841 buf.Append(buffer.NewViewWithData(payload)) 1842 // Take a copy before InjectInbound takes ownership of vv 1843 // as vv may be changed during the call. 1844 v := buf.Flatten() 1845 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 1846 Payload: buf, 1847 }) 1848 e.InjectInbound(header.IPv4ProtocolNumber, pkt) 1849 pkt.DecRef() 1850 return v 1851 } 1852 1853 // IPv6 function to create a packet and send it to the stack. 1854 // The packet should be errant in a way that causes the stack to send an 1855 // ICMP error response and have enough data to allow the testing of the 1856 // inclusion of the errant packet. Use `unknown next header' to generate 1857 // the error. 1858 rxIPv6Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) []byte { 1859 hdr := prependable.New(header.IPv6MinimumSize) 1860 ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) 1861 ip.Encode(&header.IPv6Fields{ 1862 PayloadLength: uint16(len(payload)), 1863 TransportProtocol: reservedProtocol, 1864 HopLimit: ipv6.DefaultTTL, 1865 SrcAddr: src, 1866 DstAddr: localIPv6Addr, 1867 }) 1868 buf := buffer.MakeWithData(hdr.View()) 1869 buf.Append(buffer.NewViewWithData(payload)) 1870 // Take a copy before InjectInbound takes ownership of vv 1871 // as vv may be changed during the call. 1872 v := buf.Flatten() 1873 1874 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 1875 Payload: buf, 1876 }) 1877 e.InjectInbound(header.IPv6ProtocolNumber, pkt) 1878 pkt.DecRef() 1879 return v 1880 } 1881 1882 v4Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload []byte) { 1883 // We already know the entire packet is the right size so we can use its 1884 // length to calculate the right payload size to check. 1885 expectedPayloadLength := pkt.Size() - header.IPv4MinimumSize - header.ICMPv4MinimumSize 1886 p := stack.PayloadSince(pkt.NetworkHeader()) 1887 defer p.Release() 1888 checker.IPv4(t, p, 1889 checker.SrcAddr(localIPv4Addr), 1890 checker.DstAddr(remoteIPv4Addr), 1891 checker.IPv4HeaderLength(header.IPv4MinimumSize), 1892 checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+expectedPayloadLength)), 1893 checker.ICMPv4( 1894 checker.ICMPv4Checksum(), 1895 checker.ICMPv4Type(header.ICMPv4DstUnreachable), 1896 checker.ICMPv4Code(header.ICMPv4ProtoUnreachable), 1897 checker.ICMPv4Payload(payload[:expectedPayloadLength]), 1898 ), 1899 ) 1900 } 1901 1902 v6Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload []byte) { 1903 // We already know the entire packet is the right size so we can use its 1904 // length to calculate the right payload size to check. 1905 expectedPayloadLength := pkt.Size() - header.IPv6MinimumSize - header.ICMPv6MinimumSize 1906 p := stack.PayloadSince(pkt.NetworkHeader()) 1907 defer p.Release() 1908 checker.IPv6(t, p, 1909 checker.SrcAddr(localIPv6Addr), 1910 checker.DstAddr(remoteIPv6Addr), 1911 checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectedPayloadLength)), 1912 checker.ICMPv6( 1913 checker.ICMPv6Type(header.ICMPv6ParamProblem), 1914 checker.ICMPv6Code(header.ICMPv6UnknownHeader), 1915 checker.ICMPv6Payload(payload[:expectedPayloadLength]), 1916 ), 1917 ) 1918 } 1919 tests := []struct { 1920 name string 1921 srcAddress tcpip.Address 1922 injector func(*channel.Endpoint, tcpip.Address, []byte) []byte 1923 checker func(*testing.T, *stack.PacketBuffer, []byte) 1924 payloadLength int // Not including IP header. 1925 linkMTU uint32 // Largest IP packet that the link can send as payload. 1926 replyLength int // Total size of IP/ICMP packet expected back. 1927 }{ 1928 { 1929 name: "IPv4 exact match", 1930 srcAddress: remoteIPv4Addr, 1931 injector: rxIPv4Bad, 1932 checker: v4Checker, 1933 payloadLength: targetSize4 - replyHeaderLength4, 1934 linkMTU: targetSize4, 1935 replyLength: targetSize4, 1936 }, 1937 { 1938 name: "IPv4 larger MTU", 1939 srcAddress: remoteIPv4Addr, 1940 injector: rxIPv4Bad, 1941 checker: v4Checker, 1942 payloadLength: targetSize4, 1943 linkMTU: targetSize4 + 1000, 1944 replyLength: targetSize4, 1945 }, 1946 { 1947 name: "IPv4 smaller MTU", 1948 srcAddress: remoteIPv4Addr, 1949 injector: rxIPv4Bad, 1950 checker: v4Checker, 1951 payloadLength: targetSize4, 1952 linkMTU: targetSize4 - 50, 1953 replyLength: targetSize4 - 50, 1954 }, 1955 { 1956 name: "IPv4 payload exceeds", 1957 srcAddress: remoteIPv4Addr, 1958 injector: rxIPv4Bad, 1959 checker: v4Checker, 1960 payloadLength: targetSize4 + 10, 1961 linkMTU: targetSize4, 1962 replyLength: targetSize4, 1963 }, 1964 { 1965 name: "IPv4 1 byte less", 1966 srcAddress: remoteIPv4Addr, 1967 injector: rxIPv4Bad, 1968 checker: v4Checker, 1969 payloadLength: targetSize4 - replyHeaderLength4 - 1, 1970 linkMTU: targetSize4, 1971 replyLength: targetSize4 - 1, 1972 }, 1973 { 1974 name: "IPv4 No payload", 1975 srcAddress: remoteIPv4Addr, 1976 injector: rxIPv4Bad, 1977 checker: v4Checker, 1978 payloadLength: 0, 1979 linkMTU: targetSize4, 1980 replyLength: replyHeaderLength4, 1981 }, 1982 { 1983 name: "IPv6 exact match", 1984 srcAddress: remoteIPv6Addr, 1985 injector: rxIPv6Bad, 1986 checker: v6Checker, 1987 payloadLength: targetSize6 - replyHeaderLength6, 1988 linkMTU: targetSize6, 1989 replyLength: targetSize6, 1990 }, 1991 { 1992 name: "IPv6 larger MTU", 1993 srcAddress: remoteIPv6Addr, 1994 injector: rxIPv6Bad, 1995 checker: v6Checker, 1996 payloadLength: targetSize6, 1997 linkMTU: targetSize6 + 400, 1998 replyLength: targetSize6, 1999 }, 2000 // NB. No "smaller MTU" test here as less than 1280 is not permitted 2001 // in IPv6. 2002 { 2003 name: "IPv6 payload exceeds", 2004 srcAddress: remoteIPv6Addr, 2005 injector: rxIPv6Bad, 2006 checker: v6Checker, 2007 payloadLength: targetSize6, 2008 linkMTU: targetSize6, 2009 replyLength: targetSize6, 2010 }, 2011 { 2012 name: "IPv6 1 byte less", 2013 srcAddress: remoteIPv6Addr, 2014 injector: rxIPv6Bad, 2015 checker: v6Checker, 2016 payloadLength: targetSize6 - replyHeaderLength6 - 1, 2017 linkMTU: targetSize6, 2018 replyLength: targetSize6 - 1, 2019 }, 2020 { 2021 name: "IPv6 no payload", 2022 srcAddress: remoteIPv6Addr, 2023 injector: rxIPv6Bad, 2024 checker: v6Checker, 2025 payloadLength: 0, 2026 linkMTU: targetSize6, 2027 replyLength: replyHeaderLength6, 2028 }, 2029 } 2030 2031 for _, test := range tests { 2032 t.Run(test.name, func(t *testing.T) { 2033 ctx := newTestContext() 2034 defer ctx.cleanup() 2035 s := ctx.s 2036 2037 e := addLinkEndpointToStackWithMTU(t, s, test.linkMTU) 2038 defer e.Close() 2039 // Allocate and initialize the payload view. 2040 payload := make([]byte, test.payloadLength) 2041 for i := 0; i < len(payload); i++ { 2042 payload[i] = uint8(i) 2043 } 2044 // Default routes for IPv4&6 so ICMP can find a route to the remote 2045 // node when attempting to send the ICMP error Reply. 2046 s.SetRouteTable([]tcpip.Route{ 2047 { 2048 Destination: header.IPv4EmptySubnet, 2049 NIC: nicID, 2050 }, 2051 { 2052 Destination: header.IPv6EmptySubnet, 2053 NIC: nicID, 2054 }, 2055 }) 2056 v := test.injector(e, test.srcAddress, payload) 2057 pkt := e.Read() 2058 if pkt == nil { 2059 t.Fatal("expected a packet to be written") 2060 } 2061 if got, want := pkt.Size(), test.replyLength; got != want { 2062 t.Fatalf("got %d bytes of icmp error packet, want %d", got, want) 2063 } 2064 test.checker(t, pkt, v) 2065 pkt.DecRef() 2066 }) 2067 } 2068 } 2069 2070 func TestJoinLeaveAllRoutersGroup(t *testing.T) { 2071 const nicID = 1 2072 2073 tests := []struct { 2074 name string 2075 netProto tcpip.NetworkProtocolNumber 2076 protoFactory stack.NetworkProtocolFactory 2077 allRoutersAddr tcpip.Address 2078 }{ 2079 { 2080 name: "IPv4", 2081 netProto: ipv4.ProtocolNumber, 2082 protoFactory: ipv4.NewProtocol, 2083 allRoutersAddr: header.IPv4AllRoutersGroup, 2084 }, 2085 { 2086 name: "IPv6 Interface Local", 2087 netProto: ipv6.ProtocolNumber, 2088 protoFactory: ipv6.NewProtocol, 2089 allRoutersAddr: header.IPv6AllRoutersInterfaceLocalMulticastAddress, 2090 }, 2091 { 2092 name: "IPv6 Link Local", 2093 netProto: ipv6.ProtocolNumber, 2094 protoFactory: ipv6.NewProtocol, 2095 allRoutersAddr: header.IPv6AllRoutersLinkLocalMulticastAddress, 2096 }, 2097 { 2098 name: "IPv6 Site Local", 2099 netProto: ipv6.ProtocolNumber, 2100 protoFactory: ipv6.NewProtocol, 2101 allRoutersAddr: header.IPv6AllRoutersSiteLocalMulticastAddress, 2102 }, 2103 } 2104 2105 for _, test := range tests { 2106 t.Run(test.name, func(t *testing.T) { 2107 for _, nicDisabled := range [...]bool{true, false} { 2108 t.Run(fmt.Sprintf("NIC Disabled = %t", nicDisabled), func(t *testing.T) { 2109 ctx := newTestContext() 2110 defer ctx.cleanup() 2111 s := ctx.s 2112 2113 opts := stack.NICOptions{Disabled: nicDisabled} 2114 if err := s.CreateNICWithOptions(nicID, channel.New(0, 0, ""), opts); err != nil { 2115 t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err) 2116 } 2117 2118 if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { 2119 t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) 2120 } else if got { 2121 t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) 2122 } 2123 2124 if err := s.SetForwardingDefaultAndAllNICs(test.netProto, true); err != nil { 2125 t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", test.netProto, err) 2126 } 2127 if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { 2128 t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) 2129 } else if !got { 2130 t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr) 2131 } 2132 2133 if err := s.SetForwardingDefaultAndAllNICs(test.netProto, false); err != nil { 2134 t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, false): %s", test.netProto, err) 2135 } 2136 if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { 2137 t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) 2138 } else if got { 2139 t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) 2140 } 2141 }) 2142 } 2143 }) 2144 } 2145 } 2146 2147 func TestSetNICIDBeforeDeliveringToRawEndpoint(t *testing.T) { 2148 const nicID = 1 2149 2150 tests := []struct { 2151 name string 2152 proto tcpip.NetworkProtocolNumber 2153 addr tcpip.AddressWithPrefix 2154 payloadOffset int 2155 }{ 2156 { 2157 name: "IPv4", 2158 proto: header.IPv4ProtocolNumber, 2159 addr: localIPv4AddrWithPrefix, 2160 payloadOffset: header.IPv4MinimumSize, 2161 }, 2162 { 2163 name: "IPv6", 2164 proto: header.IPv6ProtocolNumber, 2165 addr: localIPv6AddrWithPrefix, 2166 payloadOffset: 0, 2167 }, 2168 } 2169 2170 for _, test := range tests { 2171 t.Run(test.name, func(t *testing.T) { 2172 ctx := newTestContext() 2173 defer ctx.cleanup() 2174 s := ctx.s 2175 2176 if err := s.CreateNIC(nicID, loopback.New()); err != nil { 2177 t.Fatalf("CreateNIC(%d, _): %s", nicID, err) 2178 } 2179 protocolAddr := tcpip.ProtocolAddress{ 2180 Protocol: test.proto, 2181 AddressWithPrefix: test.addr, 2182 } 2183 if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { 2184 t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) 2185 } 2186 2187 s.SetRouteTable([]tcpip.Route{ 2188 { 2189 Destination: test.addr.Subnet(), 2190 NIC: nicID, 2191 }, 2192 }) 2193 2194 var wq waiter.Queue 2195 we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) 2196 wq.EventRegister(&we) 2197 ep, err := s.NewRawEndpoint(udp.ProtocolNumber, test.proto, &wq, true /* associated */) 2198 if err != nil { 2199 t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err) 2200 } 2201 defer ep.Close() 2202 2203 writeOpts := tcpip.WriteOptions{ 2204 To: &tcpip.FullAddress{ 2205 Addr: test.addr.Address, 2206 }, 2207 } 2208 data := []byte{1, 2, 3, 4} 2209 var r bytes.Reader 2210 r.Reset(data) 2211 if n, err := ep.Write(&r, writeOpts); err != nil { 2212 t.Fatalf("ep.Write(_, _): %s", err) 2213 } else if want := int64(len(data)); n != want { 2214 t.Fatalf("got ep.Write(_, _) = (%d, nil), want = (%d, nil)", n, want) 2215 } 2216 2217 // Wait for the endpoint to become readable. 2218 <-ch 2219 2220 var w bytes.Buffer 2221 rr, err := ep.Read(&w, tcpip.ReadOptions{ 2222 NeedRemoteAddr: true, 2223 }) 2224 if err != nil { 2225 t.Fatalf("ep.Read(...): %s", err) 2226 } 2227 if diff := cmp.Diff(data, w.Bytes()[test.payloadOffset:]); diff != "" { 2228 t.Errorf("payload mismatch (-want +got):\n%s", diff) 2229 } 2230 if diff := cmp.Diff(tcpip.FullAddress{Addr: test.addr.Address, NIC: nicID}, rr.RemoteAddr); diff != "" { 2231 t.Errorf("remote addr mismatch (-want +got):\n%s", diff) 2232 } 2233 }) 2234 } 2235 }