github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/tests/integration/link_resolution_test.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package link_resolution_test
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"net"
    21  	"runtime"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/google/go-cmp/cmp"
    26  	"github.com/google/go-cmp/cmp/cmpopts"
    27  	"github.com/SagerNet/gvisor/pkg/tcpip"
    28  	"github.com/SagerNet/gvisor/pkg/tcpip/buffer"
    29  	"github.com/SagerNet/gvisor/pkg/tcpip/checker"
    30  	"github.com/SagerNet/gvisor/pkg/tcpip/faketime"
    31  	"github.com/SagerNet/gvisor/pkg/tcpip/header"
    32  	"github.com/SagerNet/gvisor/pkg/tcpip/link/channel"
    33  	"github.com/SagerNet/gvisor/pkg/tcpip/link/pipe"
    34  	"github.com/SagerNet/gvisor/pkg/tcpip/network/arp"
    35  	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv4"
    36  	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv6"
    37  	"github.com/SagerNet/gvisor/pkg/tcpip/stack"
    38  	"github.com/SagerNet/gvisor/pkg/tcpip/tests/utils"
    39  	tcptestutil "github.com/SagerNet/gvisor/pkg/tcpip/testutil"
    40  	"github.com/SagerNet/gvisor/pkg/tcpip/transport/icmp"
    41  	"github.com/SagerNet/gvisor/pkg/tcpip/transport/tcp"
    42  	"github.com/SagerNet/gvisor/pkg/tcpip/transport/udp"
    43  	"github.com/SagerNet/gvisor/pkg/waiter"
    44  )
    45  
    46  func setupStack(t *testing.T, stackOpts stack.Options, host1NICID, host2NICID tcpip.NICID) (*stack.Stack, *stack.Stack) {
    47  	host1Stack := stack.New(stackOpts)
    48  	host2Stack := stack.New(stackOpts)
    49  
    50  	host1NIC, host2NIC := pipe.New(utils.LinkAddr1, utils.LinkAddr2)
    51  
    52  	if err := host1Stack.CreateNIC(host1NICID, utils.NewEthernetEndpoint(host1NIC)); err != nil {
    53  		t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
    54  	}
    55  	if err := host2Stack.CreateNIC(host2NICID, utils.NewEthernetEndpoint(host2NIC)); err != nil {
    56  		t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
    57  	}
    58  
    59  	if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1); err != nil {
    60  		t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv4Addr1, err)
    61  	}
    62  	if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2); err != nil {
    63  		t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv4Addr2, err)
    64  	}
    65  	if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1); err != nil {
    66  		t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv6Addr1, err)
    67  	}
    68  	if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2); err != nil {
    69  		t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv6Addr2, err)
    70  	}
    71  
    72  	host1Stack.SetRouteTable([]tcpip.Route{
    73  		{
    74  			Destination: utils.Ipv4Addr1.AddressWithPrefix.Subnet(),
    75  			NIC:         host1NICID,
    76  		},
    77  		{
    78  			Destination: utils.Ipv6Addr1.AddressWithPrefix.Subnet(),
    79  			NIC:         host1NICID,
    80  		},
    81  	})
    82  	host2Stack.SetRouteTable([]tcpip.Route{
    83  		{
    84  			Destination: utils.Ipv4Addr2.AddressWithPrefix.Subnet(),
    85  			NIC:         host2NICID,
    86  		},
    87  		{
    88  			Destination: utils.Ipv6Addr2.AddressWithPrefix.Subnet(),
    89  			NIC:         host2NICID,
    90  		},
    91  	})
    92  
    93  	return host1Stack, host2Stack
    94  }
    95  
    96  // TestPing tests that two hosts can ping eachother when link resolution is
    97  // enabled.
    98  func TestPing(t *testing.T) {
    99  	const (
   100  		host1NICID = 1
   101  		host2NICID = 4
   102  
   103  		// icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo
   104  		// request/reply packets.
   105  		icmpDataOffset = 8
   106  	)
   107  
   108  	tests := []struct {
   109  		name       string
   110  		transProto tcpip.TransportProtocolNumber
   111  		netProto   tcpip.NetworkProtocolNumber
   112  		remoteAddr tcpip.Address
   113  		icmpBuf    func(*testing.T) []byte
   114  	}{
   115  		{
   116  			name:       "IPv4 Ping",
   117  			transProto: icmp.ProtocolNumber4,
   118  			netProto:   ipv4.ProtocolNumber,
   119  			remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
   120  			icmpBuf: func(t *testing.T) []byte {
   121  				data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
   122  				hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data)))
   123  				hdr.SetType(header.ICMPv4Echo)
   124  				if n := copy(hdr.Payload(), data[:]); n != len(data) {
   125  					t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
   126  				}
   127  				return hdr
   128  			},
   129  		},
   130  		{
   131  			name:       "IPv6 Ping",
   132  			transProto: icmp.ProtocolNumber6,
   133  			netProto:   ipv6.ProtocolNumber,
   134  			remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
   135  			icmpBuf: func(t *testing.T) []byte {
   136  				data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
   137  				hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data)))
   138  				hdr.SetType(header.ICMPv6EchoRequest)
   139  				if n := copy(hdr.Payload(), data[:]); n != len(data) {
   140  					t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
   141  				}
   142  				return hdr
   143  			},
   144  		},
   145  	}
   146  
   147  	for _, test := range tests {
   148  		t.Run(test.name, func(t *testing.T) {
   149  			stackOpts := stack.Options{
   150  				NetworkProtocols:   []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
   151  				TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
   152  			}
   153  
   154  			host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
   155  
   156  			var wq waiter.Queue
   157  			we, waiterCH := waiter.NewChannelEntry(nil)
   158  			wq.EventRegister(&we, waiter.ReadableEvents)
   159  			ep, err := host1Stack.NewEndpoint(test.transProto, test.netProto, &wq)
   160  			if err != nil {
   161  				t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err)
   162  			}
   163  			defer ep.Close()
   164  
   165  			icmpBuf := test.icmpBuf(t)
   166  			var r bytes.Reader
   167  			r.Reset(icmpBuf)
   168  			wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}}
   169  			if n, err := ep.Write(&r, wOpts); err != nil {
   170  				t.Fatalf("ep.Write(_, _): %s", err)
   171  			} else if want := int64(len(icmpBuf)); n != want {
   172  				t.Fatalf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want)
   173  			}
   174  
   175  			// Wait for the endpoint to be readable.
   176  			<-waiterCH
   177  
   178  			var buf bytes.Buffer
   179  			opts := tcpip.ReadOptions{NeedRemoteAddr: true}
   180  			res, err := ep.Read(&buf, opts)
   181  			if err != nil {
   182  				t.Fatalf("ep.Read(_, %d, %#v): %s", len(icmpBuf), opts, err)
   183  			}
   184  			if diff := cmp.Diff(tcpip.ReadResult{
   185  				Count:      buf.Len(),
   186  				Total:      buf.Len(),
   187  				RemoteAddr: tcpip.FullAddress{Addr: test.remoteAddr},
   188  			}, res, checker.IgnoreCmpPath(
   189  				"ControlMessages",
   190  				"RemoteAddr.NIC",
   191  				"RemoteAddr.Port",
   192  			)); diff != "" {
   193  				t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
   194  			}
   195  			if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" {
   196  				t.Errorf("received data mismatch (-want +got):\n%s", diff)
   197  			}
   198  		})
   199  	}
   200  }
   201  
   202  type transportError struct {
   203  	origin tcpip.SockErrOrigin
   204  	typ    uint8
   205  	code   uint8
   206  	info   uint32
   207  	kind   stack.TransportErrorKind
   208  }
   209  
   210  func TestTCPLinkResolutionFailure(t *testing.T) {
   211  	const (
   212  		host1NICID = 1
   213  		host2NICID = 4
   214  	)
   215  
   216  	tests := []struct {
   217  		name             string
   218  		netProto         tcpip.NetworkProtocolNumber
   219  		remoteAddr       tcpip.Address
   220  		expectedWriteErr tcpip.Error
   221  		sockError        tcpip.SockError
   222  		transErr         transportError
   223  	}{
   224  		{
   225  			name:             "IPv4 with resolvable remote",
   226  			netProto:         ipv4.ProtocolNumber,
   227  			remoteAddr:       utils.Ipv4Addr2.AddressWithPrefix.Address,
   228  			expectedWriteErr: nil,
   229  		},
   230  		{
   231  			name:             "IPv6 with resolvable remote",
   232  			netProto:         ipv6.ProtocolNumber,
   233  			remoteAddr:       utils.Ipv6Addr2.AddressWithPrefix.Address,
   234  			expectedWriteErr: nil,
   235  		},
   236  		{
   237  			name:             "IPv4 without resolvable remote",
   238  			netProto:         ipv4.ProtocolNumber,
   239  			remoteAddr:       utils.Ipv4Addr3.AddressWithPrefix.Address,
   240  			expectedWriteErr: &tcpip.ErrNoRoute{},
   241  			sockError: tcpip.SockError{
   242  				Err: &tcpip.ErrNoRoute{},
   243  				Dst: tcpip.FullAddress{
   244  					NIC:  host1NICID,
   245  					Addr: utils.Ipv4Addr3.AddressWithPrefix.Address,
   246  					Port: 1234,
   247  				},
   248  				Offender: tcpip.FullAddress{
   249  					NIC:  host1NICID,
   250  					Addr: utils.Ipv4Addr1.AddressWithPrefix.Address,
   251  				},
   252  				NetProto: ipv4.ProtocolNumber,
   253  			},
   254  			transErr: transportError{
   255  				origin: tcpip.SockExtErrorOriginICMP,
   256  				typ:    uint8(header.ICMPv4DstUnreachable),
   257  				code:   uint8(header.ICMPv4HostUnreachable),
   258  				kind:   stack.DestinationHostUnreachableTransportError,
   259  			},
   260  		},
   261  		{
   262  			name:             "IPv6 without resolvable remote",
   263  			netProto:         ipv6.ProtocolNumber,
   264  			remoteAddr:       utils.Ipv6Addr3.AddressWithPrefix.Address,
   265  			expectedWriteErr: &tcpip.ErrNoRoute{},
   266  			sockError: tcpip.SockError{
   267  				Err: &tcpip.ErrNoRoute{},
   268  				Dst: tcpip.FullAddress{
   269  					NIC:  host1NICID,
   270  					Addr: utils.Ipv6Addr3.AddressWithPrefix.Address,
   271  					Port: 1234,
   272  				},
   273  				Offender: tcpip.FullAddress{
   274  					NIC:  host1NICID,
   275  					Addr: utils.Ipv6Addr1.AddressWithPrefix.Address,
   276  				},
   277  				NetProto: ipv6.ProtocolNumber,
   278  			},
   279  			transErr: transportError{
   280  				origin: tcpip.SockExtErrorOriginICMP6,
   281  				typ:    uint8(header.ICMPv6DstUnreachable),
   282  				code:   uint8(header.ICMPv6AddressUnreachable),
   283  				kind:   stack.DestinationHostUnreachableTransportError,
   284  			},
   285  		},
   286  	}
   287  
   288  	for _, test := range tests {
   289  		t.Run(test.name, func(t *testing.T) {
   290  			clock := faketime.NewManualClock()
   291  			stackOpts := stack.Options{
   292  				NetworkProtocols:   []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
   293  				TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
   294  				Clock:              clock,
   295  			}
   296  
   297  			host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID)
   298  
   299  			var listenerWQ waiter.Queue
   300  			listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, test.netProto, &listenerWQ)
   301  			if err != nil {
   302  				t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, test.netProto, err)
   303  			}
   304  			defer listenerEP.Close()
   305  
   306  			listenerAddr := tcpip.FullAddress{Port: 1234}
   307  			if err := listenerEP.Bind(listenerAddr); err != nil {
   308  				t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err)
   309  			}
   310  
   311  			if err := listenerEP.Listen(1); err != nil {
   312  				t.Fatalf("listenerEP.Listen(1): %s", err)
   313  			}
   314  
   315  			var clientWQ waiter.Queue
   316  			we, ch := waiter.NewChannelEntry(nil)
   317  			clientWQ.EventRegister(&we, waiter.WritableEvents|waiter.EventErr)
   318  			clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, test.netProto, &clientWQ)
   319  			if err != nil {
   320  				t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, test.netProto, err)
   321  			}
   322  			defer clientEP.Close()
   323  
   324  			sockOpts := clientEP.SocketOptions()
   325  			sockOpts.SetRecvError(true)
   326  
   327  			remoteAddr := listenerAddr
   328  			remoteAddr.Addr = test.remoteAddr
   329  			{
   330  				err := clientEP.Connect(remoteAddr)
   331  				if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
   332  					t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, &tcpip.ErrConnectStarted{})
   333  				}
   334  			}
   335  
   336  			// Wait for an error due to link resolution failing, or the endpoint to be
   337  			// writable.
   338  			if test.expectedWriteErr != nil {
   339  				nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto)
   340  				if err != nil {
   341  					t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err)
   342  				}
   343  				clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer)
   344  			} else {
   345  				clock.RunImmediatelyScheduledJobs()
   346  			}
   347  			<-ch
   348  
   349  			{
   350  				var r bytes.Reader
   351  				r.Reset([]byte{0})
   352  				var wOpts tcpip.WriteOptions
   353  				_, err := clientEP.Write(&r, wOpts)
   354  				if diff := cmp.Diff(test.expectedWriteErr, err); diff != "" {
   355  					t.Errorf("unexpected error from clientEP.Write(_, %#v), (-want, +got):\n%s", wOpts, diff)
   356  				}
   357  			}
   358  
   359  			if test.expectedWriteErr == nil {
   360  				return
   361  			}
   362  
   363  			sockErr := sockOpts.DequeueErr()
   364  			if sockErr == nil {
   365  				t.Fatalf("got sockOpts.DequeueErr() = nil, want = non-nil")
   366  			}
   367  
   368  			sockErrCmpOpts := []cmp.Option{
   369  				cmpopts.IgnoreUnexported(tcpip.SockError{}),
   370  				cmp.Comparer(func(a, b tcpip.Error) bool {
   371  					// tcpip.Error holds an unexported field but the errors netstack uses
   372  					// are pre defined so we can simply compare pointers.
   373  					return a == b
   374  				}),
   375  				checker.IgnoreCmpPath(
   376  					// Ignore the payload since we do not know the TCP seq/ack numbers.
   377  					"Payload",
   378  					// Ignore the cause since we will compare its properties separately
   379  					// since the concrete type of the cause is unknown.
   380  					"Cause",
   381  				),
   382  			}
   383  
   384  			if addr, err := clientEP.GetLocalAddress(); err != nil {
   385  				t.Fatalf("clientEP.GetLocalAddress(): %s", err)
   386  			} else {
   387  				test.sockError.Offender.Port = addr.Port
   388  			}
   389  			if diff := cmp.Diff(&test.sockError, sockErr, sockErrCmpOpts...); diff != "" {
   390  				t.Errorf("socket error mismatch (-want +got):\n%s", diff)
   391  			}
   392  
   393  			transErr, ok := sockErr.Cause.(stack.TransportError)
   394  			if !ok {
   395  				t.Fatalf("socket error cause is not a transport error; cause = %#v", sockErr.Cause)
   396  			}
   397  			if diff := cmp.Diff(
   398  				test.transErr,
   399  				transportError{
   400  					origin: transErr.Origin(),
   401  					typ:    transErr.Type(),
   402  					code:   transErr.Code(),
   403  					info:   transErr.Info(),
   404  					kind:   transErr.Kind(),
   405  				},
   406  				cmp.AllowUnexported(transportError{}),
   407  			); diff != "" {
   408  				t.Errorf("socket error mismatch (-want +got):\n%s", diff)
   409  			}
   410  		})
   411  	}
   412  }
   413  
   414  func TestForwardingWithLinkResolutionFailure(t *testing.T) {
   415  	const (
   416  		incomingNICID                     = 1
   417  		outgoingNICID                     = 2
   418  		ttl                               = 2
   419  		expectedHostUnreachableErrorCount = 1
   420  	)
   421  	outgoingLinkAddr := tcptestutil.MustParseLink("02:03:03:04:05:06")
   422  
   423  	rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
   424  		utils.RxICMPv4EchoRequest(e, src, dst, ttl)
   425  	}
   426  
   427  	rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
   428  		utils.RxICMPv6EchoRequest(e, src, dst, ttl)
   429  	}
   430  
   431  	arpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) {
   432  		if request.Proto != arp.ProtocolNumber {
   433  			t.Errorf("got request.Proto = %d, want = %d", request.Proto, arp.ProtocolNumber)
   434  		}
   435  		if request.Route.RemoteLinkAddress != header.EthernetBroadcastAddress {
   436  			t.Errorf("got request.Route.RemoteLinkAddress = %s, want = %s", request.Route.RemoteLinkAddress, header.EthernetBroadcastAddress)
   437  		}
   438  		rep := header.ARP(request.Pkt.NetworkHeader().View())
   439  		if got := rep.Op(); got != header.ARPRequest {
   440  			t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest)
   441  		}
   442  		if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != outgoingLinkAddr {
   443  			t.Errorf("got HardwareAddressSender = %s, want = %s", got, outgoingLinkAddr)
   444  		}
   445  		if got := tcpip.Address(rep.ProtocolAddressSender()); got != src {
   446  			t.Errorf("got ProtocolAddressSender = %s, want = %s", got, src)
   447  		}
   448  		if got := tcpip.Address(rep.ProtocolAddressTarget()); got != dst {
   449  			t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, dst)
   450  		}
   451  	}
   452  
   453  	ndpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) {
   454  		if request.Proto != header.IPv6ProtocolNumber {
   455  			t.Fatalf("got Proto = %d, want = %d", request.Proto, header.IPv6ProtocolNumber)
   456  		}
   457  
   458  		snmc := header.SolicitedNodeAddr(dst)
   459  		if want := header.EthernetAddressFromMulticastIPv6Address(snmc); request.Route.RemoteLinkAddress != want {
   460  			t.Errorf("got remote link address = %s, want = %s", request.Route.RemoteLinkAddress, want)
   461  		}
   462  
   463  		checker.IPv6(t, stack.PayloadSince(request.Pkt.NetworkHeader()),
   464  			checker.SrcAddr(src),
   465  			checker.DstAddr(snmc),
   466  			checker.TTL(header.NDPHopLimit),
   467  			checker.NDPNS(
   468  				checker.NDPNSTargetAddress(dst),
   469  			))
   470  	}
   471  
   472  	icmpv4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
   473  		checker.IPv4(t, b,
   474  			checker.SrcAddr(src),
   475  			checker.DstAddr(dst),
   476  			checker.TTL(ipv4.DefaultTTL),
   477  			checker.ICMPv4(
   478  				checker.ICMPv4Checksum(),
   479  				checker.ICMPv4Type(header.ICMPv4DstUnreachable),
   480  				checker.ICMPv4Code(header.ICMPv4HostUnreachable),
   481  			),
   482  		)
   483  	}
   484  
   485  	icmpv6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
   486  		checker.IPv6(t, b,
   487  			checker.SrcAddr(src),
   488  			checker.DstAddr(dst),
   489  			checker.TTL(ipv6.DefaultTTL),
   490  			checker.ICMPv6(
   491  				checker.ICMPv6Type(header.ICMPv6DstUnreachable),
   492  				checker.ICMPv6Code(header.ICMPv6AddressUnreachable),
   493  			),
   494  		)
   495  	}
   496  
   497  	tests := []struct {
   498  		name                         string
   499  		networkProtocolFactory       []stack.NetworkProtocolFactory
   500  		networkProtocolNumber        tcpip.NetworkProtocolNumber
   501  		sourceAddr                   tcpip.Address
   502  		destAddr                     tcpip.Address
   503  		incomingAddr                 tcpip.AddressWithPrefix
   504  		outgoingAddr                 tcpip.AddressWithPrefix
   505  		transportProtocol            func(*stack.Stack) stack.TransportProtocol
   506  		rx                           func(*channel.Endpoint, tcpip.Address, tcpip.Address)
   507  		linkResolutionRequestChecker func(*testing.T, channel.PacketInfo, tcpip.Address, tcpip.Address)
   508  		icmpReplyChecker             func(*testing.T, []byte, tcpip.Address, tcpip.Address)
   509  		mtu                          uint32
   510  	}{
   511  		{
   512  			name:                   "IPv4 Host unreachable",
   513  			networkProtocolFactory: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
   514  			networkProtocolNumber:  header.IPv4ProtocolNumber,
   515  			sourceAddr:             tcptestutil.MustParse4("10.0.0.2"),
   516  			destAddr:               tcptestutil.MustParse4("11.0.0.2"),
   517  			incomingAddr: tcpip.AddressWithPrefix{
   518  				Address:   tcpip.Address(net.ParseIP("10.0.0.1").To4()),
   519  				PrefixLen: 8,
   520  			},
   521  			outgoingAddr: tcpip.AddressWithPrefix{
   522  				Address:   tcpip.Address(net.ParseIP("11.0.0.1").To4()),
   523  				PrefixLen: 8,
   524  			},
   525  			transportProtocol:            icmp.NewProtocol4,
   526  			linkResolutionRequestChecker: arpChecker,
   527  			icmpReplyChecker:             icmpv4Checker,
   528  			rx:                           rxICMPv4EchoRequest,
   529  			mtu:                          ipv4.MaxTotalSize,
   530  		},
   531  		{
   532  			name:                   "IPv6 Host unreachable",
   533  			networkProtocolFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
   534  			networkProtocolNumber:  header.IPv6ProtocolNumber,
   535  			sourceAddr:             tcptestutil.MustParse6("10::2"),
   536  			destAddr:               tcptestutil.MustParse6("11::2"),
   537  			incomingAddr: tcpip.AddressWithPrefix{
   538  				Address:   tcpip.Address(net.ParseIP("10::1").To16()),
   539  				PrefixLen: 64,
   540  			},
   541  			outgoingAddr: tcpip.AddressWithPrefix{
   542  				Address:   tcpip.Address(net.ParseIP("11::1").To16()),
   543  				PrefixLen: 64,
   544  			},
   545  			transportProtocol:            icmp.NewProtocol6,
   546  			linkResolutionRequestChecker: ndpChecker,
   547  			icmpReplyChecker:             icmpv6Checker,
   548  			rx:                           rxICMPv6EchoRequest,
   549  			mtu:                          header.IPv6MinimumMTU,
   550  		},
   551  	}
   552  	for _, test := range tests {
   553  		t.Run(test.name, func(t *testing.T) {
   554  			clock := faketime.NewManualClock()
   555  
   556  			s := stack.New(stack.Options{
   557  				NetworkProtocols:   test.networkProtocolFactory,
   558  				TransportProtocols: []stack.TransportProtocolFactory{test.transportProtocol},
   559  				Clock:              clock,
   560  			})
   561  
   562  			// Set up endpoint through which we will receive packets.
   563  			incomingEndpoint := channel.New(1, test.mtu, "")
   564  			if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil {
   565  				t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
   566  			}
   567  			incomingProtoAddr := tcpip.ProtocolAddress{
   568  				Protocol:          test.networkProtocolNumber,
   569  				AddressWithPrefix: test.incomingAddr,
   570  			}
   571  			if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr); err != nil {
   572  				t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingProtoAddr, err)
   573  			}
   574  
   575  			// Set up endpoint through which we will attempt to forward packets.
   576  			outgoingEndpoint := channel.New(1, test.mtu, outgoingLinkAddr)
   577  			outgoingEndpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
   578  			if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil {
   579  				t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
   580  			}
   581  			outgoingProtoAddr := tcpip.ProtocolAddress{
   582  				Protocol:          test.networkProtocolNumber,
   583  				AddressWithPrefix: test.outgoingAddr,
   584  			}
   585  			if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr); err != nil {
   586  				t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingProtoAddr, err)
   587  			}
   588  
   589  			s.SetRouteTable([]tcpip.Route{
   590  				{
   591  					Destination: test.incomingAddr.Subnet(),
   592  					NIC:         incomingNICID,
   593  				},
   594  				{
   595  					Destination: test.outgoingAddr.Subnet(),
   596  					NIC:         outgoingNICID,
   597  				},
   598  			})
   599  
   600  			if err := s.SetForwardingDefaultAndAllNICs(test.networkProtocolNumber, true); err != nil {
   601  				t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", test.networkProtocolNumber, err)
   602  			}
   603  
   604  			test.rx(incomingEndpoint, test.sourceAddr, test.destAddr)
   605  
   606  			nudConfigs, err := s.NUDConfigurations(outgoingNICID, test.networkProtocolNumber)
   607  			if err != nil {
   608  				t.Fatalf("s.NUDConfigurations(%d, %d): %s", outgoingNICID, test.networkProtocolNumber, err)
   609  			}
   610  			// Trigger the first packet on the endpoint.
   611  			clock.RunImmediatelyScheduledJobs()
   612  
   613  			for i := 0; i < int(nudConfigs.MaxMulticastProbes); i++ {
   614  				request, ok := outgoingEndpoint.Read()
   615  				if !ok {
   616  					t.Fatal("expected ARP packet through outgoing NIC")
   617  				}
   618  
   619  				test.linkResolutionRequestChecker(t, request, test.outgoingAddr.Address, test.destAddr)
   620  
   621  				// Advance the clock the span of one request timeout.
   622  				clock.Advance(nudConfigs.RetransmitTimer)
   623  			}
   624  
   625  			// Next, we make a blocking read to retrieve the error packet. This is
   626  			// necessary because outgoing packets are dequeued asynchronously when
   627  			// link resolution fails, and this dequeue is what triggers the ICMP
   628  			// error.
   629  			reply, ok := incomingEndpoint.Read()
   630  			if !ok {
   631  				t.Fatal("expected ICMP packet through incoming NIC")
   632  			}
   633  
   634  			test.icmpReplyChecker(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), test.incomingAddr.Address, test.sourceAddr)
   635  
   636  			// Since link resolution failed, we don't expect the packet to be
   637  			// forwarded.
   638  			forwardedPacket, ok := outgoingEndpoint.Read()
   639  			if ok {
   640  				t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", forwardedPacket)
   641  			}
   642  
   643  			if got, want := s.Stats().IP.Forwarding.HostUnreachable.Value(), expectedHostUnreachableErrorCount; int(got) != want {
   644  				t.Errorf("got rt.Stats().IP.Forwarding.HostUnreachable.Value() = %d, want = %d", got, want)
   645  			}
   646  		})
   647  	}
   648  }
   649  
   650  func TestGetLinkAddress(t *testing.T) {
   651  	const (
   652  		host1NICID = 1
   653  		host2NICID = 4
   654  	)
   655  
   656  	tests := []struct {
   657  		name                  string
   658  		netProto              tcpip.NetworkProtocolNumber
   659  		remoteAddr, localAddr tcpip.Address
   660  		expectedErr           tcpip.Error
   661  	}{
   662  		{
   663  			name:        "IPv4 resolvable",
   664  			netProto:    ipv4.ProtocolNumber,
   665  			remoteAddr:  utils.Ipv4Addr2.AddressWithPrefix.Address,
   666  			expectedErr: nil,
   667  		},
   668  		{
   669  			name:        "IPv6 resolvable",
   670  			netProto:    ipv6.ProtocolNumber,
   671  			remoteAddr:  utils.Ipv6Addr2.AddressWithPrefix.Address,
   672  			expectedErr: nil,
   673  		},
   674  		{
   675  			name:        "IPv4 not resolvable",
   676  			netProto:    ipv4.ProtocolNumber,
   677  			remoteAddr:  utils.Ipv4Addr3.AddressWithPrefix.Address,
   678  			expectedErr: &tcpip.ErrTimeout{},
   679  		},
   680  		{
   681  			name:        "IPv6 not resolvable",
   682  			netProto:    ipv6.ProtocolNumber,
   683  			remoteAddr:  utils.Ipv6Addr3.AddressWithPrefix.Address,
   684  			expectedErr: &tcpip.ErrTimeout{},
   685  		},
   686  		{
   687  			name:        "IPv4 bad local address",
   688  			netProto:    ipv4.ProtocolNumber,
   689  			remoteAddr:  utils.Ipv4Addr2.AddressWithPrefix.Address,
   690  			localAddr:   utils.Ipv4Addr2.AddressWithPrefix.Address,
   691  			expectedErr: &tcpip.ErrBadLocalAddress{},
   692  		},
   693  		{
   694  			name:        "IPv6 bad local address",
   695  			netProto:    ipv6.ProtocolNumber,
   696  			remoteAddr:  utils.Ipv6Addr2.AddressWithPrefix.Address,
   697  			localAddr:   utils.Ipv6Addr2.AddressWithPrefix.Address,
   698  			expectedErr: &tcpip.ErrBadLocalAddress{},
   699  		},
   700  	}
   701  
   702  	for _, test := range tests {
   703  		t.Run(test.name, func(t *testing.T) {
   704  			clock := faketime.NewManualClock()
   705  			stackOpts := stack.Options{
   706  				NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
   707  				Clock:            clock,
   708  			}
   709  
   710  			host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
   711  
   712  			ch := make(chan stack.LinkResolutionResult, 1)
   713  			err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, test.localAddr, test.netProto, func(r stack.LinkResolutionResult) {
   714  				ch <- r
   715  			})
   716  			if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
   717  				t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, &tcpip.ErrWouldBlock{})
   718  			}
   719  			wantRes := stack.LinkResolutionResult{Err: test.expectedErr}
   720  			if test.expectedErr == nil {
   721  				wantRes.LinkAddress = utils.LinkAddr2
   722  			}
   723  
   724  			nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto)
   725  			if err != nil {
   726  				t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err)
   727  			}
   728  
   729  			clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer)
   730  			select {
   731  			case got := <-ch:
   732  				if diff := cmp.Diff(wantRes, got); diff != "" {
   733  					t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff)
   734  				}
   735  			default:
   736  				t.Fatal("event didn't arrive")
   737  			}
   738  		})
   739  	}
   740  }
   741  
   742  func TestRouteResolvedFields(t *testing.T) {
   743  	const (
   744  		host1NICID = 1
   745  		host2NICID = 4
   746  	)
   747  
   748  	tests := []struct {
   749  		name                  string
   750  		netProto              tcpip.NetworkProtocolNumber
   751  		localAddr             tcpip.Address
   752  		remoteAddr            tcpip.Address
   753  		immediatelyResolvable bool
   754  		expectedErr           tcpip.Error
   755  		expectedLinkAddr      tcpip.LinkAddress
   756  	}{
   757  		{
   758  			name:                  "IPv4 immediately resolvable",
   759  			netProto:              ipv4.ProtocolNumber,
   760  			localAddr:             utils.Ipv4Addr1.AddressWithPrefix.Address,
   761  			remoteAddr:            header.IPv4AllSystems,
   762  			immediatelyResolvable: true,
   763  			expectedErr:           nil,
   764  			expectedLinkAddr:      header.EthernetAddressFromMulticastIPv4Address(header.IPv4AllSystems),
   765  		},
   766  		{
   767  			name:                  "IPv6 immediately resolvable",
   768  			netProto:              ipv6.ProtocolNumber,
   769  			localAddr:             utils.Ipv6Addr1.AddressWithPrefix.Address,
   770  			remoteAddr:            header.IPv6AllNodesMulticastAddress,
   771  			immediatelyResolvable: true,
   772  			expectedErr:           nil,
   773  			expectedLinkAddr:      header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress),
   774  		},
   775  		{
   776  			name:                  "IPv4 resolvable",
   777  			netProto:              ipv4.ProtocolNumber,
   778  			localAddr:             utils.Ipv4Addr1.AddressWithPrefix.Address,
   779  			remoteAddr:            utils.Ipv4Addr2.AddressWithPrefix.Address,
   780  			immediatelyResolvable: false,
   781  			expectedErr:           nil,
   782  			expectedLinkAddr:      utils.LinkAddr2,
   783  		},
   784  		{
   785  			name:                  "IPv6 resolvable",
   786  			netProto:              ipv6.ProtocolNumber,
   787  			localAddr:             utils.Ipv6Addr1.AddressWithPrefix.Address,
   788  			remoteAddr:            utils.Ipv6Addr2.AddressWithPrefix.Address,
   789  			immediatelyResolvable: false,
   790  			expectedErr:           nil,
   791  			expectedLinkAddr:      utils.LinkAddr2,
   792  		},
   793  		{
   794  			name:                  "IPv4 not resolvable",
   795  			netProto:              ipv4.ProtocolNumber,
   796  			localAddr:             utils.Ipv4Addr1.AddressWithPrefix.Address,
   797  			remoteAddr:            utils.Ipv4Addr3.AddressWithPrefix.Address,
   798  			immediatelyResolvable: false,
   799  			expectedErr:           &tcpip.ErrTimeout{},
   800  		},
   801  		{
   802  			name:                  "IPv6 not resolvable",
   803  			netProto:              ipv6.ProtocolNumber,
   804  			localAddr:             utils.Ipv6Addr1.AddressWithPrefix.Address,
   805  			remoteAddr:            utils.Ipv6Addr3.AddressWithPrefix.Address,
   806  			immediatelyResolvable: false,
   807  			expectedErr:           &tcpip.ErrTimeout{},
   808  		},
   809  	}
   810  
   811  	for _, test := range tests {
   812  		t.Run(test.name, func(t *testing.T) {
   813  			clock := faketime.NewManualClock()
   814  			stackOpts := stack.Options{
   815  				NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
   816  				Clock:            clock,
   817  			}
   818  
   819  			host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
   820  			r, err := host1Stack.FindRoute(host1NICID, test.localAddr, test.remoteAddr, test.netProto, false /* multicastLoop */)
   821  			if err != nil {
   822  				t.Fatalf("host1Stack.FindRoute(%d, %s, %s, %d, false): %s", host1NICID, test.localAddr, test.remoteAddr, test.netProto, err)
   823  			}
   824  			defer r.Release()
   825  
   826  			var wantRouteInfo stack.RouteInfo
   827  			wantRouteInfo.LocalLinkAddress = utils.LinkAddr1
   828  			wantRouteInfo.LocalAddress = test.localAddr
   829  			wantRouteInfo.RemoteAddress = test.remoteAddr
   830  			wantRouteInfo.NetProto = test.netProto
   831  			wantRouteInfo.Loop = stack.PacketOut
   832  			wantRouteInfo.RemoteLinkAddress = test.expectedLinkAddr
   833  
   834  			ch := make(chan stack.ResolvedFieldsResult, 1)
   835  
   836  			if !test.immediatelyResolvable {
   837  				wantUnresolvedRouteInfo := wantRouteInfo
   838  				wantUnresolvedRouteInfo.RemoteLinkAddress = ""
   839  
   840  				err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) {
   841  					ch <- r
   842  				})
   843  				if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
   844  					t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{})
   845  				}
   846  
   847  				nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto)
   848  				if err != nil {
   849  					t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err)
   850  				}
   851  				clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer)
   852  
   853  				select {
   854  				case got := <-ch:
   855  					if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: test.expectedErr}, got, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
   856  						t.Errorf("route resolve result mismatch (-want +got):\n%s", diff)
   857  					}
   858  				default:
   859  					t.Fatalf("event didn't arrive")
   860  				}
   861  
   862  				if test.expectedErr != nil {
   863  					return
   864  				}
   865  
   866  				// At this point the neighbor table should be populated so the route
   867  				// should be immediately resolvable.
   868  			}
   869  
   870  			if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) {
   871  				ch <- r
   872  			}); err != nil {
   873  				t.Errorf("r.ResolvedFields(_): %s", err)
   874  			}
   875  			select {
   876  			case routeResolveRes := <-ch:
   877  				if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: nil}, routeResolveRes, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
   878  					t.Errorf("route resolve result from resolved route mismatch (-want +got):\n%s", diff)
   879  				}
   880  			default:
   881  				t.Fatal("expected route to be immediately resolvable")
   882  			}
   883  		})
   884  	}
   885  }
   886  
   887  func TestWritePacketsLinkResolution(t *testing.T) {
   888  	const (
   889  		host1NICID = 1
   890  		host2NICID = 4
   891  	)
   892  
   893  	tests := []struct {
   894  		name             string
   895  		netProto         tcpip.NetworkProtocolNumber
   896  		remoteAddr       tcpip.Address
   897  		expectedWriteErr tcpip.Error
   898  	}{
   899  		{
   900  			name:             "IPv4",
   901  			netProto:         ipv4.ProtocolNumber,
   902  			remoteAddr:       utils.Ipv4Addr2.AddressWithPrefix.Address,
   903  			expectedWriteErr: nil,
   904  		},
   905  		{
   906  			name:             "IPv6",
   907  			netProto:         ipv6.ProtocolNumber,
   908  			remoteAddr:       utils.Ipv6Addr2.AddressWithPrefix.Address,
   909  			expectedWriteErr: nil,
   910  		},
   911  	}
   912  
   913  	for _, test := range tests {
   914  		t.Run(test.name, func(t *testing.T) {
   915  			stackOpts := stack.Options{
   916  				NetworkProtocols:   []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
   917  				TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   918  			}
   919  
   920  			host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID)
   921  
   922  			var serverWQ waiter.Queue
   923  			serverWE, serverCH := waiter.NewChannelEntry(nil)
   924  			serverWQ.EventRegister(&serverWE, waiter.ReadableEvents)
   925  			serverEP, err := host2Stack.NewEndpoint(udp.ProtocolNumber, test.netProto, &serverWQ)
   926  			if err != nil {
   927  				t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.netProto, err)
   928  			}
   929  			defer serverEP.Close()
   930  
   931  			serverAddr := tcpip.FullAddress{Port: 1234}
   932  			if err := serverEP.Bind(serverAddr); err != nil {
   933  				t.Fatalf("serverEP.Bind(%#v): %s", serverAddr, err)
   934  			}
   935  
   936  			r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */)
   937  			if err != nil {
   938  				t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err)
   939  			}
   940  			defer r.Release()
   941  
   942  			data := []byte{1, 2}
   943  			var pkts stack.PacketBufferList
   944  			for _, d := range data {
   945  				pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   946  					ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
   947  					Data:               buffer.View([]byte{d}).ToVectorisedView(),
   948  				})
   949  				pkt.TransportProtocolNumber = udp.ProtocolNumber
   950  				length := uint16(pkt.Size())
   951  				udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
   952  				udpHdr.Encode(&header.UDPFields{
   953  					SrcPort: 5555,
   954  					DstPort: serverAddr.Port,
   955  					Length:  length,
   956  				})
   957  				xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length)
   958  				xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
   959  				udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
   960  
   961  				pkts.PushBack(pkt)
   962  			}
   963  
   964  			params := stack.NetworkHeaderParams{
   965  				Protocol: udp.ProtocolNumber,
   966  				TTL:      64,
   967  				TOS:      stack.DefaultTOS,
   968  			}
   969  
   970  			if n, err := r.WritePackets(pkts, params); err != nil {
   971  				t.Fatalf("r.WritePackets(_, %#v): %s", params, err)
   972  			} else if want := pkts.Len(); want != n {
   973  				t.Fatalf("got r.WritePackets(_, %#v) = %d, want = %d", params, n, want)
   974  			}
   975  
   976  			var writer bytes.Buffer
   977  			count := 0
   978  			for {
   979  				var rOpts tcpip.ReadOptions
   980  				res, err := serverEP.Read(&writer, rOpts)
   981  				if err != nil {
   982  					if _, ok := err.(*tcpip.ErrWouldBlock); ok {
   983  						// Should not have anymore bytes to read after we read the sent
   984  						// number of bytes.
   985  						if count == len(data) {
   986  							break
   987  						}
   988  
   989  						<-serverCH
   990  						continue
   991  					}
   992  
   993  					t.Fatalf("serverEP.Read(_, %#v): %s", rOpts, err)
   994  				}
   995  				count += res.Count
   996  			}
   997  
   998  			if got, want := host2Stack.Stats().UDP.PacketsReceived.Value(), uint64(len(data)); got != want {
   999  				t.Errorf("got host2Stack.Stats().UDP.PacketsReceived.Value() = %d, want = %d", got, want)
  1000  			}
  1001  			if diff := cmp.Diff(data, writer.Bytes()); diff != "" {
  1002  				t.Errorf("read bytes mismatch (-want +got):\n%s", diff)
  1003  			}
  1004  		})
  1005  	}
  1006  }
  1007  
  1008  type eventType int
  1009  
  1010  const (
  1011  	entryAdded eventType = iota
  1012  	entryChanged
  1013  	entryRemoved
  1014  )
  1015  
  1016  func (t eventType) String() string {
  1017  	switch t {
  1018  	case entryAdded:
  1019  		return "add"
  1020  	case entryChanged:
  1021  		return "change"
  1022  	case entryRemoved:
  1023  		return "remove"
  1024  	default:
  1025  		return fmt.Sprintf("unknown (%d)", t)
  1026  	}
  1027  }
  1028  
  1029  type eventInfo struct {
  1030  	eventType eventType
  1031  	nicID     tcpip.NICID
  1032  	entry     stack.NeighborEntry
  1033  }
  1034  
  1035  func (e eventInfo) String() string {
  1036  	return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry)
  1037  }
  1038  
  1039  var _ stack.NUDDispatcher = (*nudDispatcher)(nil)
  1040  
  1041  type nudDispatcher struct {
  1042  	c chan eventInfo
  1043  }
  1044  
  1045  func (d *nudDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) {
  1046  	e := eventInfo{
  1047  		eventType: entryAdded,
  1048  		nicID:     nicID,
  1049  		entry:     entry,
  1050  	}
  1051  	d.c <- e
  1052  }
  1053  
  1054  func (d *nudDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) {
  1055  	e := eventInfo{
  1056  		eventType: entryChanged,
  1057  		nicID:     nicID,
  1058  		entry:     entry,
  1059  	}
  1060  	d.c <- e
  1061  }
  1062  
  1063  func (d *nudDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) {
  1064  	e := eventInfo{
  1065  		eventType: entryRemoved,
  1066  		nicID:     nicID,
  1067  		entry:     entry,
  1068  	}
  1069  	d.c <- e
  1070  }
  1071  
  1072  func (d *nudDispatcher) expectEvent(want eventInfo) error {
  1073  	select {
  1074  	case got := <-d.c:
  1075  		if diff := cmp.Diff(want, got, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" {
  1076  			return fmt.Errorf("got invalid event (-want +got):\n%s", diff)
  1077  		}
  1078  		return nil
  1079  	default:
  1080  		return fmt.Errorf("event didn't arrive")
  1081  	}
  1082  }
  1083  
  1084  // TestTCPConfirmNeighborReachability tests that TCP informs layers beneath it
  1085  // that the neighbor used for a route is reachable.
  1086  func TestTCPConfirmNeighborReachability(t *testing.T) {
  1087  	tests := []struct {
  1088  		name            string
  1089  		netProto        tcpip.NetworkProtocolNumber
  1090  		remoteAddr      tcpip.Address
  1091  		neighborAddr    tcpip.Address
  1092  		getEndpoints    func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{})
  1093  		isHost1Listener bool
  1094  	}{
  1095  		{
  1096  			name:         "IPv4 active connection through neighbor",
  1097  			netProto:     ipv4.ProtocolNumber,
  1098  			remoteAddr:   utils.Host2IPv4Addr.AddressWithPrefix.Address,
  1099  			neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
  1100  			getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
  1101  				var listenerWQ waiter.Queue
  1102  				listenerWE, listenerCH := waiter.NewChannelEntry(nil)
  1103  				listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
  1104  				listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
  1105  				if err != nil {
  1106  					t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
  1107  				}
  1108  				t.Cleanup(listenerEP.Close)
  1109  
  1110  				var clientWQ waiter.Queue
  1111  				clientWE, clientCH := waiter.NewChannelEntry(nil)
  1112  				clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
  1113  				clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
  1114  				if err != nil {
  1115  					t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
  1116  				}
  1117  
  1118  				return listenerEP, listenerCH, clientEP, clientCH
  1119  			},
  1120  		},
  1121  		{
  1122  			name:         "IPv6 active connection through neighbor",
  1123  			netProto:     ipv6.ProtocolNumber,
  1124  			remoteAddr:   utils.Host2IPv6Addr.AddressWithPrefix.Address,
  1125  			neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
  1126  			getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
  1127  				var listenerWQ waiter.Queue
  1128  				listenerWE, listenerCH := waiter.NewChannelEntry(nil)
  1129  				listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
  1130  				listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
  1131  				if err != nil {
  1132  					t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
  1133  				}
  1134  				t.Cleanup(listenerEP.Close)
  1135  
  1136  				var clientWQ waiter.Queue
  1137  				clientWE, clientCH := waiter.NewChannelEntry(nil)
  1138  				clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
  1139  				clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
  1140  				if err != nil {
  1141  					t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
  1142  				}
  1143  
  1144  				return listenerEP, listenerCH, clientEP, clientCH
  1145  			},
  1146  		},
  1147  		{
  1148  			name:         "IPv4 active connection to neighbor",
  1149  			netProto:     ipv4.ProtocolNumber,
  1150  			remoteAddr:   utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
  1151  			neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
  1152  			getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
  1153  				var listenerWQ waiter.Queue
  1154  				listenerWE, listenerCH := waiter.NewChannelEntry(nil)
  1155  				listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
  1156  				listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
  1157  				if err != nil {
  1158  					t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
  1159  				}
  1160  				t.Cleanup(listenerEP.Close)
  1161  
  1162  				var clientWQ waiter.Queue
  1163  				clientWE, clientCH := waiter.NewChannelEntry(nil)
  1164  				clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
  1165  				clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
  1166  				if err != nil {
  1167  					t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
  1168  				}
  1169  
  1170  				return listenerEP, listenerCH, clientEP, clientCH
  1171  			},
  1172  		},
  1173  		{
  1174  			name:         "IPv6 active connection to neighbor",
  1175  			netProto:     ipv6.ProtocolNumber,
  1176  			remoteAddr:   utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
  1177  			neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
  1178  			getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
  1179  				var listenerWQ waiter.Queue
  1180  				listenerWE, listenerCH := waiter.NewChannelEntry(nil)
  1181  				listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
  1182  				listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
  1183  				if err != nil {
  1184  					t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
  1185  				}
  1186  				t.Cleanup(listenerEP.Close)
  1187  
  1188  				var clientWQ waiter.Queue
  1189  				clientWE, clientCH := waiter.NewChannelEntry(nil)
  1190  				clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
  1191  				clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
  1192  				if err != nil {
  1193  					t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
  1194  				}
  1195  
  1196  				return listenerEP, listenerCH, clientEP, clientCH
  1197  			},
  1198  		},
  1199  		{
  1200  			name:         "IPv4 passive connection to neighbor",
  1201  			netProto:     ipv4.ProtocolNumber,
  1202  			remoteAddr:   utils.Host1IPv4Addr.AddressWithPrefix.Address,
  1203  			neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
  1204  			getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
  1205  				var listenerWQ waiter.Queue
  1206  				listenerWE, listenerCH := waiter.NewChannelEntry(nil)
  1207  				listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
  1208  				listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
  1209  				if err != nil {
  1210  					t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
  1211  				}
  1212  				t.Cleanup(listenerEP.Close)
  1213  
  1214  				var clientWQ waiter.Queue
  1215  				clientWE, clientCH := waiter.NewChannelEntry(nil)
  1216  				clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
  1217  				clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
  1218  				if err != nil {
  1219  					t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
  1220  				}
  1221  
  1222  				return listenerEP, listenerCH, clientEP, clientCH
  1223  			},
  1224  			isHost1Listener: true,
  1225  		},
  1226  		{
  1227  			name:         "IPv6 passive connection to neighbor",
  1228  			netProto:     ipv6.ProtocolNumber,
  1229  			remoteAddr:   utils.Host1IPv6Addr.AddressWithPrefix.Address,
  1230  			neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
  1231  			getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
  1232  				var listenerWQ waiter.Queue
  1233  				listenerWE, listenerCH := waiter.NewChannelEntry(nil)
  1234  				listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
  1235  				listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
  1236  				if err != nil {
  1237  					t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
  1238  				}
  1239  				t.Cleanup(listenerEP.Close)
  1240  
  1241  				var clientWQ waiter.Queue
  1242  				clientWE, clientCH := waiter.NewChannelEntry(nil)
  1243  				clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
  1244  				clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
  1245  				if err != nil {
  1246  					t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
  1247  				}
  1248  
  1249  				return listenerEP, listenerCH, clientEP, clientCH
  1250  			},
  1251  			isHost1Listener: true,
  1252  		},
  1253  		{
  1254  			name:         "IPv4 passive connection through neighbor",
  1255  			netProto:     ipv4.ProtocolNumber,
  1256  			remoteAddr:   utils.Host1IPv4Addr.AddressWithPrefix.Address,
  1257  			neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
  1258  			getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
  1259  				var listenerWQ waiter.Queue
  1260  				listenerWE, listenerCH := waiter.NewChannelEntry(nil)
  1261  				listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
  1262  				listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
  1263  				if err != nil {
  1264  					t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
  1265  				}
  1266  				t.Cleanup(listenerEP.Close)
  1267  
  1268  				var clientWQ waiter.Queue
  1269  				clientWE, clientCH := waiter.NewChannelEntry(nil)
  1270  				clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
  1271  				clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
  1272  				if err != nil {
  1273  					t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
  1274  				}
  1275  
  1276  				return listenerEP, listenerCH, clientEP, clientCH
  1277  			},
  1278  			isHost1Listener: true,
  1279  		},
  1280  		{
  1281  			name:         "IPv6 passive connection through neighbor",
  1282  			netProto:     ipv6.ProtocolNumber,
  1283  			remoteAddr:   utils.Host1IPv6Addr.AddressWithPrefix.Address,
  1284  			neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
  1285  			getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
  1286  				var listenerWQ waiter.Queue
  1287  				listenerWE, listenerCH := waiter.NewChannelEntry(nil)
  1288  				listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
  1289  				listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
  1290  				if err != nil {
  1291  					t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
  1292  				}
  1293  				t.Cleanup(listenerEP.Close)
  1294  
  1295  				var clientWQ waiter.Queue
  1296  				clientWE, clientCH := waiter.NewChannelEntry(nil)
  1297  				clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
  1298  				clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
  1299  				if err != nil {
  1300  					t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
  1301  				}
  1302  
  1303  				return listenerEP, listenerCH, clientEP, clientCH
  1304  			},
  1305  			isHost1Listener: true,
  1306  		},
  1307  	}
  1308  
  1309  	for _, test := range tests {
  1310  		t.Run(test.name, func(t *testing.T) {
  1311  			clock := faketime.NewManualClock()
  1312  			nudDisp := nudDispatcher{
  1313  				c: make(chan eventInfo, 3),
  1314  			}
  1315  			stackOpts := stack.Options{
  1316  				NetworkProtocols:   []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
  1317  				TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
  1318  				Clock:              clock,
  1319  			}
  1320  			host1StackOpts := stackOpts
  1321  			host1StackOpts.NUDDisp = &nudDisp
  1322  
  1323  			host1Stack := stack.New(host1StackOpts)
  1324  			routerStack := stack.New(stackOpts)
  1325  			host2Stack := stack.New(stackOpts)
  1326  			utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack)
  1327  
  1328  			// Add a reachable dynamic entry to our neighbor table for the remote.
  1329  			{
  1330  				ch := make(chan stack.LinkResolutionResult, 1)
  1331  				err := host1Stack.GetLinkAddress(utils.Host1NICID, test.neighborAddr, "", test.netProto, func(r stack.LinkResolutionResult) {
  1332  					ch <- r
  1333  				})
  1334  				if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
  1335  					t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", utils.Host1NICID, test.neighborAddr, test.netProto, err, &tcpip.ErrWouldBlock{})
  1336  				}
  1337  				if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: utils.LinkAddr2, Err: nil}, <-ch); diff != "" {
  1338  					t.Fatalf("link resolution mismatch (-want +got):\n%s", diff)
  1339  				}
  1340  			}
  1341  			if err := nudDisp.expectEvent(eventInfo{
  1342  				eventType: entryAdded,
  1343  				nicID:     utils.Host1NICID,
  1344  				entry:     stack.NeighborEntry{State: stack.Incomplete, Addr: test.neighborAddr},
  1345  			}); err != nil {
  1346  				t.Fatalf("error waiting for initial NUD event: %s", err)
  1347  			}
  1348  			if err := nudDisp.expectEvent(eventInfo{
  1349  				eventType: entryChanged,
  1350  				nicID:     utils.Host1NICID,
  1351  				entry:     stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
  1352  			}); err != nil {
  1353  				t.Fatalf("error waiting for reachable NUD event: %s", err)
  1354  			}
  1355  
  1356  			// Wait for the remote's neighbor entry to be stale before creating a
  1357  			// TCP connection from host1 to some remote.
  1358  			nudConfigs, err := host1Stack.NUDConfigurations(utils.Host1NICID, test.netProto)
  1359  			if err != nil {
  1360  				t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", utils.Host1NICID, test.netProto, err)
  1361  			}
  1362  			// The maximum reachable time for a neighbor is some maximum random factor
  1363  			// applied to the base reachable time.
  1364  			//
  1365  			// See NUDConfigurations.BaseReachableTime for more information.
  1366  			maxReachableTime := time.Duration(float32(nudConfigs.BaseReachableTime) * nudConfigs.MaxRandomFactor)
  1367  			clock.Advance(maxReachableTime)
  1368  			if err := nudDisp.expectEvent(eventInfo{
  1369  				eventType: entryChanged,
  1370  				nicID:     utils.Host1NICID,
  1371  				entry:     stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
  1372  			}); err != nil {
  1373  				t.Fatalf("error waiting for stale NUD event: %s", err)
  1374  			}
  1375  
  1376  			listenerEP, listenerCH, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack)
  1377  			defer clientEP.Close()
  1378  			listenerAddr := tcpip.FullAddress{Addr: test.remoteAddr, Port: 1234}
  1379  			if err := listenerEP.Bind(listenerAddr); err != nil {
  1380  				t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err)
  1381  			}
  1382  			if err := listenerEP.Listen(1); err != nil {
  1383  				t.Fatalf("listenerEP.Listen(1): %s", err)
  1384  			}
  1385  			{
  1386  				err := clientEP.Connect(listenerAddr)
  1387  				if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
  1388  					t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", listenerAddr, err, &tcpip.ErrConnectStarted{})
  1389  				}
  1390  			}
  1391  
  1392  			// Wait for the TCP handshake to complete then make sure the neighbor is
  1393  			// reachable without entering the probe state as TCP should provide NUD
  1394  			// with confirmation that the neighbor is reachable (indicated by a
  1395  			// successful 3-way handshake).
  1396  			<-clientCH
  1397  			if err := nudDisp.expectEvent(eventInfo{
  1398  				eventType: entryChanged,
  1399  				nicID:     utils.Host1NICID,
  1400  				entry:     stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
  1401  			}); err != nil {
  1402  				t.Fatalf("error waiting for delay NUD event: %s", err)
  1403  			}
  1404  			<-listenerCH
  1405  			if err := nudDisp.expectEvent(eventInfo{
  1406  				eventType: entryChanged,
  1407  				nicID:     utils.Host1NICID,
  1408  				entry:     stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
  1409  			}); err != nil {
  1410  				t.Fatalf("error waiting for reachable NUD event: %s", err)
  1411  			}
  1412  
  1413  			peerEP, peerWQ, err := listenerEP.Accept(nil)
  1414  			if err != nil {
  1415  				t.Fatalf("listenerEP.Accept(): %s", err)
  1416  			}
  1417  			defer peerEP.Close()
  1418  			peerWE, peerCH := waiter.NewChannelEntry(nil)
  1419  			peerWQ.EventRegister(&peerWE, waiter.ReadableEvents)
  1420  
  1421  			// Wait for the neighbor to be stale again then send data to the remote.
  1422  			//
  1423  			// On successful transmission, the neighbor should become reachable
  1424  			// without probing the neighbor as a TCP ACK would be received which is an
  1425  			// indication of the neighbor being reachable.
  1426  			clock.Advance(maxReachableTime)
  1427  			if err := nudDisp.expectEvent(eventInfo{
  1428  				eventType: entryChanged,
  1429  				nicID:     utils.Host1NICID,
  1430  				entry:     stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
  1431  			}); err != nil {
  1432  				t.Fatalf("error waiting for stale NUD event: %s", err)
  1433  			}
  1434  			{
  1435  				var r bytes.Reader
  1436  				r.Reset([]byte{0})
  1437  				var wOpts tcpip.WriteOptions
  1438  				if _, err := clientEP.Write(&r, wOpts); err != nil {
  1439  					t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err)
  1440  				}
  1441  			}
  1442  			// Heads up, there is a race here.
  1443  			//
  1444  			// Incoming TCP segments are handled in
  1445  			// tcp.(*endpoint).handleSegmentLocked:
  1446  			//
  1447  			// - tcp.(*endpoint).rcv.handleRcvdSegment puts the segment on the
  1448  			// segment queue and notifies waiting readers (such as this channel)
  1449  			//
  1450  			// - tcp.(*endpoint).snd.handleRcvdSegment sends an ACK for the segment
  1451  			// and notifies the NUD machinery that the peer is reachable
  1452  			//
  1453  			// Thus we must permit a delay between the readable signal and the
  1454  			// expected NUD event.
  1455  			//
  1456  			// At the time of writing, this race is reliably hit with gotsan.
  1457  			<-peerCH
  1458  			for len(nudDisp.c) == 0 {
  1459  				runtime.Gosched()
  1460  			}
  1461  			if err := nudDisp.expectEvent(eventInfo{
  1462  				eventType: entryChanged,
  1463  				nicID:     utils.Host1NICID,
  1464  				entry:     stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
  1465  			}); err != nil {
  1466  				t.Fatalf("error waiting for delay NUD event: %s", err)
  1467  			}
  1468  			if test.isHost1Listener {
  1469  				// If host1 is not the client, host1 does not send any data so TCP
  1470  				// has no way to know it is making forward progress. Because of this,
  1471  				// TCP should not mark the route reachable and NUD should go through the
  1472  				// probe state.
  1473  				clock.Advance(nudConfigs.DelayFirstProbeTime)
  1474  				if err := nudDisp.expectEvent(eventInfo{
  1475  					eventType: entryChanged,
  1476  					nicID:     utils.Host1NICID,
  1477  					entry:     stack.NeighborEntry{State: stack.Probe, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
  1478  				}); err != nil {
  1479  					t.Fatalf("error waiting for probe NUD event: %s", err)
  1480  				}
  1481  			}
  1482  			{
  1483  				var r bytes.Reader
  1484  				r.Reset([]byte{0})
  1485  				var wOpts tcpip.WriteOptions
  1486  				if _, err := peerEP.Write(&r, wOpts); err != nil {
  1487  					t.Errorf("peerEP.Write(_, %#v): %s", wOpts, err)
  1488  				}
  1489  			}
  1490  			<-clientCH
  1491  			if err := nudDisp.expectEvent(eventInfo{
  1492  				eventType: entryChanged,
  1493  				nicID:     utils.Host1NICID,
  1494  				entry:     stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
  1495  			}); err != nil {
  1496  				t.Fatalf("error waiting for reachable NUD event: %s", err)
  1497  			}
  1498  		})
  1499  	}
  1500  }
  1501  
  1502  func TestDAD(t *testing.T) {
  1503  	dadConfigs := stack.DADConfigurations{
  1504  		DupAddrDetectTransmits: 1,
  1505  		RetransmitTimer:        time.Second,
  1506  	}
  1507  
  1508  	tests := []struct {
  1509  		name           string
  1510  		netProto       tcpip.NetworkProtocolNumber
  1511  		dadNetProto    tcpip.NetworkProtocolNumber
  1512  		remoteAddr     tcpip.Address
  1513  		expectedResult stack.DADResult
  1514  	}{
  1515  		{
  1516  			name:           "IPv4 own address",
  1517  			netProto:       ipv4.ProtocolNumber,
  1518  			dadNetProto:    arp.ProtocolNumber,
  1519  			remoteAddr:     utils.Ipv4Addr1.AddressWithPrefix.Address,
  1520  			expectedResult: &stack.DADSucceeded{},
  1521  		},
  1522  		{
  1523  			name:           "IPv6 own address",
  1524  			netProto:       ipv6.ProtocolNumber,
  1525  			dadNetProto:    ipv6.ProtocolNumber,
  1526  			remoteAddr:     utils.Ipv6Addr1.AddressWithPrefix.Address,
  1527  			expectedResult: &stack.DADSucceeded{},
  1528  		},
  1529  		{
  1530  			name:           "IPv4 duplicate address",
  1531  			netProto:       ipv4.ProtocolNumber,
  1532  			dadNetProto:    arp.ProtocolNumber,
  1533  			remoteAddr:     utils.Ipv4Addr2.AddressWithPrefix.Address,
  1534  			expectedResult: &stack.DADDupAddrDetected{HolderLinkAddress: utils.LinkAddr2},
  1535  		},
  1536  		{
  1537  			name:           "IPv6 duplicate address",
  1538  			netProto:       ipv6.ProtocolNumber,
  1539  			dadNetProto:    ipv6.ProtocolNumber,
  1540  			remoteAddr:     utils.Ipv6Addr2.AddressWithPrefix.Address,
  1541  			expectedResult: &stack.DADDupAddrDetected{HolderLinkAddress: utils.LinkAddr2},
  1542  		},
  1543  		{
  1544  			name:           "IPv4 no duplicate address",
  1545  			netProto:       ipv4.ProtocolNumber,
  1546  			dadNetProto:    arp.ProtocolNumber,
  1547  			remoteAddr:     utils.Ipv4Addr3.AddressWithPrefix.Address,
  1548  			expectedResult: &stack.DADSucceeded{},
  1549  		},
  1550  		{
  1551  			name:           "IPv6 no duplicate address",
  1552  			netProto:       ipv6.ProtocolNumber,
  1553  			dadNetProto:    ipv6.ProtocolNumber,
  1554  			remoteAddr:     utils.Ipv6Addr3.AddressWithPrefix.Address,
  1555  			expectedResult: &stack.DADSucceeded{},
  1556  		},
  1557  	}
  1558  
  1559  	for _, test := range tests {
  1560  		t.Run(test.name, func(t *testing.T) {
  1561  			clock := faketime.NewManualClock()
  1562  			stackOpts := stack.Options{
  1563  				Clock: clock,
  1564  				NetworkProtocols: []stack.NetworkProtocolFactory{
  1565  					arp.NewProtocol,
  1566  					ipv4.NewProtocol,
  1567  					ipv6.NewProtocol,
  1568  				},
  1569  			}
  1570  
  1571  			host1Stack, _ := setupStack(t, stackOpts, utils.Host1NICID, utils.Host2NICID)
  1572  
  1573  			// DAD should be disabled by default.
  1574  			if res, err := host1Stack.CheckDuplicateAddress(utils.Host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) {
  1575  				t.Errorf("unexpectedly called DAD completion handler when DAD was supposed to be disabled")
  1576  			}); err != nil {
  1577  				t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", utils.Host1NICID, test.netProto, test.remoteAddr, err)
  1578  			} else if res != stack.DADDisabled {
  1579  				t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", utils.Host1NICID, test.netProto, test.remoteAddr, res, stack.DADDisabled)
  1580  			}
  1581  
  1582  			// Enable DAD then attempt to check if an address is duplicated.
  1583  			netEP, err := host1Stack.GetNetworkEndpoint(utils.Host1NICID, test.dadNetProto)
  1584  			if err != nil {
  1585  				t.Fatalf("host1Stack.GetNetworkEndpoint(%d, %d): %s", utils.Host1NICID, test.dadNetProto, err)
  1586  			}
  1587  			dad, ok := netEP.(stack.DuplicateAddressDetector)
  1588  			if !ok {
  1589  				t.Fatalf("expected %T to implement stack.DuplicateAddressDetector", netEP)
  1590  			}
  1591  			dad.SetDADConfigurations(dadConfigs)
  1592  			ch := make(chan stack.DADResult, 3)
  1593  			if res, err := host1Stack.CheckDuplicateAddress(utils.Host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) {
  1594  				ch <- r
  1595  			}); err != nil {
  1596  				t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", utils.Host1NICID, test.netProto, test.remoteAddr, err)
  1597  			} else if res != stack.DADStarting {
  1598  				t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", utils.Host1NICID, test.netProto, test.remoteAddr, res, stack.DADStarting)
  1599  			}
  1600  
  1601  			expectResults := 1
  1602  			if _, ok := test.expectedResult.(*stack.DADSucceeded); ok {
  1603  				const delta = time.Nanosecond
  1604  				clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer - delta)
  1605  				select {
  1606  				case r := <-ch:
  1607  					t.Fatalf("unexpectedly got DAD result before the DAD timeout; r = %#v", r)
  1608  				default:
  1609  				}
  1610  
  1611  				// If we expect the resolve to succeed try requesting DAD again on the
  1612  				// same address. The handler for the new request should be called once
  1613  				// the original DAD request completes.
  1614  				expectResults = 2
  1615  				if res, err := host1Stack.CheckDuplicateAddress(utils.Host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) {
  1616  					ch <- r
  1617  				}); err != nil {
  1618  					t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", utils.Host1NICID, test.netProto, test.remoteAddr, err)
  1619  				} else if res != stack.DADAlreadyRunning {
  1620  					t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", utils.Host1NICID, test.netProto, test.remoteAddr, res, stack.DADAlreadyRunning)
  1621  				}
  1622  
  1623  				clock.Advance(delta)
  1624  			}
  1625  
  1626  			for i := 0; i < expectResults; i++ {
  1627  				if diff := cmp.Diff(test.expectedResult, <-ch); diff != "" {
  1628  					t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff)
  1629  				}
  1630  			}
  1631  
  1632  			// Should have no more results.
  1633  			select {
  1634  			case r := <-ch:
  1635  				t.Errorf("unexpectedly got an extra DAD result; r = %#v", r)
  1636  			default:
  1637  			}
  1638  		})
  1639  	}
  1640  }