github.com/polevpn/netstack@v1.10.9/tcpip/transport/udp/udp_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package udp_test
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"math/rand"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/polevpn/netstack/tcpip"
    25  	"github.com/polevpn/netstack/tcpip/buffer"
    26  	"github.com/polevpn/netstack/tcpip/checker"
    27  	"github.com/polevpn/netstack/tcpip/header"
    28  	"github.com/polevpn/netstack/tcpip/link/channel"
    29  	"github.com/polevpn/netstack/tcpip/link/loopback"
    30  	"github.com/polevpn/netstack/tcpip/link/sniffer"
    31  	"github.com/polevpn/netstack/tcpip/network/ipv4"
    32  	"github.com/polevpn/netstack/tcpip/network/ipv6"
    33  	"github.com/polevpn/netstack/tcpip/stack"
    34  	"github.com/polevpn/netstack/tcpip/transport/udp"
    35  	"github.com/polevpn/netstack/waiter"
    36  )
    37  
// Addresses and ports used for testing. It is recommended that tests stick to
// using these addresses as it allows using the testFlow helper.
// Naming rules: 'stack*' denotes local addresses and ports, while 'test*'
// represents the remote endpoint.
const (
	// V6 addresses, plus the V4-mapped forms obtained by prepending
	// v4MappedAddrPrefix to the 4-byte V4 addresses declared below.
	v4MappedAddrPrefix    = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
	stackV6Addr           = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
	testV6Addr            = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
	stackV4MappedAddr     = v4MappedAddrPrefix + stackAddr
	testV4MappedAddr      = v4MappedAddrPrefix + testAddr
	multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr
	broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr
	v4MappedWildcardAddr  = v4MappedAddrPrefix + "\x00\x00\x00\x00"

	// V4 addresses (10.0.0.1 local, 10.0.0.2 remote) and the UDP ports, which
	// are used for both the V4 and the V6 flows.
	stackAddr       = "\x0a\x00\x00\x01"
	stackPort       = 1234
	testAddr        = "\x0a\x00\x00\x02"
	testPort        = 4096
	multicastAddr   = "\xe8\x2b\xd3\xea"
	multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
	broadcastAddr   = header.IPv4Broadcast

	// defaultMTU is the MTU, in bytes, used throughout the tests, except
	// where another value is explicitly used. It is chosen to match the MTU
	// of loopback interfaces on linux systems.
	defaultMTU = 65536
)
    65  
// header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in
// a packet header. These values are used to populate a header or verify one.
// Note that because they are used in packet headers, the addresses are never in
// a V4-mapped format.
type header4Tuple struct {
	// srcAddr holds the source address and port of the packet.
	srcAddr tcpip.FullAddress
	// dstAddr holds the destination address and port of the packet.
	dstAddr tcpip.FullAddress
}
    74  
// testFlow implements a helper type used for sending and receiving test
// packets. A given test flow value defines 1) the socket endpoint used for the
// test and 2) the type of packet sent or received on the endpoint. E.g., a
// multicastV6Only flow is a V6 multicast packet passing through a V6-only
// endpoint. The type provides helper methods to characterize the flow (e.g.,
// isV4) as well as return a proper header4Tuple for it.
type testFlow int
    82  
// The supported test flows. Each value pairs a kind of packet (unicast,
// multicast or broadcast; V4 or V6) with the kind of socket it passes through.
// The helper methods on testFlow panic for values outside this list.
const (
	unicastV4       testFlow = iota // V4 unicast on a V4 socket
	unicastV4in6                    // V4-mapped unicast on a V6-dual socket
	unicastV6                       // V6 unicast on a V6 socket
	unicastV6Only                   // V6 unicast on a V6-only socket
	multicastV4                     // V4 multicast on a V4 socket
	multicastV4in6                  // V4-mapped multicast on a V6-dual socket
	multicastV6                     // V6 multicast on a V6 socket
	multicastV6Only                 // V6 multicast on a V6-only socket
	broadcast                       // V4 broadcast on a V4 socket
	broadcastIn6                    // V4-mapped broadcast on a V6-dual socket
)
    95  
    96  func (flow testFlow) String() string {
    97  	switch flow {
    98  	case unicastV4:
    99  		return "unicastV4"
   100  	case unicastV6:
   101  		return "unicastV6"
   102  	case unicastV6Only:
   103  		return "unicastV6Only"
   104  	case unicastV4in6:
   105  		return "unicastV4in6"
   106  	case multicastV4:
   107  		return "multicastV4"
   108  	case multicastV6:
   109  		return "multicastV6"
   110  	case multicastV6Only:
   111  		return "multicastV6Only"
   112  	case multicastV4in6:
   113  		return "multicastV4in6"
   114  	case broadcast:
   115  		return "broadcast"
   116  	case broadcastIn6:
   117  		return "broadcastIn6"
   118  	default:
   119  		return "unknown"
   120  	}
   121  }
   122  
// packetDirection explains if a flow is incoming (read) or outgoing (write).
type packetDirection int

// The two directions a test packet can travel relative to the stack under
// test: incoming packets are read by the endpoint, outgoing ones are written
// by it.
const (
	incoming packetDirection = iota
	outgoing
)
   130  
   131  // header4Tuple returns the header4Tuple for the given flow and direction. Note
   132  // that the tuple contains no mapped addresses as those only exist at the socket
   133  // level but not at the packet header level.
   134  func (flow testFlow) header4Tuple(d packetDirection) header4Tuple {
   135  	var h header4Tuple
   136  	if flow.isV4() {
   137  		if d == outgoing {
   138  			h = header4Tuple{
   139  				srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort},
   140  				dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort},
   141  			}
   142  		} else {
   143  			h = header4Tuple{
   144  				srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort},
   145  				dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort},
   146  			}
   147  		}
   148  		if flow.isMulticast() {
   149  			h.dstAddr.Addr = multicastAddr
   150  		} else if flow.isBroadcast() {
   151  			h.dstAddr.Addr = broadcastAddr
   152  		}
   153  	} else { // IPv6
   154  		if d == outgoing {
   155  			h = header4Tuple{
   156  				srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
   157  				dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
   158  			}
   159  		} else {
   160  			h = header4Tuple{
   161  				srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
   162  				dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
   163  			}
   164  		}
   165  		if flow.isMulticast() {
   166  			h.dstAddr.Addr = multicastV6Addr
   167  		}
   168  	}
   169  	return h
   170  }
   171  
   172  func (flow testFlow) getMcastAddr() tcpip.Address {
   173  	if flow.isV4() {
   174  		return multicastAddr
   175  	}
   176  	return multicastV6Addr
   177  }
   178  
   179  // mapAddrIfApplicable converts the given V4 address into its V4-mapped version
   180  // if it is applicable to the flow.
   181  func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address {
   182  	if flow.isMapped() {
   183  		return v4MappedAddrPrefix + v4Addr
   184  	}
   185  	return v4Addr
   186  }
   187  
   188  // netProto returns the protocol number used for the network packet.
   189  func (flow testFlow) netProto() tcpip.NetworkProtocolNumber {
   190  	if flow.isV4() {
   191  		return ipv4.ProtocolNumber
   192  	}
   193  	return ipv6.ProtocolNumber
   194  }
   195  
   196  // sockProto returns the protocol number used when creating the socket
   197  // endpoint for this flow.
   198  func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber {
   199  	switch flow {
   200  	case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6:
   201  		return ipv6.ProtocolNumber
   202  	case unicastV4, multicastV4, broadcast:
   203  		return ipv4.ProtocolNumber
   204  	default:
   205  		panic(fmt.Sprintf("invalid testFlow given: %d", flow))
   206  	}
   207  }
   208  
   209  func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) {
   210  	if flow.isV4() {
   211  		return checker.IPv4
   212  	}
   213  	return checker.IPv6
   214  }
   215  
   216  func (flow testFlow) isV6() bool { return !flow.isV4() }
   217  func (flow testFlow) isV4() bool {
   218  	return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped()
   219  }
   220  
   221  func (flow testFlow) isV6Only() bool {
   222  	switch flow {
   223  	case unicastV6Only, multicastV6Only:
   224  		return true
   225  	case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6:
   226  		return false
   227  	default:
   228  		panic(fmt.Sprintf("invalid testFlow given: %d", flow))
   229  	}
   230  }
   231  
   232  func (flow testFlow) isMulticast() bool {
   233  	switch flow {
   234  	case multicastV4, multicastV4in6, multicastV6, multicastV6Only:
   235  		return true
   236  	case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6:
   237  		return false
   238  	default:
   239  		panic(fmt.Sprintf("invalid testFlow given: %d", flow))
   240  	}
   241  }
   242  
   243  func (flow testFlow) isBroadcast() bool {
   244  	switch flow {
   245  	case broadcast, broadcastIn6:
   246  		return true
   247  	case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only:
   248  		return false
   249  	default:
   250  		panic(fmt.Sprintf("invalid testFlow given: %d", flow))
   251  	}
   252  }
   253  
   254  func (flow testFlow) isMapped() bool {
   255  	switch flow {
   256  	case unicastV4in6, multicastV4in6, broadcastIn6:
   257  		return true
   258  	case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast:
   259  		return false
   260  	default:
   261  		panic(fmt.Sprintf("invalid testFlow given: %d", flow))
   262  	}
   263  }
   264  
