github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/network/ip_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package ip_test
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  	"testing"
    21  
    22  	"github.com/google/go-cmp/cmp"
    23  	"github.com/SagerNet/gvisor/pkg/sync"
    24  	"github.com/SagerNet/gvisor/pkg/tcpip"
    25  	"github.com/SagerNet/gvisor/pkg/tcpip/buffer"
    26  	"github.com/SagerNet/gvisor/pkg/tcpip/checker"
    27  	"github.com/SagerNet/gvisor/pkg/tcpip/header"
    28  	"github.com/SagerNet/gvisor/pkg/tcpip/link/channel"
    29  	"github.com/SagerNet/gvisor/pkg/tcpip/link/loopback"
    30  	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv4"
    31  	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv6"
    32  	"github.com/SagerNet/gvisor/pkg/tcpip/stack"
    33  	"github.com/SagerNet/gvisor/pkg/tcpip/testutil"
    34  	"github.com/SagerNet/gvisor/pkg/tcpip/transport/icmp"
    35  	"github.com/SagerNet/gvisor/pkg/tcpip/transport/tcp"
    36  	"github.com/SagerNet/gvisor/pkg/tcpip/transport/udp"
    37  )
    38  
    39  const nicID = 1
    40  
    41  var (
    42  	localIPv4Addr  = testutil.MustParse4("10.0.0.1")
    43  	remoteIPv4Addr = testutil.MustParse4("10.0.0.2")
    44  	ipv4SubnetAddr = testutil.MustParse4("10.0.0.0")
    45  	ipv4SubnetMask = testutil.MustParse4("255.255.255.0")
    46  	ipv4Gateway    = testutil.MustParse4("10.0.0.3")
    47  	localIPv6Addr  = testutil.MustParse6("a00::1")
    48  	remoteIPv6Addr = testutil.MustParse6("a00::2")
    49  	ipv6SubnetAddr = testutil.MustParse6("a00::")
    50  	ipv6SubnetMask = testutil.MustParse6("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ff00")
    51  	ipv6Gateway    = testutil.MustParse6("a00::3")
    52  )
    53  
    54  var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{
    55  	Address:   localIPv4Addr,
    56  	PrefixLen: 24,
    57  }
    58  
    59  var localIPv6AddrWithPrefix = tcpip.AddressWithPrefix{
    60  	Address:   localIPv6Addr,
    61  	PrefixLen: 120,
    62  }
    63  
    64  type transportError struct {
    65  	origin tcpip.SockErrOrigin
    66  	typ    uint8
    67  	code   uint8
    68  	info   uint32
    69  	kind   stack.TransportErrorKind
    70  }
    71  
    72  // testObject implements two interfaces: LinkEndpoint and TransportDispatcher.
    73  // The former is used to pretend that it's a link endpoint so that we can
    74  // inspect packets written by the network endpoints. The latter is used to
    75  // pretend that it's the network stack so that it can inspect incoming packets
    76  // that have been handled by the network endpoints.
    77  //
    78  // Packets are checked by comparing their fields/values against the expected
    79  // values stored in the test object itself.
    80  type testObject struct {
    81  	t        *testing.T
    82  	protocol tcpip.TransportProtocolNumber
    83  	contents []byte
    84  	srcAddr  tcpip.Address
    85  	dstAddr  tcpip.Address
    86  	v4       bool
    87  	transErr transportError
    88  
    89  	dataCalls    int
    90  	controlCalls int
    91  	rawCalls     int
    92  }
    93  
    94  // checkValues verifies that the transport protocol, data contents, src & dst
    95  // addresses of a packet match what's expected. If any field doesn't match, the
    96  // test fails.
    97  func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, v buffer.View, srcAddr, dstAddr tcpip.Address) {
    98  	if protocol != t.protocol {
    99  		t.t.Errorf("protocol = %v, want %v", protocol, t.protocol)
   100  	}
   101  
   102  	if srcAddr != t.srcAddr {
   103  		t.t.Errorf("srcAddr = %v, want %v", srcAddr, t.srcAddr)
   104  	}
   105  
   106  	if dstAddr != t.dstAddr {
   107  		t.t.Errorf("dstAddr = %v, want %v", dstAddr, t.dstAddr)
   108  	}
   109  
   110  	if len(v) != len(t.contents) {
   111  		t.t.Fatalf("len(payload) = %v, want %v", len(v), len(t.contents))
   112  	}
   113  
   114  	for i := range t.contents {
   115  		if t.contents[i] != v[i] {
   116  			t.t.Fatalf("payload[%v] = %v, want %v", i, v[i], t.contents[i])
   117  		}
   118  	}
   119  }
   120  
   121  // DeliverTransportPacket is called by network endpoints after parsing incoming
   122  // packets. This is used by the test object to verify that the results of the
   123  // parsing are expected.
   124  func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition {
   125  	netHdr := pkt.Network()
   126  	t.checkValues(protocol, pkt.Data().AsRange().ToOwnedView(), netHdr.SourceAddress(), netHdr.DestinationAddress())
   127  	t.dataCalls++
   128  	return stack.TransportPacketHandled
   129  }
   130  
   131  // DeliverTransportError is called by network endpoints after parsing
   132  // incoming control (ICMP) packets. This is used by the test object to verify
   133  // that the results of the parsing are expected.
   134  func (t *testObject) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr stack.TransportError, pkt *stack.PacketBuffer) {
   135  	t.checkValues(trans, pkt.Data().AsRange().ToOwnedView(), remote, local)
   136  	if diff := cmp.Diff(
   137  		t.transErr,
   138  		transportError{
   139  			origin: transErr.Origin(),
   140  			typ:    transErr.Type(),
   141  			code:   transErr.Code(),
   142  			info:   transErr.Info(),
   143  			kind:   transErr.Kind(),
   144  		},
   145  		cmp.AllowUnexported(transportError{}),
   146  	); diff != "" {
   147  		t.t.Errorf("transport error mismatch (-want +got):\n%s", diff)
   148  	}
   149  	t.controlCalls++
   150  }
   151  
   152  func (t *testObject) DeliverRawPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) {
   153  	t.rawCalls++
   154  }
   155  
   156  // Attach is only implemented to satisfy the LinkEndpoint interface.
   157  func (*testObject) Attach(stack.NetworkDispatcher) {}
   158  
   159  // IsAttached implements stack.LinkEndpoint.IsAttached.
   160  func (*testObject) IsAttached() bool {
   161  	return true
   162  }
   163  
   164  // MTU implements stack.LinkEndpoint.MTU. It just returns a constant that
   165  // matches the linux loopback MTU.
   166  func (*testObject) MTU() uint32 {
   167  	return 65536
   168  }
   169  
   170  // Capabilities implements stack.LinkEndpoint.Capabilities.
   171  func (*testObject) Capabilities() stack.LinkEndpointCapabilities {
   172  	return 0
   173  }
   174  
   175  // MaxHeaderLength is only implemented to satisfy the LinkEndpoint interface.
   176  func (*testObject) MaxHeaderLength() uint16 {
   177  	return 0
   178  }
   179  
   180  // LinkAddress returns the link address of this endpoint.
   181  func (*testObject) LinkAddress() tcpip.LinkAddress {
   182  	return ""
   183  }
   184  
   185  // Wait implements stack.LinkEndpoint.Wait.
   186  func (*testObject) Wait() {}
   187  
   188  // WritePacket is called by network endpoints after producing a packet and
   189  // writing it to the link endpoint. This is used by the test object to verify
   190  // that the produced packet is as expected.
   191  func (t *testObject) WritePacket(_ *stack.Route, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
   192  	var prot tcpip.TransportProtocolNumber
   193  	var srcAddr tcpip.Address
   194  	var dstAddr tcpip.Address
   195  
   196  	if t.v4 {
   197  		h := header.IPv4(pkt.NetworkHeader().View())
   198  		prot = tcpip.TransportProtocolNumber(h.Protocol())
   199  		srcAddr = h.SourceAddress()
   200  		dstAddr = h.DestinationAddress()
   201  
   202  	} else {
   203  		h := header.IPv6(pkt.NetworkHeader().View())
   204  		prot = tcpip.TransportProtocolNumber(h.NextHeader())
   205  		srcAddr = h.SourceAddress()
   206  		dstAddr = h.DestinationAddress()
   207  	}
   208  	t.checkValues(prot, pkt.Data().AsRange().ToOwnedView(), srcAddr, dstAddr)
   209  	return nil
   210  }
   211  
   212  // WritePackets implements stack.LinkEndpoint.WritePackets.
   213  func (*testObject) WritePackets(_ *stack.Route, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
   214  	panic("not implemented")
   215  }
   216  
   217  // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
   218  func (*testObject) ARPHardwareType() header.ARPHardwareType {
   219  	panic("not implemented")
   220  }
   221  
   222  // AddHeader implements stack.LinkEndpoint.AddHeader.
   223  func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
   224  	panic("not implemented")
   225  }
   226  
   227  func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
   228  	s := stack.New(stack.Options{
   229  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol},
   230  		TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
   231  	})
   232  	s.CreateNIC(nicID, loopback.New())
   233  	s.AddAddress(nicID, ipv4.ProtocolNumber, local)
   234  	s.SetRouteTable([]tcpip.Route{{
   235  		Destination: header.IPv4EmptySubnet,
   236  		Gateway:     ipv4Gateway,
   237  		NIC:         1,
   238  	}})
   239  
   240  	return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
   241  }
   242  
   243  func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
   244  	s := stack.New(stack.Options{
   245  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv6.NewProtocol},
   246  		TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
   247  	})
   248  	s.CreateNIC(nicID, loopback.New())
   249  	s.AddAddress(nicID, ipv6.ProtocolNumber, local)
   250  	s.SetRouteTable([]tcpip.Route{{
   251  		Destination: header.IPv6EmptySubnet,
   252  		Gateway:     ipv6Gateway,
   253  		NIC:         1,
   254  	}})
   255  
   256  	return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */)
   257  }
   258  
   259  func buildDummyStackWithLinkEndpoint(t *testing.T, mtu uint32) (*stack.Stack, *channel.Endpoint) {
   260  	t.Helper()
   261  
   262  	s := stack.New(stack.Options{
   263  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   264  		TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
   265  	})
   266  	e := channel.New(1, mtu, "")
   267  	if err := s.CreateNIC(nicID, e); err != nil {
   268  		t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
   269  	}
   270  
   271  	v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix}
   272  	if err := s.AddProtocolAddress(nicID, v4Addr); err != nil {
   273  		t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err)
   274  	}
   275  
   276  	v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix}
   277  	if err := s.AddProtocolAddress(nicID, v6Addr); err != nil {
   278  		t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err)
   279  	}
   280  
   281  	return s, e
   282  }
   283  
   284  func buildDummyStack(t *testing.T) *stack.Stack {
   285  	t.Helper()
   286  
   287  	s, _ := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU)
   288  	return s
   289  }
   290  
   291  var _ stack.NetworkInterface = (*testInterface)(nil)
   292  
   293  type testInterface struct {
   294  	testObject
   295  
   296  	mu struct {
   297  		sync.RWMutex
   298  		disabled bool
   299  	}
   300  }
   301  
   302  func (*testInterface) ID() tcpip.NICID {
   303  	return nicID
   304  }
   305  
   306  func (*testInterface) IsLoopback() bool {
   307  	return false
   308  }
   309  
   310  func (*testInterface) Name() string {
   311  	return ""
   312  }
   313  
   314  func (t *testInterface) Enabled() bool {
   315  	t.mu.RLock()
   316  	defer t.mu.RUnlock()
   317  	return !t.mu.disabled
   318  }
   319  
   320  func (*testInterface) Promiscuous() bool {
   321  	return false
   322  }
   323  
   324  func (*testInterface) Spoofing() bool {
   325  	return false
   326  }
   327  
   328  func (t *testInterface) setEnabled(v bool) {
   329  	t.mu.Lock()
   330  	defer t.mu.Unlock()
   331  	t.mu.disabled = !v
   332  }
   333  
   334  func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
   335  	return &tcpip.ErrNotSupported{}
   336  }
   337  
   338  func (*testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error {
   339  	return nil
   340  }
   341  
   342  func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error {
   343  	return nil
   344  }
   345  
   346  func (*testInterface) PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) {
   347  	return tcpip.AddressWithPrefix{}, nil
   348  }
   349  
   350  func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool {
   351  	return false
   352  }
   353  
   354  func TestSourceAddressValidation(t *testing.T) {
   355  	rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) {
   356  		totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
   357  		hdr := buffer.NewPrependable(totalLen)
   358  		pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
   359  		pkt.SetType(header.ICMPv4Echo)
   360  		pkt.SetCode(0)
   361  		pkt.SetChecksum(0)
   362  		pkt.SetChecksum(^header.Checksum(pkt, 0))
   363  		ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
   364  		ip.Encode(&header.IPv4Fields{
   365  			TotalLength: uint16(totalLen),
   366  			Protocol:    uint8(icmp.ProtocolNumber4),
   367  			TTL:         ipv4.DefaultTTL,
   368  			SrcAddr:     src,
   369  			DstAddr:     localIPv4Addr,
   370  		})
   371  		ip.SetChecksum(^ip.CalculateChecksum())
   372  
   373  		e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
   374  			Data: hdr.View().ToVectorisedView(),
   375  		}))
   376  	}
   377  
   378  	rxIPv6ICMP := func(e *channel.Endpoint, src tcpip.Address) {
   379  		totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
   380  		hdr := buffer.NewPrependable(totalLen)
   381  		pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
   382  		pkt.SetType(header.ICMPv6EchoRequest)
   383  		pkt.SetCode(0)
   384  		pkt.SetChecksum(0)
   385  		pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
   386  			Header: pkt,
   387  			Src:    src,
   388  			Dst:    localIPv6Addr,
   389  		}))
   390  		ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
   391  		ip.Encode(&header.IPv6Fields{
   392  			PayloadLength:     header.ICMPv6MinimumSize,
   393  			TransportProtocol: icmp.ProtocolNumber6,
   394  			HopLimit:          ipv6.DefaultTTL,
   395  			SrcAddr:           src,
   396  			DstAddr:           localIPv6Addr,
   397  		})
   398  		e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
   399  			Data: hdr.View().ToVectorisedView(),
   400  		}))
   401  	}
   402  
   403  	tests := []struct {
   404  		name       string
   405  		srcAddress tcpip.Address
   406  		rxICMP     func(*channel.Endpoint, tcpip.Address)
   407  		valid      bool
   408  	}{
   409  		{
   410  			name:       "IPv4 valid",
   411  			srcAddress: "\x01\x02\x03\x04",
   412  			rxICMP:     rxIPv4ICMP,
   413  			valid:      true,
   414  		},
   415  		{
   416  			name:       "IPv6 valid",
   417  			srcAddress: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10",
   418  			rxICMP:     rxIPv6ICMP,
   419  			valid:      true,
   420  		},
   421  		{
   422  			name:       "IPv4 unspecified",
   423  			srcAddress: header.IPv4Any,
   424  			rxICMP:     rxIPv4ICMP,
   425  			valid:      true,
   426  		},
   427  		{
   428  			name:       "IPv6 unspecified",
   429  			srcAddress: header.IPv4Any,
   430  			rxICMP:     rxIPv6ICMP,
   431  			valid:      true,
   432  		},
   433  		{
   434  			name:       "IPv4 multicast",
   435  			srcAddress: "\xe0\x00\x00\x01",
   436  			rxICMP:     rxIPv4ICMP,
   437  			valid:      false,
   438  		},
   439  		{
   440  			name:       "IPv6 multicast",
   441  			srcAddress: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
   442  			rxICMP:     rxIPv6ICMP,
   443  			valid:      false,
   444  		},
   445  		{
   446  			name:       "IPv4 broadcast",
   447  			srcAddress: header.IPv4Broadcast,
   448  			rxICMP:     rxIPv4ICMP,
   449  			valid:      false,
   450  		},
   451  		{
   452  			name: "IPv4 subnet broadcast",
   453  			srcAddress: func() tcpip.Address {
   454  				subnet := localIPv4AddrWithPrefix.Subnet()
   455  				return subnet.Broadcast()
   456  			}(),
   457  			rxICMP: rxIPv4ICMP,
   458  			valid:  false,
   459  		},
   460  	}
   461  
   462  	for _, test := range tests {
   463  		t.Run(test.name, func(t *testing.T) {
   464  			s, e := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU)
   465  			test.rxICMP(e, test.srcAddress)
   466  
   467  			var wantValid uint64
   468  			if test.valid {
   469  				wantValid = 1
   470  			}
   471  
   472  			if got, want := s.Stats().IP.InvalidSourceAddressesReceived.Value(), 1-wantValid; got != want {
   473  				t.Errorf("got s.Stats().IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want)
   474  			}
   475  			if got := s.Stats().IP.PacketsDelivered.Value(); got != wantValid {
   476  				t.Errorf("got s.Stats().IP.PacketsDelivered.Value() = %d, want = %d", got, wantValid)
   477  			}
   478  		})
   479  	}
   480  }
   481  
   482  func TestEnableWhenNICDisabled(t *testing.T) {
   483  	tests := []struct {
   484  		name            string
   485  		protocolFactory stack.NetworkProtocolFactory
   486  		protoNum        tcpip.NetworkProtocolNumber
   487  	}{
   488  		{
   489  			name:            "IPv4",
   490  			protocolFactory: ipv4.NewProtocol,
   491  			protoNum:        ipv4.ProtocolNumber,
   492  		},
   493  		{
   494  			name:            "IPv6",
   495  			protocolFactory: ipv6.NewProtocol,
   496  			protoNum:        ipv6.ProtocolNumber,
   497  		},
   498  	}
   499  
   500  	for _, test := range tests {
   501  		t.Run(test.name, func(t *testing.T) {
   502  			var nic testInterface
   503  			nic.setEnabled(false)
   504  
   505  			s := stack.New(stack.Options{
   506  				NetworkProtocols: []stack.NetworkProtocolFactory{test.protocolFactory},
   507  			})
   508  			p := s.NetworkProtocolInstance(test.protoNum)
   509  
   510  			// We pass nil for all parameters except the NetworkInterface and Stack
   511  			// since Enable only depends on these.
   512  			ep := p.NewEndpoint(&nic, nil)
   513  
   514  			// The endpoint should initially be disabled, regardless the NIC's enabled
   515  			// status.
   516  			if ep.Enabled() {
   517  				t.Fatal("got ep.Enabled() = true, want = false")
   518  			}
   519  			nic.setEnabled(true)
   520  			if ep.Enabled() {
   521  				t.Fatal("got ep.Enabled() = true, want = false")
   522  			}
   523  
   524  			// Attempting to enable the endpoint while the NIC is disabled should
   525  			// fail.
   526  			nic.setEnabled(false)
   527  			err := ep.Enable()
   528  			if _, ok := err.(*tcpip.ErrNotPermitted); !ok {
   529  				t.Fatalf("got ep.Enable() = %s, want = %s", err, &tcpip.ErrNotPermitted{})
   530  			}
   531  			// ep should consider the NIC's enabled status when determining its own
   532  			// enabled status so we "enable" the NIC to read just the endpoint's
   533  			// enabled status.
   534  			nic.setEnabled(true)
   535  			if ep.Enabled() {
   536  				t.Fatal("got ep.Enabled() = true, want = false")
   537  			}
   538  
   539  			// Enabling the interface after the NIC has been enabled should succeed.
   540  			if err := ep.Enable(); err != nil {
   541  				t.Fatalf("ep.Enable(): %s", err)
   542  			}
   543  			if !ep.Enabled() {
   544  				t.Fatal("got ep.Enabled() = false, want = true")
   545  			}
   546  
   547  			// ep should consider the NIC's enabled status when determining its own
   548  			// enabled status.
   549  			nic.setEnabled(false)
   550  			if ep.Enabled() {
   551  				t.Fatal("got ep.Enabled() = true, want = false")
   552  			}
   553  
   554  			// Disabling the endpoint when the NIC is enabled should make the endpoint
   555  			// disabled.
   556  			nic.setEnabled(true)
   557  			ep.Disable()
   558  			if ep.Enabled() {
   559  				t.Fatal("got ep.Enabled() = true, want = false")
   560  			}
   561  		})
   562  	}
   563  }
   564  
   565  func TestIPv4Send(t *testing.T) {
   566  	s := buildDummyStack(t)
   567  	proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
   568  	nic := testInterface{
   569  		testObject: testObject{
   570  			t:  t,
   571  			v4: true,
   572  		},
   573  	}
   574  	ep := proto.NewEndpoint(&nic, nil)
   575  	defer ep.Close()
   576  
   577  	// Allocate and initialize the payload view.
   578  	payload := buffer.NewView(100)
   579  	for i := 0; i < len(payload); i++ {
   580  		payload[i] = uint8(i)
   581  	}
   582  
   583  	// Setup the packet buffer.
   584  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   585  		ReserveHeaderBytes: int(ep.MaxHeaderLength()),
   586  		Data:               payload.ToVectorisedView(),
   587  	})
   588  
   589  	// Issue the write.
   590  	nic.testObject.protocol = 123
   591  	nic.testObject.srcAddr = localIPv4Addr
   592  	nic.testObject.dstAddr = remoteIPv4Addr
   593  	nic.testObject.contents = payload
   594  
   595  	r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
   596  	if err != nil {
   597  		t.Fatalf("could not find route: %v", err)
   598  	}
   599  	if err := ep.WritePacket(r, stack.NetworkHeaderParams{
   600  		Protocol: 123,
   601  		TTL:      123,
   602  		TOS:      stack.DefaultTOS,
   603  	}, pkt); err != nil {
   604  		t.Fatalf("WritePacket failed: %v", err)
   605  	}
   606  }
   607  
   608  func TestReceive(t *testing.T) {
   609  	tests := []struct {
   610  		name         string
   611  		protoFactory stack.NetworkProtocolFactory
   612  		protoNum     tcpip.NetworkProtocolNumber
   613  		v4           bool
   614  		epAddr       tcpip.AddressWithPrefix
   615  		handlePacket func(*testing.T, stack.NetworkEndpoint, *testInterface)
   616  	}{
   617  		{
   618  			name:         "IPv4",
   619  			protoFactory: ipv4.NewProtocol,
   620  			protoNum:     ipv4.ProtocolNumber,
   621  			v4:           true,
   622  			epAddr:       localIPv4Addr.WithPrefix(),
   623  			handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) {
   624  				const totalLen = header.IPv4MinimumSize + 30 /* payload length */
   625  
   626  				view := buffer.NewView(totalLen)
   627  				ip := header.IPv4(view)
   628  				ip.Encode(&header.IPv4Fields{
   629  					TotalLength: totalLen,
   630  					TTL:         ipv4.DefaultTTL,
   631  					Protocol:    10,
   632  					SrcAddr:     remoteIPv4Addr,
   633  					DstAddr:     localIPv4Addr,
   634  				})
   635  				ip.SetChecksum(^ip.CalculateChecksum())
   636  
   637  				// Make payload be non-zero.
   638  				for i := header.IPv4MinimumSize; i < len(view); i++ {
   639  					view[i] = uint8(i)
   640  				}
   641  
   642  				// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
   643  				nic.testObject.protocol = 10
   644  				nic.testObject.srcAddr = remoteIPv4Addr
   645  				nic.testObject.dstAddr = localIPv4Addr
   646  				nic.testObject.contents = view[header.IPv4MinimumSize:totalLen]
   647  
   648  				pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   649  					Data: view.ToVectorisedView(),
   650  				})
   651  				ep.HandlePacket(pkt)
   652  			},
   653  		},
   654  		{
   655  			name:         "IPv6",
   656  			protoFactory: ipv6.NewProtocol,
   657  			protoNum:     ipv6.ProtocolNumber,
   658  			v4:           false,
   659  			epAddr:       localIPv6Addr.WithPrefix(),
   660  			handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) {
   661  				const payloadLen = 30
   662  				view := buffer.NewView(header.IPv6MinimumSize + payloadLen)
   663  				ip := header.IPv6(view)
   664  				ip.Encode(&header.IPv6Fields{
   665  					PayloadLength:     payloadLen,
   666  					TransportProtocol: 10,
   667  					HopLimit:          ipv6.DefaultTTL,
   668  					SrcAddr:           remoteIPv6Addr,
   669  					DstAddr:           localIPv6Addr,
   670  				})
   671  
   672  				// Make payload be non-zero.
   673  				for i := header.IPv6MinimumSize; i < len(view); i++ {
   674  					view[i] = uint8(i)
   675  				}
   676  
   677  				// Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
   678  				nic.testObject.protocol = 10
   679  				nic.testObject.srcAddr = remoteIPv6Addr
   680  				nic.testObject.dstAddr = localIPv6Addr
   681  				nic.testObject.contents = view[header.IPv6MinimumSize:][:payloadLen]
   682  
   683  				pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   684  					Data: view.ToVectorisedView(),
   685  				})
   686  				ep.HandlePacket(pkt)
   687  			},
   688  		},
   689  	}
   690  
   691  	for _, test := range tests {
   692  		t.Run(test.name, func(t *testing.T) {
   693  			s := stack.New(stack.Options{
   694  				NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory},
   695  			})
   696  			nic := testInterface{
   697  				testObject: testObject{
   698  					t:  t,
   699  					v4: test.v4,
   700  				},
   701  			}
   702  			ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, &nic.testObject)
   703  			defer ep.Close()
   704  
   705  			if err := ep.Enable(); err != nil {
   706  				t.Fatalf("ep.Enable(): %s", err)
   707  			}
   708  
   709  			addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
   710  			if !ok {
   711  				t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum)
   712  			}
   713  			if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
   714  				t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err)
   715  			} else {
   716  				ep.DecRef()
   717  			}
   718  
   719  			stat := s.Stats().IP.PacketsReceived
   720  			if got := stat.Value(); got != 0 {
   721  				t.Fatalf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 0", got)
   722  			}
   723  			test.handlePacket(t, ep, &nic)
   724  			if nic.testObject.dataCalls != 1 {
   725  				t.Errorf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls)
   726  			}
   727  			if nic.testObject.rawCalls != 1 {
   728  				t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls)
   729  			}
   730  			if got := stat.Value(); got != 1 {
   731  				t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got)
   732  			}
   733  		})
   734  	}
   735  }
   736  
   737  func TestIPv4ReceiveControl(t *testing.T) {
   738  	const (
   739  		mtu     = 0xbeef - header.IPv4MinimumSize
   740  		dataLen = 8
   741  	)
   742  
   743  	cases := []struct {
   744  		name           string
   745  		expectedCount  int
   746  		fragmentOffset uint16
   747  		code           header.ICMPv4Code
   748  		transErr       transportError
   749  		trunc          int
   750  	}{
   751  		{
   752  			name:           "FragmentationNeeded",
   753  			expectedCount:  1,
   754  			fragmentOffset: 0,
   755  			code:           header.ICMPv4FragmentationNeeded,
   756  			transErr: transportError{
   757  				origin: tcpip.SockExtErrorOriginICMP,
   758  				typ:    uint8(header.ICMPv4DstUnreachable),
   759  				code:   uint8(header.ICMPv4FragmentationNeeded),
   760  				info:   mtu,
   761  				kind:   stack.PacketTooBigTransportError,
   762  			},
   763  			trunc: 0,
   764  		},
   765  		{
   766  			name:           "Truncated (missing IPv4 header)",
   767  			expectedCount:  0,
   768  			fragmentOffset: 0,
   769  			code:           header.ICMPv4FragmentationNeeded,
   770  			trunc:          header.IPv4MinimumSize + header.ICMPv4MinimumSize,
   771  		},
   772  		{
   773  			name:           "Truncated (partial offending packet's IP header)",
   774  			expectedCount:  0,
   775  			fragmentOffset: 0,
   776  			code:           header.ICMPv4FragmentationNeeded,
   777  			trunc:          header.IPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize - 1,
   778  		},
   779  		{
   780  			name:           "Truncated (partial offending packet's data)",
   781  			expectedCount:  0,
   782  			fragmentOffset: 0,
   783  			code:           header.ICMPv4FragmentationNeeded,
   784  			trunc:          header.ICMPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize + dataLen - 1,
   785  		},
   786  		{
   787  			name:           "Port unreachable",
   788  			expectedCount:  1,
   789  			fragmentOffset: 0,
   790  			code:           header.ICMPv4PortUnreachable,
   791  			transErr: transportError{
   792  				origin: tcpip.SockExtErrorOriginICMP,
   793  				typ:    uint8(header.ICMPv4DstUnreachable),
   794  				code:   uint8(header.ICMPv4PortUnreachable),
   795  				kind:   stack.DestinationPortUnreachableTransportError,
   796  			},
   797  			trunc: 0,
   798  		},
   799  		{
   800  			name:           "Non-zero fragment offset",
   801  			expectedCount:  0,
   802  			fragmentOffset: 100,
   803  			code:           header.ICMPv4PortUnreachable,
   804  			trunc:          0,
   805  		},
   806  		{
   807  			name:           "Zero-length packet",
   808  			expectedCount:  0,
   809  			fragmentOffset: 100,
   810  			code:           header.ICMPv4PortUnreachable,
   811  			trunc:          2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + dataLen,
   812  		},
   813  	}
   814  	for _, c := range cases {
   815  		t.Run(c.name, func(t *testing.T) {
   816  			s := buildDummyStack(t)
   817  			proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
   818  			nic := testInterface{
   819  				testObject: testObject{
   820  					t: t,
   821  				},
   822  			}
   823  			ep := proto.NewEndpoint(&nic, &nic.testObject)
   824  			defer ep.Close()
   825  
   826  			if err := ep.Enable(); err != nil {
   827  				t.Fatalf("ep.Enable(): %s", err)
   828  			}
   829  
   830  			const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize
   831  			view := buffer.NewView(dataOffset + dataLen)
   832  
   833  			// Create the outer IPv4 header.
   834  			ip := header.IPv4(view)
   835  			ip.Encode(&header.IPv4Fields{
   836  				TotalLength: uint16(len(view) - c.trunc),
   837  				TTL:         20,
   838  				Protocol:    uint8(header.ICMPv4ProtocolNumber),
   839  				SrcAddr:     "\x0a\x00\x00\xbb",
   840  				DstAddr:     localIPv4Addr,
   841  			})
   842  			ip.SetChecksum(^ip.CalculateChecksum())
   843  
   844  			// Create the ICMP header.
   845  			icmp := header.ICMPv4(view[header.IPv4MinimumSize:])
   846  			icmp.SetType(header.ICMPv4DstUnreachable)
   847  			icmp.SetCode(c.code)
   848  			icmp.SetIdent(0xdead)
   849  			icmp.SetSequence(0xbeef)
   850  
   851  			// Create the inner IPv4 header.
   852  			ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:])
   853  			ip.Encode(&header.IPv4Fields{
   854  				TotalLength:    100,
   855  				TTL:            20,
   856  				Protocol:       10,
   857  				FragmentOffset: c.fragmentOffset,
   858  				SrcAddr:        localIPv4Addr,
   859  				DstAddr:        remoteIPv4Addr,
   860  			})
   861  			ip.SetChecksum(^ip.CalculateChecksum())
   862  
   863  			// Make payload be non-zero.
   864  			for i := dataOffset; i < len(view); i++ {
   865  				view[i] = uint8(i)
   866  			}
   867  
   868  			icmp.SetChecksum(0)
   869  			checksum := ^header.Checksum(icmp, 0 /* initial */)
   870  			icmp.SetChecksum(checksum)
   871  
   872  			// Give packet to IPv4 endpoint, dispatcher will validate that
   873  			// it's ok.
   874  			nic.testObject.protocol = 10
   875  			nic.testObject.srcAddr = remoteIPv4Addr
   876  			nic.testObject.dstAddr = localIPv4Addr
   877  			nic.testObject.contents = view[dataOffset:]
   878  			nic.testObject.transErr = c.transErr
   879  
   880  			addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
   881  			if !ok {
   882  				t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
   883  			}
   884  			addr := localIPv4Addr.WithPrefix()
   885  			if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
   886  				t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
   887  			} else {
   888  				ep.DecRef()
   889  			}
   890  
   891  			pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize)
   892  			ep.HandlePacket(pkt)
   893  			if want := c.expectedCount; nic.testObject.controlCalls != want {
   894  				t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
   895  			}
   896  		})
   897  	}
   898  }
   899  
   900  func TestIPv4FragmentationReceive(t *testing.T) {
   901  	s := stack.New(stack.Options{
   902  		NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
   903  	})
   904  	proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
   905  	nic := testInterface{
   906  		testObject: testObject{
   907  			t:  t,
   908  			v4: true,
   909  		},
   910  	}
   911  	ep := proto.NewEndpoint(&nic, &nic.testObject)
   912  	defer ep.Close()
   913  
   914  	if err := ep.Enable(); err != nil {
   915  		t.Fatalf("ep.Enable(): %s", err)
   916  	}
   917  
   918  	totalLen := header.IPv4MinimumSize + 24
   919  
   920  	frag1 := buffer.NewView(totalLen)
   921  	ip1 := header.IPv4(frag1)
   922  	ip1.Encode(&header.IPv4Fields{
   923  		TotalLength:    uint16(totalLen),
   924  		TTL:            20,
   925  		Protocol:       10,
   926  		FragmentOffset: 0,
   927  		Flags:          header.IPv4FlagMoreFragments,
   928  		SrcAddr:        remoteIPv4Addr,
   929  		DstAddr:        localIPv4Addr,
   930  	})
   931  	ip1.SetChecksum(^ip1.CalculateChecksum())
   932  
   933  	// Make payload be non-zero.
   934  	for i := header.IPv4MinimumSize; i < totalLen; i++ {
   935  		frag1[i] = uint8(i)
   936  	}
   937  
   938  	frag2 := buffer.NewView(totalLen)
   939  	ip2 := header.IPv4(frag2)
   940  	ip2.Encode(&header.IPv4Fields{
   941  		TotalLength:    uint16(totalLen),
   942  		TTL:            20,
   943  		Protocol:       10,
   944  		FragmentOffset: 24,
   945  		SrcAddr:        remoteIPv4Addr,
   946  		DstAddr:        localIPv4Addr,
   947  	})
   948  	ip2.SetChecksum(^ip2.CalculateChecksum())
   949  
   950  	// Make payload be non-zero.
   951  	for i := header.IPv4MinimumSize; i < totalLen; i++ {
   952  		frag2[i] = uint8(i)
   953  	}
   954  
   955  	// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
   956  	nic.testObject.protocol = 10
   957  	nic.testObject.srcAddr = remoteIPv4Addr
   958  	nic.testObject.dstAddr = localIPv4Addr
   959  	nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
   960  
   961  	// Send first segment.
   962  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   963  		Data: frag1.ToVectorisedView(),
   964  	})
   965  
   966  	addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
   967  	if !ok {
   968  		t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
   969  	}
   970  	addr := localIPv4Addr.WithPrefix()
   971  	if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
   972  		t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
   973  	} else {
   974  		ep.DecRef()
   975  	}
   976  
   977  	ep.HandlePacket(pkt)
   978  	if nic.testObject.dataCalls != 0 {
   979  		t.Fatalf("Bad number of data calls: got %d, want 0", nic.testObject.dataCalls)
   980  	}
   981  	if nic.testObject.rawCalls != 0 {
   982  		t.Errorf("Bad number of raw calls: got %d, want 0", nic.testObject.rawCalls)
   983  	}
   984  
   985  	// Send second segment.
   986  	pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{
   987  		Data: frag2.ToVectorisedView(),
   988  	})
   989  	ep.HandlePacket(pkt)
   990  	if nic.testObject.dataCalls != 1 {
   991  		t.Fatalf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls)
   992  	}
   993  	if nic.testObject.rawCalls != 1 {
   994  		t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls)
   995  	}
   996  }
   997  
   998  func TestIPv6Send(t *testing.T) {
   999  	s := buildDummyStack(t)
  1000  	proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
  1001  	nic := testInterface{
  1002  		testObject: testObject{
  1003  			t: t,
  1004  		},
  1005  	}
  1006  	ep := proto.NewEndpoint(&nic, nil)
  1007  	defer ep.Close()
  1008  
  1009  	if err := ep.Enable(); err != nil {
  1010  		t.Fatalf("ep.Enable(): %s", err)
  1011  	}
  1012  
  1013  	// Allocate and initialize the payload view.
  1014  	payload := buffer.NewView(100)
  1015  	for i := 0; i < len(payload); i++ {
  1016  		payload[i] = uint8(i)
  1017  	}
  1018  
  1019  	// Setup the packet buffer.
  1020  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
  1021  		ReserveHeaderBytes: int(ep.MaxHeaderLength()),
  1022  		Data:               payload.ToVectorisedView(),
  1023  	})
  1024  
  1025  	// Issue the write.
  1026  	nic.testObject.protocol = 123
  1027  	nic.testObject.srcAddr = localIPv6Addr
  1028  	nic.testObject.dstAddr = remoteIPv6Addr
  1029  	nic.testObject.contents = payload
  1030  
  1031  	r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
  1032  	if err != nil {
  1033  		t.Fatalf("could not find route: %v", err)
  1034  	}
  1035  	if err := ep.WritePacket(r, stack.NetworkHeaderParams{
  1036  		Protocol: 123,
  1037  		TTL:      123,
  1038  		TOS:      stack.DefaultTOS,
  1039  	}, pkt); err != nil {
  1040  		t.Fatalf("WritePacket failed: %v", err)
  1041  	}
  1042  }
  1043  
