gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/network/ip_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package ip_test
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"strings"
    21  	"testing"
    22  
    23  	"github.com/google/go-cmp/cmp"
    24  	"gvisor.dev/gvisor/pkg/buffer"
    25  	"gvisor.dev/gvisor/pkg/refs"
    26  	"gvisor.dev/gvisor/pkg/sync"
    27  	"gvisor.dev/gvisor/pkg/tcpip"
    28  	"gvisor.dev/gvisor/pkg/tcpip/checker"
    29  	"gvisor.dev/gvisor/pkg/tcpip/checksum"
    30  	"gvisor.dev/gvisor/pkg/tcpip/header"
    31  	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
    32  	"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
    33  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    34  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
    35  	"gvisor.dev/gvisor/pkg/tcpip/prependable"
    36  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    37  	"gvisor.dev/gvisor/pkg/tcpip/testutil"
    38  	"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
    39  	"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
    40  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
    41  	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
    42  	"gvisor.dev/gvisor/pkg/waiter"
    43  )
    44  
    45  const nicID = 1
    46  
    47  var (
    48  	localIPv4Addr  = testutil.MustParse4("10.0.0.1")
    49  	remoteIPv4Addr = testutil.MustParse4("10.0.0.2")
    50  	ipv4SubnetAddr = testutil.MustParse4("10.0.0.0")
    51  	ipv4SubnetMask = testutil.MustParse4("255.255.255.0")
    52  	ipv4Gateway    = testutil.MustParse4("10.0.0.3")
    53  	localIPv6Addr  = testutil.MustParse6("a00::1")
    54  	remoteIPv6Addr = testutil.MustParse6("a00::2")
    55  	ipv6SubnetAddr = testutil.MustParse6("a00::")
    56  	ipv6SubnetMask = testutil.MustParse6("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ff00")
    57  	ipv6Gateway    = testutil.MustParse6("a00::3")
    58  )
    59  
    60  var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{
    61  	Address:   localIPv4Addr,
    62  	PrefixLen: 24,
    63  }
    64  
    65  var localIPv6AddrWithPrefix = tcpip.AddressWithPrefix{
    66  	Address:   localIPv6Addr,
    67  	PrefixLen: 120,
    68  }
    69  
// transportError is a plain-value snapshot of a stack.TransportError. Each
// field holds the result of the accessor with the matching name: origin from
// Origin(), typ from Type(), code from Code(), info from Info() and kind from
// Kind(). testObject stores the expected error in this form, and
// DeliverTransportError builds one from the error it receives so that the two
// can be compared with cmp.Diff.
type transportError struct {
	origin tcpip.SockErrOrigin
	typ    uint8
	code   uint8
	info   uint32
	kind   stack.TransportErrorKind
}
    77  
// testObject implements two interfaces: LinkEndpoint and TransportDispatcher.
// The former is used to pretend that it's a link endpoint so that we can
// inspect packets written by the network endpoints. The latter is used to
// pretend that it's the network stack so that it can inspect incoming packets
// that have been handled by the network endpoints.
//
// Packets are checked by comparing their fields/values against the expected
// values stored in the test object itself.
//
// Fields:
//   - t is the test that failures are reported against.
//   - protocol, contents, srcAddr and dstAddr are the expected transport
//     protocol, payload bytes and addresses; checkValues compares observed
//     packets against them.
//   - v4 selects whether WritePacket parses the network header as IPv4 (true)
//     or IPv6 (false).
//   - transErr is the expected error checked by DeliverTransportError.
//   - dataCalls, controlCalls and rawCalls count the invocations of
//     DeliverTransportPacket, DeliverTransportError and DeliverRawPacket
//     respectively.
type testObject struct {
	t        *testing.T
	protocol tcpip.TransportProtocolNumber
	contents []byte
	srcAddr  tcpip.Address
	dstAddr  tcpip.Address
	v4       bool
	transErr transportError

	dataCalls    int
	controlCalls int
	rawCalls     int
}
    99  
   100  // checkValues verifies that the transport protocol, data contents, src & dst
   101  // addresses of a packet match what's expected. If any field doesn't match, the
   102  // test fails.
   103  func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, v []byte, srcAddr, dstAddr tcpip.Address) {
   104  	if protocol != t.protocol {
   105  		t.t.Errorf("protocol = %v, want %v", protocol, t.protocol)
   106  	}
   107  
   108  	if srcAddr != t.srcAddr {
   109  		t.t.Errorf("srcAddr = %v, want %v", srcAddr, t.srcAddr)
   110  	}
   111  
   112  	if dstAddr != t.dstAddr {
   113  		t.t.Errorf("dstAddr = %v, want %v", dstAddr, t.dstAddr)
   114  	}
   115  
   116  	if len(v) != len(t.contents) {
   117  		t.t.Fatalf("len(payload) = %v, want %v", len(v), len(t.contents))
   118  	}
   119  
   120  	for i := range t.contents {
   121  		if t.contents[i] != v[i] {
   122  			t.t.Fatalf("payload[%v] = %v, want %v", i, v[i], t.contents[i])
   123  		}
   124  	}
   125  }
   126  
   127  // DeliverTransportPacket is called by network endpoints after parsing incoming
   128  // packets. This is used by the test object to verify that the results of the
   129  // parsing are expected.
   130  func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition {
   131  	netHdr := pkt.Network()
   132  	v := pkt.Data().AsRange().ToView()
   133  	defer v.Release()
   134  	t.checkValues(protocol, v.AsSlice(), netHdr.SourceAddress(), netHdr.DestinationAddress())
   135  	t.dataCalls++
   136  	return stack.TransportPacketHandled
   137  }
   138  
   139  // DeliverTransportError is called by network endpoints after parsing
   140  // incoming control (ICMP) packets. This is used by the test object to verify
   141  // that the results of the parsing are expected.
   142  func (t *testObject) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr stack.TransportError, pkt *stack.PacketBuffer) {
   143  	v := pkt.Data().AsRange().ToView()
   144  	defer v.Release()
   145  	t.checkValues(trans, v.AsSlice(), remote, local)
   146  	if diff := cmp.Diff(
   147  		t.transErr,
   148  		transportError{
   149  			origin: transErr.Origin(),
   150  			typ:    transErr.Type(),
   151  			code:   transErr.Code(),
   152  			info:   transErr.Info(),
   153  			kind:   transErr.Kind(),
   154  		},
   155  		cmp.AllowUnexported(transportError{}),
   156  	); diff != "" {
   157  		t.t.Errorf("transport error mismatch (-want +got):\n%s", diff)
   158  	}
   159  	t.controlCalls++
   160  }
   161  
// DeliverRawPacket only counts how many times it has been invoked; the
// protocol number and the packet it is given are ignored.
func (t *testObject) DeliverRawPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) {
	t.rawCalls++
}
   165  
// Attach is only implemented to satisfy the LinkEndpoint interface. The
// dispatcher it is given is discarded.
func (*testObject) Attach(stack.NetworkDispatcher) {}

// IsAttached implements stack.LinkEndpoint.IsAttached. The test object always
// reports itself as attached.
func (*testObject) IsAttached() bool {
	return true
}

// MTU implements stack.LinkEndpoint.MTU. It just returns a constant that
// matches the Linux loopback MTU.
func (*testObject) MTU() uint32 {
	return 65536
}

// Capabilities implements stack.LinkEndpoint.Capabilities. The test object
// returns the zero value, i.e. no capability bits are set.
func (*testObject) Capabilities() stack.LinkEndpointCapabilities {
	return 0
}

// MaxHeaderLength is only implemented to satisfy the LinkEndpoint interface.
// It always returns 0.
func (*testObject) MaxHeaderLength() uint16 {
	return 0
}

// LinkAddress returns the link address of this endpoint, which is always the
// empty link address.
func (*testObject) LinkAddress() tcpip.LinkAddress {
	return ""
}

// Wait implements stack.LinkEndpoint.Wait. It returns immediately.
func (*testObject) Wait() {}
   197  
   198  // WritePacket is called by network endpoints after producing a packet and
   199  // writing it to the link endpoint. This is used by the test object to verify
   200  // that the produced packet is as expected.
   201  func (t *testObject) WritePacket(_ *stack.Route, pkt *stack.PacketBuffer) tcpip.Error {
   202  	var prot tcpip.TransportProtocolNumber
   203  	var srcAddr tcpip.Address
   204  	var dstAddr tcpip.Address
   205  
   206  	if t.v4 {
   207  		h := header.IPv4(pkt.NetworkHeader().Slice())
   208  		prot = tcpip.TransportProtocolNumber(h.Protocol())
   209  		srcAddr = h.SourceAddress()
   210  		dstAddr = h.DestinationAddress()
   211  
   212  	} else {
   213  		h := header.IPv6(pkt.NetworkHeader().Slice())
   214  		prot = tcpip.TransportProtocolNumber(h.NextHeader())
   215  		srcAddr = h.SourceAddress()
   216  		dstAddr = h.DestinationAddress()
   217  	}
   218  	t.checkValues(prot, pkt.Data().AsRange().ToSlice(), srcAddr, dstAddr)
   219  	return nil
   220  }
   221  
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. The test
// object does not support it and always panics with "not implemented".
func (*testObject) ARPHardwareType() header.ARPHardwareType {
	panic("not implemented")
}

// AddHeader implements stack.LinkEndpoint.AddHeader. The test object does not
// support it and always panics with "not implemented".
func (*testObject) AddHeader(*stack.PacketBuffer) {
	panic("not implemented")
}

// ParseHeader implements stack.LinkEndpoint.ParseHeader. The test object does
// not support it and always panics with "not implemented".
func (*testObject) ParseHeader(*stack.PacketBuffer) bool {
	panic("not implemented")
}
   236  
// testContext bundles the stack used by a test. It is created with
// newTestContext and torn down with cleanup.
type testContext struct {
	// s is the network stack under test.
	s *stack.Stack
}
   240  
   241  func newTestContext() testContext {
   242  	s := stack.New(stack.Options{
   243  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   244  		TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
   245  		RawFactory:         raw.EndpointFactory{},
   246  	})
   247  	return testContext{s: s}
   248  }
   249  
   250  func (ctx *testContext) cleanup() {
   251  	ctx.s.Close()
   252  	ctx.s.Wait()
   253  	refs.DoRepeatedLeakCheck()
   254  }
   255  
   256  func buildIPv4Route(ctx testContext, local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
   257  	s := ctx.s
   258  	s.CreateNIC(nicID, loopback.New())
   259  	protocolAddr := tcpip.ProtocolAddress{
   260  		Protocol:          ipv4.ProtocolNumber,
   261  		AddressWithPrefix: local.WithPrefix(),
   262  	}
   263  	if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
   264  		return nil, err
   265  	}
   266  	s.SetRouteTable([]tcpip.Route{{
   267  		Destination: header.IPv4EmptySubnet,
   268  		Gateway:     ipv4Gateway,
   269  		NIC:         1,
   270  	}})
   271  
   272  	return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
   273  }
   274  
   275  func buildIPv6Route(ctx testContext, local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
   276  	s := ctx.s
   277  	s.CreateNIC(nicID, loopback.New())
   278  	protocolAddr := tcpip.ProtocolAddress{
   279  		Protocol:          ipv6.ProtocolNumber,
   280  		AddressWithPrefix: local.WithPrefix(),
   281  	}
   282  	if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
   283  		return nil, err
   284  	}
   285  	s.SetRouteTable([]tcpip.Route{{
   286  		Destination: header.IPv6EmptySubnet,
   287  		Gateway:     ipv6Gateway,
   288  		NIC:         1,
   289  	}})
   290  
   291  	return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */)
   292  }
   293  
   294  func addLinkEndpointToStackWithMTU(t *testing.T, s *stack.Stack, mtu uint32) *channel.Endpoint {
   295  	t.Helper()
   296  	e := channel.New(1, mtu, "")
   297  	if err := s.CreateNIC(nicID, e); err != nil {
   298  		t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
   299  	}
   300  
   301  	v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix}
   302  	if err := s.AddProtocolAddress(nicID, v4Addr, stack.AddressProperties{}); err != nil {
   303  		t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v4Addr, err)
   304  	}
   305  
   306  	v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix}
   307  	if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil {
   308  		t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v6Addr, err)
   309  	}
   310  
   311  	return e
   312  }
   313  