// testContext bundles the state shared by the helpers of a single test: the
// stack under test, the link endpoint used to exchange packets with it, and
// the UDP endpoint created by the test.
type testContext struct {
	// t is the test the context was created for; helpers report failures on it.
	t      *testing.T
	// linkEP is the channel link endpoint attached to the stack; tests inject
	// inbound packets into it and read outbound packets from it.
	linkEP *channel.Endpoint
	// s is the network stack under test.
	s      *stack.Stack

	// ep is the endpoint created by createEndpoint; it is nil until then and
	// is closed by cleanup.
	ep tcpip.Endpoint
	// wq is the waiter queue handed to the stack when creating ep.
	wq waiter.Queue
}
   273  
   274  func newDualTestContext(t *testing.T, mtu uint32) *testContext {
   275  	t.Helper()
   276  
   277  	s := stack.New(stack.Options{
   278  		NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
   279  		TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
   280  	})
   281  	ep := channel.New(256, mtu, "")
   282  	wep := stack.LinkEndpoint(ep)
   283  
   284  	if testing.Verbose() {
   285  		wep = sniffer.New(ep)
   286  	}
   287  	if err := s.CreateNIC(1, wep); err != nil {
   288  		t.Fatalf("CreateNIC failed: %v", err)
   289  	}
   290  
   291  	if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
   292  		t.Fatalf("AddAddress failed: %v", err)
   293  	}
   294  
   295  	if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil {
   296  		t.Fatalf("AddAddress failed: %v", err)
   297  	}
   298  
   299  	s.SetRouteTable([]tcpip.Route{
   300  		{
   301  			Destination: header.IPv4EmptySubnet,
   302  			NIC:         1,
   303  		},
   304  		{
   305  			Destination: header.IPv6EmptySubnet,
   306  			NIC:         1,
   307  		},
   308  	})
   309  
   310  	return &testContext{
   311  		t:      t,
   312  		s:      s,
   313  		linkEP: ep,
   314  	}
   315  }
   316  
   317  func (c *testContext) cleanup() {
   318  	if c.ep != nil {
   319  		c.ep.Close()
   320  	}
   321  }
   322  
   323  func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) {
   324  	c.t.Helper()
   325  
   326  	var err *tcpip.Error
   327  	c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq)
   328  	if err != nil {
   329  		c.t.Fatal("NewEndpoint failed: ", err)
   330  	}
   331  }
   332  
   333  func (c *testContext) createEndpointForFlow(flow testFlow) {
   334  	c.t.Helper()
   335  
   336  	c.createEndpoint(flow.sockProto())
   337  	if flow.isV6Only() {
   338  		if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
   339  			c.t.Fatalf("SetSockOpt failed: %v", err)
   340  		}
   341  	} else if flow.isBroadcast() {
   342  		if err := c.ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil {
   343  			c.t.Fatal("SetSockOpt failed:", err)
   344  		}
   345  	}
   346  }
   347  
   348  // getPacketAndVerify reads a packet from the link endpoint and verifies the
   349  // header against expected values from the given test flow. In addition, it
   350  // calls any extra checker functions provided.
   351  func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte {
   352  	c.t.Helper()
   353  
   354  	select {
   355  	case p := <-c.linkEP.C:
   356  		if p.Proto != flow.netProto() {
   357  			c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
   358  		}
   359  
   360  		hdr := p.Pkt.Header.View()
   361  		b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
   362  
   363  		h := flow.header4Tuple(outgoing)
   364  		checkers := append(
   365  			checkers,
   366  			checker.SrcAddr(h.srcAddr.Addr),
   367  			checker.DstAddr(h.dstAddr.Addr),
   368  			checker.UDP(checker.DstPort(h.dstAddr.Port)),
   369  		)
   370  		flow.checkerFn()(c.t, b, checkers...)
   371  		return b
   372  
   373  	case <-time.After(2 * time.Second):
   374  		c.t.Fatalf("Packet wasn't written out")
   375  	}
   376  
   377  	return nil
   378  }
   379  
   380  // injectPacket creates a packet of the given flow and with the given payload,
   381  // and injects it into the link endpoint.
   382  func (c *testContext) injectPacket(flow testFlow, payload []byte) {
   383  	c.t.Helper()
   384  
   385  	h := flow.header4Tuple(incoming)
   386  	if flow.isV4() {
   387  		c.injectV4Packet(payload, &h, true /* valid */)
   388  	} else {
   389  		c.injectV6Packet(payload, &h, true /* valid */)
   390  	}
   391  }
   392  
   393  // injectV6Packet creates a V6 test packet with the given payload and header
   394  // values, and injects it into the link endpoint. valid indicates if the
   395  // caller intends to inject a packet with a valid or an invalid UDP header.
   396  // We can invalidate the header by corrupting the UDP payload length.
   397  func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) {
   398  	// Allocate a buffer for data and headers.
   399  	buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
   400  	payloadStart := len(buf) - len(payload)
   401  	copy(buf[payloadStart:], payload)
   402  
   403  	// Initialize the IP header.
   404  	ip := header.IPv6(buf)
   405  	ip.Encode(&header.IPv6Fields{
   406  		PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
   407  		NextHeader:    uint8(udp.ProtocolNumber),
   408  		HopLimit:      65,
   409  		SrcAddr:       h.srcAddr.Addr,
   410  		DstAddr:       h.dstAddr.Addr,
   411  	})
   412  
   413  	// Initialize the UDP header.
   414  	u := header.UDP(buf[header.IPv6MinimumSize:])
   415  	l := uint16(header.UDPMinimumSize + len(payload))
   416  	if !valid {
   417  		// Change the UDP payload length to corrupt the header
   418  		// as requested by the caller.
   419  		l++
   420  	}
   421  	u.Encode(&header.UDPFields{
   422  		SrcPort: h.srcAddr.Port,
   423  		DstPort: h.dstAddr.Port,
   424  		Length:  l,
   425  	})
   426  
   427  	// Calculate the UDP pseudo-header checksum.
   428  	xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u)))
   429  
   430  	// Calculate the UDP checksum and set it.
   431  	xsum = header.Checksum(payload, xsum)
   432  	u.SetChecksum(^u.CalculateChecksum(xsum))
   433  
   434  	// Inject packet.
   435  	c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
   436  		Data:            buf.ToVectorisedView(),
   437  		NetworkHeader:   buffer.View(ip),
   438  		TransportHeader: buffer.View(u),
   439  	})
   440  }
   441  
   442  // injectV4Packet creates a V4 test packet with the given payload and header
   443  // values, and injects it into the link endpoint. valid indicates if the
   444  // caller intends to inject a packet with a valid or an invalid UDP header.
   445  // We can invalidate the header by corrupting the UDP payload length.
   446  func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) {
   447  	// Allocate a buffer for data and headers.
   448  	buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
   449  	payloadStart := len(buf) - len(payload)
   450  	copy(buf[payloadStart:], payload)
   451  
   452  	// Initialize the IP header.
   453  	ip := header.IPv4(buf)
   454  	ip.Encode(&header.IPv4Fields{
   455  		IHL:         header.IPv4MinimumSize,
   456  		TotalLength: uint16(len(buf)),
   457  		TTL:         65,
   458  		Protocol:    uint8(udp.ProtocolNumber),
   459  		SrcAddr:     h.srcAddr.Addr,
   460  		DstAddr:     h.dstAddr.Addr,
   461  	})
   462  	ip.SetChecksum(^ip.CalculateChecksum())
   463  
   464  	// Initialize the UDP header.
   465  	u := header.UDP(buf[header.IPv4MinimumSize:])
   466  	u.Encode(&header.UDPFields{
   467  		SrcPort: h.srcAddr.Port,
   468  		DstPort: h.dstAddr.Port,
   469  		Length:  uint16(header.UDPMinimumSize + len(payload)),
   470  	})
   471  
   472  	// Calculate the UDP pseudo-header checksum.
   473  	xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u)))
   474  
   475  	// Calculate the UDP checksum and set it.
   476  	xsum = header.Checksum(payload, xsum)
   477  	u.SetChecksum(^u.CalculateChecksum(xsum))
   478  
   479  	// Inject packet.
   480  
   481  	c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
   482  		Data:            buf.ToVectorisedView(),
   483  		NetworkHeader:   buffer.View(ip),
   484  		TransportHeader: buffer.View(u),
   485  	})
   486  }
   487  
   488  func newPayload() []byte {
   489  	return newMinPayload(30)
   490  }
   491  
   492  func newMinPayload(minSize int) []byte {
   493  	b := make([]byte, minSize+rand.Intn(100))
   494  	for i := range b {
   495  		b[i] = byte(rand.Intn(256))
   496  	}
   497  	return b
   498  }
   499  
   500  func TestBindToDeviceOption(t *testing.T) {
   501  	s := stack.New(stack.Options{
   502  		NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol()},
   503  		TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
   504  
   505  	ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
   506  	if err != nil {
   507  		t.Fatalf("NewEndpoint failed; %v", err)
   508  	}
   509  	defer ep.Close()
   510  
   511  	if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
   512  		t.Errorf("CreateNamedNIC failed: %v", err)
   513  	}
   514  
   515  	// Make an nameless NIC.
   516  	if err := s.CreateNIC(54321, loopback.New()); err != nil {
   517  		t.Errorf("CreateNIC failed: %v", err)
   518  	}
   519  
   520  	// strPtr is used instead of taking the address of string literals, which is
   521  	// a compiler error.
   522  	strPtr := func(s string) *string {
   523  		return &s
   524  	}
   525  
   526  	testActions := []struct {
   527  		name                 string
   528  		setBindToDevice      *string
   529  		setBindToDeviceError *tcpip.Error
   530  		getBindToDevice      tcpip.BindToDeviceOption
   531  	}{
   532  		{"GetDefaultValue", nil, nil, ""},
   533  		{"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
   534  		{"BindToExistent", strPtr("my_device"), nil, "my_device"},
   535  		{"UnbindToDevice", strPtr(""), nil, ""},
   536  	}
   537  	for _, testAction := range testActions {
   538  		t.Run(testAction.name, func(t *testing.T) {
   539  			if testAction.setBindToDevice != nil {
   540  				bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
   541  				if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
   542  					t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
   543  				}
   544  			}
   545  			bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
   546  			if ep.GetSockOpt(&bindToDevice) != nil {
   547  				t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
   548  			}
   549  			if got, want := bindToDevice, testAction.getBindToDevice; got != want {
   550  				t.Errorf("bindToDevice got %q, want %q", got, want)
   551  			}
   552  		})
   553  	}
   554  }
   555  
   556  // testReadInternal sends a packet of the given test flow into the stack by
   557  // injecting it into the link endpoint. It then attempts to read it from the
   558  // UDP endpoint and depending on if this was expected to succeed verifies its
   559  // correctness.
   560  func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) {
   561  	c.t.Helper()
   562  
   563  	payload := newPayload()
   564  	c.injectPacket(flow, payload)
   565  
   566  	// Try to receive the data.
   567  	we, ch := waiter.NewChannelEntry(nil)
   568  	c.wq.EventRegister(&we, waiter.EventIn)
   569  	defer c.wq.EventUnregister(&we)
   570  
   571  	// Take a snapshot of the stats to validate them at the end of the test.
   572  	epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
   573  
   574  	var addr tcpip.FullAddress
   575  	v, _, err := c.ep.Read(&addr)
   576  	if err == tcpip.ErrWouldBlock {
   577  		// Wait for data to become available.
   578  		select {
   579  		case <-ch:
   580  			v, _, err = c.ep.Read(&addr)
   581  
   582  		case <-time.After(300 * time.Millisecond):
   583  			if packetShouldBeDropped {
   584  				return // expected to time out
   585  			}
   586  			c.t.Fatal("timed out waiting for data")
   587  		}
   588  	}
   589  
   590  	if expectReadError && err != nil {
   591  		c.checkEndpointReadStats(1, epstats, err)
   592  		return
   593  	}
   594  
   595  	if err != nil {
   596  		c.t.Fatal("Read failed:", err)
   597  	}
   598  
   599  	if packetShouldBeDropped {
   600  		c.t.Fatalf("Read unexpectedly received data from %s", addr.Addr)
   601  	}
   602  
   603  	// Check the peer address.
   604  	h := flow.header4Tuple(incoming)
   605  	if addr.Addr != h.srcAddr.Addr {
   606  		c.t.Fatalf("unexpected remote address: got %s, want %s", addr.Addr, h.srcAddr)
   607  	}
   608  
   609  	// Check the payload.
   610  	if !bytes.Equal(payload, v) {
   611  		c.t.Fatalf("bad payload: got %x, want %x", v, payload)
   612  	}
   613  	c.checkEndpointReadStats(1, epstats, err)
   614  }
   615  
   616  // testRead sends a packet of the given test flow into the stack by injecting it
   617  // into the link endpoint. It then reads it from the UDP endpoint and verifies
   618  // its correctness.
   619  func testRead(c *testContext, flow testFlow) {
   620  	c.t.Helper()
   621  	testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */)
   622  }
   623  
   624  // testFailingRead sends a packet of the given test flow into the stack by
   625  // injecting it into the link endpoint. It then tries to read it from the UDP
   626  // endpoint and expects this to fail.
   627  func testFailingRead(c *testContext, flow testFlow, expectReadError bool) {
   628  	c.t.Helper()
   629  	testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError)
   630  }
   631  
   632  func TestBindEphemeralPort(t *testing.T) {
   633  	c := newDualTestContext(t, defaultMTU)
   634  	defer c.cleanup()
   635  
   636  	c.createEndpoint(ipv6.ProtocolNumber)
   637  
   638  	if err := c.ep.Bind(tcpip.FullAddress{}); err != nil {
   639  		t.Fatalf("ep.Bind(...) failed: %v", err)
   640  	}
   641  }
   642  
   643  func TestBindReservedPort(t *testing.T) {
   644  	c := newDualTestContext(t, defaultMTU)
   645  	defer c.cleanup()
   646  
   647  	c.createEndpoint(ipv6.ProtocolNumber)
   648  
   649  	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
   650  		c.t.Fatalf("Connect failed: %v", err)
   651  	}
   652  
   653  	addr, err := c.ep.GetLocalAddress()
   654  	if err != nil {
   655  		t.Fatalf("GetLocalAddress failed: %v", err)
   656  	}
   657  
   658  	// We can't bind the address reserved by the connected endpoint above.
   659  	{
   660  		ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
   661  		if err != nil {
   662  			t.Fatalf("NewEndpoint failed: %v", err)
   663  		}
   664  		defer ep.Close()
   665  		if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want {
   666  			t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want)
   667  		}
   668  	}
   669  
   670  	func() {
   671  		ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
   672  		if err != nil {
   673  			t.Fatalf("NewEndpoint failed: %v", err)
   674  		}
   675  		defer ep.Close()
   676  		// We can't bind ipv4-any on the port reserved by the connected endpoint
   677  		// above, since the endpoint is dual-stack.
   678  		if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}), tcpip.ErrPortInUse; got != want {
   679  			t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want)
   680  		}
   681  		// We can bind an ipv4 address on this port, though.
   682  		if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil {
   683  			t.Fatalf("ep.Bind(...) failed: %v", err)
   684  		}
   685  	}()
   686  
   687  	// Once the connected endpoint releases its port reservation, we are able to
   688  	// bind ipv4-any once again.
   689  	c.ep.Close()
   690  	func() {
   691  		ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
   692  		if err != nil {
   693  			t.Fatalf("NewEndpoint failed: %v", err)
   694  		}
   695  		defer ep.Close()
   696  		if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil {
   697  			t.Fatalf("ep.Bind(...) failed: %v", err)
   698  		}
   699  	}()
   700  }
   701  
   702  func TestV4ReadOnV6(t *testing.T) {
   703  	c := newDualTestContext(t, defaultMTU)
   704  	defer c.cleanup()
   705  
   706  	c.createEndpointForFlow(unicastV4in6)
   707  
   708  	// Bind to wildcard.
   709  	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
   710  		c.t.Fatalf("Bind failed: %v", err)
   711  	}
   712  
   713  	// Test acceptance.
   714  	testRead(c, unicastV4in6)
   715  }
   716  
   717  func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
   718  	c := newDualTestContext(t, defaultMTU)
   719  	defer c.cleanup()
   720  
   721  	c.createEndpointForFlow(unicastV4in6)
   722  
   723  	// Bind to v4 mapped wildcard.
   724  	if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil {
   725  		c.t.Fatalf("Bind failed: %v", err)
   726  	}
   727  
   728  	// Test acceptance.
   729  	testRead(c, unicastV4in6)
   730  }
   731  
   732  func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
   733  	c := newDualTestContext(t, defaultMTU)
   734  	defer c.cleanup()
   735  
   736  	c.createEndpointForFlow(unicastV4in6)
   737  
   738  	// Bind to local address.
   739  	if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
   740  		c.t.Fatalf("Bind failed: %v", err)
   741  	}
   742  
   743  	// Test acceptance.
   744  	testRead(c, unicastV4in6)
   745  }
   746  
   747  func TestV6ReadOnV6(t *testing.T) {
   748  	c := newDualTestContext(t, defaultMTU)
   749  	defer c.cleanup()
   750  
   751  	c.createEndpointForFlow(unicastV6)
   752  
   753  	// Bind to wildcard.
   754  	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
   755  		c.t.Fatalf("Bind failed: %v", err)
   756  	}
   757  
   758  	// Test acceptance.
   759  	testRead(c, unicastV6)
   760  }
   761  
   762  func TestV4ReadOnV4(t *testing.T) {
   763  	c := newDualTestContext(t, defaultMTU)
   764  	defer c.cleanup()
   765  
   766  	c.createEndpointForFlow(unicastV4)
   767  
   768  	// Bind to wildcard.
   769  	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
   770  		c.t.Fatalf("Bind failed: %v", err)
   771  	}
   772  
   773  	// Test acceptance.
   774  	testRead(c, unicastV4)
   775  }
   776  
   777  // TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast
   778  // address and receive data sent to that address.
   779  func TestReadOnBoundToMulticast(t *testing.T) {
   780  	// FIXME(b/128189410): multicastV4in6 currently doesn't work as
   781  	// AddMembershipOption doesn't handle V4in6 addresses.
   782  	for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} {
   783  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
   784  			c := newDualTestContext(t, defaultMTU)
   785  			defer c.cleanup()
   786  
   787  			c.createEndpointForFlow(flow)
   788  
   789  			// Bind to multicast address.
   790  			mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr())
   791  			if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil {
   792  				c.t.Fatal("Bind failed:", err)
   793  			}
   794  
   795  			// Join multicast group.
   796  			ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr}
   797  			if err := c.ep.SetSockOpt(ifoptSet); err != nil {
   798  				c.t.Fatal("SetSockOpt failed:", err)
   799  			}
   800  
   801  			// Check that we receive multicast packets but not unicast or broadcast
   802  			// ones.
   803  			testRead(c, flow)
   804  			testFailingRead(c, broadcast, false /* expectReadError */)
   805  			testFailingRead(c, unicastV4, false /* expectReadError */)
   806  		})
   807  	}
   808  }
   809  
   810  // TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast
   811  // address and can receive only broadcast data.
   812  func TestV4ReadOnBoundToBroadcast(t *testing.T) {
   813  	for _, flow := range []testFlow{broadcast, broadcastIn6} {
   814  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
   815  			c := newDualTestContext(t, defaultMTU)
   816  			defer c.cleanup()
   817  
   818  			c.createEndpointForFlow(flow)
   819  
   820  			// Bind to broadcast address.
   821  			bcastAddr := flow.mapAddrIfApplicable(broadcastAddr)
   822  			if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil {
   823  				c.t.Fatalf("Bind failed: %s", err)
   824  			}
   825  
   826  			// Check that we receive broadcast packets but not unicast ones.
   827  			testRead(c, flow)
   828  			testFailingRead(c, unicastV4, false /* expectReadError */)
   829  		})
   830  	}
   831  }
   832  
   833  // TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY
   834  // and receive broadcast and unicast data.
   835  func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
   836  	for _, flow := range []testFlow{broadcast, broadcastIn6} {
   837  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
   838  			c := newDualTestContext(t, defaultMTU)
   839  			defer c.cleanup()
   840  
   841  			c.createEndpointForFlow(flow)
   842  
   843  			// Bind to wildcard.
   844  			if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
   845  				c.t.Fatalf("Bind failed: %s (", err)
   846  			}
   847  
   848  			// Check that we receive both broadcast and unicast packets.
   849  			testRead(c, flow)
   850  			testRead(c, unicastV4)
   851  		})
   852  	}
   853  }
   854  
   855  // testFailingWrite sends a packet of the given test flow into the UDP endpoint
   856  // and verifies it fails with the provided error code.
   857  func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
   858  	c.t.Helper()
   859  	// Take a snapshot of the stats to validate them at the end of the test.
   860  	epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
   861  	h := flow.header4Tuple(outgoing)
   862  	writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
   863  
   864  	payload := buffer.View(newPayload())
   865  	_, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
   866  		To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
   867  	})
   868  	c.checkEndpointWriteStats(1, epstats, gotErr)
   869  	if gotErr != wantErr {
   870  		c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr)
   871  	}
   872  }
   873  
   874  // testWrite sends a packet of the given test flow from the UDP endpoint to the
   875  // flow's destination address:port. It then receives it from the link endpoint
   876  // and verifies its correctness including any additional checker functions
   877  // provided.
   878  func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
   879  	c.t.Helper()
   880  	return testWriteInternal(c, flow, true, checkers...)
   881  }
   882  
   883  // testWriteWithoutDestination sends a packet of the given test flow from the
   884  // UDP endpoint without giving a destination address:port. It then receives it
   885  // from the link endpoint and verifies its correctness including any additional
   886  // checker functions provided.
   887  func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
   888  	c.t.Helper()
   889  	return testWriteInternal(c, flow, false, checkers...)
   890  }
   891  
   892  func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
   893  	c.t.Helper()
   894  	// Take a snapshot of the stats to validate them at the end of the test.
   895  	epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
   896  
   897  	writeOpts := tcpip.WriteOptions{}
   898  	if setDest {
   899  		h := flow.header4Tuple(outgoing)
   900  		writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
   901  		writeOpts = tcpip.WriteOptions{
   902  			To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
   903  		}
   904  	}
   905  	payload := buffer.View(newPayload())
   906  	n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
   907  	if err != nil {
   908  		c.t.Fatalf("Write failed: %v", err)
   909  	}
   910  	if n != int64(len(payload)) {
   911  		c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
   912  	}
   913  	c.checkEndpointWriteStats(1, epstats, err)
   914  	// Received the packet and check the payload.
   915  	b := c.getPacketAndVerify(flow, checkers...)
   916  	var udp header.UDP
   917  	if flow.isV4() {
   918  		udp = header.UDP(header.IPv4(b).Payload())
   919  	} else {
   920  		udp = header.UDP(header.IPv6(b).Payload())
   921  	}
   922  	if !bytes.Equal(payload, udp.Payload()) {
   923  		c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
   924  	}
   925  
   926  	return udp.SourcePort()
   927  }
   928  
   929  func testDualWrite(c *testContext) uint16 {
   930  	c.t.Helper()
   931  
   932  	v4Port := testWrite(c, unicastV4in6)
   933  	v6Port := testWrite(c, unicastV6)
   934  	if v4Port != v6Port {
   935  		c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port)
   936  	}
   937  
   938  	return v4Port
   939  }
   940  
   941  func TestDualWriteUnbound(t *testing.T) {
   942  	c := newDualTestContext(t, defaultMTU)
   943  	defer c.cleanup()
   944  
   945  	c.createEndpoint(ipv6.ProtocolNumber)
   946  
   947  	testDualWrite(c)
   948  }
   949  
   950  func TestDualWriteBoundToWildcard(t *testing.T) {
   951  	c := newDualTestContext(t, defaultMTU)
   952  	defer c.cleanup()
   953  
   954  	c.createEndpoint(ipv6.ProtocolNumber)
   955  
   956  	// Bind to wildcard.
   957  	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
   958  		c.t.Fatalf("Bind failed: %v", err)
   959  	}
   960  
   961  	p := testDualWrite(c)
   962  	if p != stackPort {
   963  		c.t.Fatalf("Bad port: got %v, want %v", p, stackPort)
   964  	}
   965  }
   966  
   967  func TestDualWriteConnectedToV6(t *testing.T) {
   968  	c := newDualTestContext(t, defaultMTU)
   969  	defer c.cleanup()
   970  
   971  	c.createEndpoint(ipv6.ProtocolNumber)
   972  
   973  	// Connect to v6 address.
   974  	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
   975  		c.t.Fatalf("Bind failed: %v", err)
   976  	}
   977  
   978  	testWrite(c, unicastV6)
   979  
   980  	// Write to V4 mapped address.
   981  	testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable)
   982  	const want = 1
   983  	if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want {
   984  		c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want)
   985  	}
   986  }
   987  
   988  func TestDualWriteConnectedToV4Mapped(t *testing.T) {
   989  	c := newDualTestContext(t, defaultMTU)
   990  	defer c.cleanup()
   991  
   992  	c.createEndpoint(ipv6.ProtocolNumber)
   993  
   994  	// Connect to v4 mapped address.
   995  	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
   996  		c.t.Fatalf("Bind failed: %v", err)
   997  	}
   998  
   999  	testWrite(c, unicastV4in6)
  1000  
  1001  	// Write to v6 address.
  1002  	testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState)
  1003  }
  1004  
  1005  func TestV4WriteOnV6Only(t *testing.T) {
  1006  	c := newDualTestContext(t, defaultMTU)
  1007  	defer c.cleanup()
  1008  
  1009  	c.createEndpointForFlow(unicastV6Only)
  1010  
  1011  	// Write to V4 mapped address.
  1012  	testFailingWrite(c, unicastV4in6, tcpip.ErrNoRoute)
  1013  }
  1014  
  1015  func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
  1016  	c := newDualTestContext(t, defaultMTU)
  1017  	defer c.cleanup()
  1018  
  1019  	c.createEndpoint(ipv6.ProtocolNumber)
  1020  
  1021  	// Bind to v4 mapped address.
  1022  	if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
  1023  		c.t.Fatalf("Bind failed: %v", err)
  1024  	}
  1025  
  1026  	// Write to v6 address.
  1027  	testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState)
  1028  }
  1029  
  1030  func TestV6WriteOnConnected(t *testing.T) {
  1031  	c := newDualTestContext(t, defaultMTU)
  1032  	defer c.cleanup()
  1033  
  1034  	c.createEndpoint(ipv6.ProtocolNumber)
  1035  
  1036  	// Connect to v6 address.
  1037  	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
  1038  		c.t.Fatalf("Connect failed: %v", err)
  1039  	}
  1040  
  1041  	testWriteWithoutDestination(c, unicastV6)
  1042  }
  1043  
  1044  func TestV4WriteOnConnected(t *testing.T) {
  1045  	c := newDualTestContext(t, defaultMTU)
  1046  	defer c.cleanup()
  1047  
  1048  	c.createEndpoint(ipv6.ProtocolNumber)
  1049  
  1050  	// Connect to v4 mapped address.
  1051  	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
  1052  		c.t.Fatalf("Connect failed: %v", err)
  1053  	}
  1054  
  1055  	testWriteWithoutDestination(c, unicastV4)
  1056  }
  1057  
  1058  // TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket
  1059  // that is bound to a V4 multicast address.
  1060  func TestWriteOnBoundToV4Multicast(t *testing.T) {
  1061  	for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
  1062  		t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
  1063  			c := newDualTestContext(t, defaultMTU)
  1064  			defer c.cleanup()
  1065  
  1066  			c.createEndpointForFlow(flow)
  1067  
  1068  			// Bind to V4 mcast address.
  1069  			if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil {
  1070  				c.t.Fatal("Bind failed:", err)
  1071  			}
  1072  
  1073  			testWrite(c, flow)
  1074  		})
  1075  	}
  1076  }
  1077  
  1078  // TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a
  1079  // socket that is bound to a V4-mapped multicast address.
  1080  func TestWriteOnBoundToV4MappedMulticast(t *testing.T) {
  1081  	for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
  1082  		t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
  1083  			c := newDualTestContext(t, defaultMTU)
  1084  			defer c.cleanup()
  1085  
  1086  			c.createEndpointForFlow(flow)
  1087  
  1088  			// Bind to V4Mapped mcast address.
  1089  			if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil {
  1090  				c.t.Fatalf("Bind failed: %s", err)
  1091  			}
  1092  
  1093  			testWrite(c, flow)
  1094  		})
  1095  	}
  1096  }
  1097  
  1098  // TestWriteOnBoundToV6Multicast checks that we can send packets out of a
  1099  // socket that is bound to a V6 multicast address.
  1100  func TestWriteOnBoundToV6Multicast(t *testing.T) {
  1101  	for _, flow := range []testFlow{unicastV6, multicastV6} {
  1102  		t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
  1103  			c := newDualTestContext(t, defaultMTU)
  1104  			defer c.cleanup()
  1105  
  1106  			c.createEndpointForFlow(flow)
  1107  
  1108  			// Bind to V6 mcast address.
  1109  			if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
  1110  				c.t.Fatalf("Bind failed: %s", err)
  1111  			}
  1112  
  1113  			testWrite(c, flow)
  1114  		})
  1115  	}
  1116  }
  1117  
  1118  // TestWriteOnBoundToV6Multicast checks that we can send packets out of a
  1119  // V6-only socket that is bound to a V6 multicast address.
  1120  func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) {
  1121  	for _, flow := range []testFlow{unicastV6Only, multicastV6Only} {
  1122  		t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
  1123  			c := newDualTestContext(t, defaultMTU)
  1124  			defer c.cleanup()
  1125  
  1126  			c.createEndpointForFlow(flow)
  1127  
  1128  			// Bind to V6 mcast address.
  1129  			if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
  1130  				c.t.Fatalf("Bind failed: %s", err)
  1131  			}
  1132  
  1133  			testWrite(c, flow)
  1134  		})
  1135  	}
  1136  }
  1137  
  1138  // TestWriteOnBoundToBroadcast checks that we can send packets out of a
  1139  // socket that is bound to the broadcast address.
  1140  func TestWriteOnBoundToBroadcast(t *testing.T) {
  1141  	for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
  1142  		t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
  1143  			c := newDualTestContext(t, defaultMTU)
  1144  			defer c.cleanup()
  1145  
  1146  			c.createEndpointForFlow(flow)
  1147  
  1148  			// Bind to V4 broadcast address.
  1149  			if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil {
  1150  				c.t.Fatal("Bind failed:", err)
  1151  			}
  1152  
  1153  			testWrite(c, flow)
  1154  		})
  1155  	}
  1156  }
  1157  
  1158  // TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a
  1159  // socket that is bound to the V4-mapped broadcast address.
  1160  func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) {
  1161  	for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
  1162  		t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
  1163  			c := newDualTestContext(t, defaultMTU)
  1164  			defer c.cleanup()
  1165  
  1166  			c.createEndpointForFlow(flow)
  1167  
  1168  			// Bind to V4Mapped mcast address.
  1169  			if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil {
  1170  				c.t.Fatalf("Bind failed: %s", err)
  1171  			}
  1172  
  1173  			testWrite(c, flow)
  1174  		})
  1175  	}
  1176  }
  1177  
  1178  func TestReadIncrementsPacketsReceived(t *testing.T) {
  1179  	c := newDualTestContext(t, defaultMTU)
  1180  	defer c.cleanup()
  1181  
  1182  	// Create IPv4 UDP endpoint
  1183  	c.createEndpoint(ipv6.ProtocolNumber)
  1184  
  1185  	// Bind to wildcard.
  1186  	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
  1187  		c.t.Fatalf("Bind failed: %v", err)
  1188  	}
  1189  
  1190  	testRead(c, unicastV4)
  1191  
  1192  	var want uint64 = 1
  1193  	if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want {
  1194  		c.t.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want)
  1195  	}
  1196  }
  1197  
  1198  func TestWriteIncrementsPacketsSent(t *testing.T) {
  1199  	c := newDualTestContext(t, defaultMTU)
  1200  	defer c.cleanup()
  1201  
  1202  	c.createEndpoint(ipv6.ProtocolNumber)
  1203  
  1204  	testDualWrite(c)
  1205  
  1206  	var want uint64 = 2
  1207  	if got := c.s.Stats().UDP.PacketsSent.Value(); got != want {
  1208  		c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want)
  1209  	}
  1210  }
  1211  
  1212  func TestTTL(t *testing.T) {
  1213  	for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
  1214  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
  1215  			c := newDualTestContext(t, defaultMTU)
  1216  			defer c.cleanup()
  1217  
  1218  			c.createEndpointForFlow(flow)
  1219  
  1220  			const multicastTTL = 42
  1221  			if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
  1222  				c.t.Fatalf("SetSockOpt failed: %v", err)
  1223  			}
  1224  
  1225  			var wantTTL uint8
  1226  			if flow.isMulticast() {
  1227  				wantTTL = multicastTTL
  1228  			} else {
  1229  				var p stack.NetworkProtocol
  1230  				if flow.isV4() {
  1231  					p = ipv4.NewProtocol()
  1232  				} else {
  1233  					p = ipv6.NewProtocol()
  1234  				}
  1235  				ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
  1236  				if err != nil {
  1237  					t.Fatal(err)
  1238  				}
  1239  				wantTTL = ep.DefaultTTL()
  1240  				ep.Close()
  1241  			}
  1242  
  1243  			testWrite(c, flow, checker.TTL(wantTTL))
  1244  		})
  1245  	}
  1246  }
  1247  
  1248  func TestSetTTL(t *testing.T) {
  1249  	for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6} {
  1250  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
  1251  			for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} {
  1252  				t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) {
  1253  					c := newDualTestContext(t, defaultMTU)
  1254  					defer c.cleanup()
  1255  
  1256  					c.createEndpointForFlow(flow)
  1257  
  1258  					if err := c.ep.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil {
  1259  						c.t.Fatalf("SetSockOpt failed: %v", err)
  1260  					}
  1261  
  1262  					var p stack.NetworkProtocol
  1263  					if flow.isV4() {
  1264  						p = ipv4.NewProtocol()
  1265  					} else {
  1266  						p = ipv6.NewProtocol()
  1267  					}
  1268  					ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
  1269  					if err != nil {
  1270  						t.Fatal(err)
  1271  					}
  1272  					ep.Close()
  1273  
  1274  					testWrite(c, flow, checker.TTL(wantTTL))
  1275  				})
  1276  			}
  1277  		})
  1278  	}
  1279  }
  1280  
  1281  func TestTOSV4(t *testing.T) {
  1282  	for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
  1283  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
  1284  			c := newDualTestContext(t, defaultMTU)
  1285  			defer c.cleanup()
  1286  
  1287  			c.createEndpointForFlow(flow)
  1288  
  1289  			const tos = 0xC0
  1290  			var v tcpip.IPv4TOSOption
  1291  			if err := c.ep.GetSockOpt(&v); err != nil {
  1292  				c.t.Errorf("GetSockopt failed: %s", err)
  1293  			}
  1294  			// Test for expected default value.
  1295  			if v != 0 {
  1296  				c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0)
  1297  			}
  1298  
  1299  			if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil {
  1300  				c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err)
  1301  			}
  1302  
  1303  			if err := c.ep.GetSockOpt(&v); err != nil {
  1304  				c.t.Errorf("GetSockopt failed: %s", err)
  1305  			}
  1306  
  1307  			if want := tcpip.IPv4TOSOption(tos); v != want {
  1308  				c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
  1309  			}
  1310  
  1311  			testWrite(c, flow, checker.TOS(tos, 0))
  1312  		})
  1313  	}
  1314  }
  1315  
  1316  func TestTOSV6(t *testing.T) {
  1317  	for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} {
  1318  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
  1319  			c := newDualTestContext(t, defaultMTU)
  1320  			defer c.cleanup()
  1321  
  1322  			c.createEndpointForFlow(flow)
  1323  
  1324  			const tos = 0xC0
  1325  			var v tcpip.IPv6TrafficClassOption
  1326  			if err := c.ep.GetSockOpt(&v); err != nil {
  1327  				c.t.Errorf("GetSockopt failed: %s", err)
  1328  			}
  1329  			// Test for expected default value.
  1330  			if v != 0 {
  1331  				c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0)
  1332  			}
  1333  
  1334  			if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil {
  1335  				c.t.Errorf("SetSockOpt failed: %s", err)
  1336  			}
  1337  
  1338  			if err := c.ep.GetSockOpt(&v); err != nil {
  1339  				c.t.Errorf("GetSockopt failed: %s", err)
  1340  			}
  1341  
  1342  			if want := tcpip.IPv6TrafficClassOption(tos); v != want {
  1343  				c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
  1344  			}
  1345  
  1346  			testWrite(c, flow, checker.TOS(tos, 0))
  1347  		})
  1348  	}
  1349  }
  1350  
  1351  func TestMulticastInterfaceOption(t *testing.T) {
  1352  	for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
  1353  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
  1354  			for _, bindTyp := range []string{"bound", "unbound"} {
  1355  				t.Run(bindTyp, func(t *testing.T) {
  1356  					for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} {
  1357  						t.Run(optTyp, func(t *testing.T) {
  1358  							h := flow.header4Tuple(outgoing)
  1359  							mcastAddr := h.dstAddr.Addr
  1360  							localIfAddr := h.srcAddr.Addr
  1361  
  1362  							var ifoptSet tcpip.MulticastInterfaceOption
  1363  							switch optTyp {
  1364  							case "use local-addr":
  1365  								ifoptSet.InterfaceAddr = localIfAddr
  1366  							case "use NICID":
  1367  								ifoptSet.NIC = 1
  1368  							case "use local-addr and NIC":
  1369  								ifoptSet.InterfaceAddr = localIfAddr
  1370  								ifoptSet.NIC = 1
  1371  							default:
  1372  								t.Fatal("unknown test variant")
  1373  							}
  1374  
  1375  							c := newDualTestContext(t, defaultMTU)
  1376  							defer c.cleanup()
  1377  
  1378  							c.createEndpoint(flow.sockProto())
  1379  
  1380  							if bindTyp == "bound" {
  1381  								// Bind the socket by connecting to the multicast address.
  1382  								// This may have an influence on how the multicast interface
  1383  								// is set.
  1384  								addr := tcpip.FullAddress{
  1385  									Addr: flow.mapAddrIfApplicable(mcastAddr),
  1386  									Port: stackPort,
  1387  								}
  1388  								if err := c.ep.Connect(addr); err != nil {
  1389  									c.t.Fatalf("Connect failed: %v", err)
  1390  								}
  1391  							}
  1392  
  1393  							if err := c.ep.SetSockOpt(ifoptSet); err != nil {
  1394  								c.t.Fatalf("SetSockOpt failed: %v", err)
  1395  							}
  1396  
  1397  							// Verify multicast interface addr and NIC were set correctly.
  1398  							// Note that NIC must be 1 since this is our outgoing interface.
  1399  							ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
  1400  							var ifoptGot tcpip.MulticastInterfaceOption
  1401  							if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
  1402  								c.t.Fatalf("GetSockOpt failed: %v", err)
  1403  							}
  1404  							if ifoptGot != ifoptWant {
  1405  								c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
  1406  							}
  1407  						})
  1408  					}
  1409  				})
  1410  			}
  1411  		})
  1412  	}
  1413  }
  1414  
  1415  // TestV4UnknownDestination verifies that we generate an ICMPv4 Destination
  1416  // Unreachable message when a udp datagram is received on ports for which there
  1417  // is no bound udp socket.
  1418  func TestV4UnknownDestination(t *testing.T) {
  1419  	c := newDualTestContext(t, defaultMTU)
  1420  	defer c.cleanup()
  1421  
  1422  	testCases := []struct {
  1423  		flow         testFlow
  1424  		icmpRequired bool
  1425  		// largePayload if true, will result in a payload large enough
  1426  		// so that the final generated IPv4 packet is larger than
  1427  		// header.IPv4MinimumProcessableDatagramSize.
  1428  		largePayload bool
  1429  	}{
  1430  		{unicastV4, true, false},
  1431  		{unicastV4, true, true},
  1432  		{multicastV4, false, false},
  1433  		{multicastV4, false, true},
  1434  		{broadcast, false, false},
  1435  		{broadcast, false, true},
  1436  	}
  1437  	for _, tc := range testCases {
  1438  		t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) {
  1439  			payload := newPayload()
  1440  			if tc.largePayload {
  1441  				payload = newMinPayload(576)
  1442  			}
  1443  			c.injectPacket(tc.flow, payload)
  1444  			if !tc.icmpRequired {
  1445  				select {
  1446  				case p := <-c.linkEP.C:
  1447  					t.Fatalf("unexpected packet received: %+v", p)
  1448  				case <-time.After(1 * time.Second):
  1449  					return
  1450  				}
  1451  			}
  1452  
  1453  			select {
  1454  			case p := <-c.linkEP.C:
  1455  				var pkt []byte
  1456  				pkt = append(pkt, p.Pkt.Header.View()...)
  1457  				pkt = append(pkt, p.Pkt.Data.ToView()...)
  1458  				if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
  1459  					t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
  1460  				}
  1461  
  1462  				hdr := header.IPv4(pkt)
  1463  				checker.IPv4(t, hdr, checker.ICMPv4(
  1464  					checker.ICMPv4Type(header.ICMPv4DstUnreachable),
  1465  					checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
  1466  
  1467  				icmpPkt := header.ICMPv4(hdr.Payload())
  1468  				payloadIPHeader := header.IPv4(icmpPkt.Payload())
  1469  				wantLen := len(payload)
  1470  				if tc.largePayload {
  1471  					wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
  1472  				}
  1473  
  1474  				// In case of large payloads the IP packet may be truncated. Update
  1475  				// the length field before retrieving the udp datagram payload.
  1476  				payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
  1477  
  1478  				origDgram := header.UDP(payloadIPHeader.Payload())
  1479  				if got, want := len(origDgram.Payload()), wantLen; got != want {
  1480  					t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
  1481  				}
  1482  				if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
  1483  					t.Fatalf("unexpected payload got: %d, want: %d", got, want)
  1484  				}
  1485  			case <-time.After(1 * time.Second):
  1486  				t.Fatalf("packet wasn't written out")
  1487  			}
  1488  		})
  1489  	}
  1490  }
  1491  
  1492  // TestV6UnknownDestination verifies that we generate an ICMPv6 Destination
  1493  // Unreachable message when a udp datagram is received on ports for which there
  1494  // is no bound udp socket.
  1495  func TestV6UnknownDestination(t *testing.T) {
  1496  	c := newDualTestContext(t, defaultMTU)
  1497  	defer c.cleanup()
  1498  
  1499  	testCases := []struct {
  1500  		flow         testFlow
  1501  		icmpRequired bool
  1502  		// largePayload if true will result in a payload large enough to
  1503  		// create an IPv6 packet > header.IPv6MinimumMTU bytes.
  1504  		largePayload bool
  1505  	}{
  1506  		{unicastV6, true, false},
  1507  		{unicastV6, true, true},
  1508  		{multicastV6, false, false},
  1509  		{multicastV6, false, true},
  1510  	}
  1511  	for _, tc := range testCases {
  1512  		t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) {
  1513  			payload := newPayload()
  1514  			if tc.largePayload {
  1515  				payload = newMinPayload(1280)
  1516  			}
  1517  			c.injectPacket(tc.flow, payload)
  1518  			if !tc.icmpRequired {
  1519  				select {
  1520  				case p := <-c.linkEP.C:
  1521  					t.Fatalf("unexpected packet received: %+v", p)
  1522  				case <-time.After(1 * time.Second):
  1523  					return
  1524  				}
  1525  			}
  1526  
  1527  			select {
  1528  			case p := <-c.linkEP.C:
  1529  				var pkt []byte
  1530  				pkt = append(pkt, p.Pkt.Header.View()...)
  1531  				pkt = append(pkt, p.Pkt.Data.ToView()...)
  1532  				if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
  1533  					t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
  1534  				}
  1535  
  1536  				hdr := header.IPv6(pkt)
  1537  				checker.IPv6(t, hdr, checker.ICMPv6(
  1538  					checker.ICMPv6Type(header.ICMPv6DstUnreachable),
  1539  					checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
  1540  
  1541  				icmpPkt := header.ICMPv6(hdr.Payload())
  1542  				payloadIPHeader := header.IPv6(icmpPkt.Payload())
  1543  				wantLen := len(payload)
  1544  				if tc.largePayload {
  1545  					wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
  1546  				}
  1547  				// In case of large payloads the IP packet may be truncated. Update
  1548  				// the length field before retrieving the udp datagram payload.
  1549  				payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
  1550  
  1551  				origDgram := header.UDP(payloadIPHeader.Payload())
  1552  				if got, want := len(origDgram.Payload()), wantLen; got != want {
  1553  					t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
  1554  				}
  1555  				if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
  1556  					t.Fatalf("unexpected payload got: %v, want: %v", got, want)
  1557  				}
  1558  			case <-time.After(1 * time.Second):
  1559  				t.Fatalf("packet wasn't written out")
  1560  			}
  1561  		})
  1562  	}
  1563  }
  1564  
  1565  // TestIncrementMalformedPacketsReceived verifies if the malformed received
  1566  // global and endpoint stats get incremented.
  1567  func TestIncrementMalformedPacketsReceived(t *testing.T) {
  1568  	c := newDualTestContext(t, defaultMTU)
  1569  	defer c.cleanup()
  1570  
  1571  	c.createEndpoint(ipv6.ProtocolNumber)
  1572  	// Bind to wildcard.
  1573  	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
  1574  		c.t.Fatalf("Bind failed: %v", err)
  1575  	}
  1576  
  1577  	payload := newPayload()
  1578  	c.t.Helper()
  1579  	h := unicastV6.header4Tuple(incoming)
  1580  	c.injectV6Packet(payload, &h, false /* !valid */)
  1581  
  1582  	var want uint64 = 1
  1583  	if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want {
  1584  		t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want)
  1585  	}
  1586  	if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
  1587  		t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want)
  1588  	}
  1589  }
  1590  
  1591  // TestShutdownRead verifies endpoint read shutdown and error
  1592  // stats increment on packet receive.
  1593  func TestShutdownRead(t *testing.T) {
  1594  	c := newDualTestContext(t, defaultMTU)
  1595  	defer c.cleanup()
  1596  
  1597  	c.createEndpoint(ipv6.ProtocolNumber)
  1598  
  1599  	// Bind to wildcard.
  1600  	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
  1601  		c.t.Fatalf("Bind failed: %v", err)
  1602  	}
  1603  
  1604  	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
  1605  		c.t.Fatalf("Connect failed: %v", err)
  1606  	}
  1607  
  1608  	if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil {
  1609  		t.Fatalf("Shutdown failed: %v", err)
  1610  	}
  1611  
  1612  	testFailingRead(c, unicastV6, true /* expectReadError */)
  1613  
  1614  	var want uint64 = 1
  1615  	if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want {
  1616  		t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want)
  1617  	}
  1618  	if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want {
  1619  		t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want)
  1620  	}
  1621  }
  1622  
  1623  // TestShutdownWrite verifies endpoint write shutdown and error
  1624  // stats increment on packet write.
  1625  func TestShutdownWrite(t *testing.T) {
  1626  	c := newDualTestContext(t, defaultMTU)
  1627  	defer c.cleanup()
  1628  
  1629  	c.createEndpoint(ipv6.ProtocolNumber)
  1630  
  1631  	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
  1632  		c.t.Fatalf("Connect failed: %v", err)
  1633  	}
  1634  
  1635  	if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil {
  1636  		t.Fatalf("Shutdown failed: %v", err)
  1637  	}
  1638  
  1639  	testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend)
  1640  }
  1641  
  1642  func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) {
  1643  	got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
  1644  	switch err {
  1645  	case nil:
  1646  		want.PacketsSent.IncrementBy(incr)
  1647  	case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
  1648  		want.WriteErrors.InvalidArgs.IncrementBy(incr)
  1649  	case tcpip.ErrClosedForSend:
  1650  		want.WriteErrors.WriteClosed.IncrementBy(incr)
  1651  	case tcpip.ErrInvalidEndpointState:
  1652  		want.WriteErrors.InvalidEndpointState.IncrementBy(incr)
  1653  	case tcpip.ErrNoLinkAddress:
  1654  		want.SendErrors.NoLinkAddr.IncrementBy(incr)
  1655  	case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
  1656  		want.SendErrors.NoRoute.IncrementBy(incr)
  1657  	default:
  1658  		want.SendErrors.SendToNetworkFailed.IncrementBy(incr)
  1659  	}
  1660  	if got != want {
  1661  		c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
  1662  	}
  1663  }
  1664  
  1665  func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) {
  1666  	got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
  1667  	switch err {
  1668  	case nil, tcpip.ErrWouldBlock:
  1669  	case tcpip.ErrClosedForReceive:
  1670  		want.ReadErrors.ReadClosed.IncrementBy(incr)
  1671  	default:
  1672  		c.t.Errorf("Endpoint error missing stats update err %v", err)
  1673  	}
  1674  	if got != want {
  1675  		c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
  1676  	}
  1677  }