github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/transport/udp/udp_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package udp_test
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"io/ioutil"
    21  	"math/rand"
    22  	"testing"
    23  
    24  	"github.com/google/go-cmp/cmp"
    25  	"github.com/SagerNet/gvisor/pkg/tcpip"
    26  	"github.com/SagerNet/gvisor/pkg/tcpip/buffer"
    27  	"github.com/SagerNet/gvisor/pkg/tcpip/checker"
    28  	"github.com/SagerNet/gvisor/pkg/tcpip/faketime"
    29  	"github.com/SagerNet/gvisor/pkg/tcpip/header"
    30  	"github.com/SagerNet/gvisor/pkg/tcpip/link/channel"
    31  	"github.com/SagerNet/gvisor/pkg/tcpip/link/loopback"
    32  	"github.com/SagerNet/gvisor/pkg/tcpip/link/sniffer"
    33  	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv4"
    34  	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv6"
    35  	"github.com/SagerNet/gvisor/pkg/tcpip/stack"
    36  	"github.com/SagerNet/gvisor/pkg/tcpip/testutil"
    37  	"github.com/SagerNet/gvisor/pkg/tcpip/transport/icmp"
    38  	"github.com/SagerNet/gvisor/pkg/tcpip/transport/udp"
    39  	"github.com/SagerNet/gvisor/pkg/waiter"
    40  )
    41  
    42  // Addresses and ports used for testing. It is recommended that tests stick to
    43  // using these addresses as it allows using the testFlow helper.
    44  // Naming rules: 'stack*'' denotes local addresses and ports, while 'test*'
    45  // represents the remote endpoint.
    46  const (
    47  	v4MappedAddrPrefix    = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
    48  	stackV6Addr           = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
    49  	testV6Addr            = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
    50  	stackV4MappedAddr     = v4MappedAddrPrefix + stackAddr
    51  	testV4MappedAddr      = v4MappedAddrPrefix + testAddr
    52  	multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr
    53  	broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr
    54  	v4MappedWildcardAddr  = v4MappedAddrPrefix + "\x00\x00\x00\x00"
    55  
    56  	stackAddr       = "\x0a\x00\x00\x01"
    57  	stackPort       = 1234
    58  	testAddr        = "\x0a\x00\x00\x02"
    59  	testPort        = 4096
    60  	invalidPort     = 8192
    61  	multicastAddr   = "\xe8\x2b\xd3\xea"
    62  	multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
    63  	broadcastAddr   = header.IPv4Broadcast
    64  	testTOS         = 0x80
    65  
    66  	// defaultMTU is the MTU, in bytes, used throughout the tests, except
    67  	// where another value is explicitly used. It is chosen to match the MTU
    68  	// of loopback interfaces on linux systems.
    69  	defaultMTU = 65536
    70  )
    71  
    72  // header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in
    73  // a packet header. These values are used to populate a header or verify one.
    74  // Note that because they are used in packet headers, the addresses are never in
    75  // a V4-mapped format.
    76  type header4Tuple struct {
    77  	srcAddr tcpip.FullAddress
    78  	dstAddr tcpip.FullAddress
    79  }
    80  
    81  // testFlow implements a helper type used for sending and receiving test
    82  // packets. A given test flow value defines 1) the socket endpoint used for the
    83  // test and 2) the type of packet send or received on the endpoint. E.g., a
    84  // multicastV6Only flow is a V6 multicast packet passing through a V6-only
    85  // endpoint. The type provides helper methods to characterize the flow (e.g.,
    86  // isV4) as well as return a proper header4Tuple for it.
    87  type testFlow int
    88  
    89  const (
    90  	unicastV4         testFlow = iota // V4 unicast on a V4 socket
    91  	unicastV4in6                      // V4-mapped unicast on a V6-dual socket
    92  	unicastV6                         // V6 unicast on a V6 socket
    93  	unicastV6Only                     // V6 unicast on a V6-only socket
    94  	multicastV4                       // V4 multicast on a V4 socket
    95  	multicastV4in6                    // V4-mapped multicast on a V6-dual socket
    96  	multicastV6                       // V6 multicast on a V6 socket
    97  	multicastV6Only                   // V6 multicast on a V6-only socket
    98  	broadcast                         // V4 broadcast on a V4 socket
    99  	broadcastIn6                      // V4-mapped broadcast on a V6-dual socket
   100  	reverseMulticast4                 // V4 multicast src. Must fail.
   101  	reverseMulticast6                 // V6 multicast src. Must fail.
   102  )
   103  
   104  func (flow testFlow) String() string {
   105  	switch flow {
   106  	case unicastV4:
   107  		return "unicastV4"
   108  	case unicastV6:
   109  		return "unicastV6"
   110  	case unicastV6Only:
   111  		return "unicastV6Only"
   112  	case unicastV4in6:
   113  		return "unicastV4in6"
   114  	case multicastV4:
   115  		return "multicastV4"
   116  	case multicastV6:
   117  		return "multicastV6"
   118  	case multicastV6Only:
   119  		return "multicastV6Only"
   120  	case multicastV4in6:
   121  		return "multicastV4in6"
   122  	case broadcast:
   123  		return "broadcast"
   124  	case broadcastIn6:
   125  		return "broadcastIn6"
   126  	case reverseMulticast4:
   127  		return "reverseMulticast4"
   128  	case reverseMulticast6:
   129  		return "reverseMulticast6"
   130  	default:
   131  		return "unknown"
   132  	}
   133  }
   134  
   135  // packetDirection explains if a flow is incoming (read) or outgoing (write).
   136  type packetDirection int
   137  
   138  const (
   139  	incoming packetDirection = iota
   140  	outgoing
   141  )
   142  
   143  // header4Tuple returns the header4Tuple for the given flow and direction. Note
   144  // that the tuple contains no mapped addresses as those only exist at the socket
   145  // level but not at the packet header level.
   146  func (flow testFlow) header4Tuple(d packetDirection) header4Tuple {
   147  	var h header4Tuple
   148  	if flow.isV4() {
   149  		if d == outgoing {
   150  			h = header4Tuple{
   151  				srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort},
   152  				dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort},
   153  			}
   154  		} else {
   155  			h = header4Tuple{
   156  				srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort},
   157  				dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort},
   158  			}
   159  		}
   160  		if flow.isMulticast() {
   161  			h.dstAddr.Addr = multicastAddr
   162  		} else if flow.isBroadcast() {
   163  			h.dstAddr.Addr = broadcastAddr
   164  		}
   165  	} else { // IPv6
   166  		if d == outgoing {
   167  			h = header4Tuple{
   168  				srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
   169  				dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
   170  			}
   171  		} else {
   172  			h = header4Tuple{
   173  				srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
   174  				dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
   175  			}
   176  		}
   177  		if flow.isMulticast() {
   178  			h.dstAddr.Addr = multicastV6Addr
   179  		}
   180  	}
   181  	if flow.isReverseMulticast() {
   182  		h.srcAddr.Addr = flow.getMcastAddr()
   183  	}
   184  	return h
   185  }
   186  
   187  func (flow testFlow) getMcastAddr() tcpip.Address {
   188  	if flow.isV4() {
   189  		return multicastAddr
   190  	}
   191  	return multicastV6Addr
   192  }
   193  
   194  // mapAddrIfApplicable converts the given V4 address into its V4-mapped version
   195  // if it is applicable to the flow.
   196  func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address {
   197  	if flow.isMapped() {
   198  		return v4MappedAddrPrefix + v4Addr
   199  	}
   200  	return v4Addr
   201  }
   202  
   203  // netProto returns the protocol number used for the network packet.
   204  func (flow testFlow) netProto() tcpip.NetworkProtocolNumber {
   205  	if flow.isV4() {
   206  		return ipv4.ProtocolNumber
   207  	}
   208  	return ipv6.ProtocolNumber
   209  }
   210  
   211  // sockProto returns the protocol number used when creating the socket
   212  // endpoint for this flow.
   213  func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber {
   214  	switch flow {
   215  	case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6:
   216  		return ipv6.ProtocolNumber
   217  	case unicastV4, multicastV4, broadcast, reverseMulticast4:
   218  		return ipv4.ProtocolNumber
   219  	default:
   220  		panic(fmt.Sprintf("invalid testFlow given: %d", flow))
   221  	}
   222  }
   223  
   224  func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) {
   225  	if flow.isV4() {
   226  		return checker.IPv4
   227  	}
   228  	return checker.IPv6
   229  }
   230  
   231  func (flow testFlow) isV6() bool { return !flow.isV4() }
   232  func (flow testFlow) isV4() bool {
   233  	return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped()
   234  }
   235  
   236  func (flow testFlow) isV6Only() bool {
   237  	switch flow {
   238  	case unicastV6Only, multicastV6Only:
   239  		return true
   240  	case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
   241  		return false
   242  	default:
   243  		panic(fmt.Sprintf("invalid testFlow given: %d", flow))
   244  	}
   245  }
   246  
   247  func (flow testFlow) isMulticast() bool {
   248  	switch flow {
   249  	case multicastV4, multicastV4in6, multicastV6, multicastV6Only:
   250  		return true
   251  	case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
   252  		return false
   253  	default:
   254  		panic(fmt.Sprintf("invalid testFlow given: %d", flow))
   255  	}
   256  }
   257  
   258  func (flow testFlow) isBroadcast() bool {
   259  	switch flow {
   260  	case broadcast, broadcastIn6:
   261  		return true
   262  	case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6:
   263  		return false
   264  	default:
   265  		panic(fmt.Sprintf("invalid testFlow given: %d", flow))
   266  	}
   267  }
   268  
   269  func (flow testFlow) isMapped() bool {
   270  	switch flow {
   271  	case unicastV4in6, multicastV4in6, broadcastIn6:
   272  		return true
   273  	case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6:
   274  		return false
   275  	default:
   276  		panic(fmt.Sprintf("invalid testFlow given: %d", flow))
   277  	}
   278  }
   279  
   280  func (flow testFlow) isReverseMulticast() bool {
   281  	switch flow {
   282  	case reverseMulticast4, reverseMulticast6:
   283  		return true
   284  	default:
   285  		return false
   286  	}
   287  }
   288  
   289  type testContext struct {
   290  	t      *testing.T
   291  	linkEP *channel.Endpoint
   292  	s      *stack.Stack
   293  
   294  	ep tcpip.Endpoint
   295  	wq waiter.Queue
   296  }
   297  
   298  func newDualTestContext(t *testing.T, mtu uint32) *testContext {
   299  	t.Helper()
   300  	return newDualTestContextWithHandleLocal(t, mtu, true)
   301  }
   302  
   303  func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal bool) *testContext {
   304  	t.Helper()
   305  
   306  	options := stack.Options{
   307  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   308  		TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
   309  		HandleLocal:        handleLocal,
   310  		Clock:              &faketime.NullClock{},
   311  	}
   312  	s := stack.New(options)
   313  	ep := channel.New(256, mtu, "")
   314  	wep := stack.LinkEndpoint(ep)
   315  
   316  	if testing.Verbose() {
   317  		wep = sniffer.New(ep)
   318  	}
   319  	if err := s.CreateNIC(1, wep); err != nil {
   320  		t.Fatalf("CreateNIC failed: %s", err)
   321  	}
   322  
   323  	if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
   324  		t.Fatalf("AddAddress failed: %s", err)
   325  	}
   326  
   327  	if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil {
   328  		t.Fatalf("AddAddress failed: %s", err)
   329  	}
   330  
   331  	s.SetRouteTable([]tcpip.Route{
   332  		{
   333  			Destination: header.IPv4EmptySubnet,
   334  			NIC:         1,
   335  		},
   336  		{
   337  			Destination: header.IPv6EmptySubnet,
   338  			NIC:         1,
   339  		},
   340  	})
   341  
   342  	return &testContext{
   343  		t:      t,
   344  		s:      s,
   345  		linkEP: ep,
   346  	}
   347  }
   348  
   349  func (c *testContext) cleanup() {
   350  	if c.ep != nil {
   351  		c.ep.Close()
   352  	}
   353  }
   354  
   355  func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) {
   356  	c.t.Helper()
   357  
   358  	var err tcpip.Error
   359  	c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq)
   360  	if err != nil {
   361  		c.t.Fatal("NewEndpoint failed: ", err)
   362  	}
   363  }
   364  
   365  func (c *testContext) createEndpointForFlow(flow testFlow) {
   366  	c.t.Helper()
   367  
   368  	c.createEndpoint(flow.sockProto())
   369  	if flow.isV6Only() {
   370  		c.ep.SocketOptions().SetV6Only(true)
   371  	} else if flow.isBroadcast() {
   372  		c.ep.SocketOptions().SetBroadcast(true)
   373  	}
   374  }
   375  
   376  // getPacketAndVerify reads a packet from the link endpoint and verifies the
   377  // header against expected values from the given test flow. In addition, it
   378  // calls any extra checker functions provided.
   379  func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte {
   380  	c.t.Helper()
   381  
   382  	p, ok := c.linkEP.Read()
   383  	if !ok {
   384  		c.t.Fatalf("Packet wasn't written out")
   385  		return nil
   386  	}
   387  
   388  	if p.Proto != flow.netProto() {
   389  		c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
   390  	}
   391  
   392  	if got, want := p.Pkt.TransportProtocolNumber, header.UDPProtocolNumber; got != want {
   393  		c.t.Errorf("got p.Pkt.TransportProtocolNumber = %d, want = %d", got, want)
   394  	}
   395  
   396  	vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
   397  	b := vv.ToView()
   398  
   399  	h := flow.header4Tuple(outgoing)
   400  	checkers = append(
   401  		checkers,
   402  		checker.SrcAddr(h.srcAddr.Addr),
   403  		checker.DstAddr(h.dstAddr.Addr),
   404  		checker.UDP(checker.DstPort(h.dstAddr.Port)),
   405  	)
   406  	flow.checkerFn()(c.t, b, checkers...)
   407  	return b
   408  }
   409  
   410  // injectPacket creates a packet of the given flow and with the given payload,
   411  // and injects it into the link endpoint. If badChecksum is true, the packet has
   412  // a bad checksum in the UDP header.
   413  func (c *testContext) injectPacket(flow testFlow, payload []byte, badChecksum bool) {
   414  	c.t.Helper()
   415  
   416  	h := flow.header4Tuple(incoming)
   417  	if flow.isV4() {
   418  		buf := c.buildV4Packet(payload, &h)
   419  		if badChecksum {
   420  			// Invalidate the UDP header checksum field, taking care to avoid
   421  			// overflow to zero, which would disable checksum validation.
   422  			for u := header.UDP(buf[header.IPv4MinimumSize:]); ; {
   423  				u.SetChecksum(u.Checksum() + 1)
   424  				if u.Checksum() != 0 {
   425  					break
   426  				}
   427  			}
   428  		}
   429  		c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
   430  			Data: buf.ToVectorisedView(),
   431  		}))
   432  	} else {
   433  		buf := c.buildV6Packet(payload, &h)
   434  		if badChecksum {
   435  			// Invalidate the UDP header checksum field (Unlike IPv4, zero is
   436  			// a valid checksum value for IPv6 so no need to avoid it).
   437  			u := header.UDP(buf[header.IPv6MinimumSize:])
   438  			u.SetChecksum(u.Checksum() + 1)
   439  		}
   440  		c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
   441  			Data: buf.ToVectorisedView(),
   442  		}))
   443  	}
   444  }
   445  
   446  // buildV6Packet creates a V6 test packet with the given payload and header
   447  // values in a buffer.
   448  func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View {
   449  	// Allocate a buffer for data and headers.
   450  	buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
   451  	payloadStart := len(buf) - len(payload)
   452  	copy(buf[payloadStart:], payload)
   453  
   454  	// Initialize the IP header.
   455  	ip := header.IPv6(buf)
   456  	ip.Encode(&header.IPv6Fields{
   457  		TrafficClass:      testTOS,
   458  		PayloadLength:     uint16(header.UDPMinimumSize + len(payload)),
   459  		TransportProtocol: udp.ProtocolNumber,
   460  		HopLimit:          65,
   461  		SrcAddr:           h.srcAddr.Addr,
   462  		DstAddr:           h.dstAddr.Addr,
   463  	})
   464  
   465  	// Initialize the UDP header.
   466  	u := header.UDP(buf[header.IPv6MinimumSize:])
   467  	u.Encode(&header.UDPFields{
   468  		SrcPort: h.srcAddr.Port,
   469  		DstPort: h.dstAddr.Port,
   470  		Length:  uint16(header.UDPMinimumSize + len(payload)),
   471  	})
   472  
   473  	// Calculate the UDP pseudo-header checksum.
   474  	xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u)))
   475  
   476  	// Calculate the UDP checksum and set it.
   477  	xsum = header.Checksum(payload, xsum)
   478  	u.SetChecksum(^u.CalculateChecksum(xsum))
   479  
   480  	return buf
   481  }
   482  
   483  // buildV4Packet creates a V4 test packet with the given payload and header
   484  // values in a buffer.
   485  func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View {
   486  	// Allocate a buffer for data and headers.
   487  	buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
   488  	payloadStart := len(buf) - len(payload)
   489  	copy(buf[payloadStart:], payload)
   490  
   491  	// Initialize the IP header.
   492  	ip := header.IPv4(buf)
   493  	ip.Encode(&header.IPv4Fields{
   494  		TOS:         testTOS,
   495  		TotalLength: uint16(len(buf)),
   496  		TTL:         65,
   497  		Protocol:    uint8(udp.ProtocolNumber),
   498  		SrcAddr:     h.srcAddr.Addr,
   499  		DstAddr:     h.dstAddr.Addr,
   500  	})
   501  	ip.SetChecksum(^ip.CalculateChecksum())
   502  
   503  	// Initialize the UDP header.
   504  	u := header.UDP(buf[header.IPv4MinimumSize:])
   505  	u.Encode(&header.UDPFields{
   506  		SrcPort: h.srcAddr.Port,
   507  		DstPort: h.dstAddr.Port,
   508  		Length:  uint16(header.UDPMinimumSize + len(payload)),
   509  	})
   510  
   511  	// Calculate the UDP pseudo-header checksum.
   512  	xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u)))
   513  
   514  	// Calculate the UDP checksum and set it.
   515  	xsum = header.Checksum(payload, xsum)
   516  	u.SetChecksum(^u.CalculateChecksum(xsum))
   517  
   518  	return buf
   519  }
   520  
   521  func newPayload() []byte {
   522  	return newMinPayload(30)
   523  }
   524  
   525  func newMinPayload(minSize int) []byte {
   526  	b := make([]byte, minSize+rand.Intn(100))
   527  	for i := range b {
   528  		b[i] = byte(rand.Intn(256))
   529  	}
   530  	return b
   531  }
   532  
   533  func TestBindToDeviceOption(t *testing.T) {
   534  	s := stack.New(stack.Options{
   535  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol},
   536  		TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   537  		Clock:              &faketime.NullClock{},
   538  	})
   539  
   540  	ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
   541  	if err != nil {
   542  		t.Fatalf("NewEndpoint failed; %s", err)
   543  	}
   544  	defer ep.Close()
   545  
   546  	opts := stack.NICOptions{Name: "my_device"}
   547  	if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil {
   548  		t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %s", opts, err)
   549  	}
   550  
   551  	// nicIDPtr is used instead of taking the address of NICID literals, which is
   552  	// a compiler error.
   553  	nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
   554  		return &s
   555  	}
   556  
   557  	testActions := []struct {
   558  		name                 string
   559  		setBindToDevice      *tcpip.NICID
   560  		setBindToDeviceError tcpip.Error
   561  		getBindToDevice      int32
   562  	}{
   563  		{"GetDefaultValue", nil, nil, 0},
   564  		{"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0},
   565  		{"BindToExistent", nicIDPtr(321), nil, 321},
   566  		{"UnbindToDevice", nicIDPtr(0), nil, 0},
   567  	}
   568  	for _, testAction := range testActions {
   569  		t.Run(testAction.name, func(t *testing.T) {
   570  			if testAction.setBindToDevice != nil {
   571  				bindToDevice := int32(*testAction.setBindToDevice)
   572  				if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
   573  					t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr)
   574  				}
   575  			}
   576  			bindToDevice := ep.SocketOptions().GetBindToDevice()
   577  			if bindToDevice != testAction.getBindToDevice {
   578  				t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice)
   579  			}
   580  		})
   581  	}
   582  }
   583  
   584  // testReadInternal sends a packet of the given test flow into the stack by
   585  // injecting it into the link endpoint. It then attempts to read it from the
   586  // UDP endpoint and depending on if this was expected to succeed verifies its
   587  // correctness including any additional checker functions provided.
   588  func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) {
   589  	c.t.Helper()
   590  
   591  	payload := newPayload()
   592  	c.injectPacket(flow, payload, false)
   593  
   594  	// Try to receive the data.
   595  	we, ch := waiter.NewChannelEntry(nil)
   596  	c.wq.EventRegister(&we, waiter.ReadableEvents)
   597  	defer c.wq.EventUnregister(&we)
   598  
   599  	// Take a snapshot of the stats to validate them at the end of the test.
   600  	epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
   601  
   602  	var buf bytes.Buffer
   603  	res, err := c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true})
   604  	if _, ok := err.(*tcpip.ErrWouldBlock); ok {
   605  		// Wait for data to become available.
   606  		select {
   607  		case <-ch:
   608  			res, err = c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true})
   609  
   610  		default:
   611  			if packetShouldBeDropped {
   612  				return // expected to time out
   613  			}
   614  			c.t.Fatal("timed out waiting for data")
   615  		}
   616  	}
   617  
   618  	if expectReadError && err != nil {
   619  		c.checkEndpointReadStats(1, epstats, err)
   620  		return
   621  	}
   622  
   623  	if err != nil {
   624  		c.t.Fatal("Read failed:", err)
   625  	}
   626  
   627  	if packetShouldBeDropped {
   628  		c.t.Fatalf("Read unexpectedly received data from %s", res.RemoteAddr.Addr)
   629  	}
   630  
   631  	// Check the read result.
   632  	h := flow.header4Tuple(incoming)
   633  	if diff := cmp.Diff(tcpip.ReadResult{
   634  		Count:      buf.Len(),
   635  		Total:      buf.Len(),
   636  		RemoteAddr: tcpip.FullAddress{Addr: h.srcAddr.Addr},
   637  	}, res, checker.IgnoreCmpPath(
   638  		"ControlMessages", // ControlMessages will be checked later.
   639  		"RemoteAddr.NIC",
   640  		"RemoteAddr.Port",
   641  	)); diff != "" {
   642  		c.t.Fatalf("Read: unexpected result (-want +got):\n%s", diff)
   643  	}
   644  
   645  	// Check the payload.
   646  	v := buf.Bytes()
   647  	if !bytes.Equal(payload, v) {
   648  		c.t.Fatalf("got payload = %x, want = %x", v, payload)
   649  	}
   650  
   651  	// Run any checkers against the ControlMessages.
   652  	for _, f := range checkers {
   653  		f(c.t, res.ControlMessages)
   654  	}
   655  
   656  	c.checkEndpointReadStats(1, epstats, err)
   657  }
   658  
   659  // testRead sends a packet of the given test flow into the stack by injecting it
   660  // into the link endpoint. It then reads it from the UDP endpoint and verifies
   661  // its correctness including any additional checker functions provided.
   662  func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) {
   663  	c.t.Helper()
   664  	testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...)
   665  }
   666  
   667  // testFailingRead sends a packet of the given test flow into the stack by
   668  // injecting it into the link endpoint. It then tries to read it from the UDP
   669  // endpoint and expects this to fail.
   670  func testFailingRead(c *testContext, flow testFlow, expectReadError bool) {
   671  	c.t.Helper()
   672  	testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError)
   673  }
   674  
   675  func TestBindEphemeralPort(t *testing.T) {
   676  	c := newDualTestContext(t, defaultMTU)
   677  	defer c.cleanup()
   678  
   679  	c.createEndpoint(ipv6.ProtocolNumber)
   680  
   681  	if err := c.ep.Bind(tcpip.FullAddress{}); err != nil {
   682  		t.Fatalf("ep.Bind(...) failed: %s", err)
   683  	}
   684  }
   685  
   686  func TestBindReservedPort(t *testing.T) {
   687  	c := newDualTestContext(t, defaultMTU)
   688  	defer c.cleanup()
   689  
   690  	c.createEndpoint(ipv6.ProtocolNumber)
   691  
   692  	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
   693  		c.t.Fatalf("Connect failed: %s", err)
   694  	}
   695  
   696  	addr, err := c.ep.GetLocalAddress()
   697  	if err != nil {
   698  		t.Fatalf("GetLocalAddress failed: %s", err)
   699  	}
   700  
   701  	// We can't bind the address reserved by the connected endpoint above.
   702  	{
   703  		ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
   704  		if err != nil {
   705  			t.Fatalf("NewEndpoint failed: %s", err)
   706  		}
   707  		defer ep.Close()
   708  		{
   709  			err := ep.Bind(addr)
   710  			if _, ok := err.(*tcpip.ErrPortInUse); !ok {
   711  				t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{})
   712  			}
   713  		}
   714  	}
   715  
   716  	func() {
   717  		ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
   718  		if err != nil {
   719  			t.Fatalf("NewEndpoint failed: %s", err)
   720  		}
   721  		defer ep.Close()
   722  		// We can't bind ipv4-any on the port reserved by the connected endpoint
   723  		// above, since the endpoint is dual-stack.
   724  		{
   725  			err := ep.Bind(tcpip.FullAddress{Port: addr.Port})
   726  			if _, ok := err.(*tcpip.ErrPortInUse); !ok {
   727  				t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{})
   728  			}
   729  		}
   730  		// We can bind an ipv4 address on this port, though.
   731  		if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil {
   732  			t.Fatalf("ep.Bind(...) failed: %s", err)
   733  		}
   734  	}()
   735  
   736  	// Once the connected endpoint releases its port reservation, we are able to
   737  	// bind ipv4-any once again.
   738  	c.ep.Close()
   739  	func() {
   740  		ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
   741  		if err != nil {
   742  			t.Fatalf("NewEndpoint failed: %s", err)
   743  		}
   744  		defer ep.Close()
   745  		if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil {
   746  			t.Fatalf("ep.Bind(...) failed: %s", err)
   747  		}
   748  	}()
   749  }
   750  
   751  func TestV4ReadOnV6(t *testing.T) {
   752  	c := newDualTestContext(t, defaultMTU)
   753  	defer c.cleanup()
   754  
   755  	c.createEndpointForFlow(unicastV4in6)
   756  
   757  	// Bind to wildcard.
   758  	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
   759  		c.t.Fatalf("Bind failed: %s", err)
   760  	}
   761  
   762  	// Test acceptance.
   763  	testRead(c, unicastV4in6)
   764  }
   765  
   766  func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
   767  	c := newDualTestContext(t, defaultMTU)
   768  	defer c.cleanup()
   769  
   770  	c.createEndpointForFlow(unicastV4in6)
   771  
   772  	// Bind to v4 mapped wildcard.
   773  	if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil {
   774  		c.t.Fatalf("Bind failed: %s", err)
   775  	}
   776  
   777  	// Test acceptance.
   778  	testRead(c, unicastV4in6)
   779  }
   780  
   781  func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
   782  	c := newDualTestContext(t, defaultMTU)
   783  	defer c.cleanup()
   784  
   785  	c.createEndpointForFlow(unicastV4in6)
   786  
   787  	// Bind to local address.
   788  	if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
   789  		c.t.Fatalf("Bind failed: %s", err)
   790  	}
   791  
   792  	// Test acceptance.
   793  	testRead(c, unicastV4in6)
   794  }
   795  
   796  func TestV6ReadOnV6(t *testing.T) {
   797  	c := newDualTestContext(t, defaultMTU)
   798  	defer c.cleanup()
   799  
   800  	c.createEndpointForFlow(unicastV6)
   801  
   802  	// Bind to wildcard.
   803  	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
   804  		c.t.Fatalf("Bind failed: %s", err)
   805  	}
   806  
   807  	// Test acceptance.
   808  	testRead(c, unicastV6)
   809  }
   810  
   811  // TestV4ReadSelfSource checks that packets coming from a local IP address are
   812  // correctly dropped when handleLocal is true and not otherwise.
   813  func TestV4ReadSelfSource(t *testing.T) {
   814  	for _, tt := range []struct {
   815  		name              string
   816  		handleLocal       bool
   817  		wantErr           tcpip.Error
   818  		wantInvalidSource uint64
   819  	}{
   820  		{"HandleLocal", false, nil, 0},
   821  		{"NoHandleLocal", true, &tcpip.ErrWouldBlock{}, 1},
   822  	} {
   823  		t.Run(tt.name, func(t *testing.T) {
   824  			c := newDualTestContextWithHandleLocal(t, defaultMTU, tt.handleLocal)
   825  			defer c.cleanup()
   826  
   827  			c.createEndpointForFlow(unicastV4)
   828  
   829  			if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
   830  				t.Fatalf("Bind failed: %s", err)
   831  			}
   832  
   833  			payload := newPayload()
   834  			h := unicastV4.header4Tuple(incoming)
   835  			h.srcAddr = h.dstAddr
   836  
   837  			buf := c.buildV4Packet(payload, &h)
   838  			c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
   839  				Data: buf.ToVectorisedView(),
   840  			}))
   841  
   842  			if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource {
   843  				t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
   844  			}
   845  
   846  			if _, err := c.ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tt.wantErr {
   847  				t.Errorf("got c.ep.Read = %s, want = %s", err, tt.wantErr)
   848  			}
   849  		})
   850  	}
   851  }
   852  
   853  func TestV4ReadOnV4(t *testing.T) {
   854  	c := newDualTestContext(t, defaultMTU)
   855  	defer c.cleanup()
   856  
   857  	c.createEndpointForFlow(unicastV4)
   858  
   859  	// Bind to wildcard.
   860  	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
   861  		c.t.Fatalf("Bind failed: %s", err)
   862  	}
   863  
   864  	// Test acceptance.
   865  	testRead(c, unicastV4)
   866  }
   867  
   868  // TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast
   869  // address and receive data sent to that address.
   870  func TestReadOnBoundToMulticast(t *testing.T) {
   871  	// FIXME(b/128189410): multicastV4in6 currently doesn't work as
   872  	// AddMembershipOption doesn't handle V4in6 addresses.
   873  	for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} {
   874  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
   875  			c := newDualTestContext(t, defaultMTU)
   876  			defer c.cleanup()
   877  
   878  			c.createEndpointForFlow(flow)
   879  
   880  			// Bind to multicast address.
   881  			mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr())
   882  			if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil {
   883  				c.t.Fatal("Bind failed:", err)
   884  			}
   885  
   886  			// Join multicast group.
   887  			ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr}
   888  			if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
   889  				c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err)
   890  			}
   891  
   892  			// Check that we receive multicast packets but not unicast or broadcast
   893  			// ones.
   894  			testRead(c, flow)
   895  			testFailingRead(c, broadcast, false /* expectReadError */)
   896  			testFailingRead(c, unicastV4, false /* expectReadError */)
   897  		})
   898  	}
   899  }
   900  
   901  // TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast
   902  // address and can receive only broadcast data.
   903  func TestV4ReadOnBoundToBroadcast(t *testing.T) {
   904  	for _, flow := range []testFlow{broadcast, broadcastIn6} {
   905  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
   906  			c := newDualTestContext(t, defaultMTU)
   907  			defer c.cleanup()
   908  
   909  			c.createEndpointForFlow(flow)
   910  
   911  			// Bind to broadcast address.
   912  			bcastAddr := flow.mapAddrIfApplicable(broadcastAddr)
   913  			if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil {
   914  				c.t.Fatalf("Bind failed: %s", err)
   915  			}
   916  
   917  			// Check that we receive broadcast packets but not unicast ones.
   918  			testRead(c, flow)
   919  			testFailingRead(c, unicastV4, false /* expectReadError */)
   920  		})
   921  	}
   922  }
   923  
   924  // TestReadFromMulticast checks that an endpoint will NOT receive a packet
   925  // that was sent with multicast SOURCE address.
   926  func TestReadFromMulticast(t *testing.T) {
   927  	for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} {
   928  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
   929  			c := newDualTestContext(t, defaultMTU)
   930  			defer c.cleanup()
   931  
   932  			c.createEndpointForFlow(flow)
   933  
   934  			if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
   935  				t.Fatalf("Bind failed: %s", err)
   936  			}
   937  			testFailingRead(c, flow, false /* expectReadError */)
   938  		})
   939  	}
   940  }
   941  
   942  // TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY
   943  // and receive broadcast and unicast data.
   944  func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
   945  	for _, flow := range []testFlow{broadcast, broadcastIn6} {
   946  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
   947  			c := newDualTestContext(t, defaultMTU)
   948  			defer c.cleanup()
   949  
   950  			c.createEndpointForFlow(flow)
   951  
   952  			// Bind to wildcard.
   953  			if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
   954  				c.t.Fatalf("Bind failed: %s (", err)
   955  			}
   956  
   957  			// Check that we receive both broadcast and unicast packets.
   958  			testRead(c, flow)
   959  			testRead(c, unicastV4)
   960  		})
   961  	}
   962  }
   963  
   964  // testFailingWrite sends a packet of the given test flow into the UDP endpoint
   965  // and verifies it fails with the provided error code.
   966  func testFailingWrite(c *testContext, flow testFlow, wantErr tcpip.Error) {
   967  	c.t.Helper()
   968  	// Take a snapshot of the stats to validate them at the end of the test.
   969  	epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
   970  	h := flow.header4Tuple(outgoing)
   971  	writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
   972  
   973  	var r bytes.Reader
   974  	r.Reset(newPayload())
   975  	_, gotErr := c.ep.Write(&r, tcpip.WriteOptions{
   976  		To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
   977  	})
   978  	c.checkEndpointWriteStats(1, epstats, gotErr)
   979  	if gotErr != wantErr {
   980  		c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr)
   981  	}
   982  }
   983  
   984  // testWrite sends a packet of the given test flow from the UDP endpoint to the
   985  // flow's destination address:port. It then receives it from the link endpoint
   986  // and verifies its correctness including any additional checker functions
   987  // provided.
   988  func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
   989  	c.t.Helper()
   990  	return testWriteAndVerifyInternal(c, flow, true, checkers...)
   991  }
   992  
   993  // testWriteWithoutDestination sends a packet of the given test flow from the
   994  // UDP endpoint without giving a destination address:port. It then receives it
   995  // from the link endpoint and verifies its correctness including any additional
   996  // checker functions provided.
   997  func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
   998  	c.t.Helper()
   999  	return testWriteAndVerifyInternal(c, flow, false, checkers...)
  1000  }
  1001  
  1002  func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View {
  1003  	c.t.Helper()
  1004  	// Take a snapshot of the stats to validate them at the end of the test.
  1005  	epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
  1006  
  1007  	writeOpts := tcpip.WriteOptions{}
  1008  	if setDest {
  1009  		h := flow.header4Tuple(outgoing)
  1010  		writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
  1011  		writeOpts = tcpip.WriteOptions{
  1012  			To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
  1013  		}
  1014  	}
  1015  	var r bytes.Reader
  1016  	payload := newPayload()
  1017  	r.Reset(payload)
  1018  	n, err := c.ep.Write(&r, writeOpts)
  1019  	if err != nil {
  1020  		c.t.Fatalf("Write failed: %s", err)
  1021  	}
  1022  	if n != int64(len(payload)) {
  1023  		c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
  1024  	}
  1025  	c.checkEndpointWriteStats(1, epstats, err)
  1026  	return payload
  1027  }
  1028  
  1029  func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
  1030  	c.t.Helper()
  1031  	payload := testWriteNoVerify(c, flow, setDest)
  1032  	// Received the packet and check the payload.
  1033  	b := c.getPacketAndVerify(flow, checkers...)
  1034  	var udpH header.UDP
  1035  	if flow.isV4() {
  1036  		udpH = header.IPv4(b).Payload()
  1037  	} else {
  1038  		udpH = header.IPv6(b).Payload()
  1039  	}
  1040  	if !bytes.Equal(payload, udpH.Payload()) {
  1041  		c.t.Fatalf("Bad payload: got %x, want %x", udpH.Payload(), payload)
  1042  	}
  1043  
  1044  	return udpH.SourcePort()
  1045  }
  1046  
  1047  func testDualWrite(c *testContext) uint16 {
  1048  	c.t.Helper()
  1049  
  1050  	v4Port := testWrite(c, unicastV4in6)
  1051  	v6Port := testWrite(c, unicastV6)
  1052  	if v4Port != v6Port {
  1053  		c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port)
  1054  	}
  1055  
  1056  	return v4Port
  1057  }
  1058  
  1059  func TestDualWriteUnbound(t *testing.T) {
  1060  	c := newDualTestContext(t, defaultMTU)
  1061  	defer c.cleanup()
  1062  
  1063  	c.createEndpoint(ipv6.ProtocolNumber)
  1064  
  1065  	testDualWrite(c)
  1066  }
  1067  
  1068  func TestDualWriteBoundToWildcard(t *testing.T) {
  1069  	c := newDualTestContext(t, defaultMTU)
  1070  	defer c.cleanup()
  1071  
  1072  	c.createEndpoint(ipv6.ProtocolNumber)
  1073  
  1074  	// Bind to wildcard.
  1075  	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
  1076  		c.t.Fatalf("Bind failed: %s", err)
  1077  	}
  1078  
  1079  	p := testDualWrite(c)
  1080  	if p != stackPort {
  1081  		c.t.Fatalf("Bad port: got %v, want %v", p, stackPort)
  1082  	}
  1083  }
  1084  
  1085  func TestDualWriteConnectedToV6(t *testing.T) {
  1086  	c := newDualTestContext(t, defaultMTU)
  1087  	defer c.cleanup()
  1088  
  1089  	c.createEndpoint(ipv6.ProtocolNumber)
  1090  
  1091  	// Connect to v6 address.
  1092  	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
  1093  		c.t.Fatalf("Bind failed: %s", err)
  1094  	}
  1095  
  1096  	testWrite(c, unicastV6)
  1097  
  1098  	// Write to V4 mapped address.
  1099  	testFailingWrite(c, unicastV4in6, &tcpip.ErrNetworkUnreachable{})
  1100  	const want = 1
  1101  	if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want {
  1102  		c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want)
  1103  	}
  1104  }
  1105  
  1106  func TestDualWriteConnectedToV4Mapped(t *testing.T) {
  1107  	c := newDualTestContext(t, defaultMTU)
  1108  	defer c.cleanup()
  1109  
  1110  	c.createEndpoint(ipv6.ProtocolNumber)
  1111  
  1112  	// Connect to v4 mapped address.
  1113  	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
  1114  		c.t.Fatalf("Bind failed: %s", err)
  1115  	}
  1116  
  1117  	testWrite(c, unicastV4in6)
  1118  
  1119  	// Write to v6 address.
  1120  	testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{})
  1121  }
  1122  
  1123  func TestV4WriteOnV6Only(t *testing.T) {
  1124  	c := newDualTestContext(t, defaultMTU)
  1125  	defer c.cleanup()
  1126  
  1127  	c.createEndpointForFlow(unicastV6Only)
  1128  
  1129  	// Write to V4 mapped address.
  1130  	testFailingWrite(c, unicastV4in6, &tcpip.ErrNoRoute{})
  1131  }
  1132  
  1133  func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
  1134  	c := newDualTestContext(t, defaultMTU)
  1135  	defer c.cleanup()
  1136  
  1137  	c.createEndpoint(ipv6.ProtocolNumber)
  1138  
  1139  	// Bind to v4 mapped address.
  1140  	if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
  1141  		c.t.Fatalf("Bind failed: %s", err)
  1142  	}
  1143  
  1144  	// Write to v6 address.
  1145  	testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{})
  1146  }
  1147  
  1148  func TestV6WriteOnConnected(t *testing.T) {
  1149  	c := newDualTestContext(t, defaultMTU)
  1150  	defer c.cleanup()
  1151  
  1152  	c.createEndpoint(ipv6.ProtocolNumber)
  1153  
  1154  	// Connect to v6 address.
  1155  	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
  1156  		c.t.Fatalf("Connect failed: %s", err)
  1157  	}
  1158  
  1159  	testWriteWithoutDestination(c, unicastV6)
  1160  }
  1161  
  1162  func TestV4WriteOnConnected(t *testing.T) {
  1163  	c := newDualTestContext(t, defaultMTU)
  1164  	defer c.cleanup()
  1165  
  1166  	c.createEndpoint(ipv6.ProtocolNumber)
  1167  
  1168  	// Connect to v4 mapped address.
  1169  	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
  1170  		c.t.Fatalf("Connect failed: %s", err)
  1171  	}
  1172  
  1173  	testWriteWithoutDestination(c, unicastV4)
  1174  }
  1175  
  1176  func TestWriteOnConnectedInvalidPort(t *testing.T) {
  1177  	protocols := map[string]tcpip.NetworkProtocolNumber{
  1178  		"ipv4": ipv4.ProtocolNumber,
  1179  		"ipv6": ipv6.ProtocolNumber,
  1180  	}
  1181  	for name, pn := range protocols {
  1182  		t.Run(name, func(t *testing.T) {
  1183  			c := newDualTestContext(t, defaultMTU)
  1184  			defer c.cleanup()
  1185  
  1186  			c.createEndpoint(pn)
  1187  			if err := c.ep.Connect(tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}); err != nil {
  1188  				c.t.Fatalf("Connect failed: %s", err)
  1189  			}
  1190  			writeOpts := tcpip.WriteOptions{
  1191  				To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort},
  1192  			}
  1193  			var r bytes.Reader
  1194  			payload := newPayload()
  1195  			r.Reset(payload)
  1196  			n, err := c.ep.Write(&r, writeOpts)
  1197  			if err != nil {
  1198  				c.t.Fatalf("c.ep.Write(...) = %s, want nil", err)
  1199  			}
  1200  			if got, want := n, int64(len(payload)); got != want {
  1201  				c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want)
  1202  			}
  1203  
  1204  			{
  1205  				err := c.ep.LastError()
  1206  				if _, ok := err.(*tcpip.ErrConnectionRefused); !ok {
  1207  					c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err)
  1208  				}
  1209  			}
  1210  		})
  1211  	}
  1212  }
  1213  
  1214  // TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket
  1215  // that is bound to a V4 multicast address.
  1216  func TestWriteOnBoundToV4Multicast(t *testing.T) {
  1217  	for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
  1218  		t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
  1219  			c := newDualTestContext(t, defaultMTU)
  1220  			defer c.cleanup()
  1221  
  1222  			c.createEndpointForFlow(flow)
  1223  
  1224  			// Bind to V4 mcast address.
  1225  			if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil {
  1226  				c.t.Fatal("Bind failed:", err)
  1227  			}
  1228  
  1229  			testWrite(c, flow)
  1230  		})
  1231  	}
  1232  }
  1233  
  1234  // TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a
  1235  // socket that is bound to a V4-mapped multicast address.
  1236  func TestWriteOnBoundToV4MappedMulticast(t *testing.T) {
  1237  	for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
  1238  		t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
  1239  			c := newDualTestContext(t, defaultMTU)
  1240  			defer c.cleanup()
  1241  
  1242  			c.createEndpointForFlow(flow)
  1243  
  1244  			// Bind to V4Mapped mcast address.
  1245  			if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil {
  1246  				c.t.Fatalf("Bind failed: %s", err)
  1247  			}
  1248  
  1249  			testWrite(c, flow)
  1250  		})
  1251  	}
  1252  }
  1253  
  1254  // TestWriteOnBoundToV6Multicast checks that we can send packets out of a
  1255  // socket that is bound to a V6 multicast address.
  1256  func TestWriteOnBoundToV6Multicast(t *testing.T) {
  1257  	for _, flow := range []testFlow{unicastV6, multicastV6} {
  1258  		t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
  1259  			c := newDualTestContext(t, defaultMTU)
  1260  			defer c.cleanup()
  1261  
  1262  			c.createEndpointForFlow(flow)
  1263  
  1264  			// Bind to V6 mcast address.
  1265  			if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
  1266  				c.t.Fatalf("Bind failed: %s", err)
  1267  			}
  1268  
  1269  			testWrite(c, flow)
  1270  		})
  1271  	}
  1272  }
  1273  
  1274  // TestWriteOnBoundToV6Multicast checks that we can send packets out of a
  1275  // V6-only socket that is bound to a V6 multicast address.
  1276  func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) {
  1277  	for _, flow := range []testFlow{unicastV6Only, multicastV6Only} {
  1278  		t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
  1279  			c := newDualTestContext(t, defaultMTU)
  1280  			defer c.cleanup()
  1281  
  1282  			c.createEndpointForFlow(flow)
  1283  
  1284  			// Bind to V6 mcast address.
  1285  			if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
  1286  				c.t.Fatalf("Bind failed: %s", err)
  1287  			}
  1288  
  1289  			testWrite(c, flow)
  1290  		})
  1291  	}
  1292  }
  1293  
  1294  // TestWriteOnBoundToBroadcast checks that we can send packets out of a
  1295  // socket that is bound to the broadcast address.
  1296  func TestWriteOnBoundToBroadcast(t *testing.T) {
  1297  	for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
  1298  		t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
  1299  			c := newDualTestContext(t, defaultMTU)
  1300  			defer c.cleanup()
  1301  
  1302  			c.createEndpointForFlow(flow)
  1303  
  1304  			// Bind to V4 broadcast address.
  1305  			if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil {
  1306  				c.t.Fatal("Bind failed:", err)
  1307  			}
  1308  
  1309  			testWrite(c, flow)
  1310  		})
  1311  	}
  1312  }
  1313  
  1314  // TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a
  1315  // socket that is bound to the V4-mapped broadcast address.
  1316  func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) {
  1317  	for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
  1318  		t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
  1319  			c := newDualTestContext(t, defaultMTU)
  1320  			defer c.cleanup()
  1321  
  1322  			c.createEndpointForFlow(flow)
  1323  
  1324  			// Bind to V4Mapped mcast address.
  1325  			if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil {
  1326  				c.t.Fatalf("Bind failed: %s", err)
  1327  			}
  1328  
  1329  			testWrite(c, flow)
  1330  		})
  1331  	}
  1332  }
  1333  
  1334  func TestReadIncrementsPacketsReceived(t *testing.T) {
  1335  	c := newDualTestContext(t, defaultMTU)
  1336  	defer c.cleanup()
  1337  
  1338  	// Create IPv4 UDP endpoint
  1339  	c.createEndpoint(ipv6.ProtocolNumber)
  1340  
  1341  	// Bind to wildcard.
  1342  	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
  1343  		c.t.Fatalf("Bind failed: %s", err)
  1344  	}
  1345  
  1346  	testRead(c, unicastV4)
  1347  
  1348  	var want uint64 = 1
  1349  	if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want {
  1350  		c.t.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want)
  1351  	}
  1352  }
  1353  
  1354  func TestReadIPPacketInfo(t *testing.T) {
  1355  	tests := []struct {
  1356  		name              string
  1357  		proto             tcpip.NetworkProtocolNumber
  1358  		flow              testFlow
  1359  		expectedLocalAddr tcpip.Address
  1360  		expectedDestAddr  tcpip.Address
  1361  	}{
  1362  		{
  1363  			name:              "IPv4 unicast",
  1364  			proto:             header.IPv4ProtocolNumber,
  1365  			flow:              unicastV4,
  1366  			expectedLocalAddr: stackAddr,
  1367  			expectedDestAddr:  stackAddr,
  1368  		},
  1369  		{
  1370  			name:  "IPv4 multicast",
  1371  			proto: header.IPv4ProtocolNumber,
  1372  			flow:  multicastV4,
  1373  			// This should actually be a unicast address assigned to the interface.
  1374  			//
  1375  			// TODO(github.com/SagerNet/issue/3556): This check is validating incorrect
  1376  			// behaviour. We still include the test so that once the bug is
  1377  			// resolved, this test will start to fail and the individual tasked
  1378  			// with fixing this bug knows to also fix this test :).
  1379  			expectedLocalAddr: multicastAddr,
  1380  			expectedDestAddr:  multicastAddr,
  1381  		},
  1382  		{
  1383  			name:  "IPv4 broadcast",
  1384  			proto: header.IPv4ProtocolNumber,
  1385  			flow:  broadcast,
  1386  			// This should actually be a unicast address assigned to the interface.
  1387  			//
  1388  			// TODO(github.com/SagerNet/issue/3556): This check is validating incorrect
  1389  			// behaviour. We still include the test so that once the bug is
  1390  			// resolved, this test will start to fail and the individual tasked
  1391  			// with fixing this bug knows to also fix this test :).
  1392  			expectedLocalAddr: broadcastAddr,
  1393  			expectedDestAddr:  broadcastAddr,
  1394  		},
  1395  		{
  1396  			name:              "IPv6 unicast",
  1397  			proto:             header.IPv6ProtocolNumber,
  1398  			flow:              unicastV6,
  1399  			expectedLocalAddr: stackV6Addr,
  1400  			expectedDestAddr:  stackV6Addr,
  1401  		},
  1402  		{
  1403  			name:  "IPv6 multicast",
  1404  			proto: header.IPv6ProtocolNumber,
  1405  			flow:  multicastV6,
  1406  			// This should actually be a unicast address assigned to the interface.
  1407  			//
  1408  			// TODO(github.com/SagerNet/issue/3556): This check is validating incorrect
  1409  			// behaviour. We still include the test so that once the bug is
  1410  			// resolved, this test will start to fail and the individual tasked
  1411  			// with fixing this bug knows to also fix this test :).
  1412  			expectedLocalAddr: multicastV6Addr,
  1413  			expectedDestAddr:  multicastV6Addr,
  1414  		},
  1415  	}
  1416  
  1417  	for _, test := range tests {
  1418  		t.Run(test.name, func(t *testing.T) {
  1419  			c := newDualTestContext(t, defaultMTU)
  1420  			defer c.cleanup()
  1421  
  1422  			c.createEndpoint(test.proto)
  1423  
  1424  			bindAddr := tcpip.FullAddress{Port: stackPort}
  1425  			if err := c.ep.Bind(bindAddr); err != nil {
  1426  				t.Fatalf("Bind(%+v): %s", bindAddr, err)
  1427  			}
  1428  
  1429  			if test.flow.isMulticast() {
  1430  				ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()}
  1431  				if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
  1432  					c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err)
  1433  				}
  1434  			}
  1435  
  1436  			c.ep.SocketOptions().SetReceivePacketInfo(true)
  1437  
  1438  			testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
  1439  				NIC:             1,
  1440  				LocalAddr:       test.expectedLocalAddr,
  1441  				DestinationAddr: test.expectedDestAddr,
  1442  			}))
  1443  
  1444  			if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
  1445  				t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
  1446  			}
  1447  		})
  1448  	}
  1449  }
  1450  
  1451  func TestReadRecvOriginalDstAddr(t *testing.T) {
  1452  	tests := []struct {
  1453  		name                    string
  1454  		proto                   tcpip.NetworkProtocolNumber
  1455  		flow                    testFlow
  1456  		expectedOriginalDstAddr tcpip.FullAddress
  1457  	}{
  1458  		{
  1459  			name:                    "IPv4 unicast",
  1460  			proto:                   header.IPv4ProtocolNumber,
  1461  			flow:                    unicastV4,
  1462  			expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackAddr, Port: stackPort},
  1463  		},
  1464  		{
  1465  			name:  "IPv4 multicast",
  1466  			proto: header.IPv4ProtocolNumber,
  1467  			flow:  multicastV4,
  1468  			// This should actually be a unicast address assigned to the interface.
  1469  			//
  1470  			// TODO(github.com/SagerNet/issue/3556): This check is validating incorrect
  1471  			// behaviour. We still include the test so that once the bug is
  1472  			// resolved, this test will start to fail and the individual tasked
  1473  			// with fixing this bug knows to also fix this test :).
  1474  			expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastAddr, Port: stackPort},
  1475  		},
  1476  		{
  1477  			name:  "IPv4 broadcast",
  1478  			proto: header.IPv4ProtocolNumber,
  1479  			flow:  broadcast,
  1480  			// This should actually be a unicast address assigned to the interface.
  1481  			//
  1482  			// TODO(github.com/SagerNet/issue/3556): This check is validating incorrect
  1483  			// behaviour. We still include the test so that once the bug is
  1484  			// resolved, this test will start to fail and the individual tasked
  1485  			// with fixing this bug knows to also fix this test :).
  1486  			expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: broadcastAddr, Port: stackPort},
  1487  		},
  1488  		{
  1489  			name:                    "IPv6 unicast",
  1490  			proto:                   header.IPv6ProtocolNumber,
  1491  			flow:                    unicastV6,
  1492  			expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackV6Addr, Port: stackPort},
  1493  		},
  1494  		{
  1495  			name:  "IPv6 multicast",
  1496  			proto: header.IPv6ProtocolNumber,
  1497  			flow:  multicastV6,
  1498  			// This should actually be a unicast address assigned to the interface.
  1499  			//
  1500  			// TODO(github.com/SagerNet/issue/3556): This check is validating incorrect
  1501  			// behaviour. We still include the test so that once the bug is
  1502  			// resolved, this test will start to fail and the individual tasked
  1503  			// with fixing this bug knows to also fix this test :).
  1504  			expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastV6Addr, Port: stackPort},
  1505  		},
  1506  	}
  1507  
  1508  	for _, test := range tests {
  1509  		t.Run(test.name, func(t *testing.T) {
  1510  			c := newDualTestContext(t, defaultMTU)
  1511  			defer c.cleanup()
  1512  
  1513  			c.createEndpoint(test.proto)
  1514  
  1515  			bindAddr := tcpip.FullAddress{Port: stackPort}
  1516  			if err := c.ep.Bind(bindAddr); err != nil {
  1517  				t.Fatalf("Bind(%#v): %s", bindAddr, err)
  1518  			}
  1519  
  1520  			if test.flow.isMulticast() {
  1521  				ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()}
  1522  				if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
  1523  					c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err)
  1524  				}
  1525  			}
  1526  
  1527  			c.ep.SocketOptions().SetReceiveOriginalDstAddress(true)
  1528  
  1529  			testRead(c, test.flow, checker.ReceiveOriginalDstAddr(test.expectedOriginalDstAddr))
  1530  
  1531  			if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
  1532  				t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
  1533  			}
  1534  		})
  1535  	}
  1536  }
  1537  
  1538  func TestWriteIncrementsPacketsSent(t *testing.T) {
  1539  	c := newDualTestContext(t, defaultMTU)
  1540  	defer c.cleanup()
  1541  
  1542  	c.createEndpoint(ipv6.ProtocolNumber)
  1543  
  1544  	testDualWrite(c)
  1545  
  1546  	var want uint64 = 2
  1547  	if got := c.s.Stats().UDP.PacketsSent.Value(); got != want {
  1548  		c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want)
  1549  	}
  1550  }
  1551  
  1552  func TestNoChecksum(t *testing.T) {
  1553  	for _, flow := range []testFlow{unicastV4, unicastV6} {
  1554  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
  1555  			c := newDualTestContext(t, defaultMTU)
  1556  			defer c.cleanup()
  1557  
  1558  			c.createEndpointForFlow(flow)
  1559  
  1560  			// Disable the checksum generation.
  1561  			c.ep.SocketOptions().SetNoChecksum(true)
  1562  			// This option is effective on IPv4 only.
  1563  			testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4())))
  1564  
  1565  			// Enable the checksum generation.
  1566  			c.ep.SocketOptions().SetNoChecksum(false)
  1567  			testWrite(c, flow, checker.UDP(checker.NoChecksum(false)))
  1568  		})
  1569  	}
  1570  }
  1571  
  1572  var _ stack.NetworkInterface = (*testInterface)(nil)
  1573  
  1574  type testInterface struct {
  1575  	stack.NetworkInterface
  1576  }
  1577  
  1578  func (*testInterface) ID() tcpip.NICID {
  1579  	return 0
  1580  }
  1581  
  1582  func (*testInterface) Enabled() bool {
  1583  	return true
  1584  }
  1585  
  1586  func TestTTL(t *testing.T) {
  1587  	for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
  1588  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
  1589  			c := newDualTestContext(t, defaultMTU)
  1590  			defer c.cleanup()
  1591  
  1592  			c.createEndpointForFlow(flow)
  1593  
  1594  			const multicastTTL = 42
  1595  			if err := c.ep.SetSockOptInt(tcpip.MulticastTTLOption, multicastTTL); err != nil {
  1596  				c.t.Fatalf("SetSockOptInt failed: %s", err)
  1597  			}
  1598  
  1599  			var wantTTL uint8
  1600  			if flow.isMulticast() {
  1601  				wantTTL = multicastTTL
  1602  			} else {
  1603  				var p stack.NetworkProtocolFactory
  1604  				var n tcpip.NetworkProtocolNumber
  1605  				if flow.isV4() {
  1606  					p = ipv4.NewProtocol
  1607  					n = ipv4.ProtocolNumber
  1608  				} else {
  1609  					p = ipv6.NewProtocol
  1610  					n = ipv6.ProtocolNumber
  1611  				}
  1612  				s := stack.New(stack.Options{
  1613  					NetworkProtocols: []stack.NetworkProtocolFactory{p},
  1614  					Clock:            &faketime.NullClock{},
  1615  				})
  1616  				ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil)
  1617  				wantTTL = ep.DefaultTTL()
  1618  				ep.Close()
  1619  			}
  1620  
  1621  			testWrite(c, flow, checker.TTL(wantTTL))
  1622  		})
  1623  	}
  1624  }
  1625  
  1626  func TestSetTTL(t *testing.T) {
  1627  	for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6} {
  1628  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
  1629  			for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} {
  1630  				t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) {
  1631  					c := newDualTestContext(t, defaultMTU)
  1632  					defer c.cleanup()
  1633  
  1634  					c.createEndpointForFlow(flow)
  1635  
  1636  					if err := c.ep.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil {
  1637  						c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err)
  1638  					}
  1639  
  1640  					testWrite(c, flow, checker.TTL(wantTTL))
  1641  				})
  1642  			}
  1643  		})
  1644  	}
  1645  }
  1646  
  1647  func TestSetTOS(t *testing.T) {
  1648  	for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
  1649  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
  1650  			c := newDualTestContext(t, defaultMTU)
  1651  			defer c.cleanup()
  1652  
  1653  			c.createEndpointForFlow(flow)
  1654  
  1655  			const tos = testTOS
  1656  			v, err := c.ep.GetSockOptInt(tcpip.IPv4TOSOption)
  1657  			if err != nil {
  1658  				c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err)
  1659  			}
  1660  			// Test for expected default value.
  1661  			if v != 0 {
  1662  				c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0)
  1663  			}
  1664  
  1665  			if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil {
  1666  				c.t.Errorf("SetSockOptInt(IPv4TOSOption, 0x%x) failed: %s", tos, err)
  1667  			}
  1668  
  1669  			v, err = c.ep.GetSockOptInt(tcpip.IPv4TOSOption)
  1670  			if err != nil {
  1671  				c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err)
  1672  			}
  1673  
  1674  			if v != tos {
  1675  				c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, tos)
  1676  			}
  1677  
  1678  			testWrite(c, flow, checker.TOS(tos, 0))
  1679  		})
  1680  	}
  1681  }
  1682  
  1683  func TestSetTClass(t *testing.T) {
  1684  	for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} {
  1685  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
  1686  			c := newDualTestContext(t, defaultMTU)
  1687  			defer c.cleanup()
  1688  
  1689  			c.createEndpointForFlow(flow)
  1690  
  1691  			const tClass = testTOS
  1692  			v, err := c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
  1693  			if err != nil {
  1694  				c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err)
  1695  			}
  1696  			// Test for expected default value.
  1697  			if v != 0 {
  1698  				c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, 0)
  1699  			}
  1700  
  1701  			if err := c.ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, tClass); err != nil {
  1702  				c.t.Errorf("SetSockOptInt(IPv6TrafficClassOption, 0x%x) failed: %s", tClass, err)
  1703  			}
  1704  
  1705  			v, err = c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
  1706  			if err != nil {
  1707  				c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err)
  1708  			}
  1709  
  1710  			if v != tClass {
  1711  				c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, tClass)
  1712  			}
  1713  
  1714  			// The header getter for TClass is called TOS, so use that checker.
  1715  			testWrite(c, flow, checker.TOS(tClass, 0))
  1716  		})
  1717  	}
  1718  }
  1719  
  1720  func TestReceiveTosTClass(t *testing.T) {
  1721  	const RcvTOSOpt = "ReceiveTosOption"
  1722  	const RcvTClassOpt = "ReceiveTClassOption"
  1723  
  1724  	testCases := []struct {
  1725  		name  string
  1726  		tests []testFlow
  1727  	}{
  1728  		{RcvTOSOpt, []testFlow{unicastV4, broadcast}},
  1729  		{RcvTClassOpt, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
  1730  	}
  1731  	for _, testCase := range testCases {
  1732  		for _, flow := range testCase.tests {
  1733  			t.Run(fmt.Sprintf("%s:flow:%s", testCase.name, flow), func(t *testing.T) {
  1734  				c := newDualTestContext(t, defaultMTU)
  1735  				defer c.cleanup()
  1736  
  1737  				c.createEndpointForFlow(flow)
  1738  				name := testCase.name
  1739  
  1740  				var optionGetter func() bool
  1741  				var optionSetter func(bool)
  1742  				switch name {
  1743  				case RcvTOSOpt:
  1744  					optionGetter = c.ep.SocketOptions().GetReceiveTOS
  1745  					optionSetter = c.ep.SocketOptions().SetReceiveTOS
  1746  				case RcvTClassOpt:
  1747  					optionGetter = c.ep.SocketOptions().GetReceiveTClass
  1748  					optionSetter = c.ep.SocketOptions().SetReceiveTClass
  1749  				default:
  1750  					t.Fatalf("unkown test variant: %s", name)
  1751  				}
  1752  
  1753  				// Verify that setting and reading the option works.
  1754  				v := optionGetter()
  1755  				// Test for expected default value.
  1756  				if v != false {
  1757  					c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false)
  1758  				}
  1759  
  1760  				const want = true
  1761  				optionSetter(want)
  1762  
  1763  				got := optionGetter()
  1764  				if got != want {
  1765  					c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want)
  1766  				}
  1767  
  1768  				// Verify that the correct received TOS or TClass is handed through as
  1769  				// ancillary data to the ControlMessages struct.
  1770  				if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
  1771  					c.t.Fatalf("Bind failed: %s", err)
  1772  				}
  1773  				switch name {
  1774  				case RcvTClassOpt:
  1775  					testRead(c, flow, checker.ReceiveTClass(testTOS))
  1776  				case RcvTOSOpt:
  1777  					testRead(c, flow, checker.ReceiveTOS(testTOS))
  1778  				default:
  1779  					t.Fatalf("unknown test variant: %s", name)
  1780  				}
  1781  			})
  1782  		}
  1783  	}
  1784  }
  1785  
  1786  func TestMulticastInterfaceOption(t *testing.T) {
  1787  	for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
  1788  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
  1789  			for _, bindTyp := range []string{"bound", "unbound"} {
  1790  				t.Run(bindTyp, func(t *testing.T) {
  1791  					for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} {
  1792  						t.Run(optTyp, func(t *testing.T) {
  1793  							h := flow.header4Tuple(outgoing)
  1794  							mcastAddr := h.dstAddr.Addr
  1795  							localIfAddr := h.srcAddr.Addr
  1796  
  1797  							var ifoptSet tcpip.MulticastInterfaceOption
  1798  							switch optTyp {
  1799  							case "use local-addr":
  1800  								ifoptSet.InterfaceAddr = localIfAddr
  1801  							case "use NICID":
  1802  								ifoptSet.NIC = 1
  1803  							case "use local-addr and NIC":
  1804  								ifoptSet.InterfaceAddr = localIfAddr
  1805  								ifoptSet.NIC = 1
  1806  							default:
  1807  								t.Fatal("unknown test variant")
  1808  							}
  1809  
  1810  							c := newDualTestContext(t, defaultMTU)
  1811  							defer c.cleanup()
  1812  
  1813  							c.createEndpoint(flow.sockProto())
  1814  
  1815  							if bindTyp == "bound" {
  1816  								// Bind the socket by connecting to the multicast address.
  1817  								// This may have an influence on how the multicast interface
  1818  								// is set.
  1819  								addr := tcpip.FullAddress{
  1820  									Addr: flow.mapAddrIfApplicable(mcastAddr),
  1821  									Port: stackPort,
  1822  								}
  1823  								if err := c.ep.Connect(addr); err != nil {
  1824  									c.t.Fatalf("Connect failed: %s", err)
  1825  								}
  1826  							}
  1827  
  1828  							if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
  1829  								c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err)
  1830  							}
  1831  
  1832  							// Verify multicast interface addr and NIC were set correctly.
  1833  							// Note that NIC must be 1 since this is our outgoing interface.
  1834  							var ifoptGot tcpip.MulticastInterfaceOption
  1835  							if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
  1836  								c.t.Fatalf("GetSockOpt(&%T): %s", ifoptGot, err)
  1837  							} else if ifoptWant := (tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}); ifoptGot != ifoptWant {
  1838  								c.t.Errorf("got multicast interface option = %#v, want = %#v", ifoptGot, ifoptWant)
  1839  							}
  1840  						})
  1841  					}
  1842  				})
  1843  			}
  1844  		})
  1845  	}
  1846  }
  1847  
  1848  // TestV4UnknownDestination verifies that we generate an ICMPv4 Destination
  1849  // Unreachable message when a udp datagram is received on ports for which there
  1850  // is no bound udp socket.
  1851  func TestV4UnknownDestination(t *testing.T) {
  1852  	c := newDualTestContext(t, defaultMTU)
  1853  	defer c.cleanup()
  1854  
  1855  	testCases := []struct {
  1856  		flow         testFlow
  1857  		icmpRequired bool
  1858  		// largePayload if true, will result in a payload large enough
  1859  		// so that the final generated IPv4 packet is larger than
  1860  		// header.IPv4MinimumProcessableDatagramSize.
  1861  		largePayload bool
  1862  		// badChecksum if true, will set an invalid checksum in the
  1863  		// header.
  1864  		badChecksum bool
  1865  	}{
  1866  		{unicastV4, true, false, false},
  1867  		{unicastV4, true, true, false},
  1868  		{unicastV4, false, false, true},
  1869  		{unicastV4, false, true, true},
  1870  		{multicastV4, false, false, false},
  1871  		{multicastV4, false, true, false},
  1872  		{broadcast, false, false, false},
  1873  		{broadcast, false, true, false},
  1874  	}
  1875  	checksumErrors := uint64(0)
  1876  	for _, tc := range testCases {
  1877  		t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) {
  1878  			payload := newPayload()
  1879  			if tc.largePayload {
  1880  				payload = newMinPayload(576)
  1881  			}
  1882  			c.injectPacket(tc.flow, payload, tc.badChecksum)
  1883  			if tc.badChecksum {
  1884  				checksumErrors++
  1885  				if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want {
  1886  					t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
  1887  				}
  1888  			}
  1889  			if !tc.icmpRequired {
  1890  				if p, ok := c.linkEP.Read(); ok {
  1891  					t.Fatalf("unexpected packet received: %+v", p)
  1892  				}
  1893  				return
  1894  			}
  1895  
  1896  			// ICMP required.
  1897  			p, ok := c.linkEP.Read()
  1898  			if !ok {
  1899  				t.Fatalf("packet wasn't written out")
  1900  				return
  1901  			}
  1902  
  1903  			vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
  1904  			pkt := vv.ToView()
  1905  			if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
  1906  				t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
  1907  			}
  1908  
  1909  			hdr := header.IPv4(pkt)
  1910  			checker.IPv4(t, hdr, checker.ICMPv4(
  1911  				checker.ICMPv4Type(header.ICMPv4DstUnreachable),
  1912  				checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
  1913  
  1914  			// We need to compare the included data part of the UDP packet that is in
  1915  			// the ICMP packet with the matching original data.
  1916  			icmpPkt := header.ICMPv4(hdr.Payload())
  1917  			payloadIPHeader := header.IPv4(icmpPkt.Payload())
  1918  			incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize
  1919  			wantLen := len(payload)
  1920  			if tc.largePayload {
  1921  				// To work out the data size we need to simulate what the sender would
  1922  				// have done. The wanted size is the total available minus the sum of
  1923  				// the headers in the UDP AND ICMP packets, given that we know the test
  1924  				// had only a minimal IP header but the ICMP sender will have allowed
  1925  				// for a maximally sized packet header.
  1926  				wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
  1927  			}
  1928  
  1929  			// In the case of large payloads the IP packet may be truncated. Update
  1930  			// the length field before retrieving the udp datagram payload.
  1931  			// Add back the two headers within the payload.
  1932  			payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength))
  1933  
  1934  			origDgram := header.UDP(payloadIPHeader.Payload())
  1935  			if got, want := len(origDgram.Payload()), wantLen; got != want {
  1936  				t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
  1937  			}
  1938  			if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
  1939  				t.Fatalf("unexpected payload got: %d, want: %d", got, want)
  1940  			}
  1941  		})
  1942  	}
  1943  }
  1944  
  1945  // TestV6UnknownDestination verifies that we generate an ICMPv6 Destination
  1946  // Unreachable message when a udp datagram is received on ports for which there
  1947  // is no bound udp socket.
  1948  func TestV6UnknownDestination(t *testing.T) {
  1949  	c := newDualTestContext(t, defaultMTU)
  1950  	defer c.cleanup()
  1951  
  1952  	testCases := []struct {
  1953  		flow         testFlow
  1954  		icmpRequired bool
  1955  		// largePayload if true will result in a payload large enough to
  1956  		// create an IPv6 packet > header.IPv6MinimumMTU bytes.
  1957  		largePayload bool
  1958  		// badChecksum if true, will set an invalid checksum in the
  1959  		// header.
  1960  		badChecksum bool
  1961  	}{
  1962  		{unicastV6, true, false, false},
  1963  		{unicastV6, true, true, false},
  1964  		{unicastV6, false, false, true},
  1965  		{unicastV6, false, true, true},
  1966  		{multicastV6, false, false, false},
  1967  		{multicastV6, false, true, false},
  1968  	}
  1969  	checksumErrors := uint64(0)
  1970  	for _, tc := range testCases {
  1971  		t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) {
  1972  			payload := newPayload()
  1973  			if tc.largePayload {
  1974  				payload = newMinPayload(1280)
  1975  			}
  1976  			c.injectPacket(tc.flow, payload, tc.badChecksum)
  1977  			if tc.badChecksum {
  1978  				checksumErrors++
  1979  				if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want {
  1980  					t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
  1981  				}
  1982  			}
  1983  			if !tc.icmpRequired {
  1984  				if p, ok := c.linkEP.Read(); ok {
  1985  					t.Fatalf("unexpected packet received: %+v", p)
  1986  				}
  1987  				return
  1988  			}
  1989  
  1990  			// ICMP required.
  1991  			p, ok := c.linkEP.Read()
  1992  			if !ok {
  1993  				t.Fatalf("packet wasn't written out")
  1994  				return
  1995  			}
  1996  
  1997  			vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
  1998  			pkt := vv.ToView()
  1999  			if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
  2000  				t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
  2001  			}
  2002  
  2003  			hdr := header.IPv6(pkt)
  2004  			checker.IPv6(t, hdr, checker.ICMPv6(
  2005  				checker.ICMPv6Type(header.ICMPv6DstUnreachable),
  2006  				checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
  2007  
  2008  			icmpPkt := header.ICMPv6(hdr.Payload())
  2009  			payloadIPHeader := header.IPv6(icmpPkt.Payload())
  2010  			wantLen := len(payload)
  2011  			if tc.largePayload {
  2012  				wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
  2013  			}
  2014  			// In case of large payloads the IP packet may be truncated. Update
  2015  			// the length field before retrieving the udp datagram payload.
  2016  			payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
  2017  
  2018  			origDgram := header.UDP(payloadIPHeader.Payload())
  2019  			if got, want := len(origDgram.Payload()), wantLen; got != want {
  2020  				t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
  2021  			}
  2022  			if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
  2023  				t.Fatalf("unexpected payload got: %v, want: %v", got, want)
  2024  			}
  2025  		})
  2026  	}
  2027  }
  2028  
  2029  // TestIncrementMalformedPacketsReceived verifies if the malformed received
  2030  // global and endpoint stats are incremented.
  2031  func TestIncrementMalformedPacketsReceived(t *testing.T) {
  2032  	c := newDualTestContext(t, defaultMTU)
  2033  	defer c.cleanup()
  2034  
  2035  	c.createEndpoint(ipv6.ProtocolNumber)
  2036  	// Bind to wildcard.
  2037  	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
  2038  		c.t.Fatalf("Bind failed: %s", err)
  2039  	}
  2040  
  2041  	payload := newPayload()
  2042  	h := unicastV6.header4Tuple(incoming)
  2043  	buf := c.buildV6Packet(payload, &h)
  2044  
  2045  	// Invalidate the UDP header length field.
  2046  	u := header.UDP(buf[header.IPv6MinimumSize:])
  2047  	u.SetLength(u.Length() + 1)
  2048  
  2049  	c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
  2050  		Data: buf.ToVectorisedView(),
  2051  	}))
  2052  
  2053  	const want = 1
  2054  	if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want {
  2055  		t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want)
  2056  	}
  2057  	if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
  2058  		t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
  2059  	}
  2060  }
  2061  
  2062  // TestShortHeader verifies that when a packet with a too-short UDP header is
  2063  // received, the malformed received global stat gets incremented.
  2064  func TestShortHeader(t *testing.T) {
  2065  	c := newDualTestContext(t, defaultMTU)
  2066  	defer c.cleanup()
  2067  
  2068  	c.createEndpoint(ipv6.ProtocolNumber)
  2069  	// Bind to wildcard.
  2070  	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
  2071  		c.t.Fatalf("Bind failed: %s", err)
  2072  	}
  2073  
  2074  	h := unicastV6.header4Tuple(incoming)
  2075  
  2076  	// Allocate a buffer for an IPv6 and too-short UDP header.
  2077  	const udpSize = header.UDPMinimumSize - 1
  2078  	buf := buffer.NewView(header.IPv6MinimumSize + udpSize)
  2079  	// Initialize the IP header.
  2080  	ip := header.IPv6(buf)
  2081  	ip.Encode(&header.IPv6Fields{
  2082  		TrafficClass:      testTOS,
  2083  		PayloadLength:     uint16(udpSize),
  2084  		TransportProtocol: udp.ProtocolNumber,
  2085  		HopLimit:          65,
  2086  		SrcAddr:           h.srcAddr.Addr,
  2087  		DstAddr:           h.dstAddr.Addr,
  2088  	})
  2089  
  2090  	// Initialize the UDP header.
  2091  	udpHdr := header.UDP(buffer.NewView(header.UDPMinimumSize))
  2092  	udpHdr.Encode(&header.UDPFields{
  2093  		SrcPort: h.srcAddr.Port,
  2094  		DstPort: h.dstAddr.Port,
  2095  		Length:  header.UDPMinimumSize,
  2096  	})
  2097  	// Calculate the UDP pseudo-header checksum.
  2098  	xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(udpHdr)))
  2099  	udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
  2100  	// Copy all but the last byte of the UDP header into the packet.
  2101  	copy(buf[header.IPv6MinimumSize:], udpHdr)
  2102  
  2103  	// Inject packet.
  2104  	c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
  2105  		Data: buf.ToVectorisedView(),
  2106  	}))
  2107  
  2108  	if got, want := c.s.Stats().NICs.MalformedL4RcvdPackets.Value(), uint64(1); got != want {
  2109  		t.Errorf("got c.s.Stats().NIC.MalformedL4RcvdPackets.Value() = %d, want = %d", got, want)
  2110  	}
  2111  }
  2112  
  2113  // TestBadChecksumErrors verifies if a checksum error is detected,
  2114  // global and endpoint stats are incremented.
  2115  func TestBadChecksumErrors(t *testing.T) {
  2116  	for _, flow := range []testFlow{unicastV4, unicastV6} {
  2117  		t.Run(flow.String(), func(t *testing.T) {
  2118  			c := newDualTestContext(t, defaultMTU)
  2119  			defer c.cleanup()
  2120  
  2121  			c.createEndpoint(flow.sockProto())
  2122  			// Bind to wildcard.
  2123  			if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
  2124  				c.t.Fatalf("Bind failed: %s", err)
  2125  			}
  2126  
  2127  			payload := newPayload()
  2128  			c.injectPacket(flow, payload, true /* badChecksum */)
  2129  
  2130  			const want = 1
  2131  			if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
  2132  				t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
  2133  			}
  2134  			if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
  2135  				t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
  2136  			}
  2137  		})
  2138  	}
  2139  }
  2140  
  2141  // TestPayloadModifiedV4 verifies if a checksum error is detected,
  2142  // global and endpoint stats are incremented.
  2143  func TestPayloadModifiedV4(t *testing.T) {
  2144  	c := newDualTestContext(t, defaultMTU)
  2145  	defer c.cleanup()
  2146  
  2147  	c.createEndpoint(ipv4.ProtocolNumber)
  2148  	// Bind to wildcard.
  2149  	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
  2150  		c.t.Fatalf("Bind failed: %s", err)
  2151  	}
  2152  
  2153  	payload := newPayload()
  2154  	h := unicastV4.header4Tuple(incoming)
  2155  	buf := c.buildV4Packet(payload, &h)
  2156  	// Modify the payload so that the checksum value in the UDP header will be
  2157  	// incorrect.
  2158  	buf[len(buf)-1]++
  2159  	c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
  2160  		Data: buf.ToVectorisedView(),
  2161  	}))
  2162  
  2163  	const want = 1
  2164  	if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
  2165  		t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
  2166  	}
  2167  	if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
  2168  		t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
  2169  	}
  2170  }
  2171  
  2172  // TestPayloadModifiedV6 verifies if a checksum error is detected,
  2173  // global and endpoint stats are incremented.
  2174  func TestPayloadModifiedV6(t *testing.T) {
  2175  	c := newDualTestContext(t, defaultMTU)
  2176  	defer c.cleanup()
  2177  
  2178  	c.createEndpoint(ipv6.ProtocolNumber)
  2179  	// Bind to wildcard.
  2180  	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
  2181  		c.t.Fatalf("Bind failed: %s", err)
  2182  	}
  2183  
  2184  	payload := newPayload()
  2185  	h := unicastV6.header4Tuple(incoming)
  2186  	buf := c.buildV6Packet(payload, &h)
  2187  	// Modify the payload so that the checksum value in the UDP header will be
  2188  	// incorrect.
  2189  	buf[len(buf)-1]++
  2190  	c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
  2191  		Data: buf.ToVectorisedView(),
  2192  	}))
  2193  
  2194  	const want = 1
  2195  	if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
  2196  		t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
  2197  	}
  2198  	if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
  2199  		t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
  2200  	}
  2201  }
  2202  
  2203  // TestChecksumZeroV4 verifies if the checksum value is zero, global and
  2204  // endpoint states are *not* incremented (UDP checksum is optional on IPv4).
  2205  func TestChecksumZeroV4(t *testing.T) {
  2206  	c := newDualTestContext(t, defaultMTU)
  2207  	defer c.cleanup()
  2208  
  2209  	c.createEndpoint(ipv4.ProtocolNumber)
  2210  	// Bind to wildcard.
  2211  	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
  2212  		c.t.Fatalf("Bind failed: %s", err)
  2213  	}
  2214  
  2215  	payload := newPayload()
  2216  	h := unicastV4.header4Tuple(incoming)
  2217  	buf := c.buildV4Packet(payload, &h)
  2218  	// Set the checksum field in the UDP header to zero.
  2219  	u := header.UDP(buf[header.IPv4MinimumSize:])
  2220  	u.SetChecksum(0)
  2221  	c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
  2222  		Data: buf.ToVectorisedView(),
  2223  	}))
  2224  
  2225  	const want = 0
  2226  	if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
  2227  		t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
  2228  	}
  2229  	if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
  2230  		t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
  2231  	}
  2232  }
  2233  
  2234  // TestChecksumZeroV6 verifies if the checksum value is zero, global and
  2235  // endpoint states are incremented (UDP checksum is *not* optional on IPv6).
  2236  func TestChecksumZeroV6(t *testing.T) {
  2237  	c := newDualTestContext(t, defaultMTU)
  2238  	defer c.cleanup()
  2239  
  2240  	c.createEndpoint(ipv6.ProtocolNumber)
  2241  	// Bind to wildcard.
  2242  	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
  2243  		c.t.Fatalf("Bind failed: %s", err)
  2244  	}
  2245  
  2246  	payload := newPayload()
  2247  	h := unicastV6.header4Tuple(incoming)
  2248  	buf := c.buildV6Packet(payload, &h)
  2249  	// Set the checksum field in the UDP header to zero.
  2250  	u := header.UDP(buf[header.IPv6MinimumSize:])
  2251  	u.SetChecksum(0)
  2252  	c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
  2253  		Data: buf.ToVectorisedView(),
  2254  	}))
  2255  
  2256  	const want = 1
  2257  	if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
  2258  		t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
  2259  	}
  2260  	if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
  2261  		t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
  2262  	}
  2263  }
  2264  
  2265  // TestShutdownRead verifies endpoint read shutdown and error
  2266  // stats increment on packet receive.
  2267  func TestShutdownRead(t *testing.T) {
  2268  	c := newDualTestContext(t, defaultMTU)
  2269  	defer c.cleanup()
  2270  
  2271  	c.createEndpoint(ipv6.ProtocolNumber)
  2272  
  2273  	// Bind to wildcard.
  2274  	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
  2275  		c.t.Fatalf("Bind failed: %s", err)
  2276  	}
  2277  
  2278  	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
  2279  		c.t.Fatalf("Connect failed: %s", err)
  2280  	}
  2281  
  2282  	if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil {
  2283  		t.Fatalf("Shutdown failed: %s", err)
  2284  	}
  2285  
  2286  	testFailingRead(c, unicastV6, true /* expectReadError */)
  2287  
  2288  	var want uint64 = 1
  2289  	if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want {
  2290  		t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want)
  2291  	}
  2292  	if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want {
  2293  		t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want)
  2294  	}
  2295  }
  2296  
  2297  // TestShutdownWrite verifies endpoint write shutdown and error
  2298  // stats increment on packet write.
  2299  func TestShutdownWrite(t *testing.T) {
  2300  	c := newDualTestContext(t, defaultMTU)
  2301  	defer c.cleanup()
  2302  
  2303  	c.createEndpoint(ipv6.ProtocolNumber)
  2304  
  2305  	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
  2306  		c.t.Fatalf("Connect failed: %s", err)
  2307  	}
  2308  
  2309  	if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil {
  2310  		t.Fatalf("Shutdown failed: %s", err)
  2311  	}
  2312  
  2313  	testFailingWrite(c, unicastV6, &tcpip.ErrClosedForSend{})
  2314  }
  2315  
  2316  func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) {
  2317  	got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
  2318  	switch err.(type) {
  2319  	case nil:
  2320  		want.PacketsSent.IncrementBy(incr)
  2321  	case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue:
  2322  		want.WriteErrors.InvalidArgs.IncrementBy(incr)
  2323  	case *tcpip.ErrClosedForSend:
  2324  		want.WriteErrors.WriteClosed.IncrementBy(incr)
  2325  	case *tcpip.ErrInvalidEndpointState:
  2326  		want.WriteErrors.InvalidEndpointState.IncrementBy(incr)
  2327  	case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable:
  2328  		want.SendErrors.NoRoute.IncrementBy(incr)
  2329  	default:
  2330  		want.SendErrors.SendToNetworkFailed.IncrementBy(incr)
  2331  	}
  2332  	if got != want {
  2333  		c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
  2334  	}
  2335  }
  2336  
  2337  func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) {
  2338  	got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
  2339  	switch err.(type) {
  2340  	case nil, *tcpip.ErrWouldBlock:
  2341  	case *tcpip.ErrClosedForReceive:
  2342  		want.ReadErrors.ReadClosed.IncrementBy(incr)
  2343  	default:
  2344  		c.t.Errorf("Endpoint error missing stats update err %v", err)
  2345  	}
  2346  	if got != want {
  2347  		c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
  2348  	}
  2349  }
  2350  
  2351  func TestOutgoingSubnetBroadcast(t *testing.T) {
  2352  	const nicID1 = 1
  2353  
  2354  	ipv4Addr := tcpip.AddressWithPrefix{
  2355  		Address:   "\xc0\xa8\x01\x3a",
  2356  		PrefixLen: 24,
  2357  	}
  2358  	ipv4Subnet := ipv4Addr.Subnet()
  2359  	ipv4SubnetBcast := ipv4Subnet.Broadcast()
  2360  	ipv4Gateway := testutil.MustParse4("192.168.1.1")
  2361  	ipv4AddrPrefix31 := tcpip.AddressWithPrefix{
  2362  		Address:   "\xc0\xa8\x01\x3a",
  2363  		PrefixLen: 31,
  2364  	}
  2365  	ipv4Subnet31 := ipv4AddrPrefix31.Subnet()
  2366  	ipv4Subnet31Bcast := ipv4Subnet31.Broadcast()
  2367  	ipv4AddrPrefix32 := tcpip.AddressWithPrefix{
  2368  		Address:   "\xc0\xa8\x01\x3a",
  2369  		PrefixLen: 32,
  2370  	}
  2371  	ipv4Subnet32 := ipv4AddrPrefix32.Subnet()
  2372  	ipv4Subnet32Bcast := ipv4Subnet32.Broadcast()
  2373  	ipv6Addr := tcpip.AddressWithPrefix{
  2374  		Address:   "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
  2375  		PrefixLen: 64,
  2376  	}
  2377  	ipv6Subnet := ipv6Addr.Subnet()
  2378  	ipv6SubnetBcast := ipv6Subnet.Broadcast()
  2379  	remNetAddr := tcpip.AddressWithPrefix{
  2380  		Address:   "\x64\x0a\x7b\x18",
  2381  		PrefixLen: 24,
  2382  	}
  2383  	remNetSubnet := remNetAddr.Subnet()
  2384  	remNetSubnetBcast := remNetSubnet.Broadcast()
  2385  
  2386  	tests := []struct {
  2387  		name                 string
  2388  		nicAddr              tcpip.ProtocolAddress
  2389  		routes               []tcpip.Route
  2390  		remoteAddr           tcpip.Address
  2391  		requiresBroadcastOpt bool
  2392  	}{
  2393  		{
  2394  			name: "IPv4 Broadcast to local subnet",
  2395  			nicAddr: tcpip.ProtocolAddress{
  2396  				Protocol:          header.IPv4ProtocolNumber,
  2397  				AddressWithPrefix: ipv4Addr,
  2398  			},
  2399  			routes: []tcpip.Route{
  2400  				{
  2401  					Destination: ipv4Subnet,
  2402  					NIC:         nicID1,
  2403  				},
  2404  			},
  2405  			remoteAddr:           ipv4SubnetBcast,
  2406  			requiresBroadcastOpt: true,
  2407  		},
  2408  		{
  2409  			name: "IPv4 Broadcast to local /31 subnet",
  2410  			nicAddr: tcpip.ProtocolAddress{
  2411  				Protocol:          header.IPv4ProtocolNumber,
  2412  				AddressWithPrefix: ipv4AddrPrefix31,
  2413  			},
  2414  			routes: []tcpip.Route{
  2415  				{
  2416  					Destination: ipv4Subnet31,
  2417  					NIC:         nicID1,
  2418  				},
  2419  			},
  2420  			remoteAddr:           ipv4Subnet31Bcast,
  2421  			requiresBroadcastOpt: false,
  2422  		},
  2423  		{
  2424  			name: "IPv4 Broadcast to local /32 subnet",
  2425  			nicAddr: tcpip.ProtocolAddress{
  2426  				Protocol:          header.IPv4ProtocolNumber,
  2427  				AddressWithPrefix: ipv4AddrPrefix32,
  2428  			},
  2429  			routes: []tcpip.Route{
  2430  				{
  2431  					Destination: ipv4Subnet32,
  2432  					NIC:         nicID1,
  2433  				},
  2434  			},
  2435  			remoteAddr:           ipv4Subnet32Bcast,
  2436  			requiresBroadcastOpt: false,
  2437  		},
  2438  		// IPv6 has no notion of a broadcast.
  2439  		{
  2440  			name: "IPv6 'Broadcast' to local subnet",
  2441  			nicAddr: tcpip.ProtocolAddress{
  2442  				Protocol:          header.IPv6ProtocolNumber,
  2443  				AddressWithPrefix: ipv6Addr,
  2444  			},
  2445  			routes: []tcpip.Route{
  2446  				{
  2447  					Destination: ipv6Subnet,
  2448  					NIC:         nicID1,
  2449  				},
  2450  			},
  2451  			remoteAddr:           ipv6SubnetBcast,
  2452  			requiresBroadcastOpt: false,
  2453  		},
  2454  		{
  2455  			name: "IPv4 Broadcast to remote subnet",
  2456  			nicAddr: tcpip.ProtocolAddress{
  2457  				Protocol:          header.IPv4ProtocolNumber,
  2458  				AddressWithPrefix: ipv4Addr,
  2459  			},
  2460  			routes: []tcpip.Route{
  2461  				{
  2462  					Destination: remNetSubnet,
  2463  					Gateway:     ipv4Gateway,
  2464  					NIC:         nicID1,
  2465  				},
  2466  			},
  2467  			remoteAddr: remNetSubnetBcast,
  2468  			// TODO(github.com/SagerNet/issue/3938): Once we support marking a route as
  2469  			// broadcast, this test should require the broadcast option to be set.
  2470  			requiresBroadcastOpt: false,
  2471  		},
  2472  	}
  2473  
  2474  	for _, test := range tests {
  2475  		t.Run(test.name, func(t *testing.T) {
  2476  			s := stack.New(stack.Options{
  2477  				NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  2478  				TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
  2479  				Clock:              &faketime.NullClock{},
  2480  			})
  2481  			e := channel.New(0, defaultMTU, "")
  2482  			if err := s.CreateNIC(nicID1, e); err != nil {
  2483  				t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
  2484  			}
  2485  			if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
  2486  				t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
  2487  			}
  2488  
  2489  			s.SetRouteTable(test.routes)
  2490  
  2491  			var netProto tcpip.NetworkProtocolNumber
  2492  			switch l := len(test.remoteAddr); l {
  2493  			case header.IPv4AddressSize:
  2494  				netProto = header.IPv4ProtocolNumber
  2495  			case header.IPv6AddressSize:
  2496  				netProto = header.IPv6ProtocolNumber
  2497  			default:
  2498  				t.Fatalf("got unexpected address length = %d bytes", l)
  2499  			}
  2500  
  2501  			wq := waiter.Queue{}
  2502  			ep, err := s.NewEndpoint(udp.ProtocolNumber, netProto, &wq)
  2503  			if err != nil {
  2504  				t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netProto, err)
  2505  			}
  2506  			defer ep.Close()
  2507  
  2508  			var r bytes.Reader
  2509  			data := []byte{1, 2, 3, 4}
  2510  			to := tcpip.FullAddress{
  2511  				Addr: test.remoteAddr,
  2512  				Port: 80,
  2513  			}
  2514  			opts := tcpip.WriteOptions{To: &to}
  2515  			expectedErrWithoutBcastOpt := func(err tcpip.Error) tcpip.Error {
  2516  				if _, ok := err.(*tcpip.ErrBroadcastDisabled); ok {
  2517  					return nil
  2518  				}
  2519  				return &tcpip.ErrBroadcastDisabled{}
  2520  			}
  2521  			if !test.requiresBroadcastOpt {
  2522  				expectedErrWithoutBcastOpt = nil
  2523  			}
  2524  
  2525  			r.Reset(data)
  2526  			{
  2527  				n, err := ep.Write(&r, opts)
  2528  				if expectedErrWithoutBcastOpt != nil {
  2529  					if want := expectedErrWithoutBcastOpt(err); want != nil {
  2530  						t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want)
  2531  					}
  2532  				} else if err != nil {
  2533  					t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
  2534  				}
  2535  			}
  2536  
  2537  			ep.SocketOptions().SetBroadcast(true)
  2538  
  2539  			r.Reset(data)
  2540  			if n, err := ep.Write(&r, opts); err != nil {
  2541  				t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
  2542  			}
  2543  
  2544  			ep.SocketOptions().SetBroadcast(false)
  2545  
  2546  			r.Reset(data)
  2547  			{
  2548  				n, err := ep.Write(&r, opts)
  2549  				if expectedErrWithoutBcastOpt != nil {
  2550  					if want := expectedErrWithoutBcastOpt(err); want != nil {
  2551  						t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want)
  2552  					}
  2553  				} else if err != nil {
  2554  					t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
  2555  				}
  2556  			}
  2557  		})
  2558  	}
  2559  }