// addLinkEndpointToStack is addLinkEndpointToStackWithMTU with the MTU fixed
// to header.IPv6MinimumMTU.
func addLinkEndpointToStack(t *testing.T, s *stack.Stack) *channel.Endpoint {
	t.Helper()
	return addLinkEndpointToStackWithMTU(t, s, header.IPv6MinimumMTU)
}
   318  
// Compile-time check that *testInterface satisfies stack.NetworkInterface.
var _ stack.NetworkInterface = (*testInterface)(nil)

// testInterface is the fake NIC handed to the network endpoints under test.
// It embeds testObject and adds an enabled/disabled switch that tests flip
// with setEnabled and endpoints observe through Enabled.
type testInterface struct {
	testObject

	// mu groups the lock with the state it protects: Enabled reads disabled
	// under the read lock and setEnabled writes it under the write lock.
	mu struct {
		sync.RWMutex
		// disabled is the inverse of what Enabled reports, so the zero value
		// of testInterface is an enabled interface.
		disabled bool
	}
}
   329  
// ID returns the fixed NIC ID used throughout this file.
func (*testInterface) ID() tcpip.NICID {
	return nicID
}

// IsLoopback always reports false for the test interface.
func (*testInterface) IsLoopback() bool {
	return false
}

// Name returns the empty string; the test interface is unnamed.
func (*testInterface) Name() string {
	return ""
}
   341  
   342  func (t *testInterface) Enabled() bool {
   343  	t.mu.RLock()
   344  	defer t.mu.RUnlock()
   345  	return !t.mu.disabled
   346  }
   347  
// Promiscuous always reports false for the test interface.
func (*testInterface) Promiscuous() bool {
	return false
}

// Spoofing always reports false for the test interface.
func (*testInterface) Spoofing() bool {
	return false
}
   355  
   356  func (t *testInterface) setEnabled(v bool) {
   357  	t.mu.Lock()
   358  	defer t.mu.Unlock()
   359  	t.mu.disabled = !v
   360  }
   361  
// WritePacketToRemote is not supported by the test interface; it always
// returns *tcpip.ErrNotSupported.
func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.PacketBuffer) tcpip.Error {
	return &tcpip.ErrNotSupported{}
}

// HandleNeighborProbe ignores its arguments and returns nil.
func (*testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error {
	return nil
}

// HandleNeighborConfirmation ignores its arguments and returns nil.
func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error {
	return nil
}

// PrimaryAddress always returns the zero tcpip.AddressWithPrefix and a nil
// error, regardless of the requested protocol.
func (*testInterface) PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) {
	return tcpip.AddressWithPrefix{}, nil
}