// TestIPv6ReceiveControl hands hand-built ICMPv6 error messages (Packet Too Big
// and Destination Unreachable) to an IPv6 endpoint. It checks how many times the
// endpoint reports a transport error to its dispatcher, which is the testObject
// passed to NewEndpoint. The count comes from nic.testObject.controlCalls.
//
// Each message is built in a single view laid out as:
//
//	outer IPv6 header | ICMPv6 header | inner ("offending") IPv6 header
//	[| IPv6 fragment extension header] | dataLen bytes of offending payload
//
// The view is then shortened by c.trunc bytes before delivery.
func TestIPv6ReceiveControl(t *testing.T) {
	const (
		mtu          = 0xffff
		outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa"
		dataLen      = 8
	)

	// newUint16 returns a pointer to a copy of v; used for the optional
	// fragmentOffset field of the cases below.
	newUint16 := func(v uint16) *uint16 { return &v }

	// portUnreachableTransErr is the transport error the dispatcher is expected
	// to receive for a Destination Unreachable / Port Unreachable message.
	portUnreachableTransErr := transportError{
		origin: tcpip.SockExtErrorOriginICMP6,
		typ:    uint8(header.ICMPv6DstUnreachable),
		code:   uint8(header.ICMPv6PortUnreachable),
		kind:   stack.DestinationPortUnreachableTransportError,
	}

	// Test cases:
	//   - fragmentOffset: when non-nil, a fragment extension header with that
	//     offset is added to the offending packet.
	//   - trunc: the number of bytes removed from the END of the built view
	//     before delivery (truncatedPacket slices view[:len(view)-trunc]).
	//   - expectedCount: the expected number of control calls.
	cases := []struct {
		name           string
		expectedCount  int
		fragmentOffset *uint16
		typ            header.ICMPv6Type
		code           header.ICMPv6Code
		transErr       transportError
		trunc          int
	}{
		{
			name:           "PacketTooBig",
			expectedCount:  1,
			fragmentOffset: nil,
			typ:            header.ICMPv6PacketTooBig,
			code:           header.ICMPv6UnusedCode,
			transErr: transportError{
				origin: tcpip.SockExtErrorOriginICMP6,
				typ:    uint8(header.ICMPv6PacketTooBig),
				code:   uint8(header.ICMPv6UnusedCode),
				info:   mtu,
				kind:   stack.PacketTooBigTransportError,
			},
			trunc: 0,
		},
		{
			name:           "Truncated (missing offending packet's IPv6 header)",
			expectedCount:  0,
			fragmentOffset: nil,
			typ:            header.ICMPv6PacketTooBig,
			code:           header.ICMPv6UnusedCode,
			trunc:          header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize,
		},
		// NOTE(review): trunc counts bytes removed from the end of the view.
		// This case, "partial offending packet's data" and "Truncated
		// DstPortUnreachable" below therefore keep only len(view)-trunc bytes.
		// Assuming the ICMPv6*MinimumSize constants are equal, that is 9 or
		// 1 bytes. That is fewer than the outer IPv6 header, not a partial
		// offending header or payload as the names suggest. Confirm the
		// intended values.
		{
			name:           "Truncated PacketTooBig (partial offending packet's IPv6 header)",
			expectedCount:  0,
			fragmentOffset: nil,
			typ:            header.ICMPv6PacketTooBig,
			code:           header.ICMPv6UnusedCode,
			trunc:          header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize - 1,
		},
		{
			name:           "Truncated (partial offending packet's data)",
			expectedCount:  0,
			fragmentOffset: nil,
			typ:            header.ICMPv6PacketTooBig,
			code:           header.ICMPv6UnusedCode,
			trunc:          header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + dataLen - 1,
		},
		{
			name:           "Port unreachable",
			expectedCount:  1,
			fragmentOffset: nil,
			typ:            header.ICMPv6DstUnreachable,
			code:           header.ICMPv6PortUnreachable,
			transErr:       portUnreachableTransErr,
			trunc:          0,
		},
		{
			name:           "Truncated DstPortUnreachable (partial offending packet's IP header)",
			expectedCount:  0,
			fragmentOffset: nil,
			typ:            header.ICMPv6DstUnreachable,
			code:           header.ICMPv6PortUnreachable,
			trunc:          header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + header.IPv6MinimumSize - 1,
		},
		{
			name:           "DstPortUnreachable for Fragmented, zero offset",
			expectedCount:  1,
			fragmentOffset: newUint16(0),
			typ:            header.ICMPv6DstUnreachable,
			code:           header.ICMPv6PortUnreachable,
			transErr:       portUnreachableTransErr,
			trunc:          0,
		},
		{
			name:           "DstPortUnreachable for Non-zero fragment offset",
			expectedCount:  0,
			fragmentOffset: newUint16(100),
			typ:            header.ICMPv6DstUnreachable,
			code:           header.ICMPv6PortUnreachable,
			transErr:       portUnreachableTransErr,
			trunc:          0,
		},
		{
			// trunc equals len(view) here, so the delivered packet is empty.
			name:           "Zero-length packet",
			expectedCount:  0,
			fragmentOffset: nil,
			typ:            header.ICMPv6DstUnreachable,
			code:           header.ICMPv6PortUnreachable,
			trunc:          2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + dataLen,
		},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			s := buildDummyStack(t)
			proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
			nic := testInterface{
				testObject: testObject{
					t: t,
				},
			}
			ep := proto.NewEndpoint(&nic, &nic.testObject)
			defer ep.Close()

			if err := ep.Enable(); err != nil {
				t.Fatalf("ep.Enable(): %s", err)
			}

			// dataOffset is where the offending packet's payload starts. It
			// follows the outer IPv6 header, the ICMPv6 header, the inner IPv6
			// header and, when present, the fragment extension header.
			dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize
			if c.fragmentOffset != nil {
				dataOffset += header.IPv6FragmentHeaderSize
			}
			view := buffer.NewView(dataOffset + dataLen)

			// Create the outer IPv6 header. PayloadLength is derived from the
			// truncated size, so it matches what is actually delivered. When
			// trunc exceeds len(view)-IPv6MinimumSize the conversion wraps.
			// In that case the outer header itself is cut off anyway.
			ip := header.IPv6(view)
			ip.Encode(&header.IPv6Fields{
				PayloadLength:     uint16(len(view) - header.IPv6MinimumSize - c.trunc),
				TransportProtocol: header.ICMPv6ProtocolNumber,
				HopLimit:          20,
				SrcAddr:           outerSrcAddr,
				DstAddr:           localIPv6Addr,
			})

			// Create the ICMP header.
			// NOTE(review): for Packet Too Big, bytes 4-7 of the ICMPv6 header
			// are the MTU field (RFC 4443). If SetIdent/SetSequence write those
			// bytes, the advertised MTU is 0xdeadbeef. The expected transErr.info
			// (mtu = 0xffff) would then presumably reflect clamping by the
			// endpoint. Confirm.
			icmp := header.ICMPv6(view[header.IPv6MinimumSize:])
			icmp.SetType(c.typ)
			icmp.SetCode(c.code)
			icmp.SetIdent(0xdead)
			icmp.SetSequence(0xbeef)

			var extHdrs header.IPv6ExtHdrSerializer
			// Build the fragmentation header if needed.
			if c.fragmentOffset != nil {
				extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{
					FragmentOffset: *c.fragmentOffset,
					M:              true,
					Identification: 0x12345678,
				})
			}

			// Create the inner (offending) IPv6 header, followed by any
			// extension headers built above.
			ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:])
			ip.Encode(&header.IPv6Fields{
				PayloadLength:     100,
				TransportProtocol: 10,
				HopLimit:          20,
				SrcAddr:           localIPv6Addr,
				DstAddr:           remoteIPv6Addr,
				ExtensionHeaders:  extHdrs,
			})

			// Make payload be non-zero.
			for i := dataOffset; i < len(view); i++ {
				view[i] = uint8(i)
			}

			// Give packet to IPv6 endpoint, dispatcher will validate that
			// it's ok. These fields are the values the testObject expects to
			// see when a transport error is delivered.
			nic.testObject.protocol = 10
			nic.testObject.srcAddr = remoteIPv6Addr
			nic.testObject.dstAddr = localIPv6Addr
			nic.testObject.contents = view[dataOffset:]
			nic.testObject.transErr = c.transErr

			// Set ICMPv6 checksum. This must follow every write above, because
			// icmp spans everything after the outer header: inner header,
			// extension headers and payload. The checksum covers the
			// untruncated view. Truncated cases therefore presumably also fail
			// checksum validation.
			icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
				Header: icmp,
				Src:    outerSrcAddr,
				Dst:    localIPv6Addr,
			}))

			addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
			if !ok {
				t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint")
			}
			addr := localIPv6Addr.WithPrefix()
			// Inside this if, ep shadows the network endpoint. It is the
			// acquired address endpoint, which is released immediately.
			if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
				t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
			} else {
				ep.DecRef()
			}
			pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize)
			ep.HandlePacket(pkt)
			if want := c.expectedCount; nic.testObject.controlCalls != want {
				t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
			}
		})
	}
}
  1250  
  1251  // truncatedPacket returns a PacketBuffer based on a truncated view. If view,
  1252  // after truncation, is large enough to hold a network header, it makes part of
  1253  // view the packet's NetworkHeader and the rest its Data. Otherwise all of view
  1254  // becomes Data.
  1255  func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer {
  1256  	v := view[:len(view)-trunc]
  1257  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
  1258  		Data: v.ToVectorisedView(),
  1259  	})
  1260  	return pkt
  1261  }
  1262  
  1263  func TestWriteHeaderIncludedPacket(t *testing.T) {
  1264  	const (
  1265  		nicID          = 1
  1266  		transportProto = 5
  1267  
  1268  		dataLen = 4
  1269  	)
  1270  
  1271  	dataBuf := [dataLen]byte{1, 2, 3, 4}
  1272  	data := dataBuf[:]
  1273  
  1274  	ipv4Options := header.IPv4OptionsSerializer{
  1275  		&header.IPv4SerializableListEndOption{},
  1276  		&header.IPv4SerializableNOPOption{},
  1277  		&header.IPv4SerializableListEndOption{},
  1278  		&header.IPv4SerializableNOPOption{},
  1279  	}
  1280  
  1281  	expectOptions := header.IPv4Options{
  1282  		byte(header.IPv4OptionListEndType),
  1283  		byte(header.IPv4OptionNOPType),
  1284  		byte(header.IPv4OptionListEndType),
  1285  		byte(header.IPv4OptionNOPType),
  1286  	}
  1287  
  1288  	ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4}
  1289  	ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:]
  1290  
  1291  	var ipv6PayloadWithExtHdrBuf [dataLen + header.IPv6FragmentExtHdrLength]byte
  1292  	ipv6PayloadWithExtHdr := ipv6PayloadWithExtHdrBuf[:]
  1293  	if n := copy(ipv6PayloadWithExtHdr, ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) {
  1294  		t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr))
  1295  	}
  1296  	if n := copy(ipv6PayloadWithExtHdr[header.IPv6FragmentExtHdrLength:], data); n != len(data) {
  1297  		t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
  1298  	}
  1299  
  1300  	tests := []struct {
  1301  		name         string
  1302  		protoFactory stack.NetworkProtocolFactory
  1303  		protoNum     tcpip.NetworkProtocolNumber
  1304  		nicAddr      tcpip.Address
  1305  		remoteAddr   tcpip.Address
  1306  		pktGen       func(*testing.T, tcpip.Address) buffer.VectorisedView
  1307  		checker      func(*testing.T, *stack.PacketBuffer, tcpip.Address)
  1308  		expectedErr  tcpip.Error
  1309  	}{
  1310  		{
  1311  			name:         "IPv4",
  1312  			protoFactory: ipv4.NewProtocol,
  1313  			protoNum:     ipv4.ProtocolNumber,
  1314  			nicAddr:      localIPv4Addr,
  1315  			remoteAddr:   remoteIPv4Addr,
  1316  			pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
  1317  				totalLen := header.IPv4MinimumSize + len(data)
  1318  				hdr := buffer.NewPrependable(totalLen)
  1319  				if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
  1320  					t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
  1321  				}
  1322  				ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
  1323  				ip.Encode(&header.IPv4Fields{
  1324  					Protocol: transportProto,
  1325  					TTL:      ipv4.DefaultTTL,
  1326  					SrcAddr:  src,
  1327  					DstAddr:  remoteIPv4Addr,
  1328  				})
  1329  				return hdr.View().ToVectorisedView()
  1330  			},
  1331  			checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
  1332  				if src == header.IPv4Any {
  1333  					src = localIPv4Addr
  1334  				}
  1335  
  1336  				netHdr := pkt.NetworkHeader()
  1337  
  1338  				if len(netHdr.View()) != header.IPv4MinimumSize {
  1339  					t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize)
  1340  				}
  1341  
  1342  				checker.IPv4(t, stack.PayloadSince(netHdr),
  1343  					checker.SrcAddr(src),
  1344  					checker.DstAddr(remoteIPv4Addr),
  1345  					checker.IPv4HeaderLength(header.IPv4MinimumSize),
  1346  					checker.IPFullLength(uint16(header.IPv4MinimumSize+len(data))),
  1347  					checker.IPPayload(data),
  1348  				)
  1349  			},
  1350  		},
  1351  		{
  1352  			name:         "IPv4 with IHL too small",
  1353  			protoFactory: ipv4.NewProtocol,
  1354  			protoNum:     ipv4.ProtocolNumber,
  1355  			nicAddr:      localIPv4Addr,
  1356  			remoteAddr:   remoteIPv4Addr,
  1357  			pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
  1358  				totalLen := header.IPv4MinimumSize + len(data)
  1359  				hdr := buffer.NewPrependable(totalLen)
  1360  				if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
  1361  					t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
  1362  				}
  1363  				ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
  1364  				ip.Encode(&header.IPv4Fields{
  1365  					Protocol: transportProto,
  1366  					TTL:      ipv4.DefaultTTL,
  1367  					SrcAddr:  src,
  1368  					DstAddr:  remoteIPv4Addr,
  1369  				})
  1370  				ip.SetHeaderLength(header.IPv4MinimumSize - 1)
  1371  				return hdr.View().ToVectorisedView()
  1372  			},
  1373  			expectedErr: &tcpip.ErrMalformedHeader{},
  1374  		},
  1375  		{
  1376  			name:         "IPv4 too small",
  1377  			protoFactory: ipv4.NewProtocol,
  1378  			protoNum:     ipv4.ProtocolNumber,
  1379  			nicAddr:      localIPv4Addr,
  1380  			remoteAddr:   remoteIPv4Addr,
  1381  			pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
  1382  				ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
  1383  				ip.Encode(&header.IPv4Fields{
  1384  					Protocol: transportProto,
  1385  					TTL:      ipv4.DefaultTTL,
  1386  					SrcAddr:  src,
  1387  					DstAddr:  remoteIPv4Addr,
  1388  				})
  1389  				return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
  1390  			},
  1391  			expectedErr: &tcpip.ErrMalformedHeader{},
  1392  		},
  1393  		{
  1394  			name:         "IPv4 minimum size",
  1395  			protoFactory: ipv4.NewProtocol,
  1396  			protoNum:     ipv4.ProtocolNumber,
  1397  			nicAddr:      localIPv4Addr,
  1398  			remoteAddr:   remoteIPv4Addr,
  1399  			pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
  1400  				ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
  1401  				ip.Encode(&header.IPv4Fields{
  1402  					Protocol: transportProto,
  1403  					TTL:      ipv4.DefaultTTL,
  1404  					SrcAddr:  src,
  1405  					DstAddr:  remoteIPv4Addr,
  1406  				})
  1407  				return buffer.View(ip).ToVectorisedView()
  1408  			},
  1409  			checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
  1410  				if src == header.IPv4Any {
  1411  					src = localIPv4Addr
  1412  				}
  1413  
  1414  				netHdr := pkt.NetworkHeader()
  1415  
  1416  				if len(netHdr.View()) != header.IPv4MinimumSize {
  1417  					t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize)
  1418  				}
  1419  
  1420  				checker.IPv4(t, stack.PayloadSince(netHdr),
  1421  					checker.SrcAddr(src),
  1422  					checker.DstAddr(remoteIPv4Addr),
  1423  					checker.IPv4HeaderLength(header.IPv4MinimumSize),
  1424  					checker.IPFullLength(header.IPv4MinimumSize),
  1425  					checker.IPPayload(nil),
  1426  				)
  1427  			},
  1428  		},
  1429  		{
  1430  			name:         "IPv4 with options",
  1431  			protoFactory: ipv4.NewProtocol,
  1432  			protoNum:     ipv4.ProtocolNumber,
  1433  			nicAddr:      localIPv4Addr,
  1434  			remoteAddr:   remoteIPv4Addr,
  1435  			pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
  1436  				ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
  1437  				totalLen := ipHdrLen + len(data)
  1438  				hdr := buffer.NewPrependable(totalLen)
  1439  				if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
  1440  					t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
  1441  				}
  1442  				ip := header.IPv4(hdr.Prepend(ipHdrLen))
  1443  				ip.Encode(&header.IPv4Fields{
  1444  					Protocol: transportProto,
  1445  					TTL:      ipv4.DefaultTTL,
  1446  					SrcAddr:  src,
  1447  					DstAddr:  remoteIPv4Addr,
  1448  					Options:  ipv4Options,
  1449  				})
  1450  				return hdr.View().ToVectorisedView()
  1451  			},
  1452  			checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
  1453  				if src == header.IPv4Any {
  1454  					src = localIPv4Addr
  1455  				}
  1456  
  1457  				netHdr := pkt.NetworkHeader()
  1458  
  1459  				hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
  1460  				if len(netHdr.View()) != hdrLen {
  1461  					t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
  1462  				}
  1463  
  1464  				checker.IPv4(t, stack.PayloadSince(netHdr),
  1465  					checker.SrcAddr(src),
  1466  					checker.DstAddr(remoteIPv4Addr),
  1467  					checker.IPv4HeaderLength(hdrLen),
  1468  					checker.IPFullLength(uint16(hdrLen+len(data))),
  1469  					checker.IPv4Options(expectOptions),
  1470  					checker.IPPayload(data),
  1471  				)
  1472  			},
  1473  		},
  1474  		{
  1475  			name:         "IPv4 with options and data across views",
  1476  			protoFactory: ipv4.NewProtocol,
  1477  			protoNum:     ipv4.ProtocolNumber,
  1478  			nicAddr:      localIPv4Addr,
  1479  			remoteAddr:   remoteIPv4Addr,
  1480  			pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
  1481  				ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length()))
  1482  				ip.Encode(&header.IPv4Fields{
  1483  					Protocol: transportProto,
  1484  					TTL:      ipv4.DefaultTTL,
  1485  					SrcAddr:  src,
  1486  					DstAddr:  remoteIPv4Addr,
  1487  					Options:  ipv4Options,
  1488  				})
  1489  				vv := buffer.View(ip).ToVectorisedView()
  1490  				vv.AppendView(data)
  1491  				return vv
  1492  			},
  1493  			checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
  1494  				if src == header.IPv4Any {
  1495  					src = localIPv4Addr
  1496  				}
  1497  
  1498  				netHdr := pkt.NetworkHeader()
  1499  
  1500  				hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
  1501  				if len(netHdr.View()) != hdrLen {
  1502  					t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
  1503  				}
  1504  
  1505  				checker.IPv4(t, stack.PayloadSince(netHdr),
  1506  					checker.SrcAddr(src),
  1507  					checker.DstAddr(remoteIPv4Addr),
  1508  					checker.IPv4HeaderLength(hdrLen),
  1509  					checker.IPFullLength(uint16(hdrLen+len(data))),
  1510  					checker.IPv4Options(expectOptions),
  1511  					checker.IPPayload(data),
  1512  				)
  1513  			},
  1514  		},
  1515  		{
  1516  			name:         "IPv6",
  1517  			protoFactory: ipv6.NewProtocol,
  1518  			protoNum:     ipv6.ProtocolNumber,
  1519  			nicAddr:      localIPv6Addr,
  1520  			remoteAddr:   remoteIPv6Addr,
  1521  			pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
  1522  				totalLen := header.IPv6MinimumSize + len(data)
  1523  				hdr := buffer.NewPrependable(totalLen)
  1524  				if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
  1525  					t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
  1526  				}
  1527  				ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
  1528  				ip.Encode(&header.IPv6Fields{
  1529  					TransportProtocol: transportProto,
  1530  					HopLimit:          ipv6.DefaultTTL,
  1531  					SrcAddr:           src,
  1532  					DstAddr:           remoteIPv6Addr,
  1533  				})
  1534  				return hdr.View().ToVectorisedView()
  1535  			},
  1536  			checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
  1537  				if src == header.IPv6Any {
  1538  					src = localIPv6Addr
  1539  				}
  1540  
  1541  				netHdr := pkt.NetworkHeader()
  1542  
  1543  				if len(netHdr.View()) != header.IPv6MinimumSize {
  1544  					t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize)
  1545  				}
  1546  
  1547  				checker.IPv6(t, stack.PayloadSince(netHdr),
  1548  					checker.SrcAddr(src),
  1549  					checker.DstAddr(remoteIPv6Addr),
  1550  					checker.IPFullLength(uint16(header.IPv6MinimumSize+len(data))),
  1551  					checker.IPPayload(data),
  1552  				)
  1553  			},
  1554  		},
  1555  		{
  1556  			name:         "IPv6 with extension header",
  1557  			protoFactory: ipv6.NewProtocol,
  1558  			protoNum:     ipv6.ProtocolNumber,
  1559  			nicAddr:      localIPv6Addr,
  1560  			remoteAddr:   remoteIPv6Addr,
  1561  			pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
  1562  				totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data)
  1563  				hdr := buffer.NewPrependable(totalLen)
  1564  				if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
  1565  					t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
  1566  				}
  1567  				if n := copy(hdr.Prepend(len(ipv6FragmentExtHdr)), ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) {
  1568  					t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr))
  1569  				}
  1570  				ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
  1571  				ip.Encode(&header.IPv6Fields{
  1572  					// NB: we're lying about transport protocol here to verify the raw
  1573  					// fragment header bytes.
  1574  					TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier),
  1575  					HopLimit:          ipv6.DefaultTTL,
  1576  					SrcAddr:           src,
  1577  					DstAddr:           remoteIPv6Addr,
  1578  				})
  1579  				return hdr.View().ToVectorisedView()
  1580  			},
  1581  			checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
  1582  				if src == header.IPv6Any {
  1583  					src = localIPv6Addr
  1584  				}
  1585  
  1586  				netHdr := pkt.NetworkHeader()
  1587  
  1588  				if want := header.IPv6MinimumSize + len(ipv6FragmentExtHdr); len(netHdr.View()) != want {
  1589  					t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), want)
  1590  				}
  1591  
  1592  				checker.IPv6(t, stack.PayloadSince(netHdr),
  1593  					checker.SrcAddr(src),
  1594  					checker.DstAddr(remoteIPv6Addr),
  1595  					checker.IPFullLength(uint16(header.IPv6MinimumSize+len(ipv6PayloadWithExtHdr))),
  1596  					checker.IPPayload(ipv6PayloadWithExtHdr),
  1597  				)
  1598  			},
  1599  		},
  1600  		{
  1601  			name:         "IPv6 minimum size",
  1602  			protoFactory: ipv6.NewProtocol,
  1603  			protoNum:     ipv6.ProtocolNumber,
  1604  			nicAddr:      localIPv6Addr,
  1605  			remoteAddr:   remoteIPv6Addr,
  1606  			pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
  1607  				ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
  1608  				ip.Encode(&header.IPv6Fields{
  1609  					TransportProtocol: transportProto,
  1610  					HopLimit:          ipv6.DefaultTTL,
  1611  					SrcAddr:           src,
  1612  					DstAddr:           remoteIPv6Addr,
  1613  				})
  1614  				return buffer.View(ip).ToVectorisedView()
  1615  			},
  1616  			checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
  1617  				if src == header.IPv6Any {
  1618  					src = localIPv6Addr
  1619  				}
  1620  
  1621  				netHdr := pkt.NetworkHeader()
  1622  
  1623  				if len(netHdr.View()) != header.IPv6MinimumSize {
  1624  					t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize)
  1625  				}
  1626  
  1627  				checker.IPv6(t, stack.PayloadSince(netHdr),
  1628  					checker.SrcAddr(src),
  1629  					checker.DstAddr(remoteIPv6Addr),
  1630  					checker.IPFullLength(header.IPv6MinimumSize),
  1631  					checker.IPPayload(nil),
  1632  				)
  1633  			},
  1634  		},
  1635  		{
  1636  			name:         "IPv6 too small",
  1637  			protoFactory: ipv6.NewProtocol,
  1638  			protoNum:     ipv6.ProtocolNumber,
  1639  			nicAddr:      localIPv6Addr,
  1640  			remoteAddr:   remoteIPv6Addr,
  1641  			pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
  1642  				ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
  1643  				ip.Encode(&header.IPv6Fields{
  1644  					TransportProtocol: transportProto,
  1645  					HopLimit:          ipv6.DefaultTTL,
  1646  					SrcAddr:           src,
  1647  					DstAddr:           remoteIPv4Addr,
  1648  				})
  1649  				return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
  1650  			},
  1651  			expectedErr: &tcpip.ErrMalformedHeader{},
  1652  		},
  1653  	}
  1654  
  1655  	for _, test := range tests {
  1656  		t.Run(test.name, func(t *testing.T) {
  1657  			subTests := []struct {
  1658  				name    string
  1659  				srcAddr tcpip.Address
  1660  			}{
  1661  				{
  1662  					name:    "unspecified source",
  1663  					srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))),
  1664  				},
  1665  				{
  1666  					name:    "random source",
  1667  					srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))),
  1668  				},
  1669  			}
  1670  
  1671  			for _, subTest := range subTests {
  1672  				t.Run(subTest.name, func(t *testing.T) {
  1673  					s := stack.New(stack.Options{
  1674  						NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory},
  1675  					})
  1676  					e := channel.New(1, header.IPv6MinimumMTU, "")
  1677  					if err := s.CreateNIC(nicID, e); err != nil {
  1678  						t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
  1679  					}
  1680  					if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil {
  1681  						t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err)
  1682  					}
  1683  
  1684  					s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}})
  1685  
  1686  					r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */)
  1687  					if err != nil {
  1688  						t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err)
  1689  					}
  1690  					defer r.Release()
  1691  
  1692  					{
  1693  						err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
  1694  							Data: test.pktGen(t, subTest.srcAddr),
  1695  						}))
  1696  						if diff := cmp.Diff(test.expectedErr, err); diff != "" {
  1697  							t.Fatalf("unexpected error from r.WriteHeaderIncludedPacket(_), (-want, +got):\n%s", diff)
  1698  						}
  1699  					}
  1700  
  1701  					if test.expectedErr != nil {
  1702  						return
  1703  					}
  1704  
  1705  					pkt, ok := e.Read()
  1706  					if !ok {
  1707  						t.Fatal("expected a packet to be written")
  1708  					}
  1709  					test.checker(t, pkt.Pkt, subTest.srcAddr)
  1710  				})
  1711  			}
  1712  		})
  1713  	}
  1714  }
  1715  
  1716  // Test that the included data in an ICMP error packet conforms to the
  1717  // requirements of RFC 972, RFC 4443 section 2.4 and RFC 1812 Section 4.3.2.3
  1718  func TestICMPInclusionSize(t *testing.T) {
  1719  	const (
  1720  		replyHeaderLength4 = header.IPv4MinimumSize + header.IPv4MinimumSize + header.ICMPv4MinimumSize
  1721  		replyHeaderLength6 = header.IPv6MinimumSize + header.IPv6MinimumSize + header.ICMPv6MinimumSize
  1722  		targetSize4        = header.IPv4MinimumProcessableDatagramSize
  1723  		targetSize6        = header.IPv6MinimumMTU
  1724  		// A protocol number that will cause an error response.
  1725  		reservedProtocol = 254
  1726  	)
  1727  
  1728  	// IPv4 function to create a IP packet and send it to the stack.
  1729  	// The packet should generate an error response. We can do that by using an
  1730  	// unknown transport protocol (254).
  1731  	rxIPv4Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View {
  1732  		totalLen := header.IPv4MinimumSize + len(payload)
  1733  		hdr := buffer.NewPrependable(header.IPv4MinimumSize)
  1734  		ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
  1735  		ip.Encode(&header.IPv4Fields{
  1736  			TotalLength: uint16(totalLen),
  1737  			Protocol:    reservedProtocol,
  1738  			TTL:         ipv4.DefaultTTL,
  1739  			SrcAddr:     src,
  1740  			DstAddr:     localIPv4Addr,
  1741  		})
  1742  		ip.SetChecksum(^ip.CalculateChecksum())
  1743  		vv := hdr.View().ToVectorisedView()
  1744  		vv.AppendView(buffer.View(payload))
  1745  		// Take a copy before InjectInbound takes ownership of vv
  1746  		// as vv may be changed during the call.
  1747  		v := vv.ToView()
  1748  		e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
  1749  			Data: vv,
  1750  		}))
  1751  		return v
  1752  	}
  1753  
  1754  	// IPv6 function to create a packet and send it to the stack.
  1755  	// The packet should be errant in a way that causes the stack to send an
  1756  	// ICMP error response and have enough data to allow the testing of the
  1757  	// inclusion of the errant packet. Use `unknown next header' to generate
  1758  	// the error.
  1759  	rxIPv6Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View {
  1760  		hdr := buffer.NewPrependable(header.IPv6MinimumSize)
  1761  		ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
  1762  		ip.Encode(&header.IPv6Fields{
  1763  			PayloadLength:     uint16(len(payload)),
  1764  			TransportProtocol: reservedProtocol,
  1765  			HopLimit:          ipv6.DefaultTTL,
  1766  			SrcAddr:           src,
  1767  			DstAddr:           localIPv6Addr,
  1768  		})
  1769  		vv := hdr.View().ToVectorisedView()
  1770  		vv.AppendView(buffer.View(payload))
  1771  		// Take a copy before InjectInbound takes ownership of vv
  1772  		// as vv may be changed during the call.
  1773  		v := vv.ToView()
  1774  
  1775  		e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
  1776  			Data: vv,
  1777  		}))
  1778  		return v
  1779  	}
  1780  
  1781  	v4Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) {
  1782  		// We already know the entire packet is the right size so we can use its
  1783  		// length to calculate the right payload size to check.
  1784  		expectedPayloadLength := pkt.Size() - header.IPv4MinimumSize - header.ICMPv4MinimumSize
  1785  		checker.IPv4(t, stack.PayloadSince(pkt.NetworkHeader()),
  1786  			checker.SrcAddr(localIPv4Addr),
  1787  			checker.DstAddr(remoteIPv4Addr),
  1788  			checker.IPv4HeaderLength(header.IPv4MinimumSize),
  1789  			checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+expectedPayloadLength)),
  1790  			checker.ICMPv4(
  1791  				checker.ICMPv4Checksum(),
  1792  				checker.ICMPv4Type(header.ICMPv4DstUnreachable),
  1793  				checker.ICMPv4Code(header.ICMPv4ProtoUnreachable),
  1794  				checker.ICMPv4Payload(payload[:expectedPayloadLength]),
  1795  			),
  1796  		)
  1797  	}
  1798  
  1799  	v6Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) {
  1800  		// We already know the entire packet is the right size so we can use its
  1801  		// length to calculate the right payload size to check.
  1802  		expectedPayloadLength := pkt.Size() - header.IPv6MinimumSize - header.ICMPv6MinimumSize
  1803  		checker.IPv6(t, stack.PayloadSince(pkt.NetworkHeader()),
  1804  			checker.SrcAddr(localIPv6Addr),
  1805  			checker.DstAddr(remoteIPv6Addr),
  1806  			checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectedPayloadLength)),
  1807  			checker.ICMPv6(
  1808  				checker.ICMPv6Type(header.ICMPv6ParamProblem),
  1809  				checker.ICMPv6Code(header.ICMPv6UnknownHeader),
  1810  				checker.ICMPv6Payload(payload[:expectedPayloadLength]),
  1811  			),
  1812  		)
  1813  	}
  1814  	tests := []struct {
  1815  		name          string
  1816  		srcAddress    tcpip.Address
  1817  		injector      func(*channel.Endpoint, tcpip.Address, []byte) buffer.View
  1818  		checker       func(*testing.T, *stack.PacketBuffer, buffer.View)
  1819  		payloadLength int    // Not including IP header.
  1820  		linkMTU       uint32 // Largest IP packet that the link can send as payload.
  1821  		replyLength   int    // Total size of IP/ICMP packet expected back.
  1822  	}{
  1823  		{
  1824  			name:          "IPv4 exact match",
  1825  			srcAddress:    remoteIPv4Addr,
  1826  			injector:      rxIPv4Bad,
  1827  			checker:       v4Checker,
  1828  			payloadLength: targetSize4 - replyHeaderLength4,
  1829  			linkMTU:       targetSize4,
  1830  			replyLength:   targetSize4,
  1831  		},
  1832  		{
  1833  			name:          "IPv4 larger MTU",
  1834  			srcAddress:    remoteIPv4Addr,
  1835  			injector:      rxIPv4Bad,
  1836  			checker:       v4Checker,
  1837  			payloadLength: targetSize4,
  1838  			linkMTU:       targetSize4 + 1000,
  1839  			replyLength:   targetSize4,
  1840  		},
  1841  		{
  1842  			name:          "IPv4 smaller MTU",
  1843  			srcAddress:    remoteIPv4Addr,
  1844  			injector:      rxIPv4Bad,
  1845  			checker:       v4Checker,
  1846  			payloadLength: targetSize4,
  1847  			linkMTU:       targetSize4 - 50,
  1848  			replyLength:   targetSize4 - 50,
  1849  		},
  1850  		{
  1851  			name:          "IPv4 payload exceeds",
  1852  			srcAddress:    remoteIPv4Addr,
  1853  			injector:      rxIPv4Bad,
  1854  			checker:       v4Checker,
  1855  			payloadLength: targetSize4 + 10,
  1856  			linkMTU:       targetSize4,
  1857  			replyLength:   targetSize4,
  1858  		},
  1859  		{
  1860  			name:          "IPv4 1 byte less",
  1861  			srcAddress:    remoteIPv4Addr,
  1862  			injector:      rxIPv4Bad,
  1863  			checker:       v4Checker,
  1864  			payloadLength: targetSize4 - replyHeaderLength4 - 1,
  1865  			linkMTU:       targetSize4,
  1866  			replyLength:   targetSize4 - 1,
  1867  		},
  1868  		{
  1869  			name:          "IPv4 No payload",
  1870  			srcAddress:    remoteIPv4Addr,
  1871  			injector:      rxIPv4Bad,
  1872  			checker:       v4Checker,
  1873  			payloadLength: 0,
  1874  			linkMTU:       targetSize4,
  1875  			replyLength:   replyHeaderLength4,
  1876  		},
  1877  		{
  1878  			name:          "IPv6 exact match",
  1879  			srcAddress:    remoteIPv6Addr,
  1880  			injector:      rxIPv6Bad,
  1881  			checker:       v6Checker,
  1882  			payloadLength: targetSize6 - replyHeaderLength6,
  1883  			linkMTU:       targetSize6,
  1884  			replyLength:   targetSize6,
  1885  		},
  1886  		{
  1887  			name:          "IPv6 larger MTU",
  1888  			srcAddress:    remoteIPv6Addr,
  1889  			injector:      rxIPv6Bad,
  1890  			checker:       v6Checker,
  1891  			payloadLength: targetSize6,
  1892  			linkMTU:       targetSize6 + 400,
  1893  			replyLength:   targetSize6,
  1894  		},
  1895  		// NB. No "smaller MTU" test here as less than 1280 is not permitted
  1896  		// in IPv6.
  1897  		{
  1898  			name:          "IPv6 payload exceeds",
  1899  			srcAddress:    remoteIPv6Addr,
  1900  			injector:      rxIPv6Bad,
  1901  			checker:       v6Checker,
  1902  			payloadLength: targetSize6,
  1903  			linkMTU:       targetSize6,
  1904  			replyLength:   targetSize6,
  1905  		},
  1906  		{
  1907  			name:          "IPv6 1 byte less",
  1908  			srcAddress:    remoteIPv6Addr,
  1909  			injector:      rxIPv6Bad,
  1910  			checker:       v6Checker,
  1911  			payloadLength: targetSize6 - replyHeaderLength6 - 1,
  1912  			linkMTU:       targetSize6,
  1913  			replyLength:   targetSize6 - 1,
  1914  		},
  1915  		{
  1916  			name:          "IPv6 no payload",
  1917  			srcAddress:    remoteIPv6Addr,
  1918  			injector:      rxIPv6Bad,
  1919  			checker:       v6Checker,
  1920  			payloadLength: 0,
  1921  			linkMTU:       targetSize6,
  1922  			replyLength:   replyHeaderLength6,
  1923  		},
  1924  	}
  1925  
  1926  	for _, test := range tests {
  1927  		t.Run(test.name, func(t *testing.T) {
  1928  			s, e := buildDummyStackWithLinkEndpoint(t, test.linkMTU)
  1929  			// Allocate and initialize the payload view.
  1930  			payload := buffer.NewView(test.payloadLength)
  1931  			for i := 0; i < len(payload); i++ {
  1932  				payload[i] = uint8(i)
  1933  			}
  1934  			// Default routes for IPv4&6 so ICMP can find a route to the remote
  1935  			// node when attempting to send the ICMP error Reply.
  1936  			s.SetRouteTable([]tcpip.Route{
  1937  				{
  1938  					Destination: header.IPv4EmptySubnet,
  1939  					NIC:         nicID,
  1940  				},
  1941  				{
  1942  					Destination: header.IPv6EmptySubnet,
  1943  					NIC:         nicID,
  1944  				},
  1945  			})
  1946  			v := test.injector(e, test.srcAddress, payload)
  1947  			pkt, ok := e.Read()
  1948  			if !ok {
  1949  				t.Fatal("expected a packet to be written")
  1950  			}
  1951  			if got, want := pkt.Pkt.Size(), test.replyLength; got != want {
  1952  				t.Fatalf("got %d bytes of icmp error packet, want %d", got, want)
  1953  			}
  1954  			test.checker(t, pkt.Pkt, v)
  1955  		})
  1956  	}
  1957  }
  1958  
  1959  func TestJoinLeaveAllRoutersGroup(t *testing.T) {
  1960  	const nicID = 1
  1961  
  1962  	tests := []struct {
  1963  		name           string
  1964  		netProto       tcpip.NetworkProtocolNumber
  1965  		protoFactory   stack.NetworkProtocolFactory
  1966  		allRoutersAddr tcpip.Address
  1967  	}{
  1968  		{
  1969  			name:           "IPv4",
  1970  			netProto:       ipv4.ProtocolNumber,
  1971  			protoFactory:   ipv4.NewProtocol,
  1972  			allRoutersAddr: header.IPv4AllRoutersGroup,
  1973  		},
  1974  		{
  1975  			name:           "IPv6 Interface Local",
  1976  			netProto:       ipv6.ProtocolNumber,
  1977  			protoFactory:   ipv6.NewProtocol,
  1978  			allRoutersAddr: header.IPv6AllRoutersInterfaceLocalMulticastAddress,
  1979  		},
  1980  		{
  1981  			name:           "IPv6 Link Local",
  1982  			netProto:       ipv6.ProtocolNumber,
  1983  			protoFactory:   ipv6.NewProtocol,
  1984  			allRoutersAddr: header.IPv6AllRoutersLinkLocalMulticastAddress,
  1985  		},
  1986  		{
  1987  			name:           "IPv6 Site Local",
  1988  			netProto:       ipv6.ProtocolNumber,
  1989  			protoFactory:   ipv6.NewProtocol,
  1990  			allRoutersAddr: header.IPv6AllRoutersSiteLocalMulticastAddress,
  1991  		},
  1992  	}
  1993  
  1994  	for _, test := range tests {
  1995  		t.Run(test.name, func(t *testing.T) {
  1996  			for _, nicDisabled := range [...]bool{true, false} {
  1997  				t.Run(fmt.Sprintf("NIC Disabled = %t", nicDisabled), func(t *testing.T) {
  1998  					s := stack.New(stack.Options{
  1999  						NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  2000  						TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
  2001  					})
  2002  					opts := stack.NICOptions{Disabled: nicDisabled}
  2003  					if err := s.CreateNICWithOptions(nicID, channel.New(0, 0, ""), opts); err != nil {
  2004  						t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err)
  2005  					}
  2006  
  2007  					if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
  2008  						t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
  2009  					} else if got {
  2010  						t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr)
  2011  					}
  2012  
  2013  					if err := s.SetForwardingDefaultAndAllNICs(test.netProto, true); err != nil {
  2014  						t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", test.netProto, err)
  2015  					}
  2016  					if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
  2017  						t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
  2018  					} else if !got {
  2019  						t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr)
  2020  					}
  2021  
  2022  					if err := s.SetForwardingDefaultAndAllNICs(test.netProto, false); err != nil {
  2023  						t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, false): %s", test.netProto, err)
  2024  					}
  2025  					if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
  2026  						t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
  2027  					} else if got {
  2028  						t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr)
  2029  					}
  2030  				})
  2031  			}
  2032  		})
  2033  	}
  2034  }