// CheckLocalAddress always returns false, regardless of its arguments.
func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool {
	return false
}
   381  
   382  func TestSourceAddressValidation(t *testing.T) {
   383  	rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) {
   384  		totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
   385  		hdr := prependable.New(totalLen)
   386  		pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
   387  		pkt.SetType(header.ICMPv4Echo)
   388  		pkt.SetCode(0)
   389  		pkt.SetChecksum(0)
   390  		pkt.SetChecksum(^checksum.Checksum(pkt, 0))
   391  		ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
   392  		ip.Encode(&header.IPv4Fields{
   393  			TotalLength: uint16(totalLen),
   394  			Protocol:    uint8(icmp.ProtocolNumber4),
   395  			TTL:         ipv4.DefaultTTL,
   396  			SrcAddr:     src,
   397  			DstAddr:     localIPv4Addr,
   398  		})
   399  		ip.SetChecksum(^ip.CalculateChecksum())
   400  
   401  		pktBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
   402  			Payload: buffer.MakeWithData(hdr.View()),
   403  		})
   404  		e.InjectInbound(header.IPv4ProtocolNumber, pktBuf)
   405  		pktBuf.DecRef()
   406  	}
   407  
   408  	rxIPv6ICMP := func(e *channel.Endpoint, src tcpip.Address) {
   409  		totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
   410  		hdr := prependable.New(totalLen)
   411  		pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
   412  		pkt.SetType(header.ICMPv6EchoRequest)
   413  		pkt.SetCode(0)
   414  		pkt.SetChecksum(0)
   415  		pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
   416  			Header: pkt,
   417  			Src:    src,
   418  			Dst:    localIPv6Addr,
   419  		}))
   420  		ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
   421  		ip.Encode(&header.IPv6Fields{
   422  			PayloadLength:     header.ICMPv6MinimumSize,
   423  			TransportProtocol: icmp.ProtocolNumber6,
   424  			HopLimit:          ipv6.DefaultTTL,
   425  			SrcAddr:           src,
   426  			DstAddr:           localIPv6Addr,
   427  		})
   428  		pktBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
   429  			Payload: buffer.MakeWithData(hdr.View()),
   430  		})
   431  		e.InjectInbound(header.IPv6ProtocolNumber, pktBuf)
   432  		pktBuf.DecRef()
   433  	}
   434  
   435  	tests := []struct {
   436  		name       string
   437  		srcAddress tcpip.Address
   438  		rxICMP     func(*channel.Endpoint, tcpip.Address)
   439  		valid      bool
   440  	}{
   441  		{
   442  			name:       "IPv4 valid",
   443  			srcAddress: tcpip.AddrFromSlice([]byte("\x01\x02\x03\x04")),
   444  			rxICMP:     rxIPv4ICMP,
   445  			valid:      true,
   446  		},
   447  		{
   448  			name:       "IPv6 valid",
   449  			srcAddress: tcpip.AddrFromSlice([]byte("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10")),
   450  			rxICMP:     rxIPv6ICMP,
   451  			valid:      true,
   452  		},
   453  		{
   454  			name:       "IPv4 unspecified",
   455  			srcAddress: header.IPv4Any,
   456  			rxICMP:     rxIPv4ICMP,
   457  			valid:      true,
   458  		},
   459  		{
   460  			name:       "IPv6 unspecified",
   461  			srcAddress: header.IPv4Any,
   462  			rxICMP:     rxIPv6ICMP,
   463  			valid:      true,
   464  		},
   465  		{
   466  			name:       "IPv4 multicast",
   467  			srcAddress: tcpip.AddrFromSlice([]byte("\xe0\x00\x00\x01")),
   468  			rxICMP:     rxIPv4ICMP,
   469  			valid:      false,
   470  		},
   471  		{
   472  			name:       "IPv6 multicast",
   473  			srcAddress: tcpip.AddrFromSlice([]byte("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")),
   474  			rxICMP:     rxIPv6ICMP,
   475  			valid:      false,
   476  		},
   477  		{
   478  			name:       "IPv4 broadcast",
   479  			srcAddress: header.IPv4Broadcast,
   480  			rxICMP:     rxIPv4ICMP,
   481  			valid:      false,
   482  		},
   483  		{
   484  			name: "IPv4 subnet broadcast",
   485  			srcAddress: func() tcpip.Address {
   486  				subnet := localIPv4AddrWithPrefix.Subnet()
   487  				return subnet.Broadcast()
   488  			}(),
   489  			rxICMP: rxIPv4ICMP,
   490  			valid:  false,
   491  		},
   492  	}
   493  
   494  	for _, test := range tests {
   495  		t.Run(test.name, func(t *testing.T) {
   496  			ctx := newTestContext()
   497  			defer ctx.cleanup()
   498  			s := ctx.s
   499  
   500  			e := addLinkEndpointToStack(t, s)
   501  			defer e.Close()
   502  			test.rxICMP(e, test.srcAddress)
   503  
   504  			var wantValid uint64
   505  			if test.valid {
   506  				wantValid = 1
   507  			}
   508  
   509  			if got, want := s.Stats().IP.InvalidSourceAddressesReceived.Value(), 1-wantValid; got != want {
   510  				t.Errorf("got s.Stats().IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want)
   511  			}
   512  			if got := s.Stats().IP.PacketsDelivered.Value(); got != wantValid {
   513  				t.Errorf("got s.Stats().IP.PacketsDelivered.Value() = %d, want = %d", got, wantValid)
   514  			}
   515  		})
   516  	}
   517  }
   518  
   519  func TestEnableWhenNICDisabled(t *testing.T) {
   520  	tests := []struct {
   521  		name            string
   522  		protocolFactory stack.NetworkProtocolFactory
   523  		protoNum        tcpip.NetworkProtocolNumber
   524  	}{
   525  		{
   526  			name:            "IPv4",
   527  			protocolFactory: ipv4.NewProtocol,
   528  			protoNum:        ipv4.ProtocolNumber,
   529  		},
   530  		{
   531  			name:            "IPv6",
   532  			protocolFactory: ipv6.NewProtocol,
   533  			protoNum:        ipv6.ProtocolNumber,
   534  		},
   535  	}
   536  
   537  	for _, test := range tests {
   538  		t.Run(test.name, func(t *testing.T) {
   539  			var nic testInterface
   540  			nic.setEnabled(false)
   541  
   542  			s := stack.New(stack.Options{
   543  				NetworkProtocols: []stack.NetworkProtocolFactory{test.protocolFactory},
   544  			})
   545  			defer func() {
   546  				s.Close()
   547  				s.Wait()
   548  			}()
   549  
   550  			p := s.NetworkProtocolInstance(test.protoNum)
   551  
   552  			// We pass nil for all parameters except the NetworkInterface and Stack
   553  			// since Enable only depends on these.
   554  			ep := p.NewEndpoint(&nic, nil)
   555  
   556  			// The endpoint should initially be disabled, regardless the NIC's enabled
   557  			// status.
   558  			if ep.Enabled() {
   559  				t.Fatal("got ep.Enabled() = true, want = false")
   560  			}
   561  			nic.setEnabled(true)
   562  			if ep.Enabled() {
   563  				t.Fatal("got ep.Enabled() = true, want = false")
   564  			}
   565  
   566  			// Attempting to enable the endpoint while the NIC is disabled should
   567  			// fail.
   568  			nic.setEnabled(false)
   569  			err := ep.Enable()
   570  			if _, ok := err.(*tcpip.ErrNotPermitted); !ok {
   571  				t.Fatalf("got ep.Enable() = %s, want = %s", err, &tcpip.ErrNotPermitted{})
   572  			}
   573  			// ep should consider the NIC's enabled status when determining its own
   574  			// enabled status so we "enable" the NIC to read just the endpoint's
   575  			// enabled status.
   576  			nic.setEnabled(true)
   577  			if ep.Enabled() {
   578  				t.Fatal("got ep.Enabled() = true, want = false")
   579  			}
   580  
   581  			// Enabling the interface after the NIC has been enabled should succeed.
   582  			if err := ep.Enable(); err != nil {
   583  				t.Fatalf("ep.Enable(): %s", err)
   584  			}
   585  			if !ep.Enabled() {
   586  				t.Fatal("got ep.Enabled() = false, want = true")
   587  			}
   588  
   589  			// ep should consider the NIC's enabled status when determining its own
   590  			// enabled status.
   591  			nic.setEnabled(false)
   592  			if ep.Enabled() {
   593  				t.Fatal("got ep.Enabled() = true, want = false")
   594  			}
   595  
   596  			// Disabling the endpoint when the NIC is enabled should make the endpoint
   597  			// disabled.
   598  			nic.setEnabled(true)
   599  			ep.Disable()
   600  			if ep.Enabled() {
   601  				t.Fatal("got ep.Enabled() = true, want = false")
   602  			}
   603  		})
   604  	}
   605  }
   606  
   607  func TestIPv4Send(t *testing.T) {
   608  	ctx := newTestContext()
   609  	defer ctx.cleanup()
   610  	s := ctx.s
   611  
   612  	proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
   613  	nic := testInterface{
   614  		testObject: testObject{
   615  			t:  t,
   616  			v4: true,
   617  		},
   618  	}
   619  	ep := proto.NewEndpoint(&nic, nil)
   620  	defer ep.Close()
   621  
   622  	// Allocate and initialize the payload view.
   623  	payload := make([]byte, 100)
   624  	for i := 0; i < len(payload); i++ {
   625  		payload[i] = uint8(i)
   626  	}
   627  
   628  	// Setup the packet buffer.
   629  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   630  		ReserveHeaderBytes: int(ep.MaxHeaderLength()),
   631  		Payload:            buffer.MakeWithData(payload),
   632  	})
   633  	defer pkt.DecRef()
   634  
   635  	// Issue the write.
   636  	nic.testObject.protocol = 123
   637  	nic.testObject.srcAddr = localIPv4Addr
   638  	nic.testObject.dstAddr = remoteIPv4Addr
   639  	nic.testObject.contents = payload
   640  
   641  	r, err := buildIPv4Route(ctx, localIPv4Addr, remoteIPv4Addr)
   642  	if err != nil {
   643  		t.Fatalf("could not find route: %v", err)
   644  	}
   645  	defer r.Release()
   646  	if err := ep.WritePacket(r, stack.NetworkHeaderParams{
   647  		Protocol: 123,
   648  		TTL:      123,
   649  		TOS:      stack.DefaultTOS,
   650  	}, pkt); err != nil {
   651  		t.Fatalf("WritePacket failed: %v", err)
   652  	}
   653  }
   654  
   655  func TestReceive(t *testing.T) {
   656  	tests := []struct {
   657  		name         string
   658  		protoFactory stack.NetworkProtocolFactory
   659  		protoNum     tcpip.NetworkProtocolNumber
   660  		v4           bool
   661  		epAddr       tcpip.AddressWithPrefix
   662  		handlePacket func(*testing.T, stack.NetworkEndpoint, *testInterface)
   663  	}{
   664  		{
   665  			name:         "IPv4",
   666  			protoFactory: ipv4.NewProtocol,
   667  			protoNum:     ipv4.ProtocolNumber,
   668  			v4:           true,
   669  			epAddr:       localIPv4Addr.WithPrefix(),
   670  			handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) {
   671  				const totalLen = header.IPv4MinimumSize + 30 /* payload length */
   672  
   673  				view := make([]byte, totalLen)
   674  				ip := header.IPv4(view)
   675  				ip.Encode(&header.IPv4Fields{
   676  					TotalLength: totalLen,
   677  					TTL:         ipv4.DefaultTTL,
   678  					Protocol:    10,
   679  					SrcAddr:     remoteIPv4Addr,
   680  					DstAddr:     localIPv4Addr,
   681  				})
   682  				ip.SetChecksum(^ip.CalculateChecksum())
   683  
   684  				// Make payload be non-zero.
   685  				for i := header.IPv4MinimumSize; i < len(view); i++ {
   686  					view[i] = uint8(i)
   687  				}
   688  
   689  				// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
   690  				nic.testObject.protocol = 10
   691  				nic.testObject.srcAddr = remoteIPv4Addr
   692  				nic.testObject.dstAddr = localIPv4Addr
   693  				nic.testObject.contents = view[header.IPv4MinimumSize:totalLen]
   694  
   695  				pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   696  					Payload: buffer.MakeWithData(view),
   697  				})
   698  				ep.HandlePacket(pkt)
   699  				pkt.DecRef()
   700  			},
   701  		},
   702  		{
   703  			name:         "IPv6",
   704  			protoFactory: ipv6.NewProtocol,
   705  			protoNum:     ipv6.ProtocolNumber,
   706  			v4:           false,
   707  			epAddr:       localIPv6Addr.WithPrefix(),
   708  			handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) {
   709  				const payloadLen = 30
   710  				view := make([]byte, header.IPv6MinimumSize+payloadLen)
   711  				ip := header.IPv6(view)
   712  				ip.Encode(&header.IPv6Fields{
   713  					PayloadLength:     payloadLen,
   714  					TransportProtocol: 10,
   715  					HopLimit:          ipv6.DefaultTTL,
   716  					SrcAddr:           remoteIPv6Addr,
   717  					DstAddr:           localIPv6Addr,
   718  				})
   719  
   720  				// Make payload be non-zero.
   721  				for i := header.IPv6MinimumSize; i < len(view); i++ {
   722  					view[i] = uint8(i)
   723  				}
   724  
   725  				// Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
   726  				nic.testObject.protocol = 10
   727  				nic.testObject.srcAddr = remoteIPv6Addr
   728  				nic.testObject.dstAddr = localIPv6Addr
   729  				nic.testObject.contents = view[header.IPv6MinimumSize:][:payloadLen]
   730  
   731  				pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   732  					Payload: buffer.MakeWithData(view),
   733  				})
   734  				ep.HandlePacket(pkt)
   735  				pkt.DecRef()
   736  			},
   737  		},
   738  	}
   739  
   740  	for _, test := range tests {
   741  		t.Run(test.name, func(t *testing.T) {
   742  			s := stack.New(stack.Options{
   743  				NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory},
   744  			})
   745  			defer func() {
   746  				s.Close()
   747  				s.Wait()
   748  			}()
   749  
   750  			nic := testInterface{
   751  				testObject: testObject{
   752  					t:  t,
   753  					v4: test.v4,
   754  				},
   755  			}
   756  			ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, &nic.testObject)
   757  			defer ep.Close()
   758  
   759  			if err := ep.Enable(); err != nil {
   760  				t.Fatalf("ep.Enable(): %s", err)
   761  			}
   762  
   763  			addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
   764  			if !ok {
   765  				t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum)
   766  			}
   767  			if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.AddressProperties{}); err != nil {
   768  				t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", test.epAddr, err)
   769  			} else {
   770  				ep.DecRef()
   771  			}
   772  
   773  			stat := s.Stats().IP.PacketsReceived
   774  			if got := stat.Value(); got != 0 {
   775  				t.Fatalf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 0", got)
   776  			}
   777  			test.handlePacket(t, ep, &nic)
   778  			if nic.testObject.dataCalls != 1 {
   779  				t.Errorf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls)
   780  			}
   781  			if nic.testObject.rawCalls != 1 {
   782  				t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls)
   783  			}
   784  			if got := stat.Value(); got != 1 {
   785  				t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got)
   786  			}
   787  		})
   788  	}
   789  }
   790  
   791  func TestIPv4ReceiveControl(t *testing.T) {
   792  	const (
   793  		mtu     = 0xbeef - header.IPv4MinimumSize
   794  		dataLen = 8
   795  	)
   796  
   797  	cases := []struct {
   798  		name           string
   799  		expectedCount  int
   800  		fragmentOffset uint16
   801  		code           header.ICMPv4Code
   802  		transErr       transportError
   803  		trunc          int
   804  	}{
   805  		{
   806  			name:           "FragmentationNeeded",
   807  			expectedCount:  1,
   808  			fragmentOffset: 0,
   809  			code:           header.ICMPv4FragmentationNeeded,
   810  			transErr: transportError{
   811  				origin: tcpip.SockExtErrorOriginICMP,
   812  				typ:    uint8(header.ICMPv4DstUnreachable),
   813  				code:   uint8(header.ICMPv4FragmentationNeeded),
   814  				info:   mtu,
   815  				kind:   stack.PacketTooBigTransportError,
   816  			},
   817  			trunc: 0,
   818  		},
   819  		{
   820  			name:           "Truncated (missing IPv4 header)",
   821  			expectedCount:  0,
   822  			fragmentOffset: 0,
   823  			code:           header.ICMPv4FragmentationNeeded,
   824  			trunc:          header.IPv4MinimumSize + header.ICMPv4MinimumSize,
   825  		},
   826  		{
   827  			name:           "Truncated (partial offending packet's IP header)",
   828  			expectedCount:  0,
   829  			fragmentOffset: 0,
   830  			code:           header.ICMPv4FragmentationNeeded,
   831  			trunc:          header.IPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize - 1,
   832  		},
   833  		{
   834  			name:           "Truncated (partial offending packet's data)",
   835  			expectedCount:  0,
   836  			fragmentOffset: 0,
   837  			code:           header.ICMPv4FragmentationNeeded,
   838  			trunc:          header.ICMPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize + dataLen - 1,
   839  		},
   840  		{
   841  			name:           "Port unreachable",
   842  			expectedCount:  1,
   843  			fragmentOffset: 0,
   844  			code:           header.ICMPv4PortUnreachable,
   845  			transErr: transportError{
   846  				origin: tcpip.SockExtErrorOriginICMP,
   847  				typ:    uint8(header.ICMPv4DstUnreachable),
   848  				code:   uint8(header.ICMPv4PortUnreachable),
   849  				kind:   stack.DestinationPortUnreachableTransportError,
   850  			},
   851  			trunc: 0,
   852  		},
   853  		{
   854  			name:           "Non-zero fragment offset",
   855  			expectedCount:  0,
   856  			fragmentOffset: 100,
   857  			code:           header.ICMPv4PortUnreachable,
   858  			trunc:          0,
   859  		},
   860  		{
   861  			name:           "Zero-length packet",
   862  			expectedCount:  0,
   863  			fragmentOffset: 100,
   864  			code:           header.ICMPv4PortUnreachable,
   865  			trunc:          2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + dataLen,
   866  		},
   867  	}
   868  	for _, c := range cases {
   869  		t.Run(c.name, func(t *testing.T) {
   870  			ctx := newTestContext()
   871  			defer ctx.cleanup()
   872  			s := ctx.s
   873  
   874  			proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
   875  			nic := testInterface{
   876  				testObject: testObject{
   877  					t: t,
   878  				},
   879  			}
   880  			ep := proto.NewEndpoint(&nic, &nic.testObject)
   881  			defer ep.Close()
   882  
   883  			if err := ep.Enable(); err != nil {
   884  				t.Fatalf("ep.Enable(): %s", err)
   885  			}
   886  
   887  			const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize
   888  			view := make([]byte, dataOffset+dataLen)
   889  
   890  			// Create the outer IPv4 header.
   891  			ip := header.IPv4(view)
   892  			ip.Encode(&header.IPv4Fields{
   893  				TotalLength: uint16(len(view) - c.trunc),
   894  				TTL:         20,
   895  				Protocol:    uint8(header.ICMPv4ProtocolNumber),
   896  				SrcAddr:     tcpip.AddrFromSlice([]byte("\x0a\x00\x00\xbb")),
   897  				DstAddr:     localIPv4Addr,
   898  			})
   899  			ip.SetChecksum(^ip.CalculateChecksum())
   900  
   901  			// Create the ICMP header.
   902  			icmp := header.ICMPv4(view[header.IPv4MinimumSize:])
   903  			icmp.SetType(header.ICMPv4DstUnreachable)
   904  			icmp.SetCode(c.code)
   905  			icmp.SetIdent(0xdead)
   906  			icmp.SetSequence(0xbeef)
   907  
   908  			// Create the inner IPv4 header.
   909  			ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:])
   910  			ip.Encode(&header.IPv4Fields{
   911  				TotalLength:    100,
   912  				TTL:            20,
   913  				Protocol:       10,
   914  				FragmentOffset: c.fragmentOffset,
   915  				SrcAddr:        localIPv4Addr,
   916  				DstAddr:        remoteIPv4Addr,
   917  			})
   918  			ip.SetChecksum(^ip.CalculateChecksum())
   919  
   920  			// Make payload be non-zero.
   921  			for i := dataOffset; i < len(view); i++ {
   922  				view[i] = uint8(i)
   923  			}
   924  
   925  			icmp.SetChecksum(0)
   926  			xsum := ^checksum.Checksum(icmp, 0 /* initial */)
   927  			icmp.SetChecksum(xsum)
   928  
   929  			// Give packet to IPv4 endpoint, dispatcher will validate that
   930  			// it's ok.
   931  			nic.testObject.protocol = 10
   932  			nic.testObject.srcAddr = remoteIPv4Addr
   933  			nic.testObject.dstAddr = localIPv4Addr
   934  			nic.testObject.contents = view[dataOffset:]
   935  			nic.testObject.transErr = c.transErr
   936  
   937  			addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
   938  			if !ok {
   939  				t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
   940  			}
   941  			addr := localIPv4Addr.WithPrefix()
   942  			if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
   943  				t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
   944  			} else {
   945  				ep.DecRef()
   946  			}
   947  
   948  			pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize)
   949  			ep.HandlePacket(pkt)
   950  			pkt.DecRef()
   951  			if want := c.expectedCount; nic.testObject.controlCalls != want {
   952  				t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
   953  			}
   954  		})
   955  	}
   956  }
   957  
   958  func TestIPv4FragmentationReceive(t *testing.T) {
   959  	ctx := newTestContext()
   960  	defer ctx.cleanup()
   961  	s := ctx.s
   962  
   963  	proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
   964  	nic := testInterface{
   965  		testObject: testObject{
   966  			t:  t,
   967  			v4: true,
   968  		},
   969  	}
   970  	ep := proto.NewEndpoint(&nic, &nic.testObject)
   971  	defer ep.Close()
   972  
   973  	if err := ep.Enable(); err != nil {
   974  		t.Fatalf("ep.Enable(): %s", err)
   975  	}
   976  
   977  	totalLen := header.IPv4MinimumSize + 24
   978  
   979  	frag1 := make([]byte, totalLen)
   980  	ip1 := header.IPv4(frag1)
   981  	ip1.Encode(&header.IPv4Fields{
   982  		TotalLength:    uint16(totalLen),
   983  		TTL:            20,
   984  		Protocol:       10,
   985  		FragmentOffset: 0,
   986  		Flags:          header.IPv4FlagMoreFragments,
   987  		SrcAddr:        remoteIPv4Addr,
   988  		DstAddr:        localIPv4Addr,
   989  	})
   990  	ip1.SetChecksum(^ip1.CalculateChecksum())
   991  
   992  	// Make payload be non-zero.
   993  	for i := header.IPv4MinimumSize; i < totalLen; i++ {
   994  		frag1[i] = uint8(i)
   995  	}
   996  
   997  	frag2 := make([]byte, totalLen)
   998  	ip2 := header.IPv4(frag2)
   999  	ip2.Encode(&header.IPv4Fields{
  1000  		TotalLength:    uint16(totalLen),
  1001  		TTL:            20,
  1002  		Protocol:       10,
  1003  		FragmentOffset: 24,
  1004  		SrcAddr:        remoteIPv4Addr,
  1005  		DstAddr:        localIPv4Addr,
  1006  	})
  1007  	ip2.SetChecksum(^ip2.CalculateChecksum())
  1008  
  1009  	// Make payload be non-zero.
  1010  	for i := header.IPv4MinimumSize; i < totalLen; i++ {
  1011  		frag2[i] = uint8(i)
  1012  	}
  1013  
  1014  	// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
  1015  	nic.testObject.protocol = 10
  1016  	nic.testObject.srcAddr = remoteIPv4Addr
  1017  	nic.testObject.dstAddr = localIPv4Addr
  1018  	nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
  1019  
  1020  	addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
  1021  	if !ok {
  1022  		t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
  1023  	}
  1024  	addr := localIPv4Addr.WithPrefix()
  1025  	if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
  1026  		t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
  1027  	} else {
  1028  		ep.DecRef()
  1029  	}
  1030  
  1031  	// Send first segment.
  1032  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
  1033  		Payload: buffer.MakeWithData(frag1),
  1034  	})
  1035  	ep.HandlePacket(pkt)
  1036  	pkt.DecRef()
  1037  
  1038  	if nic.testObject.dataCalls != 0 {
  1039  		t.Fatalf("Bad number of data calls: got %d, want 0", nic.testObject.dataCalls)
  1040  	}
  1041  	if nic.testObject.rawCalls != 0 {
  1042  		t.Errorf("Bad number of raw calls: got %d, want 0", nic.testObject.rawCalls)
  1043  	}
  1044  
  1045  	// Send second segment.
  1046  	pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{
  1047  		Payload: buffer.MakeWithData(frag2),
  1048  	})
  1049  	ep.HandlePacket(pkt)
  1050  	pkt.DecRef()
  1051  
  1052  	if nic.testObject.dataCalls != 1 {
  1053  		t.Fatalf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls)
  1054  	}
  1055  	if nic.testObject.rawCalls != 1 {
  1056  		t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls)
  1057  	}
  1058  }
  1059  
  1060  func TestIPv6Send(t *testing.T) {
  1061  	ctx := newTestContext()
  1062  	defer ctx.cleanup()
  1063  	s := ctx.s
  1064  
  1065  	proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
  1066  	nic := testInterface{
  1067  		testObject: testObject{
  1068  			t: t,
  1069  		},
  1070  	}
  1071  	ep := proto.NewEndpoint(&nic, nil)
  1072  	defer ep.Close()
  1073  
  1074  	if err := ep.Enable(); err != nil {
  1075  		t.Fatalf("ep.Enable(): %s", err)
  1076  	}
  1077  
  1078  	// Allocate and initialize the payload view.
  1079  	payload := make([]byte, 100)
  1080  	for i := 0; i < len(payload); i++ {
  1081  		payload[i] = uint8(i)
  1082  	}
  1083  
  1084  	// Setup the packet buffer.
  1085  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
  1086  		ReserveHeaderBytes: int(ep.MaxHeaderLength()),
  1087  		Payload:            buffer.MakeWithData(payload),
  1088  	})
  1089  	defer pkt.DecRef()
  1090  	// Issue the write.
  1091  	nic.testObject.protocol = 123
  1092  	nic.testObject.srcAddr = localIPv6Addr
  1093  	nic.testObject.dstAddr = remoteIPv6Addr
  1094  	nic.testObject.contents = payload
  1095  
  1096  	r, err := buildIPv6Route(ctx, localIPv6Addr, remoteIPv6Addr)
  1097  	if err != nil {
  1098  		t.Fatalf("could not find route: %v", err)
  1099  	}
  1100  	defer r.Release()
  1101  	if err := ep.WritePacket(r, stack.NetworkHeaderParams{
  1102  		Protocol: 123,
  1103  		TTL:      123,
  1104  		TOS:      stack.DefaultTOS,
  1105  	}, pkt); err != nil {
  1106  		t.Fatalf("WritePacket failed: %v", err)
  1107  	}
  1108  }
  1109  
  1110  func TestIPv6ReceiveControl(t *testing.T) {
  1111  	const (
  1112  		mtu     = 0xffff
  1113  		dataLen = 8
  1114  	)
  1115  	outerSrcAddr := tcpip.AddrFromSlice([]byte("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa\x00\x00\x00"))
  1116  
  1117  	newUint16 := func(v uint16) *uint16 { return &v }
  1118  
  1119  	portUnreachableTransErr := transportError{
  1120  		origin: tcpip.SockExtErrorOriginICMP6,
  1121  		typ:    uint8(header.ICMPv6DstUnreachable),
  1122  		code:   uint8(header.ICMPv6PortUnreachable),
  1123  		kind:   stack.DestinationPortUnreachableTransportError,
  1124  	}
  1125  
  1126  	cases := []struct {
  1127  		name           string
  1128  		expectedCount  int
  1129  		fragmentOffset *uint16
  1130  		typ            header.ICMPv6Type
  1131  		code           header.ICMPv6Code
  1132  		transErr       transportError
  1133  		trunc          int
  1134  	}{
  1135  		{
  1136  			name:           "PacketTooBig",
  1137  			expectedCount:  1,
  1138  			fragmentOffset: nil,
  1139  			typ:            header.ICMPv6PacketTooBig,
  1140  			code:           header.ICMPv6UnusedCode,
  1141  			transErr: transportError{
  1142  				origin: tcpip.SockExtErrorOriginICMP6,
  1143  				typ:    uint8(header.ICMPv6PacketTooBig),
  1144  				code:   uint8(header.ICMPv6UnusedCode),
  1145  				info:   mtu,
  1146  				kind:   stack.PacketTooBigTransportError,
  1147  			},
  1148  			trunc: 0,
  1149  		},
  1150  		{
  1151  			name:           "Truncated (missing offending packet's IPv6 header)",
  1152  			expectedCount:  0,
  1153  			fragmentOffset: nil,
  1154  			typ:            header.ICMPv6PacketTooBig,
  1155  			code:           header.ICMPv6UnusedCode,
  1156  			trunc:          header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize,
  1157  		},
  1158  		{
  1159  			name:           "Truncated PacketTooBig (partial offending packet's IPv6 header)",
  1160  			expectedCount:  0,
  1161  			fragmentOffset: nil,
  1162  			typ:            header.ICMPv6PacketTooBig,
  1163  			code:           header.ICMPv6UnusedCode,
  1164  			trunc:          header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize - 1,
  1165  		},
  1166  		{
  1167  			name:           "Truncated (partial offending packet's data)",
  1168  			expectedCount:  0,
  1169  			fragmentOffset: nil,
  1170  			typ:            header.ICMPv6PacketTooBig,
  1171  			code:           header.ICMPv6UnusedCode,
  1172  			trunc:          header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + dataLen - 1,
  1173  		},
  1174  		{
  1175  			name:           "Port unreachable",
  1176  			expectedCount:  1,
  1177  			fragmentOffset: nil,
  1178  			typ:            header.ICMPv6DstUnreachable,
  1179  			code:           header.ICMPv6PortUnreachable,
  1180  			transErr:       portUnreachableTransErr,
  1181  			trunc:          0,
  1182  		},
  1183  		{
  1184  			name:           "Truncated DstPortUnreachable (partial offending packet's IP header)",
  1185  			expectedCount:  0,
  1186  			fragmentOffset: nil,
  1187  			typ:            header.ICMPv6DstUnreachable,
  1188  			code:           header.ICMPv6PortUnreachable,
  1189  			trunc:          header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + header.IPv6MinimumSize - 1,
  1190  		},
  1191  		{
  1192  			name:           "DstPortUnreachable for Fragmented, zero offset",
  1193  			expectedCount:  1,
  1194  			fragmentOffset: newUint16(0),
  1195  			typ:            header.ICMPv6DstUnreachable,
  1196  			code:           header.ICMPv6PortUnreachable,
  1197  			transErr:       portUnreachableTransErr,
  1198  			trunc:          0,
  1199  		},
  1200  		{
  1201  			name:           "DstPortUnreachable for Non-zero fragment offset",
  1202  			expectedCount:  0,
  1203  			fragmentOffset: newUint16(100),
  1204  			typ:            header.ICMPv6DstUnreachable,
  1205  			code:           header.ICMPv6PortUnreachable,
  1206  			transErr:       portUnreachableTransErr,
  1207  			trunc:          0,
  1208  		},
  1209  		{
  1210  			name:           "Zero-length packet",
  1211  			expectedCount:  0,
  1212  			fragmentOffset: nil,
  1213  			typ:            header.ICMPv6DstUnreachable,
  1214  			code:           header.ICMPv6PortUnreachable,
  1215  			trunc:          2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + dataLen,
  1216  		},
  1217  	}
  1218  	for _, c := range cases {
  1219  		t.Run(c.name, func(t *testing.T) {
  1220  			ctx := newTestContext()
  1221  			defer ctx.cleanup()
  1222  			s := ctx.s
  1223  
  1224  			proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
  1225  			nic := testInterface{
  1226  				testObject: testObject{
  1227  					t: t,
  1228  				},
  1229  			}
  1230  			ep := proto.NewEndpoint(&nic, &nic.testObject)
  1231  			defer ep.Close()
  1232  
  1233  			if err := ep.Enable(); err != nil {
  1234  				t.Fatalf("ep.Enable(): %s", err)
  1235  			}
  1236  
  1237  			dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize
  1238  			if c.fragmentOffset != nil {
  1239  				dataOffset += header.IPv6FragmentHeaderSize
  1240  			}
  1241  			view := make([]byte, dataOffset+dataLen)
  1242  
  1243  			// Create the outer IPv6 header.
  1244  			ip := header.IPv6(view)
  1245  			ip.Encode(&header.IPv6Fields{
  1246  				PayloadLength:     uint16(len(view) - header.IPv6MinimumSize - c.trunc),
  1247  				TransportProtocol: header.ICMPv6ProtocolNumber,
  1248  				HopLimit:          20,
  1249  				SrcAddr:           outerSrcAddr,
  1250  				DstAddr:           localIPv6Addr,
  1251  			})
  1252  
  1253  			// Create the ICMP header.
  1254  			icmp := header.ICMPv6(view[header.IPv6MinimumSize:])
  1255  			icmp.SetType(c.typ)
  1256  			icmp.SetCode(c.code)
  1257  			icmp.SetIdent(0xdead)
  1258  			icmp.SetSequence(0xbeef)
  1259  
  1260  			var extHdrs header.IPv6ExtHdrSerializer
  1261  			// Build the fragmentation header if needed.
  1262  			if c.fragmentOffset != nil {
  1263  				extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{
  1264  					FragmentOffset: *c.fragmentOffset,
  1265  					M:              true,
  1266  					Identification: 0x12345678,
  1267  				})
  1268  			}
  1269  
  1270  			// Create the inner IPv6 header.
  1271  			ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:])
  1272  			ip.Encode(&header.IPv6Fields{
  1273  				PayloadLength:     100,
  1274  				TransportProtocol: 10,
  1275  				HopLimit:          20,
  1276  				SrcAddr:           localIPv6Addr,
  1277  				DstAddr:           remoteIPv6Addr,
  1278  				ExtensionHeaders:  extHdrs,
  1279  			})
  1280  
  1281  			// Make payload be non-zero.
  1282  			for i := dataOffset; i < len(view); i++ {
  1283  				view[i] = uint8(i)
  1284  			}
  1285  
  1286  			// Give packet to IPv6 endpoint, dispatcher will validate that
  1287  			// it's ok.
  1288  			nic.testObject.protocol = 10
  1289  			nic.testObject.srcAddr = remoteIPv6Addr
  1290  			nic.testObject.dstAddr = localIPv6Addr
  1291  			nic.testObject.contents = view[dataOffset:]
  1292  			nic.testObject.transErr = c.transErr
  1293  
  1294  			// Set ICMPv6 checksum.
  1295  			icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
  1296  				Header: icmp,
  1297  				Src:    outerSrcAddr,
  1298  				Dst:    localIPv6Addr,
  1299  			}))
  1300  
  1301  			addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
  1302  			if !ok {
  1303  				t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint")
  1304  			}
  1305  			addr := localIPv6Addr.WithPrefix()
  1306  			if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
  1307  				t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
  1308  			} else {
  1309  				ep.DecRef()
  1310  			}
  1311  			pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize)
  1312  			ep.HandlePacket(pkt)
  1313  			pkt.DecRef()
  1314  			if want := c.expectedCount; nic.testObject.controlCalls != want {
  1315  				t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
  1316  			}
  1317  		})
  1318  	}
  1319  }
  1320  
  1321  // truncatedPacket returns a PacketBuffer based on a truncated view. If view,
  1322  // after truncation, is large enough to hold a network header, it makes part of
  1323  // view the packet's NetworkHeader and the rest its Data. Otherwise all of view
  1324  // becomes Data.
  1325  func truncatedPacket(view []byte, trunc, netHdrLen int) *stack.PacketBuffer {
  1326  	v := view[:len(view)-trunc]
  1327  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
  1328  		Payload: buffer.MakeWithData(v),
  1329  	})
  1330  	return pkt
  1331  }
  1332  
  1333  func TestWriteHeaderIncludedPacket(t *testing.T) {
  1334  	const (
  1335  		nicID          = 1
  1336  		transportProto = 5
  1337  
  1338  		dataLen = 4
  1339  	)
  1340  
  1341  	dataBuf := [dataLen]byte{1, 2, 3, 4}
  1342  	data := dataBuf[:]
  1343  
  1344  	ipv4Options := header.IPv4OptionsSerializer{
  1345  		&header.IPv4SerializableListEndOption{},
  1346  		&header.IPv4SerializableNOPOption{},
  1347  		&header.IPv4SerializableListEndOption{},
  1348  		&header.IPv4SerializableNOPOption{},
  1349  	}
  1350  
  1351  	expectOptions := header.IPv4Options{
  1352  		byte(header.IPv4OptionListEndType),
  1353  		byte(header.IPv4OptionNOPType),
  1354  		byte(header.IPv4OptionListEndType),
  1355  		byte(header.IPv4OptionNOPType),
  1356  	}
  1357  
  1358  	ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4}
  1359  	ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:]
  1360  
  1361  	var ipv6PayloadWithExtHdrBuf [dataLen + header.IPv6FragmentExtHdrLength]byte
  1362  	ipv6PayloadWithExtHdr := ipv6PayloadWithExtHdrBuf[:]
  1363  	if n := copy(ipv6PayloadWithExtHdr, ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) {
  1364  		t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr))
  1365  	}
  1366  	if n := copy(ipv6PayloadWithExtHdr[header.IPv6FragmentExtHdrLength:], data); n != len(data) {
  1367  		t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
  1368  	}
  1369  
  1370  	tests := []struct {
  1371  		name         string
  1372  		protoFactory stack.NetworkProtocolFactory
  1373  		protoNum     tcpip.NetworkProtocolNumber
  1374  		nicAddr      tcpip.AddressWithPrefix
  1375  		remoteAddr   tcpip.Address
  1376  		pktGen       func(*testing.T, tcpip.Address) buffer.Buffer
  1377  		checker      func(*testing.T, *stack.PacketBuffer, tcpip.Address)
  1378  		expectedErr  tcpip.Error
  1379  	}{
  1380  		{
  1381  			name:         "IPv4",
  1382  			protoFactory: ipv4.NewProtocol,
  1383  			protoNum:     ipv4.ProtocolNumber,
  1384  			nicAddr:      localIPv4AddrWithPrefix,
  1385  			remoteAddr:   remoteIPv4Addr,
  1386  			pktGen: func(t *testing.T, src tcpip.Address) buffer.Buffer {
  1387  				totalLen := header.IPv4MinimumSize + len(data)
  1388  				hdr := prependable.New(totalLen)
  1389  				if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
  1390  					t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
  1391  				}
  1392  				ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
  1393  				ip.Encode(&header.IPv4Fields{
  1394  					Protocol: transportProto,
  1395  					TTL:      ipv4.DefaultTTL,
  1396  					SrcAddr:  src,
  1397  					DstAddr:  remoteIPv4Addr,
  1398  				})
  1399  				return buffer.MakeWithData(hdr.View())
  1400  			},
  1401  			checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
  1402  				if src == header.IPv4Any {
  1403  					src = localIPv4Addr
  1404  				}
  1405  
  1406  				netHdr := pkt.NetworkHeader()
  1407  
  1408  				if len(netHdr.Slice()) != header.IPv4MinimumSize {
  1409  					t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.Slice()), header.IPv4MinimumSize)
  1410  				}
  1411  
  1412  				payload := stack.PayloadSince(netHdr)
  1413  				defer payload.Release()
  1414  				checker.IPv4(t, payload,
  1415  					checker.SrcAddr(src),
  1416  					checker.DstAddr(remoteIPv4Addr),
  1417  					checker.IPv4HeaderLength(header.IPv4MinimumSize),
  1418  					checker.IPFullLength(uint16(header.IPv4MinimumSize+len(data))),
  1419  					checker.IPPayload(data),
  1420  				)
  1421  			},
  1422  		},
  1423  		{
  1424  			name:         "IPv4 with IHL too small",
  1425  			protoFactory: ipv4.NewProtocol,
  1426  			protoNum:     ipv4.ProtocolNumber,
  1427  			nicAddr:      localIPv4AddrWithPrefix,
  1428  			remoteAddr:   remoteIPv4Addr,
  1429  			pktGen: func(t *testing.T, src tcpip.Address) buffer.Buffer {
  1430  				totalLen := header.IPv4MinimumSize + len(data)
  1431  				hdr := prependable.New(totalLen)
  1432  				if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
  1433  					t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
  1434  				}
  1435  				ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
  1436  				ip.Encode(&header.IPv4Fields{
  1437  					Protocol: transportProto,
  1438  					TTL:      ipv4.DefaultTTL,
  1439  					SrcAddr:  src,
  1440  					DstAddr:  remoteIPv4Addr,
  1441  				})
  1442  				ip.SetHeaderLength(header.IPv4MinimumSize - 1)
  1443  				return buffer.MakeWithData(hdr.View())
  1444  			},
  1445  			expectedErr: &tcpip.ErrMalformedHeader{},
  1446  		},
  1447  		{
  1448  			name:         "IPv4 too small",
  1449  			protoFactory: ipv4.NewProtocol,
  1450  			protoNum:     ipv4.ProtocolNumber,
  1451  			nicAddr:      localIPv4AddrWithPrefix,
  1452  			remoteAddr:   remoteIPv4Addr,
  1453  			pktGen: func(t *testing.T, src tcpip.Address) buffer.Buffer {
  1454  				ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
  1455  				ip.Encode(&header.IPv4Fields{
  1456  					Protocol: transportProto,
  1457  					TTL:      ipv4.DefaultTTL,
  1458  					SrcAddr:  src,
  1459  					DstAddr:  remoteIPv4Addr,
  1460  				})
  1461  				return buffer.MakeWithData(ip[:len(ip)-1])
  1462  			},
  1463  			expectedErr: &tcpip.ErrMalformedHeader{},
  1464  		},
  1465  		{
  1466  			name:         "IPv4 minimum size",
  1467  			protoFactory: ipv4.NewProtocol,
  1468  			protoNum:     ipv4.ProtocolNumber,
  1469  			nicAddr:      localIPv4AddrWithPrefix,
  1470  			remoteAddr:   remoteIPv4Addr,
  1471  			pktGen: func(t *testing.T, src tcpip.Address) buffer.Buffer {
  1472  				ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
  1473  				ip.Encode(&header.IPv4Fields{
  1474  					Protocol: transportProto,
  1475  					TTL:      ipv4.DefaultTTL,
  1476  					SrcAddr:  src,
  1477  					DstAddr:  remoteIPv4Addr,
  1478  				})
  1479  				return buffer.MakeWithData(ip)
  1480  			},
  1481  			checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
  1482  				if src == header.IPv4Any {
  1483  					src = localIPv4Addr
  1484  				}
  1485  
  1486  				netHdr := pkt.NetworkHeader()
  1487  
  1488  				if len(netHdr.Slice()) != header.IPv4MinimumSize {
  1489  					t.Errorf("got len(netHdr.Slice()) = %d, want = %d", len(netHdr.Slice()), header.IPv4MinimumSize)
  1490  				}
  1491  
  1492  				payload := stack.PayloadSince(netHdr)
  1493  				defer payload.Release()
  1494  				checker.IPv4(t, payload,
  1495  					checker.SrcAddr(src),
  1496  					checker.DstAddr(remoteIPv4Addr),
  1497  					checker.IPv4HeaderLength(header.IPv4MinimumSize),
  1498  					checker.IPFullLength(header.IPv4MinimumSize),
  1499  					checker.IPPayload(nil),
  1500  				)
  1501  			},
  1502  		},
  1503  		{
  1504  			name:         "IPv4 with options",
  1505  			protoFactory: ipv4.NewProtocol,
  1506  			protoNum:     ipv4.ProtocolNumber,
  1507  			nicAddr:      localIPv4AddrWithPrefix,
  1508  			remoteAddr:   remoteIPv4Addr,
  1509  			pktGen: func(t *testing.T, src tcpip.Address) buffer.Buffer {
  1510  				ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
  1511  				totalLen := ipHdrLen + len(data)
  1512  				hdr := prependable.New(totalLen)
  1513  				if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
  1514  					t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
  1515  				}
  1516  				ip := header.IPv4(hdr.Prepend(ipHdrLen))
  1517  				ip.Encode(&header.IPv4Fields{
  1518  					Protocol: transportProto,
  1519  					TTL:      ipv4.DefaultTTL,
  1520  					SrcAddr:  src,
  1521  					DstAddr:  remoteIPv4Addr,
  1522  					Options:  ipv4Options,
  1523  				})
  1524  				return buffer.MakeWithData(hdr.View())
  1525  			},
  1526  			checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
  1527  				if src == header.IPv4Any {
  1528  					src = localIPv4Addr
  1529  				}
  1530  
  1531  				netHdr := pkt.NetworkHeader()
  1532  
  1533  				hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
  1534  				if len(netHdr.Slice()) != hdrLen {
  1535  					t.Errorf("got len(netHdr.Slice()) = %d, want = %d", len(netHdr.Slice()), hdrLen)
  1536  				}
  1537  
  1538  				payload := stack.PayloadSince(netHdr)
  1539  				defer payload.Release()
  1540  				checker.IPv4(t, payload,
  1541  					checker.SrcAddr(src),
  1542  					checker.DstAddr(remoteIPv4Addr),
  1543  					checker.IPv4HeaderLength(hdrLen),
  1544  					checker.IPFullLength(uint16(hdrLen+len(data))),
  1545  					checker.IPv4Options(expectOptions),
  1546  					checker.IPPayload(data),
  1547  				)
  1548  			},
  1549  		},
  1550  		{
  1551  			name:         "IPv4 with options and data across views",
  1552  			protoFactory: ipv4.NewProtocol,
  1553  			protoNum:     ipv4.ProtocolNumber,
  1554  			nicAddr:      localIPv4AddrWithPrefix,
  1555  			remoteAddr:   remoteIPv4Addr,
  1556  			pktGen: func(t *testing.T, src tcpip.Address) buffer.Buffer {
  1557  				ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length()))
  1558  				ip.Encode(&header.IPv4Fields{
  1559  					Protocol: transportProto,
  1560  					TTL:      ipv4.DefaultTTL,
  1561  					SrcAddr:  src,
  1562  					DstAddr:  remoteIPv4Addr,
  1563  					Options:  ipv4Options,
  1564  				})
  1565  				buf := buffer.MakeWithData(ip)
  1566  				buf.Append(buffer.NewViewWithData(data))
  1567  				return buf
  1568  			},
  1569  			checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
  1570  				if src == header.IPv4Any {
  1571  					src = localIPv4Addr
  1572  				}
  1573  
  1574  				netHdr := pkt.NetworkHeader()
  1575  
  1576  				hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
  1577  				if len(netHdr.Slice()) != hdrLen {
  1578  					t.Errorf("got len(netHdr.Slice()) = %d, want = %d", len(netHdr.Slice()), hdrLen)
  1579  				}
  1580  
  1581  				payload := stack.PayloadSince(netHdr)
  1582  				defer payload.Release()
  1583  				checker.IPv4(t, payload,
  1584  					checker.SrcAddr(src),
  1585  					checker.DstAddr(remoteIPv4Addr),
  1586  					checker.IPv4HeaderLength(hdrLen),
  1587  					checker.IPFullLength(uint16(hdrLen+len(data))),
  1588  					checker.IPv4Options(expectOptions),
  1589  					checker.IPPayload(data),
  1590  				)
  1591  			},
  1592  		},
  1593  		{
  1594  			name:         "IPv6",
  1595  			protoFactory: ipv6.NewProtocol,
  1596  			protoNum:     ipv6.ProtocolNumber,
  1597  			nicAddr:      localIPv6AddrWithPrefix,
  1598  			remoteAddr:   remoteIPv6Addr,
  1599  			pktGen: func(t *testing.T, src tcpip.Address) buffer.Buffer {
  1600  				totalLen := header.IPv6MinimumSize + len(data)
  1601  				hdr := prependable.New(totalLen)
  1602  				if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
  1603  					t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
  1604  				}
  1605  				ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
  1606  				ip.Encode(&header.IPv6Fields{
  1607  					TransportProtocol: transportProto,
  1608  					HopLimit:          ipv6.DefaultTTL,
  1609  					SrcAddr:           src,
  1610  					DstAddr:           remoteIPv6Addr,
  1611  				})
  1612  				return buffer.MakeWithData(hdr.View())
  1613  			},
  1614  			checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
  1615  				if src == header.IPv6Any {
  1616  					src = localIPv6Addr
  1617  				}
  1618  
  1619  				netHdr := pkt.NetworkHeader()
  1620  
  1621  				if len(netHdr.Slice()) != header.IPv6MinimumSize {
  1622  					t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.Slice()), header.IPv6MinimumSize)
  1623  				}
  1624  
  1625  				payload := stack.PayloadSince(netHdr)
  1626  				defer payload.Release()
  1627  				checker.IPv6(t, payload,
  1628  					checker.SrcAddr(src),
  1629  					checker.DstAddr(remoteIPv6Addr),
  1630  					checker.IPFullLength(uint16(header.IPv6MinimumSize+len(data))),
  1631  					checker.IPPayload(data),
  1632  				)
  1633  			},
  1634  		},
  1635  		{
  1636  			name:         "IPv6 with extension header",
  1637  			protoFactory: ipv6.NewProtocol,
  1638  			protoNum:     ipv6.ProtocolNumber,
  1639  			nicAddr:      localIPv6AddrWithPrefix,
  1640  			remoteAddr:   remoteIPv6Addr,
  1641  			pktGen: func(t *testing.T, src tcpip.Address) buffer.Buffer {
  1642  				totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data)
  1643  				hdr := prependable.New(totalLen)
  1644  				if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
  1645  					t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
  1646  				}
  1647  				if n := copy(hdr.Prepend(len(ipv6FragmentExtHdr)), ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) {
  1648  					t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr))
  1649  				}
  1650  				ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
  1651  				ip.Encode(&header.IPv6Fields{
  1652  					// NB: we're lying about transport protocol here to verify the raw
  1653  					// fragment header bytes.
  1654  					TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier),
  1655  					HopLimit:          ipv6.DefaultTTL,
  1656  					SrcAddr:           src,
  1657  					DstAddr:           remoteIPv6Addr,
  1658  				})
  1659  				return buffer.MakeWithData(hdr.View())
  1660  			},
  1661  			checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
  1662  				if src == header.IPv6Any {
  1663  					src = localIPv6Addr
  1664  				}
  1665  
  1666  				netHdr := pkt.NetworkHeader()
  1667  
  1668  				if want := header.IPv6MinimumSize + len(ipv6FragmentExtHdr); len(netHdr.Slice()) != want {
  1669  					t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.Slice()), want)
  1670  				}
  1671  
  1672  				payload := stack.PayloadSince(netHdr)
  1673  				defer payload.Release()
  1674  				checker.IPv6(t, payload,
  1675  					checker.SrcAddr(src),
  1676  					checker.DstAddr(remoteIPv6Addr),
  1677  					checker.IPFullLength(uint16(header.IPv6MinimumSize+len(ipv6PayloadWithExtHdr))),
  1678  					checker.IPPayload(ipv6PayloadWithExtHdr),
  1679  				)
  1680  			},
  1681  		},
  1682  		{
  1683  			name:         "IPv6 minimum size",
  1684  			protoFactory: ipv6.NewProtocol,
  1685  			protoNum:     ipv6.ProtocolNumber,
  1686  			nicAddr:      localIPv6AddrWithPrefix,
  1687  			remoteAddr:   remoteIPv6Addr,
  1688  			pktGen: func(t *testing.T, src tcpip.Address) buffer.Buffer {
  1689  				ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
  1690  				ip.Encode(&header.IPv6Fields{
  1691  					TransportProtocol: transportProto,
  1692  					HopLimit:          ipv6.DefaultTTL,
  1693  					SrcAddr:           src,
  1694  					DstAddr:           remoteIPv6Addr,
  1695  				})
  1696  				return buffer.MakeWithData(ip)
  1697  			},
  1698  			checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
  1699  				if src == header.IPv6Any {
  1700  					src = localIPv6Addr
  1701  				}
  1702  
  1703  				netHdr := pkt.NetworkHeader()
  1704  
  1705  				if len(netHdr.Slice()) != header.IPv6MinimumSize {
  1706  					t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.Slice()), header.IPv6MinimumSize)
  1707  				}
  1708  
  1709  				payload := stack.PayloadSince(netHdr)
  1710  				defer payload.Release()
  1711  				checker.IPv6(t, payload,
  1712  					checker.SrcAddr(src),
  1713  					checker.DstAddr(remoteIPv6Addr),
  1714  					checker.IPFullLength(header.IPv6MinimumSize),
  1715  					checker.IPPayload(nil),
  1716  				)
  1717  			},
  1718  		},
  1719  		{
  1720  			name:         "IPv6 too small",
  1721  			protoFactory: ipv6.NewProtocol,
  1722  			protoNum:     ipv6.ProtocolNumber,
  1723  			nicAddr:      localIPv6AddrWithPrefix,
  1724  			remoteAddr:   remoteIPv6Addr,
  1725  			pktGen: func(t *testing.T, src tcpip.Address) buffer.Buffer {
  1726  				ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
  1727  				ip.Encode(&header.IPv6Fields{
  1728  					TransportProtocol: transportProto,
  1729  					HopLimit:          ipv6.DefaultTTL,
  1730  					SrcAddr:           src,
  1731  					DstAddr:           remoteIPv4Addr,
  1732  				})
  1733  				return buffer.MakeWithData(ip[:len(ip)-1])
  1734  			},
  1735  			expectedErr: &tcpip.ErrMalformedHeader{},
  1736  		},
  1737  	}
  1738  
  1739  	for _, test := range tests {
  1740  		t.Run(test.name, func(t *testing.T) {
  1741  			subTests := []struct {
  1742  				name    string
  1743  				srcAddr tcpip.Address
  1744  			}{
  1745  				{
  1746  					name:    "unspecified source",
  1747  					srcAddr: tcpip.AddrFromSlice([]byte(strings.Repeat("\x00", test.nicAddr.Address.Len()))),
  1748  				},
  1749  				{
  1750  					name:    "random source",
  1751  					srcAddr: tcpip.AddrFromSlice([]byte(strings.Repeat("\xab", test.nicAddr.Address.Len()))),
  1752  				},
  1753  			}
  1754  
  1755  			for _, subTest := range subTests {
  1756  				t.Run(subTest.name, func(t *testing.T) {
  1757  					s := stack.New(stack.Options{
  1758  						NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory},
  1759  					})
  1760  					defer func() {
  1761  						s.Close()
  1762  						s.Wait()
  1763  					}()
  1764  
  1765  					e := channel.New(1, header.IPv6MinimumMTU, "")
  1766  					defer e.Close()
  1767  					if err := s.CreateNIC(nicID, e); err != nil {
  1768  						t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
  1769  					}
  1770  					protocolAddr := tcpip.ProtocolAddress{
  1771  						Protocol:          test.protoNum,
  1772  						AddressWithPrefix: test.nicAddr,
  1773  					}
  1774  					if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
  1775  						t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
  1776  					}
  1777  
  1778  					s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}})
  1779  
  1780  					r, err := s.FindRoute(nicID, test.nicAddr.Address, test.remoteAddr, test.protoNum, false /* multicastLoop */)
  1781  					if err != nil {
  1782  						t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr.Address, test.protoNum, err)
  1783  					}
  1784  					defer r.Release()
  1785  
  1786  					{
  1787  						pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
  1788  							Payload: test.pktGen(t, subTest.srcAddr),
  1789  						})
  1790  						err := r.WriteHeaderIncludedPacket(pkt)
  1791  						pkt.DecRef()
  1792  						if diff := cmp.Diff(test.expectedErr, err); diff != "" {
  1793  							t.Fatalf("unexpected error from r.WriteHeaderIncludedPacket(_), (-want, +got):\n%s", diff)
  1794  						}
  1795  					}
  1796  
  1797  					if test.expectedErr != nil {
  1798  						return
  1799  					}
  1800  
  1801  					pkt := e.Read()
  1802  					if pkt == nil {
  1803  						t.Fatal("expected a packet to be written")
  1804  					}
  1805  					test.checker(t, pkt, subTest.srcAddr)
  1806  					pkt.DecRef()
  1807  				})
  1808  			}
  1809  		})
  1810  	}
  1811  }
  1812  
  1813  // Test that the included data in an ICMP error packet conforms to the
  1814  // requirements of RFC 972, RFC 4443 section 2.4 and RFC 1812 Section 4.3.2.3
  1815  func TestICMPInclusionSize(t *testing.T) {
  1816  	const (
  1817  		replyHeaderLength4 = header.IPv4MinimumSize + header.IPv4MinimumSize + header.ICMPv4MinimumSize
  1818  		replyHeaderLength6 = header.IPv6MinimumSize + header.IPv6MinimumSize + header.ICMPv6MinimumSize
  1819  		targetSize4        = header.IPv4MinimumProcessableDatagramSize
  1820  		targetSize6        = header.IPv6MinimumMTU
  1821  		// A protocol number that will cause an error response.
  1822  		reservedProtocol = 254
  1823  	)
  1824  
  1825  	// IPv4 function to create a IP packet and send it to the stack.
  1826  	// The packet should generate an error response. We can do that by using an
  1827  	// unknown transport protocol (254).
  1828  	rxIPv4Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) []byte {
  1829  		totalLen := header.IPv4MinimumSize + len(payload)
  1830  		hdr := prependable.New(header.IPv4MinimumSize)
  1831  		ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
  1832  		ip.Encode(&header.IPv4Fields{
  1833  			TotalLength: uint16(totalLen),
  1834  			Protocol:    reservedProtocol,
  1835  			TTL:         ipv4.DefaultTTL,
  1836  			SrcAddr:     src,
  1837  			DstAddr:     localIPv4Addr,
  1838  		})
  1839  		ip.SetChecksum(^ip.CalculateChecksum())
  1840  		buf := buffer.MakeWithData(hdr.View())
  1841  		buf.Append(buffer.NewViewWithData(payload))
  1842  		// Take a copy before InjectInbound takes ownership of vv
  1843  		// as vv may be changed during the call.
  1844  		v := buf.Flatten()
  1845  		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
  1846  			Payload: buf,
  1847  		})
  1848  		e.InjectInbound(header.IPv4ProtocolNumber, pkt)
  1849  		pkt.DecRef()
  1850  		return v
  1851  	}
  1852  
  1853  	// IPv6 function to create a packet and send it to the stack.
  1854  	// The packet should be errant in a way that causes the stack to send an
  1855  	// ICMP error response and have enough data to allow the testing of the
  1856  	// inclusion of the errant packet. Use `unknown next header' to generate
  1857  	// the error.
  1858  	rxIPv6Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) []byte {
  1859  		hdr := prependable.New(header.IPv6MinimumSize)
  1860  		ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
  1861  		ip.Encode(&header.IPv6Fields{
  1862  			PayloadLength:     uint16(len(payload)),
  1863  			TransportProtocol: reservedProtocol,
  1864  			HopLimit:          ipv6.DefaultTTL,
  1865  			SrcAddr:           src,
  1866  			DstAddr:           localIPv6Addr,
  1867  		})
  1868  		buf := buffer.MakeWithData(hdr.View())
  1869  		buf.Append(buffer.NewViewWithData(payload))
  1870  		// Take a copy before InjectInbound takes ownership of vv
  1871  		// as vv may be changed during the call.
  1872  		v := buf.Flatten()
  1873  
  1874  		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
  1875  			Payload: buf,
  1876  		})
  1877  		e.InjectInbound(header.IPv6ProtocolNumber, pkt)
  1878  		pkt.DecRef()
  1879  		return v
  1880  	}
  1881  
  1882  	v4Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload []byte) {
  1883  		// We already know the entire packet is the right size so we can use its
  1884  		// length to calculate the right payload size to check.
  1885  		expectedPayloadLength := pkt.Size() - header.IPv4MinimumSize - header.ICMPv4MinimumSize
  1886  		p := stack.PayloadSince(pkt.NetworkHeader())
  1887  		defer p.Release()
  1888  		checker.IPv4(t, p,
  1889  			checker.SrcAddr(localIPv4Addr),
  1890  			checker.DstAddr(remoteIPv4Addr),
  1891  			checker.IPv4HeaderLength(header.IPv4MinimumSize),
  1892  			checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+expectedPayloadLength)),
  1893  			checker.ICMPv4(
  1894  				checker.ICMPv4Checksum(),
  1895  				checker.ICMPv4Type(header.ICMPv4DstUnreachable),
  1896  				checker.ICMPv4Code(header.ICMPv4ProtoUnreachable),
  1897  				checker.ICMPv4Payload(payload[:expectedPayloadLength]),
  1898  			),
  1899  		)
  1900  	}
  1901  
  1902  	v6Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload []byte) {
  1903  		// We already know the entire packet is the right size so we can use its
  1904  		// length to calculate the right payload size to check.
  1905  		expectedPayloadLength := pkt.Size() - header.IPv6MinimumSize - header.ICMPv6MinimumSize
  1906  		p := stack.PayloadSince(pkt.NetworkHeader())
  1907  		defer p.Release()
  1908  		checker.IPv6(t, p,
  1909  			checker.SrcAddr(localIPv6Addr),
  1910  			checker.DstAddr(remoteIPv6Addr),
  1911  			checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectedPayloadLength)),
  1912  			checker.ICMPv6(
  1913  				checker.ICMPv6Type(header.ICMPv6ParamProblem),
  1914  				checker.ICMPv6Code(header.ICMPv6UnknownHeader),
  1915  				checker.ICMPv6Payload(payload[:expectedPayloadLength]),
  1916  			),
  1917  		)
  1918  	}
  1919  	tests := []struct {
  1920  		name          string
  1921  		srcAddress    tcpip.Address
  1922  		injector      func(*channel.Endpoint, tcpip.Address, []byte) []byte
  1923  		checker       func(*testing.T, *stack.PacketBuffer, []byte)
  1924  		payloadLength int    // Not including IP header.
  1925  		linkMTU       uint32 // Largest IP packet that the link can send as payload.
  1926  		replyLength   int    // Total size of IP/ICMP packet expected back.
  1927  	}{
  1928  		{
  1929  			name:          "IPv4 exact match",
  1930  			srcAddress:    remoteIPv4Addr,
  1931  			injector:      rxIPv4Bad,
  1932  			checker:       v4Checker,
  1933  			payloadLength: targetSize4 - replyHeaderLength4,
  1934  			linkMTU:       targetSize4,
  1935  			replyLength:   targetSize4,
  1936  		},
  1937  		{
  1938  			name:          "IPv4 larger MTU",
  1939  			srcAddress:    remoteIPv4Addr,
  1940  			injector:      rxIPv4Bad,
  1941  			checker:       v4Checker,
  1942  			payloadLength: targetSize4,
  1943  			linkMTU:       targetSize4 + 1000,
  1944  			replyLength:   targetSize4,
  1945  		},
  1946  		{
  1947  			name:          "IPv4 smaller MTU",
  1948  			srcAddress:    remoteIPv4Addr,
  1949  			injector:      rxIPv4Bad,
  1950  			checker:       v4Checker,
  1951  			payloadLength: targetSize4,
  1952  			linkMTU:       targetSize4 - 50,
  1953  			replyLength:   targetSize4 - 50,
  1954  		},
  1955  		{
  1956  			name:          "IPv4 payload exceeds",
  1957  			srcAddress:    remoteIPv4Addr,
  1958  			injector:      rxIPv4Bad,
  1959  			checker:       v4Checker,
  1960  			payloadLength: targetSize4 + 10,
  1961  			linkMTU:       targetSize4,
  1962  			replyLength:   targetSize4,
  1963  		},
  1964  		{
  1965  			name:          "IPv4 1 byte less",
  1966  			srcAddress:    remoteIPv4Addr,
  1967  			injector:      rxIPv4Bad,
  1968  			checker:       v4Checker,
  1969  			payloadLength: targetSize4 - replyHeaderLength4 - 1,
  1970  			linkMTU:       targetSize4,
  1971  			replyLength:   targetSize4 - 1,
  1972  		},
  1973  		{
  1974  			name:          "IPv4 No payload",
  1975  			srcAddress:    remoteIPv4Addr,
  1976  			injector:      rxIPv4Bad,
  1977  			checker:       v4Checker,
  1978  			payloadLength: 0,
  1979  			linkMTU:       targetSize4,
  1980  			replyLength:   replyHeaderLength4,
  1981  		},
  1982  		{
  1983  			name:          "IPv6 exact match",
  1984  			srcAddress:    remoteIPv6Addr,
  1985  			injector:      rxIPv6Bad,
  1986  			checker:       v6Checker,
  1987  			payloadLength: targetSize6 - replyHeaderLength6,
  1988  			linkMTU:       targetSize6,
  1989  			replyLength:   targetSize6,
  1990  		},
  1991  		{
  1992  			name:          "IPv6 larger MTU",
  1993  			srcAddress:    remoteIPv6Addr,
  1994  			injector:      rxIPv6Bad,
  1995  			checker:       v6Checker,
  1996  			payloadLength: targetSize6,
  1997  			linkMTU:       targetSize6 + 400,
  1998  			replyLength:   targetSize6,
  1999  		},
  2000  		// NB. No "smaller MTU" test here as less than 1280 is not permitted
  2001  		// in IPv6.
  2002  		{
  2003  			name:          "IPv6 payload exceeds",
  2004  			srcAddress:    remoteIPv6Addr,
  2005  			injector:      rxIPv6Bad,
  2006  			checker:       v6Checker,
  2007  			payloadLength: targetSize6,
  2008  			linkMTU:       targetSize6,
  2009  			replyLength:   targetSize6,
  2010  		},
  2011  		{
  2012  			name:          "IPv6 1 byte less",
  2013  			srcAddress:    remoteIPv6Addr,
  2014  			injector:      rxIPv6Bad,
  2015  			checker:       v6Checker,
  2016  			payloadLength: targetSize6 - replyHeaderLength6 - 1,
  2017  			linkMTU:       targetSize6,
  2018  			replyLength:   targetSize6 - 1,
  2019  		},
  2020  		{
  2021  			name:          "IPv6 no payload",
  2022  			srcAddress:    remoteIPv6Addr,
  2023  			injector:      rxIPv6Bad,
  2024  			checker:       v6Checker,
  2025  			payloadLength: 0,
  2026  			linkMTU:       targetSize6,
  2027  			replyLength:   replyHeaderLength6,
  2028  		},
  2029  	}
  2030  
  2031  	for _, test := range tests {
  2032  		t.Run(test.name, func(t *testing.T) {
  2033  			ctx := newTestContext()
  2034  			defer ctx.cleanup()
  2035  			s := ctx.s
  2036  
  2037  			e := addLinkEndpointToStackWithMTU(t, s, test.linkMTU)
  2038  			defer e.Close()
  2039  			// Allocate and initialize the payload view.
  2040  			payload := make([]byte, test.payloadLength)
  2041  			for i := 0; i < len(payload); i++ {
  2042  				payload[i] = uint8(i)
  2043  			}
  2044  			// Default routes for IPv4&6 so ICMP can find a route to the remote
  2045  			// node when attempting to send the ICMP error Reply.
  2046  			s.SetRouteTable([]tcpip.Route{
  2047  				{
  2048  					Destination: header.IPv4EmptySubnet,
  2049  					NIC:         nicID,
  2050  				},
  2051  				{
  2052  					Destination: header.IPv6EmptySubnet,
  2053  					NIC:         nicID,
  2054  				},
  2055  			})
  2056  			v := test.injector(e, test.srcAddress, payload)
  2057  			pkt := e.Read()
  2058  			if pkt == nil {
  2059  				t.Fatal("expected a packet to be written")
  2060  			}
  2061  			if got, want := pkt.Size(), test.replyLength; got != want {
  2062  				t.Fatalf("got %d bytes of icmp error packet, want %d", got, want)
  2063  			}
  2064  			test.checker(t, pkt, v)
  2065  			pkt.DecRef()
  2066  		})
  2067  	}
  2068  }
  2069  
  2070  func TestJoinLeaveAllRoutersGroup(t *testing.T) {
  2071  	const nicID = 1
  2072  
  2073  	tests := []struct {
  2074  		name           string
  2075  		netProto       tcpip.NetworkProtocolNumber
  2076  		protoFactory   stack.NetworkProtocolFactory
  2077  		allRoutersAddr tcpip.Address
  2078  	}{
  2079  		{
  2080  			name:           "IPv4",
  2081  			netProto:       ipv4.ProtocolNumber,
  2082  			protoFactory:   ipv4.NewProtocol,
  2083  			allRoutersAddr: header.IPv4AllRoutersGroup,
  2084  		},
  2085  		{
  2086  			name:           "IPv6 Interface Local",
  2087  			netProto:       ipv6.ProtocolNumber,
  2088  			protoFactory:   ipv6.NewProtocol,
  2089  			allRoutersAddr: header.IPv6AllRoutersInterfaceLocalMulticastAddress,
  2090  		},
  2091  		{
  2092  			name:           "IPv6 Link Local",
  2093  			netProto:       ipv6.ProtocolNumber,
  2094  			protoFactory:   ipv6.NewProtocol,
  2095  			allRoutersAddr: header.IPv6AllRoutersLinkLocalMulticastAddress,
  2096  		},
  2097  		{
  2098  			name:           "IPv6 Site Local",
  2099  			netProto:       ipv6.ProtocolNumber,
  2100  			protoFactory:   ipv6.NewProtocol,
  2101  			allRoutersAddr: header.IPv6AllRoutersSiteLocalMulticastAddress,
  2102  		},
  2103  	}
  2104  
  2105  	for _, test := range tests {
  2106  		t.Run(test.name, func(t *testing.T) {
  2107  			for _, nicDisabled := range [...]bool{true, false} {
  2108  				t.Run(fmt.Sprintf("NIC Disabled = %t", nicDisabled), func(t *testing.T) {
  2109  					ctx := newTestContext()
  2110  					defer ctx.cleanup()
  2111  					s := ctx.s
  2112  
  2113  					opts := stack.NICOptions{Disabled: nicDisabled}
  2114  					if err := s.CreateNICWithOptions(nicID, channel.New(0, 0, ""), opts); err != nil {
  2115  						t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err)
  2116  					}
  2117  
  2118  					if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
  2119  						t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
  2120  					} else if got {
  2121  						t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr)
  2122  					}
  2123  
  2124  					if err := s.SetForwardingDefaultAndAllNICs(test.netProto, true); err != nil {
  2125  						t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", test.netProto, err)
  2126  					}
  2127  					if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
  2128  						t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
  2129  					} else if !got {
  2130  						t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr)
  2131  					}
  2132  
  2133  					if err := s.SetForwardingDefaultAndAllNICs(test.netProto, false); err != nil {
  2134  						t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, false): %s", test.netProto, err)
  2135  					}
  2136  					if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
  2137  						t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
  2138  					} else if got {
  2139  						t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr)
  2140  					}
  2141  				})
  2142  			}
  2143  		})
  2144  	}
  2145  }
  2146  
  2147  func TestSetNICIDBeforeDeliveringToRawEndpoint(t *testing.T) {
  2148  	const nicID = 1
  2149  
  2150  	tests := []struct {
  2151  		name          string
  2152  		proto         tcpip.NetworkProtocolNumber
  2153  		addr          tcpip.AddressWithPrefix
  2154  		payloadOffset int
  2155  	}{
  2156  		{
  2157  			name:          "IPv4",
  2158  			proto:         header.IPv4ProtocolNumber,
  2159  			addr:          localIPv4AddrWithPrefix,
  2160  			payloadOffset: header.IPv4MinimumSize,
  2161  		},
  2162  		{
  2163  			name:          "IPv6",
  2164  			proto:         header.IPv6ProtocolNumber,
  2165  			addr:          localIPv6AddrWithPrefix,
  2166  			payloadOffset: 0,
  2167  		},
  2168  	}
  2169  
  2170  	for _, test := range tests {
  2171  		t.Run(test.name, func(t *testing.T) {
  2172  			ctx := newTestContext()
  2173  			defer ctx.cleanup()
  2174  			s := ctx.s
  2175  
  2176  			if err := s.CreateNIC(nicID, loopback.New()); err != nil {
  2177  				t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
  2178  			}
  2179  			protocolAddr := tcpip.ProtocolAddress{
  2180  				Protocol:          test.proto,
  2181  				AddressWithPrefix: test.addr,
  2182  			}
  2183  			if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
  2184  				t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
  2185  			}
  2186  
  2187  			s.SetRouteTable([]tcpip.Route{
  2188  				{
  2189  					Destination: test.addr.Subnet(),
  2190  					NIC:         nicID,
  2191  				},
  2192  			})
  2193  
  2194  			var wq waiter.Queue
  2195  			we, ch := waiter.NewChannelEntry(waiter.ReadableEvents)
  2196  			wq.EventRegister(&we)
  2197  			ep, err := s.NewRawEndpoint(udp.ProtocolNumber, test.proto, &wq, true /* associated */)
  2198  			if err != nil {
  2199  				t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err)
  2200  			}
  2201  			defer ep.Close()
  2202  
  2203  			writeOpts := tcpip.WriteOptions{
  2204  				To: &tcpip.FullAddress{
  2205  					Addr: test.addr.Address,
  2206  				},
  2207  			}
  2208  			data := []byte{1, 2, 3, 4}
  2209  			var r bytes.Reader
  2210  			r.Reset(data)
  2211  			if n, err := ep.Write(&r, writeOpts); err != nil {
  2212  				t.Fatalf("ep.Write(_, _): %s", err)
  2213  			} else if want := int64(len(data)); n != want {
  2214  				t.Fatalf("got ep.Write(_, _) = (%d, nil), want = (%d, nil)", n, want)
  2215  			}
  2216  
  2217  			// Wait for the endpoint to become readable.
  2218  			<-ch
  2219  
  2220  			var w bytes.Buffer
  2221  			rr, err := ep.Read(&w, tcpip.ReadOptions{
  2222  				NeedRemoteAddr: true,
  2223  			})
  2224  			if err != nil {
  2225  				t.Fatalf("ep.Read(...): %s", err)
  2226  			}
  2227  			if diff := cmp.Diff(data, w.Bytes()[test.payloadOffset:]); diff != "" {
  2228  				t.Errorf("payload mismatch (-want +got):\n%s", diff)
  2229  			}
  2230  			if diff := cmp.Diff(tcpip.FullAddress{Addr: test.addr.Address, NIC: nicID}, rr.RemoteAddr); diff != "" {
  2231  				t.Errorf("remote addr mismatch (-want +got):\n%s", diff)
  2232  			}
  2233  		})
  2234  	}
  2235  }