gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/transport/udp/udp_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package udp_test
    16  
    17  import (
    18  	"bytes"
    19  	"encoding/binary"
    20  	"fmt"
    21  	"io/ioutil"
    22  	"math"
    23  	"math/rand"
    24  	"os"
    25  	"testing"
    26  
    27  	"gvisor.dev/gvisor/pkg/buffer"
    28  	"gvisor.dev/gvisor/pkg/refs"
    29  	"gvisor.dev/gvisor/pkg/tcpip"
    30  	"gvisor.dev/gvisor/pkg/tcpip/checker"
    31  	"gvisor.dev/gvisor/pkg/tcpip/checksum"
    32  	"gvisor.dev/gvisor/pkg/tcpip/faketime"
    33  	"gvisor.dev/gvisor/pkg/tcpip/header"
    34  	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
    35  	"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
    36  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    37  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
    38  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    39  	"gvisor.dev/gvisor/pkg/tcpip/testutil"
    40  	"gvisor.dev/gvisor/pkg/tcpip/transport"
    41  	"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
    42  	"gvisor.dev/gvisor/pkg/tcpip/transport/testing/context"
    43  	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
    44  	"gvisor.dev/gvisor/pkg/waiter"
    45  )
    46  
    47  const (
    48  	testTOS              = 0x80
    49  	testTTL              = 65
    50  	arbitraryPayloadSize = 30
    51  )
    52  
    53  // newRandomPayload returns a payload with the specified size and with
    54  // randomized content.
    55  func newRandomPayload(size int) []byte {
    56  	b := make([]byte, size)
    57  	for i := range b {
    58  		b[i] = byte(rand.Intn(math.MaxUint8 + 1))
    59  	}
    60  	return b
    61  }
    62  
    63  func testRead(c *context.Context, flow context.TestFlow, checkers ...checker.ControlMessagesChecker) {
    64  	c.T.Helper()
    65  
    66  	payload := newRandomPayload(arbitraryPayloadSize)
    67  	c.InjectPacket(flow.NetProto(), context.BuildUDPPacket(payload, flow, context.Incoming, testTOS, testTTL, false))
    68  	c.ReadFromEndpointExpectSuccess(payload, flow, checkers...)
    69  }
    70  
    71  func testFailingRead(c *context.Context, flow context.TestFlow, expectReadError bool) {
    72  	c.T.Helper()
    73  
    74  	c.InjectPacket(flow.NetProto(), context.BuildUDPPacket(newRandomPayload(arbitraryPayloadSize), flow, context.Incoming, testTOS, testTTL, false))
    75  	if expectReadError {
    76  		c.ReadFromEndpointExpectError()
    77  	} else {
    78  		c.ReadFromEndpointExpectNoPacket()
    79  	}
    80  }
    81  
    82  func TestBindToDeviceOption(t *testing.T) {
    83  	s := stack.New(stack.Options{
    84  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol},
    85  		TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
    86  		Clock:              &faketime.NullClock{},
    87  	})
    88  	defer s.Destroy()
    89  
    90  	ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
    91  	if err != nil {
    92  		t.Fatalf("NewEndpoint failed; %s", err)
    93  	}
    94  	defer ep.Close()
    95  
    96  	opts := stack.NICOptions{Name: "my_device"}
    97  	if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil {
    98  		t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %s", opts, err)
    99  	}
   100  
   101  	// nicIDPtr is used instead of taking the address of NICID literals, which is
   102  	// a compiler error.
   103  	nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
   104  		return &s
   105  	}
   106  
   107  	testActions := []struct {
   108  		name                 string
   109  		setBindToDevice      *tcpip.NICID
   110  		setBindToDeviceError tcpip.Error
   111  		getBindToDevice      int32
   112  	}{
   113  		{"GetDefaultValue", nil, nil, 0},
   114  		{"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0},
   115  		{"BindToExistent", nicIDPtr(321), nil, 321},
   116  		{"UnbindToDevice", nicIDPtr(0), nil, 0},
   117  	}
   118  	for _, testAction := range testActions {
   119  		t.Run(testAction.name, func(t *testing.T) {
   120  			if testAction.setBindToDevice != nil {
   121  				bindToDevice := int32(*testAction.setBindToDevice)
   122  				if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
   123  					t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr)
   124  				}
   125  			}
   126  			bindToDevice := ep.SocketOptions().GetBindToDevice()
   127  			if bindToDevice != testAction.getBindToDevice {
   128  				t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice)
   129  			}
   130  		})
   131  	}
   132  }
   133  
   134  func TestBindEphemeralPort(t *testing.T) {
   135  	c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   136  	defer c.Cleanup()
   137  
   138  	c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber)
   139  
   140  	if err := c.EP.Bind(tcpip.FullAddress{}); err != nil {
   141  		t.Fatalf("ep.Bind(...) failed: %s", err)
   142  	}
   143  }
   144  
   145  func TestBindReservedPort(t *testing.T) {
   146  	c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   147  	defer c.Cleanup()
   148  
   149  	c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber)
   150  
   151  	if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != nil {
   152  		c.T.Fatalf("Connect failed: %s", err)
   153  	}
   154  
   155  	addr, err := c.EP.GetLocalAddress()
   156  	if err != nil {
   157  		t.Fatalf("GetLocalAddress failed: %s", err)
   158  	}
   159  
   160  	// We can't bind the address reserved by the connected endpoint above.
   161  	{
   162  		ep, err := c.Stack.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ)
   163  		if err != nil {
   164  			t.Fatalf("NewEndpoint failed: %s", err)
   165  		}
   166  		defer ep.Close()
   167  		{
   168  			err := ep.Bind(addr)
   169  			if _, ok := err.(*tcpip.ErrPortInUse); !ok {
   170  				t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{})
   171  			}
   172  		}
   173  	}
   174  
   175  	func() {
   176  		ep, err := c.Stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
   177  		if err != nil {
   178  			t.Fatalf("NewEndpoint failed: %s", err)
   179  		}
   180  		defer ep.Close()
   181  		// We can't bind ipv4-any on the port reserved by the connected endpoint
   182  		// above, since the endpoint is dual-stack.
   183  		{
   184  			err := ep.Bind(tcpip.FullAddress{Port: addr.Port})
   185  			if _, ok := err.(*tcpip.ErrPortInUse); !ok {
   186  				t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{})
   187  			}
   188  		}
   189  		// We can bind an ipv4 address on this port, though.
   190  		if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: addr.Port}); err != nil {
   191  			t.Fatalf("ep.Bind(...) failed: %s", err)
   192  		}
   193  	}()
   194  
   195  	// Once the connected endpoint releases its port reservation, we are able to
   196  	// bind ipv4-any once again.
   197  	c.EP.Close()
   198  	func() {
   199  		ep, err := c.Stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
   200  		if err != nil {
   201  			t.Fatalf("NewEndpoint failed: %s", err)
   202  		}
   203  		defer ep.Close()
   204  		if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil {
   205  			t.Fatalf("ep.Bind(...) failed: %s", err)
   206  		}
   207  	}()
   208  }
   209  
   210  func TestV4ReadOnV6(t *testing.T) {
   211  	c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   212  	defer c.Cleanup()
   213  
   214  	c.CreateEndpointForFlow(context.UnicastV4in6, udp.ProtocolNumber)
   215  
   216  	// Bind to wildcard.
   217  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
   218  		c.T.Fatalf("Bind failed: %s", err)
   219  	}
   220  
   221  	payload := newRandomPayload(arbitraryPayloadSize)
   222  	buf := context.BuildUDPPacket(payload, context.UnicastV4in6, context.Incoming, testTOS, testTTL, false)
   223  	c.InjectPacket(header.IPv4ProtocolNumber, buf)
   224  	c.ReadFromEndpointExpectSuccess(payload, context.UnicastV4in6)
   225  }
   226  
   227  func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
   228  	c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   229  	defer c.Cleanup()
   230  
   231  	c.CreateEndpointForFlow(context.UnicastV4in6, udp.ProtocolNumber)
   232  
   233  	// Bind to v4 mapped wildcard.
   234  	if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil {
   235  		c.T.Fatalf("Bind failed: %s", err)
   236  	}
   237  
   238  	// Test acceptance.
   239  	testRead(c, context.UnicastV4in6)
   240  }
   241  
   242  func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
   243  	c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   244  	defer c.Cleanup()
   245  
   246  	c.CreateEndpointForFlow(context.UnicastV4in6, udp.ProtocolNumber)
   247  
   248  	// Bind to local address.
   249  	if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr, Port: context.StackPort}); err != nil {
   250  		c.T.Fatalf("Bind failed: %s", err)
   251  	}
   252  
   253  	// Test acceptance.
   254  	testRead(c, context.UnicastV4in6)
   255  }
   256  
   257  func TestV6ReadOnV6(t *testing.T) {
   258  	c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   259  	defer c.Cleanup()
   260  
   261  	c.CreateEndpointForFlow(context.UnicastV6, udp.ProtocolNumber)
   262  
   263  	// Bind to wildcard.
   264  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
   265  		c.T.Fatalf("Bind failed: %s", err)
   266  	}
   267  
   268  	// Test acceptance.
   269  	testRead(c, context.UnicastV6)
   270  }
   271  
   272  // TestV4ReadSelfSource checks that packets coming from a local IP address are
   273  // correctly dropped when handleLocal is true and not otherwise.
   274  func TestV4ReadSelfSource(t *testing.T) {
   275  	for _, tt := range []struct {
   276  		name              string
   277  		handleLocal       bool
   278  		wantErr           tcpip.Error
   279  		wantInvalidSource uint64
   280  	}{
   281  		{"HandleLocal", false, nil, 0},
   282  		{"NoHandleLocal", true, &tcpip.ErrWouldBlock{}, 1},
   283  	} {
   284  		t.Run(tt.name, func(t *testing.T) {
   285  			c := context.NewWithOptions(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, context.Options{
   286  				MTU:         context.DefaultMTU,
   287  				HandleLocal: tt.handleLocal,
   288  			})
   289  			defer c.Cleanup()
   290  
   291  			c.CreateEndpointForFlow(context.UnicastV4, udp.ProtocolNumber)
   292  
   293  			if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
   294  				t.Fatalf("Bind failed: %s", err)
   295  			}
   296  
   297  			payload := newRandomPayload(arbitraryPayloadSize)
   298  			h := context.UnicastV4.MakeHeader4Tuple(context.Incoming)
   299  			h.Src = h.Dst
   300  			c.InjectPacket(header.IPv4ProtocolNumber, context.BuildV4UDPPacket(payload, h, testTOS, testTTL, false))
   301  
   302  			if got := c.Stack.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource {
   303  				t.Errorf("c.Stack.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
   304  			}
   305  
   306  			if _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tt.wantErr {
   307  				t.Errorf("got c.EP.Read = %s, want = %s", err, tt.wantErr)
   308  			}
   309  		})
   310  	}
   311  }
   312  
   313  func TestV4ReadOnV4(t *testing.T) {
   314  	c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   315  	defer c.Cleanup()
   316  
   317  	c.CreateEndpointForFlow(context.UnicastV4, udp.ProtocolNumber)
   318  
   319  	// Bind to wildcard.
   320  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
   321  		c.T.Fatalf("Bind failed: %s", err)
   322  	}
   323  
   324  	// Test acceptance.
   325  	testRead(c, context.UnicastV4)
   326  }
   327  
   328  // TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast
   329  // address and receive data sent to that address.
   330  func TestReadOnBoundToMulticast(t *testing.T) {
   331  	// FIXME(b/128189410): context.MulticastV4in6 currently doesn't work as
   332  	// AddMembershipOption doesn't handle V4in6 addresses.
   333  	for _, flow := range []context.TestFlow{context.MulticastV4, context.MulticastV6, context.MulticastV6Only} {
   334  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
   335  			c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   336  			defer c.Cleanup()
   337  
   338  			c.CreateEndpointForFlow(flow, udp.ProtocolNumber)
   339  
   340  			// Bind to multicast address.
   341  			mcastAddr := flow.MapAddrIfApplicable(flow.GetMulticastAddr())
   342  			if err := c.EP.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: context.StackPort}); err != nil {
   343  				c.T.Fatal("Bind failed:", err)
   344  			}
   345  
   346  			// Join multicast group.
   347  			ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr}
   348  			if err := c.EP.SetSockOpt(&ifoptSet); err != nil {
   349  				c.T.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err)
   350  			}
   351  
   352  			// Check that we receive multicast packets but not unicast or broadcast
   353  			// ones.
   354  			testRead(c, flow)
   355  			testFailingRead(c, context.Broadcast, false /* expectReadError */)
   356  			testFailingRead(c, context.UnicastV4, false /* expectReadError */)
   357  		})
   358  	}
   359  }
   360  
   361  // TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast
   362  // address and can receive only broadcast data.
   363  func TestV4ReadOnBoundToBroadcast(t *testing.T) {
   364  	for _, flow := range []context.TestFlow{context.Broadcast, context.BroadcastIn6} {
   365  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
   366  			c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   367  			defer c.Cleanup()
   368  
   369  			c.CreateEndpointForFlow(flow, udp.ProtocolNumber)
   370  
   371  			// Bind to broadcast address.
   372  			broadcastAddr := flow.MapAddrIfApplicable(context.BroadcastAddr)
   373  			if err := c.EP.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: context.StackPort}); err != nil {
   374  				c.T.Fatalf("Bind failed: %s", err)
   375  			}
   376  
   377  			// Check that we receive broadcast packets but not unicast ones.
   378  			testRead(c, flow)
   379  			testFailingRead(c, context.UnicastV4, false /* expectReadError */)
   380  		})
   381  	}
   382  }
   383  
   384  // TestReadFromMulticast checks that an endpoint will NOT receive a packet
   385  // that was sent with multicast SOURCE address.
   386  func TestReadFromMulticast(t *testing.T) {
   387  	for _, flow := range []context.TestFlow{context.ReverseMulticastV4, context.ReverseMulticastV6} {
   388  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
   389  			c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   390  			defer c.Cleanup()
   391  
   392  			c.CreateEndpointForFlow(flow, udp.ProtocolNumber)
   393  
   394  			if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
   395  				t.Fatalf("Bind failed: %s", err)
   396  			}
   397  			testFailingRead(c, flow, false /* expectReadError */)
   398  		})
   399  	}
   400  }
   401  
   402  // TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY
   403  // and receive broadcast and unicast data.
   404  func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
   405  	for _, flow := range []context.TestFlow{context.Broadcast, context.BroadcastIn6} {
   406  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
   407  			c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   408  			defer c.Cleanup()
   409  
   410  			c.CreateEndpointForFlow(flow, udp.ProtocolNumber)
   411  
   412  			// Bind to wildcard.
   413  			if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
   414  				c.T.Fatalf("Bind failed: %s (", err)
   415  			}
   416  
   417  			// Check that we receive both broadcast and unicast packets.
   418  			testRead(c, flow)
   419  			testRead(c, context.UnicastV4)
   420  		})
   421  	}
   422  }
   423  
   424  func getEndpointWithPreflight(c *context.Context) tcpip.EndpointWithPreflight {
   425  	epWithPreflight, ok := c.EP.(tcpip.EndpointWithPreflight)
   426  
   427  	if !ok {
   428  		c.T.Fatalf("expect endpoint implements tcpip.EndpointWithPreflight; found endpoint with type %T does not", c.EP)
   429  	}
   430  	return epWithPreflight
   431  }
   432  
   433  func getWriteOptionsForFlow(flow context.TestFlow) tcpip.WriteOptions {
   434  	h := flow.MakeHeader4Tuple(context.Outgoing)
   435  	writeDstAddr := flow.MapAddrIfApplicable(h.Dst.Addr)
   436  	return tcpip.WriteOptions{
   437  		To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.Dst.Port},
   438  	}
   439  }
   440  
   441  // testWriteFails calls the endpoint's Write method with a packet of the
   442  // given test flow, verifying that the method fails with the provided error
   443  // code.
   444  // TODO(https://gvisor.dev/issue/5623): Extract the test write methods in the
   445  // testing context.
   446  func testWriteFails(c *context.Context, flow context.TestFlow, payloadSize int, wantErr tcpip.Error) {
   447  	c.T.Helper()
   448  	// Take a snapshot of the stats to validate them at the end of the test.
   449  	var epstats tcpip.TransportEndpointStats
   450  	c.EP.Stats().(*tcpip.TransportEndpointStats).Clone(&epstats)
   451  
   452  	var r bytes.Reader
   453  	r.Reset(newRandomPayload(payloadSize))
   454  	_, gotErr := c.EP.Write(&r, getWriteOptionsForFlow(flow))
   455  	c.CheckEndpointWriteStats(1, &epstats, gotErr)
   456  	if gotErr != wantErr {
   457  		c.T.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr)
   458  	}
   459  }
   460  
   461  // testPreflightSucceeds calls the endpoint's Preflight method with a
   462  // destination of the given flow, verifying that it succeeds.
   463  func testPreflightSucceeds(c *context.Context, flow context.TestFlow) {
   464  	c.T.Helper()
   465  	testPreflightImpl(c, flow, true, nil)
   466  }
   467  
   468  // testPreflightFails calls the endpoint's Preflight method with a destination
   469  // of the given flow, verifying that it fails with the provided error.
   470  func testPreflightFails(c *context.Context, flow context.TestFlow, wantErr tcpip.Error) {
   471  	c.T.Helper()
   472  	testPreflightImpl(c, flow, true, wantErr)
   473  }
   474  
   475  func testPreflightImpl(c *context.Context, flow context.TestFlow, setDest bool, wantErr tcpip.Error) {
   476  	c.T.Helper()
   477  	// Take a snapshot of the stats to validate them at the end of the test.
   478  	var epstats tcpip.TransportEndpointStats
   479  	c.EP.Stats().(*tcpip.TransportEndpointStats).Clone(&epstats)
   480  
   481  	writeOpts := tcpip.WriteOptions{}
   482  	if setDest {
   483  		writeOpts = getWriteOptionsForFlow(flow)
   484  	}
   485  
   486  	gotErr := getEndpointWithPreflight(c).Preflight(writeOpts)
   487  	if gotErr != wantErr {
   488  		c.T.Fatalf("Preflight returned unexpected error: got %v, want %v", gotErr, wantErr)
   489  	}
   490  
   491  	c.CheckEndpointWriteStats(0, &epstats, gotErr)
   492  }
   493  
   494  type writeOperation int
   495  
   496  const (
   497  	write writeOperation = iota
   498  	preflight
   499  )
   500  
   501  // testWriteOpSequenceSucceeds calls the provided sequence of write operations with a packet of the
   502  // given test flow, verifying that each operation succeeds.
   503  func testWriteOpSequenceSucceeds(c *context.Context, flow context.TestFlow, ops []writeOperation, checkers ...checker.NetworkChecker) {
   504  	c.T.Helper()
   505  	for _, op := range ops {
   506  		switch op {
   507  		case write:
   508  			testWriteSucceedsAndGetReceivedSrcPort(c, flow, checkers...)
   509  		case preflight:
   510  			testPreflightSucceeds(c, flow)
   511  		}
   512  	}
   513  }
   514  
   515  // testWriteOpSequenceSucceedsNoDestination calls the provided sequence of write operations with a
   516  // packet of the given test flow, without giving a destination address:port, verifying that each
   517  // operation succeeds.
   518  func testWriteOpSequenceSucceedsNoDestination(c *context.Context, flow context.TestFlow, ops []writeOperation) {
   519  	c.T.Helper()
   520  	for _, op := range ops {
   521  		switch op {
   522  		case write:
   523  			testWriteAndVerifyInternal(c, flow, false /* setDest */)
   524  		case preflight:
   525  			testPreflightImpl(c, flow, false /* setDest */, nil /* wantErr */)
   526  		}
   527  	}
   528  }
   529  
   530  // testWriteOpSequenceFails calls the provided sequence of write operations with a packet of the
   531  // given test flow, verifying that each operation fails with the provided err.
   532  func testWriteOpSequenceFails(c *context.Context, flow context.TestFlow, ops []writeOperation, err tcpip.Error) {
   533  	c.T.Helper()
   534  	for _, op := range ops {
   535  		switch op {
   536  		case write:
   537  			testWriteFails(c, flow, arbitraryPayloadSize, err)
   538  		case preflight:
   539  			testPreflightFails(c, flow, err)
   540  		}
   541  	}
   542  }
   543  
   544  // testWriteSucceedsAndGetReceivedSrcPort calls the endpoint's Write method with a packet of the
   545  // given test flow and a destination constructed from the flow's destination address:port. It then
   546  // receives the packet from the link endpoint and verifies its correctness using the
   547  // provided checker functions, returning the found source port.
   548  // TODO(https://gvisor.dev/issue/5623): Extract the test write methods in the
   549  // testing context.
   550  func testWriteSucceedsAndGetReceivedSrcPort(c *context.Context, flow context.TestFlow, checkers ...checker.NetworkChecker) uint16 {
   551  	c.T.Helper()
   552  	return testWriteAndVerifyInternal(c, flow, true, checkers...)
   553  }
   554  
   555  // TODO(https://gvisor.dev/issue/5623): Extract the test write methods in the
   556  // testing context.
   557  func testWriteNoVerify(c *context.Context, flow context.TestFlow, setDest bool) []byte {
   558  	c.T.Helper()
   559  	// Take a snapshot of the stats to validate them at the end of the test.
   560  	var epstats tcpip.TransportEndpointStats
   561  	c.EP.Stats().(*tcpip.TransportEndpointStats).Clone(&epstats)
   562  
   563  	writeOpts := tcpip.WriteOptions{}
   564  	if setDest {
   565  		h := flow.MakeHeader4Tuple(context.Outgoing)
   566  		writeDstAddr := flow.MapAddrIfApplicable(h.Dst.Addr)
   567  		writeOpts = tcpip.WriteOptions{
   568  			To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.Dst.Port},
   569  		}
   570  	}
   571  
   572  	var r bytes.Reader
   573  	payload := newRandomPayload(arbitraryPayloadSize)
   574  	r.Reset(payload)
   575  	n, err := c.EP.Write(&r, writeOpts)
   576  	if err != nil {
   577  		c.T.Fatalf("Write failed: %s", err)
   578  	}
   579  	if n != int64(len(payload)) {
   580  		c.T.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
   581  	}
   582  	c.CheckEndpointWriteStats(1, &epstats, err)
   583  	return payload
   584  }
   585  
   586  // TODO(https://gvisor.dev/issue/5623): Extract the test write methods in the
   587  // testing context.
   588  func testWriteAndVerifyInternal(c *context.Context, flow context.TestFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
   589  	c.T.Helper()
   590  	payload := testWriteNoVerify(c, flow, setDest)
   591  	// Received the packet and check the payload.
   592  
   593  	p := c.LinkEP.Read()
   594  	if p == nil {
   595  		c.T.Fatalf("Packet wasn't written out")
   596  	}
   597  	defer p.DecRef()
   598  
   599  	if got, want := p.NetworkProtocolNumber, flow.NetProto(); got != want {
   600  		c.T.Fatalf("got p.NetworkProtocolNumber = %d, want = %d", got, want)
   601  	}
   602  
   603  	if got, want := p.TransportProtocolNumber, header.UDPProtocolNumber; got != want {
   604  		c.T.Errorf("got p.TransportProtocolNumber = %d, want = %d", got, want)
   605  	}
   606  
   607  	v := p.ToView()
   608  	defer v.Release()
   609  
   610  	h := flow.MakeHeader4Tuple(context.Outgoing)
   611  	checkers = append(
   612  		checkers,
   613  		checker.SrcAddr(h.Src.Addr),
   614  		checker.DstAddr(h.Dst.Addr),
   615  		checker.UDP(checker.DstPort(h.Dst.Port)),
   616  	)
   617  	flow.CheckerFn()(c.T, v, checkers...)
   618  
   619  	var udpH header.UDP
   620  	if flow.IsV4() {
   621  		udpH = header.IPv4(v.AsSlice()).Payload()
   622  	} else {
   623  		udpH = header.IPv6(v.AsSlice()).Payload()
   624  	}
   625  	if !bytes.Equal(payload, udpH.Payload()) {
   626  		c.T.Fatalf("Bad payload: got %x, want %x", udpH.Payload(), payload)
   627  	}
   628  
   629  	return udpH.SourcePort()
   630  }
   631  
   632  func testDualWrite(c *context.Context) uint16 {
   633  	c.T.Helper()
   634  
   635  	v4Port := testWriteSucceedsAndGetReceivedSrcPort(c, context.UnicastV4in6)
   636  	v6Port := testWriteSucceedsAndGetReceivedSrcPort(c, context.UnicastV6)
   637  	if v4Port != v6Port {
   638  		c.T.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port)
   639  	}
   640  
   641  	return v4Port
   642  }
   643  
   644  func TestDualWriteUnbound(t *testing.T) {
   645  	c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   646  	defer c.Cleanup()
   647  
   648  	c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber)
   649  
   650  	testDualWrite(c)
   651  }
   652  
   653  func TestDualWriteBoundToWildcard(t *testing.T) {
   654  	c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   655  	defer c.Cleanup()
   656  
   657  	c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber)
   658  
   659  	// Bind to wildcard.
   660  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
   661  		c.T.Fatalf("Bind failed: %s", err)
   662  	}
   663  
   664  	p := testDualWrite(c)
   665  	if p != context.StackPort {
   666  		c.T.Fatalf("Bad port: got %v, want %v", p, context.StackPort)
   667  	}
   668  }
   669  
   670  func TestDualWriteConnectedToV6(t *testing.T) {
   671  	for _, testCase := range []struct {
   672  		writeOpSequence         []writeOperation
   673  		expectedNoRouteErrCount uint64
   674  	}{
   675  		{writeOpSequence: []writeOperation{write}, expectedNoRouteErrCount: 1},
   676  		{writeOpSequence: []writeOperation{preflight}, expectedNoRouteErrCount: 0},
   677  		{writeOpSequence: []writeOperation{preflight, write}, expectedNoRouteErrCount: 1},
   678  	} {
   679  		c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   680  
   681  		c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber)
   682  
   683  		// Connect to v6 address.
   684  		if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != nil {
   685  			c.T.Fatalf("Bind failed: %s", err)
   686  		}
   687  
   688  		testWriteOpSequenceSucceeds(c, context.UnicastV6, testCase.writeOpSequence)
   689  
   690  		// Write to V4 mapped address.
   691  		testWriteOpSequenceFails(c, context.UnicastV4in6, testCase.writeOpSequence, &tcpip.ErrNetworkUnreachable{})
   692  
   693  		if got := c.EP.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != testCase.expectedNoRouteErrCount {
   694  			c.T.Fatalf("Endpoint stat not updated. got %d want %d", got, testCase.expectedNoRouteErrCount)
   695  		}
   696  		c.Cleanup()
   697  	}
   698  }
   699  
   700  var writeOpSequences = map[string]([]writeOperation){
   701  	"write":           []writeOperation{write},
   702  	"preflight":       []writeOperation{preflight},
   703  	"write|preflight": []writeOperation{preflight, write},
   704  }
   705  
   706  func TestDualWriteConnectedToV4Mapped(t *testing.T) {
   707  	for name, writeOpSequence := range writeOpSequences {
   708  		t.Run(name, func(t *testing.T) {
   709  			c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   710  			defer c.Cleanup()
   711  
   712  			c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber)
   713  
   714  			// Connect to v4 mapped address.
   715  			if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}); err != nil {
   716  				c.T.Fatalf("Bind failed: %s", err)
   717  			}
   718  
   719  			testWriteOpSequenceSucceeds(c, context.UnicastV4in6, writeOpSequence)
   720  
   721  			// Write to v6 address.
   722  			testWriteOpSequenceFails(c, context.UnicastV6, writeOpSequence, &tcpip.ErrInvalidEndpointState{})
   723  		})
   724  	}
   725  }
   726  
   727  func TestPreflightBindsEndpoint(t *testing.T) {
   728  	tcs := []struct {
   729  		name  string
   730  		proto tcpip.NetworkProtocolNumber
   731  		flow  context.TestFlow
   732  	}{
   733  		{
   734  			name:  "ipv4",
   735  			proto: ipv4.ProtocolNumber,
   736  			flow:  context.UnicastV4,
   737  		},
   738  		{
   739  			name:  "ipv6",
   740  			proto: ipv6.ProtocolNumber,
   741  			flow:  context.UnicastV6,
   742  		},
   743  	}
   744  	for _, tc := range tcs {
   745  		t.Run(tc.name, func(t *testing.T) {
   746  			c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol})
   747  			defer c.Cleanup()
   748  
   749  			c.CreateEndpoint(tc.proto, udp.ProtocolNumber)
   750  
   751  			h := tc.flow.MakeHeader4Tuple(context.Outgoing)
   752  			writeDstAddr := tc.flow.MapAddrIfApplicable(h.Dst.Addr)
   753  			writeOpts := tcpip.WriteOptions{
   754  				To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.Dst.Port},
   755  			}
   756  
   757  			if err := getEndpointWithPreflight(c).Preflight(writeOpts); err != nil {
   758  				c.T.Fatalf("Preflight failed: %s", err)
   759  			}
   760  
   761  			if c.EP.State() != uint32(transport.DatagramEndpointStateBound) {
   762  				c.T.Fatalf("Expect UDP endpoint in state %d, found %d", transport.DatagramEndpointStateBound, c.EP.State())
   763  			}
   764  		})
   765  	}
   766  }
   767  
   768  func TestV4WriteOnV6Only(t *testing.T) {
   769  	for name, writeOpSequence := range writeOpSequences {
   770  		t.Run(name, func(t *testing.T) {
   771  			c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   772  			defer c.Cleanup()
   773  
   774  			c.CreateEndpointForFlow(context.UnicastV6Only, udp.ProtocolNumber)
   775  
   776  			// Write to V4 mapped address.
   777  			testWriteOpSequenceFails(c, context.UnicastV4in6, writeOpSequence, &tcpip.ErrHostUnreachable{})
   778  		})
   779  	}
   780  }
   781  
   782  func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
   783  	for name, writeOpSequence := range writeOpSequences {
   784  		t.Run(name, func(t *testing.T) {
   785  			c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   786  			defer c.Cleanup()
   787  
   788  			c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber)
   789  
   790  			// Bind to v4 mapped address.
   791  			if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr, Port: context.StackPort}); err != nil {
   792  				c.T.Fatalf("Bind failed: %s", err)
   793  			}
   794  
   795  			// Write to v6 address.
   796  			testWriteOpSequenceFails(c, context.UnicastV6, writeOpSequence, &tcpip.ErrInvalidEndpointState{})
   797  		})
   798  	}
   799  }
   800  
   801  func TestV6WriteOnConnected(t *testing.T) {
   802  	for name, writeOpSequence := range writeOpSequences {
   803  		t.Run(name, func(t *testing.T) {
   804  			c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   805  			defer c.Cleanup()
   806  
   807  			c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber)
   808  
   809  			// Connect to v6 address.
   810  			if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != nil {
   811  				c.T.Fatalf("Connect failed: %s", err)
   812  			}
   813  
   814  			testWriteOpSequenceSucceedsNoDestination(c, context.UnicastV6, writeOpSequence)
   815  		})
   816  	}
   817  }
   818  
   819  func TestV4WriteOnConnected(t *testing.T) {
   820  	for name, writeOpSequence := range writeOpSequences {
   821  		t.Run(name, func(t *testing.T) {
   822  			c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   823  			defer c.Cleanup()
   824  
   825  			c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber)
   826  
   827  			// Connect to v4 mapped address.
   828  			if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}); err != nil {
   829  				c.T.Fatalf("Connect failed: %s", err)
   830  			}
   831  
   832  			testWriteOpSequenceSucceedsNoDestination(c, context.UnicastV4, writeOpSequence)
   833  		})
   834  	}
   835  }
   836  
   837  func TestWriteOnConnectedInvalidPort(t *testing.T) {
   838  	const invalidPort = 8192
   839  	protocols := map[string]tcpip.NetworkProtocolNumber{
   840  		"ipv4": ipv4.ProtocolNumber,
   841  		"ipv6": ipv6.ProtocolNumber,
   842  	}
   843  	for name, proto := range protocols {
   844  		t.Run(name, func(t *testing.T) {
   845  			c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   846  			defer c.Cleanup()
   847  
   848  			c.CreateEndpoint(proto, udp.ProtocolNumber)
   849  			if err := c.EP.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: invalidPort}); err != nil {
   850  				c.T.Fatalf("Connect failed: %s", err)
   851  			}
   852  			writeOpts := tcpip.WriteOptions{
   853  				To: &tcpip.FullAddress{Addr: context.StackAddr, Port: invalidPort},
   854  			}
   855  			var r bytes.Reader
   856  			payload := newRandomPayload(arbitraryPayloadSize)
   857  			r.Reset(payload)
   858  			n, err := c.EP.Write(&r, writeOpts)
   859  			if err != nil {
   860  				c.T.Fatalf("c.EP.Write(...) = %s, want nil", err)
   861  			}
   862  			if got, want := n, int64(len(payload)); got != want {
   863  				c.T.Fatalf("c.EP.Write(...) wrote %d bytes, want %d bytes", got, want)
   864  			}
   865  
   866  			{
   867  				err := c.EP.LastError()
   868  				if _, ok := err.(*tcpip.ErrConnectionRefused); !ok {
   869  					c.T.Fatalf("expected c.EP.LastError() == ErrConnectionRefused, got: %+v", err)
   870  				}
   871  			}
   872  		})
   873  	}
   874  }
   875  
   876  // TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket
   877  // that is bound to a V4 multicast address.
   878  func TestWriteOnBoundToV4Multicast(t *testing.T) {
   879  	for _, writeOpSequence := range writeOpSequences {
   880  		for _, flow := range []context.TestFlow{context.UnicastV4, context.MulticastV4, context.Broadcast} {
   881  			t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
   882  				c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   883  				defer c.Cleanup()
   884  
   885  				c.CreateEndpointForFlow(flow, udp.ProtocolNumber)
   886  
   887  				// Bind to V4 mcast address.
   888  				if err := c.EP.Bind(tcpip.FullAddress{Addr: context.MulticastAddr, Port: context.StackPort}); err != nil {
   889  					c.T.Fatal("Bind failed:", err)
   890  				}
   891  
   892  				testWriteOpSequenceSucceeds(c, flow, writeOpSequence)
   893  			})
   894  		}
   895  	}
   896  }
   897  
   898  // TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a
   899  // socket that is bound to a V4-mapped multicast address.
   900  func TestWriteOnBoundToV4MappedMulticast(t *testing.T) {
   901  	for _, writeOpSequence := range writeOpSequences {
   902  		for _, flow := range []context.TestFlow{context.UnicastV4in6, context.MulticastV4in6, context.BroadcastIn6} {
   903  			t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
   904  				c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   905  				defer c.Cleanup()
   906  
   907  				c.CreateEndpointForFlow(flow, udp.ProtocolNumber)
   908  
   909  				// Bind to V4Mapped mcast address.
   910  				if err := c.EP.Bind(tcpip.FullAddress{Addr: context.MulticastV4MappedAddr, Port: context.StackPort}); err != nil {
   911  					c.T.Fatalf("Bind failed: %s", err)
   912  				}
   913  
   914  				testWriteOpSequenceSucceeds(c, flow, writeOpSequence)
   915  			})
   916  		}
   917  	}
   918  }
   919  
   920  // TestWriteOnBoundToV6Multicast checks that we can send packets out of a
   921  // socket that is bound to a V6 multicast address.
   922  func TestWriteOnBoundToV6Multicast(t *testing.T) {
   923  	for _, writeOpSequence := range writeOpSequences {
   924  		for _, flow := range []context.TestFlow{context.UnicastV6, context.MulticastV6} {
   925  			t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
   926  				c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   927  				defer c.Cleanup()
   928  
   929  				c.CreateEndpointForFlow(flow, udp.ProtocolNumber)
   930  
   931  				// Bind to V6 mcast address.
   932  				if err := c.EP.Bind(tcpip.FullAddress{Addr: context.MulticastV6Addr, Port: context.StackPort}); err != nil {
   933  					c.T.Fatalf("Bind failed: %s", err)
   934  				}
   935  
   936  				testWriteOpSequenceSucceeds(c, flow, writeOpSequence)
   937  			})
   938  		}
   939  	}
   940  }
   941  
   942  // TestWriteOnBoundToV6Multicast checks that we can send packets out of a
   943  // V6-only socket that is bound to a V6 multicast address.
   944  func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) {
   945  	for _, writeOpSequence := range writeOpSequences {
   946  		for _, flow := range []context.TestFlow{context.UnicastV6Only, context.MulticastV6Only} {
   947  			t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
   948  				c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   949  				defer c.Cleanup()
   950  
   951  				c.CreateEndpointForFlow(flow, udp.ProtocolNumber)
   952  
   953  				// Bind to V6 mcast address.
   954  				if err := c.EP.Bind(tcpip.FullAddress{Addr: context.MulticastV6Addr, Port: context.StackPort}); err != nil {
   955  					c.T.Fatalf("Bind failed: %s", err)
   956  				}
   957  
   958  				testWriteOpSequenceSucceeds(c, flow, writeOpSequence)
   959  			})
   960  		}
   961  	}
   962  }
   963  
   964  // TestWriteOnBoundToBroadcast checks that we can send packets out of a
   965  // socket that is bound to the broadcast address.
   966  func TestWriteOnBoundToBroadcast(t *testing.T) {
   967  	for _, writeOpSequence := range writeOpSequences {
   968  		for _, flow := range []context.TestFlow{context.UnicastV4, context.MulticastV4, context.Broadcast} {
   969  			t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
   970  				c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   971  				defer c.Cleanup()
   972  
   973  				c.CreateEndpointForFlow(flow, udp.ProtocolNumber)
   974  
   975  				// Bind to V4 broadcast address.
   976  				if err := c.EP.Bind(tcpip.FullAddress{Addr: context.BroadcastAddr, Port: context.StackPort}); err != nil {
   977  					c.T.Fatal("Bind failed:", err)
   978  				}
   979  
   980  				testWriteOpSequenceSucceeds(c, flow, writeOpSequence)
   981  			})
   982  		}
   983  	}
   984  }
   985  
   986  // TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a
   987  // socket that is bound to the V4-mapped broadcast address.
   988  func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) {
   989  	for _, writeOpSequence := range writeOpSequences {
   990  		for _, flow := range []context.TestFlow{context.UnicastV4in6, context.MulticastV4in6, context.BroadcastIn6} {
   991  			t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
   992  				c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
   993  				defer c.Cleanup()
   994  
   995  				c.CreateEndpointForFlow(flow, udp.ProtocolNumber)
   996  
   997  				// Bind to V4Mapped mcast address.
   998  				if err := c.EP.Bind(tcpip.FullAddress{Addr: context.BroadcastV4MappedAddr, Port: context.StackPort}); err != nil {
   999  					c.T.Fatalf("Bind failed: %s", err)
  1000  				}
  1001  
  1002  				testWriteOpSequenceSucceeds(c, flow, writeOpSequence)
  1003  			})
  1004  		}
  1005  	}
  1006  }
  1007  
  1008  func TestReadIncrementsPacketsReceived(t *testing.T) {
  1009  	c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
  1010  	defer c.Cleanup()
  1011  
  1012  	// Create IPv4 UDP endpoint
  1013  	c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber)
  1014  
  1015  	// Bind to wildcard.
  1016  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  1017  		c.T.Fatalf("Bind failed: %s", err)
  1018  	}
  1019  
  1020  	testRead(c, context.UnicastV4)
  1021  
  1022  	var want uint64 = 1
  1023  	if got := c.Stack.Stats().UDP.PacketsReceived.Value(); got != want {
  1024  		c.T.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want)
  1025  	}
  1026  }
  1027  
  1028  func TestReadRecvOriginalDstAddr(t *testing.T) {
  1029  	tests := []struct {
  1030  		name                    string
  1031  		proto                   tcpip.NetworkProtocolNumber
  1032  		flow                    context.TestFlow
  1033  		expectedOriginalDstAddr tcpip.FullAddress
  1034  	}{
  1035  		{
  1036  			name:                    "IPv4 unicast",
  1037  			proto:                   header.IPv4ProtocolNumber,
  1038  			flow:                    context.UnicastV4,
  1039  			expectedOriginalDstAddr: tcpip.FullAddress{NIC: context.NICID, Addr: context.StackAddr, Port: context.StackPort},
  1040  		},
  1041  		{
  1042  			name:  "IPv4 multicast",
  1043  			proto: header.IPv4ProtocolNumber,
  1044  			flow:  context.MulticastV4,
  1045  			// This should actually be a unicast address assigned to the interface.
  1046  			//
  1047  			// TODO(gvisor.dev/issue/3556): This check is validating incorrect
  1048  			// behaviour. We still include the test so that once the bug is resolved,
  1049  			// this test will start to fail and the individual tasked with fixing this
  1050  			// bug knows to also fix this test :).
  1051  			expectedOriginalDstAddr: tcpip.FullAddress{NIC: context.NICID, Addr: context.MulticastAddr, Port: context.StackPort},
  1052  		},
  1053  		{
  1054  			name:  "IPv4 broadcast",
  1055  			proto: header.IPv4ProtocolNumber,
  1056  			flow:  context.Broadcast,
  1057  			// This should actually be a unicast address assigned to the interface.
  1058  			//
  1059  			// TODO(gvisor.dev/issue/3556): This check is validating incorrect
  1060  			// behaviour. We still include the test so that once the bug is resolved,
  1061  			// this test will start to fail and the individual tasked with fixing this
  1062  			// bug knows to also fix this test :).
  1063  			expectedOriginalDstAddr: tcpip.FullAddress{NIC: context.NICID, Addr: context.BroadcastAddr, Port: context.StackPort},
  1064  		},
  1065  		{
  1066  			name:                    "IPv6 unicast",
  1067  			proto:                   header.IPv6ProtocolNumber,
  1068  			flow:                    context.UnicastV6,
  1069  			expectedOriginalDstAddr: tcpip.FullAddress{NIC: context.NICID, Addr: context.StackV6Addr, Port: context.StackPort},
  1070  		},
  1071  		{
  1072  			name:  "IPv6 multicast",
  1073  			proto: header.IPv6ProtocolNumber,
  1074  			flow:  context.MulticastV6,
  1075  			// This should actually be a unicast address assigned to the interface.
  1076  			//
  1077  			// TODO(gvisor.dev/issue/3556): This check is validating incorrect
  1078  			// behaviour. We still include the test so that once the bug is resolved,
  1079  			// this test will start to fail and the individual tasked with fixing this
  1080  			// bug knows to also fix this test :).
  1081  			expectedOriginalDstAddr: tcpip.FullAddress{NIC: context.NICID, Addr: context.MulticastV6Addr, Port: context.StackPort},
  1082  		},
  1083  	}
  1084  
  1085  	for _, test := range tests {
  1086  		t.Run(test.name, func(t *testing.T) {
  1087  			c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
  1088  			defer c.Cleanup()
  1089  
  1090  			c.CreateEndpoint(test.proto, udp.ProtocolNumber)
  1091  
  1092  			bindAddr := tcpip.FullAddress{Port: context.StackPort}
  1093  			if err := c.EP.Bind(bindAddr); err != nil {
  1094  				t.Fatalf("Bind(%#v): %s", bindAddr, err)
  1095  			}
  1096  
  1097  			if test.flow.IsMulticast() {
  1098  				ifoptSet := tcpip.AddMembershipOption{NIC: context.NICID, MulticastAddr: test.flow.GetMulticastAddr()}
  1099  				if err := c.EP.SetSockOpt(&ifoptSet); err != nil {
  1100  					c.T.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err)
  1101  				}
  1102  			}
  1103  
  1104  			c.EP.SocketOptions().SetReceiveOriginalDstAddress(true)
  1105  
  1106  			testRead(c, test.flow, checker.ReceiveOriginalDstAddr(test.expectedOriginalDstAddr))
  1107  
  1108  			if got := c.Stack.Stats().UDP.PacketsReceived.Value(); got != 1 {
  1109  				t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
  1110  			}
  1111  		})
  1112  	}
  1113  }
  1114  
  1115  func TestWriteIncrementsPacketsSent(t *testing.T) {
  1116  	c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
  1117  	defer c.Cleanup()
  1118  
  1119  	c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber)
  1120  
  1121  	testDualWrite(c)
  1122  
  1123  	var want uint64 = 2
  1124  	if got := c.Stack.Stats().UDP.PacketsSent.Value(); got != want {
  1125  		c.T.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want)
  1126  	}
  1127  }
  1128  
  1129  func TestNoChecksum(t *testing.T) {
  1130  	for _, writeOpSequence := range writeOpSequences {
  1131  		for _, flow := range []context.TestFlow{context.UnicastV4, context.UnicastV6} {
  1132  			t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
  1133  				c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
  1134  				defer c.Cleanup()
  1135  
  1136  				c.CreateEndpointForFlow(flow, udp.ProtocolNumber)
  1137  
  1138  				// Disable the checksum generation.
  1139  				c.EP.SocketOptions().SetNoChecksum(true)
  1140  				// This option is effective on IPv4 only.
  1141  				testWriteOpSequenceSucceeds(c, flow, writeOpSequence, checker.UDP(checker.NoChecksum(flow.IsV4())))
  1142  
  1143  				// Enable the checksum generation.
  1144  				c.EP.SocketOptions().SetNoChecksum(false)
  1145  				testWriteOpSequenceSucceeds(c, flow, writeOpSequence, checker.UDP(checker.NoChecksum(false)))
  1146  			})
  1147  		}
  1148  	}
  1149  }
  1150  
  1151  var _ stack.NetworkInterface = (*testInterface)(nil)
  1152  
  1153  type testInterface struct {
  1154  	stack.NetworkInterface
  1155  }
  1156  
  1157  func (*testInterface) ID() tcpip.NICID {
  1158  	return 0
  1159  }
  1160  
  1161  func (*testInterface) Enabled() bool {
  1162  	return true
  1163  }
  1164  
  1165  func TestDefaultTTL(t *testing.T) {
  1166  	for _, writeOpSequence := range writeOpSequences {
  1167  		for _, flow := range []context.TestFlow{context.UnicastV4, context.UnicastV4in6, context.UnicastV6, context.UnicastV6Only, context.Broadcast, context.BroadcastIn6} {
  1168  			t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
  1169  				c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
  1170  				defer c.Cleanup()
  1171  
  1172  				c.CreateEndpointForFlow(flow, udp.ProtocolNumber)
  1173  				proto := c.Stack.NetworkProtocolInstance(flow.NetProto())
  1174  				if proto == nil {
  1175  					t.Fatalf("c.Stack.NetworkProtocolInstance(flow.NetProto()) did not return a protocol")
  1176  				}
  1177  
  1178  				var initialDefaultTTL tcpip.DefaultTTLOption
  1179  				if err := proto.Option(&initialDefaultTTL); err != nil {
  1180  					t.Fatalf("proto.Option(&initialDefaultTTL) (%T) failed: %s", initialDefaultTTL, err)
  1181  				}
  1182  				testWriteOpSequenceSucceeds(c, flow, writeOpSequence, checker.TTL(uint8(initialDefaultTTL)))
  1183  
  1184  				newDefaultTTL := tcpip.DefaultTTLOption(initialDefaultTTL + 1)
  1185  				if err := proto.SetOption(&newDefaultTTL); err != nil {
  1186  					c.T.Fatalf("proto.SetOption(&%T(%d))) failed: %s", newDefaultTTL, newDefaultTTL, err)
  1187  				}
  1188  				testWriteOpSequenceSucceeds(c, flow, writeOpSequence, checker.TTL(uint8(newDefaultTTL)))
  1189  			})
  1190  		}
  1191  	}
  1192  }
  1193  
  1194  func TestSetNonMulticastTTL(t *testing.T) {
  1195  	for _, writeOpSequence := range writeOpSequences {
  1196  		for _, flow := range []context.TestFlow{context.UnicastV4, context.UnicastV4in6, context.UnicastV6, context.UnicastV6Only, context.Broadcast, context.BroadcastIn6} {
  1197  			t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
  1198  				for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} {
  1199  					t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) {
  1200  						c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
  1201  						defer c.Cleanup()
  1202  
  1203  						c.CreateEndpointForFlow(flow, udp.ProtocolNumber)
  1204  
  1205  						var relevantOpt tcpip.SockOptInt
  1206  						var irrelevantOpt tcpip.SockOptInt
  1207  						if flow.IsV4() {
  1208  							relevantOpt = tcpip.IPv4TTLOption
  1209  							irrelevantOpt = tcpip.IPv6HopLimitOption
  1210  						} else {
  1211  							relevantOpt = tcpip.IPv6HopLimitOption
  1212  							irrelevantOpt = tcpip.IPv4TTLOption
  1213  						}
  1214  						if err := c.EP.SetSockOptInt(relevantOpt, int(wantTTL)); err != nil {
  1215  							c.T.Fatalf("SetSockOptInt(%d, %d) failed: %s", relevantOpt, wantTTL, err)
  1216  						}
  1217  						// Set a different ttl/hoplimit for the unused protocol, showing that
  1218  						// it does not affect the other protocol.
  1219  						if err := c.EP.SetSockOptInt(irrelevantOpt, int(wantTTL+1)); err != nil {
  1220  							c.T.Fatalf("SetSockOptInt(%d, %d) failed: %s", irrelevantOpt, wantTTL, err)
  1221  						}
  1222  
  1223  						testWriteOpSequenceSucceeds(c, flow, writeOpSequence, checker.TTL(wantTTL))
  1224  					})
  1225  				}
  1226  			})
  1227  		}
  1228  	}
  1229  }
  1230  
  1231  func TestSetMulticastTTL(t *testing.T) {
  1232  	for _, writeOpSequence := range writeOpSequences {
  1233  		for _, flow := range []context.TestFlow{context.MulticastV4, context.MulticastV4in6, context.MulticastV6} {
  1234  			t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
  1235  				for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} {
  1236  					t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) {
  1237  						c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
  1238  						defer c.Cleanup()
  1239  
  1240  						c.CreateEndpointForFlow(flow, udp.ProtocolNumber)
  1241  
  1242  						if err := c.EP.SetSockOptInt(tcpip.MulticastTTLOption, int(wantTTL)); err != nil {
  1243  							c.T.Fatalf("SetSockOptInt failed: %s", err)
  1244  						}
  1245  
  1246  						testWriteOpSequenceSucceeds(c, flow, writeOpSequence, checker.TTL(wantTTL))
  1247  					})
  1248  				}
  1249  			})
  1250  		}
  1251  	}
  1252  }
  1253  
  1254  var v4PacketFlows = [...]context.TestFlow{context.UnicastV4, context.MulticastV4, context.Broadcast, context.UnicastV4in6, context.MulticastV4in6, context.BroadcastIn6}
  1255  
  1256  func TestSetTOS(t *testing.T) {
  1257  	for _, writeOpSequence := range writeOpSequences {
  1258  		for _, flow := range v4PacketFlows {
  1259  			t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
  1260  				c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
  1261  				defer c.Cleanup()
  1262  
  1263  				c.CreateEndpointForFlow(flow, udp.ProtocolNumber)
  1264  
  1265  				const tos = testTOS
  1266  				v, err := c.EP.GetSockOptInt(tcpip.IPv4TOSOption)
  1267  				if err != nil {
  1268  					c.T.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err)
  1269  				}
  1270  				// Test for expected default value.
  1271  				if v != 0 {
  1272  					c.T.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0)
  1273  				}
  1274  
  1275  				if err := c.EP.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil {
  1276  					c.T.Errorf("SetSockOptInt(IPv4TOSOption, 0x%x) failed: %s", tos, err)
  1277  				}
  1278  
  1279  				v, err = c.EP.GetSockOptInt(tcpip.IPv4TOSOption)
  1280  				if err != nil {
  1281  					c.T.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err)
  1282  				}
  1283  
  1284  				if v != tos {
  1285  					c.T.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, tos)
  1286  				}
  1287  
  1288  				testWriteOpSequenceSucceeds(c, flow, writeOpSequence, checker.TOS(tos, 0))
  1289  			})
  1290  		}
  1291  	}
  1292  }
  1293  
  1294  var v6PacketFlows = [...]context.TestFlow{context.UnicastV6, context.UnicastV6Only, context.MulticastV6}
  1295  
  1296  func TestSetTClass(t *testing.T) {
  1297  	for _, writeOpSequence := range writeOpSequences {
  1298  		for _, flow := range v6PacketFlows {
  1299  			t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
  1300  				c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
  1301  				defer c.Cleanup()
  1302  
  1303  				c.CreateEndpointForFlow(flow, udp.ProtocolNumber)
  1304  
  1305  				const tClass = testTOS
  1306  				v, err := c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption)
  1307  				if err != nil {
  1308  					c.T.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err)
  1309  				}
  1310  				// Test for expected default value.
  1311  				if v != 0 {
  1312  					c.T.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, 0)
  1313  				}
  1314  
  1315  				if err := c.EP.SetSockOptInt(tcpip.IPv6TrafficClassOption, tClass); err != nil {
  1316  					c.T.Errorf("SetSockOptInt(IPv6TrafficClassOption, 0x%x) failed: %s", tClass, err)
  1317  				}
  1318  
  1319  				v, err = c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption)
  1320  				if err != nil {
  1321  					c.T.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err)
  1322  				}
  1323  
  1324  				if v != tClass {
  1325  					c.T.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, tClass)
  1326  				}
  1327  
  1328  				// The header getter for TClass is called TOS, so use that checker.
  1329  				testWriteOpSequenceSucceeds(c, flow, writeOpSequence, checker.TOS(tClass, 0))
  1330  			})
  1331  		}
  1332  	}
  1333  }
  1334  
// TestReceiveControlMessage checks each receive-side control-message socket
// option for every listed flow. Each option must start off, turn on when set,
// and then be reported for packets whose network protocol equals the option's
// optionProtocol. Packets of the other protocol must carry no such control
// message.
func TestReceiveControlMessage(t *testing.T) {
	for _, flow := range []context.TestFlow{context.UnicastV4, context.UnicastV6, context.UnicastV6Only, context.MulticastV4, context.MulticastV6, context.MulticastV6Only, context.Broadcast} {
		t.Run(flow.String(), func(t *testing.T) {
			// Each case pairs an option's getter/setter with two checkers:
			// presenceChecker is used when the flow's network protocol equals
			// optionProtocol, absenceChecker otherwise.
			for _, test := range []struct {
				name             string
				optionProtocol   tcpip.NetworkProtocolNumber
				getReceiveOption func(tcpip.Endpoint) bool
				setReceiveOption func(tcpip.Endpoint, bool)
				presenceChecker  checker.ControlMessagesChecker
				absenceChecker   checker.ControlMessagesChecker
			}{
				{
					name:             "TOS",
					optionProtocol:   header.IPv4ProtocolNumber,
					getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetReceiveTOS() },
					setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetReceiveTOS(value) },
					presenceChecker:  checker.ReceiveTOS(testTOS),
					absenceChecker:   checker.NoTOSReceived(),
				},
				{
					name:             "TClass",
					optionProtocol:   header.IPv6ProtocolNumber,
					getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetReceiveTClass() },
					setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetReceiveTClass(value) },
					presenceChecker:  checker.ReceiveTClass(testTOS),
					absenceChecker:   checker.NoTClassReceived(),
				},
				{
					name:             "TTL",
					optionProtocol:   header.IPv4ProtocolNumber,
					getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetReceiveTTL() },
					setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetReceiveTTL(value) },
					presenceChecker:  checker.ReceiveTTL(testTTL),
					absenceChecker:   checker.NoTTLReceived(),
				},
				{
					name:             "HopLimit",
					optionProtocol:   header.IPv6ProtocolNumber,
					getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetReceiveHopLimit() },
					setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetReceiveHopLimit(value) },
					presenceChecker:  checker.ReceiveHopLimit(testTTL),
					absenceChecker:   checker.NoHopLimitReceived(),
				},
				{
					name:             "PacketInfo",
					optionProtocol:   header.IPv4ProtocolNumber,
					getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetReceivePacketInfo() },
					setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetReceivePacketInfo(value) },
					// The expected packet info is derived from the destination of
					// the incoming 4-tuple for this flow.
					presenceChecker: func() checker.ControlMessagesChecker {
						h := flow.MakeHeader4Tuple(context.Incoming)
						return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
							NIC: context.NICID,
							// TODO(https://gvisor.dev/issue/3556): Expect the NIC's address
							// instead of the header destination address for the LocalAddr
							// field.
							LocalAddr:       h.Dst.Addr,
							DestinationAddr: h.Dst.Addr,
						})
					}(),
					absenceChecker: checker.NoIPPacketInfoReceived(),
				},
				{
					name:             "IPv6PacketInfo",
					optionProtocol:   header.IPv6ProtocolNumber,
					getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetIPv6ReceivePacketInfo() },
					setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetIPv6ReceivePacketInfo(value) },
					presenceChecker: func() checker.ControlMessagesChecker {
						h := flow.MakeHeader4Tuple(context.Incoming)
						return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{
							NIC:  context.NICID,
							Addr: h.Dst.Addr,
						})
					}(),
					absenceChecker: checker.NoIPv6PacketInfoReceived(),
				},
			} {
				t.Run(test.name, func(t *testing.T) {
					c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol})
					defer c.Cleanup()

					c.CreateEndpointForFlow(flow, udp.ProtocolNumber)
					if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
						c.T.Fatalf("Bind failed: %s", err)
					}
					// For multicast flows, join the group on the NIC before the
					// packet is injected.
					if flow.IsMulticast() {
						netProto := flow.NetProto()
						addr := flow.GetMulticastAddr()
						if err := c.Stack.JoinGroup(netProto, context.NICID, addr); err != nil {
							c.T.Fatalf("JoinGroup(%d, %d, %s): %s", netProto, context.NICID, addr, err)
						}
					}

					// The incoming packet is built with testTOS and testTTL, the
					// same values the presence checkers above expect.
					payload := newRandomPayload(arbitraryPayloadSize)
					buf := context.BuildUDPPacket(payload, flow, context.Incoming, testTOS, testTTL, false)

					// The option must be off by default.
					if test.getReceiveOption(c.EP) {
						t.Fatal("got getReceiveOption() = true, want = false")
					}

					test.setReceiveOption(c.EP, true)
					if !test.getReceiveOption(c.EP) {
						t.Fatal("got getReceiveOption() = false, want = true")
					}

					// Only packets of the option's own network protocol should
					// carry the control message.
					c.InjectPacket(flow.NetProto(), buf)
					if flow.NetProto() == test.optionProtocol {
						c.ReadFromEndpointExpectSuccess(payload, flow, test.presenceChecker)
					} else {
						c.ReadFromEndpointExpectSuccess(payload, flow, test.absenceChecker)
					}
				})
			}
		})
	}
}
  1450  
  1451  func TestMulticastInterfaceOption(t *testing.T) {
  1452  	for _, flow := range []context.TestFlow{context.MulticastV4, context.MulticastV4in6, context.MulticastV6, context.MulticastV6Only} {
  1453  		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
  1454  			for _, bindTyp := range []string{"bound", "unbound"} {
  1455  				t.Run(bindTyp, func(t *testing.T) {
  1456  					for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} {
  1457  						t.Run(optTyp, func(t *testing.T) {
  1458  							h := flow.MakeHeader4Tuple(context.Outgoing)
  1459  							mcastAddr := h.Dst.Addr
  1460  							localIfAddr := h.Src.Addr
  1461  
  1462  							var ifoptSet tcpip.MulticastInterfaceOption
  1463  							switch optTyp {
  1464  							case "use local-addr":
  1465  								ifoptSet.InterfaceAddr = localIfAddr
  1466  							case "use NICID":
  1467  								ifoptSet.NIC = 1
  1468  							case "use local-addr and NIC":
  1469  								ifoptSet.InterfaceAddr = localIfAddr
  1470  								ifoptSet.NIC = 1
  1471  							default:
  1472  								t.Fatal("unknown test variant")
  1473  							}
  1474  
  1475  							c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
  1476  							defer c.Cleanup()
  1477  
  1478  							c.CreateEndpoint(flow.SockProto(), udp.ProtocolNumber)
  1479  
  1480  							if bindTyp == "bound" {
  1481  								// Bind the socket by connecting to the multicast address.
  1482  								// This may have an influence on how the multicast interface
  1483  								// is set.
  1484  								addr := tcpip.FullAddress{
  1485  									Addr: flow.MapAddrIfApplicable(mcastAddr),
  1486  									Port: context.StackPort,
  1487  								}
  1488  								if err := c.EP.Connect(addr); err != nil {
  1489  									c.T.Fatalf("Connect failed: %s", err)
  1490  								}
  1491  							}
  1492  
  1493  							if err := c.EP.SetSockOpt(&ifoptSet); err != nil {
  1494  								c.T.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err)
  1495  							}
  1496  
  1497  							// Verify multicast interface addr and NIC were set correctly.
  1498  							// Note that NIC must be 1 since this is our outgoing interface.
  1499  							var ifoptGot tcpip.MulticastInterfaceOption
  1500  							if err := c.EP.GetSockOpt(&ifoptGot); err != nil {
  1501  								c.T.Fatalf("GetSockOpt(&%T): %s", ifoptGot, err)
  1502  							} else if ifoptWant := (tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}); ifoptGot != ifoptWant {
  1503  								c.T.Errorf("got multicast interface option = %#v, want = %#v", ifoptGot, ifoptWant)
  1504  							}
  1505  						})
  1506  					}
  1507  				})
  1508  			}
  1509  		})
  1510  	}
  1511  }
  1512  
  1513  // TestV4UnknownDestination verifies that we generate an ICMPv4 Destination
  1514  // Unreachable message when a udp datagram is received on ports for which there
  1515  // is no bound udp socket.
  1516  func TestV4UnknownDestination(t *testing.T) {
  1517  	c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
  1518  	defer c.Cleanup()
  1519  
  1520  	testCases := []struct {
  1521  		flow         context.TestFlow
  1522  		icmpRequired bool
  1523  		// largePayload if true, will result in a payload large enough
  1524  		// so that the final generated IPv4 packet is larger than
  1525  		// header.IPv4MinimumProcessableDatagramSize.
  1526  		largePayload bool
  1527  		// badChecksum if true, will set an invalid checksum in the
  1528  		// header.
  1529  		badChecksum bool
  1530  	}{
  1531  		{context.UnicastV4, true, false, false},
  1532  		{context.UnicastV4, true, true, false},
  1533  		{context.UnicastV4, false, false, true},
  1534  		{context.UnicastV4, false, true, true},
  1535  		{context.MulticastV4, false, false, false},
  1536  		{context.MulticastV4, false, true, false},
  1537  		{context.Broadcast, false, false, false},
  1538  		{context.Broadcast, false, true, false},
  1539  	}
  1540  	checksumErrors := uint64(0)
  1541  	for _, tc := range testCases {
  1542  		t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) {
  1543  			payloadSize := arbitraryPayloadSize
  1544  			if tc.largePayload {
  1545  				payloadSize += header.IPv4MinimumProcessableDatagramSize
  1546  			}
  1547  			payload := newRandomPayload(payloadSize)
  1548  			c.InjectPacket(tc.flow.NetProto(), context.BuildUDPPacket(payload, tc.flow, context.Incoming, testTOS, testTTL, tc.badChecksum))
  1549  			if tc.badChecksum {
  1550  				checksumErrors++
  1551  				if got, want := c.Stack.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want {
  1552  					t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
  1553  				}
  1554  			}
  1555  			if !tc.icmpRequired {
  1556  				if p := c.LinkEP.Read(); p != nil {
  1557  					t.Fatalf("unexpected packet received: %+v", p)
  1558  				}
  1559  				return
  1560  			}
  1561  
  1562  			// ICMP required.
  1563  			p := c.LinkEP.Read()
  1564  			if p == nil {
  1565  				t.Fatalf("packet wasn't written out")
  1566  			}
  1567  
  1568  			buf := p.ToBuffer()
  1569  			defer buf.Release()
  1570  			p.DecRef()
  1571  			pkt := buf.Flatten()
  1572  			if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
  1573  				t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
  1574  			}
  1575  
  1576  			hdr := buffer.NewViewWithData(pkt)
  1577  			defer hdr.Release()
  1578  			checker.IPv4(t, hdr, checker.ICMPv4(
  1579  				checker.ICMPv4Type(header.ICMPv4DstUnreachable),
  1580  				checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
  1581  
  1582  			// We need to compare the included data part of the UDP packet that is in
  1583  			// the ICMP packet with the matching original data.
  1584  			icmpPkt := header.ICMPv4(header.IPv4(hdr.AsSlice()).Payload())
  1585  			payloadIPHeader := header.IPv4(icmpPkt.Payload())
  1586  			incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize
  1587  			wantLen := len(payload)
  1588  			if tc.largePayload {
  1589  				// To work out the data size we need to simulate what the sender would
  1590  				// have done. The wanted size is the total available minus the sum of
  1591  				// the headers in the UDP AND ICMP packets, given that we know the test
  1592  				// had only a minimal IP header but the ICMP sender will have allowed
  1593  				// for a maximally sized packet header.
  1594  				wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
  1595  			}
  1596  
  1597  			// In the case of large payloads the IP packet may be truncated. Update
  1598  			// the length field before retrieving the udp datagram payload.
  1599  			// Add back the two headers within the payload.
  1600  			payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength))
  1601  
  1602  			origDgram := header.UDP(payloadIPHeader.Payload())
  1603  			if got, want := len(origDgram.Payload()), wantLen; got != want {
  1604  				t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
  1605  			}
  1606  			if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
  1607  				t.Fatalf("unexpected payload got: %d, want: %d", got, want)
  1608  			}
  1609  		})
  1610  	}
  1611  }
  1612  
  1613  // TestV6UnknownDestination verifies that we generate an ICMPv6 Destination
  1614  // Unreachable message when a udp datagram is received on ports for which there
  1615  // is no bound udp socket.
  1616  func TestV6UnknownDestination(t *testing.T) {
  1617  	c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
  1618  	defer c.Cleanup()
  1619  
  1620  	testCases := []struct {
  1621  		flow         context.TestFlow
  1622  		icmpRequired bool
  1623  		// largePayload if true will result in a payload large enough to
  1624  		// create an IPv6 packet > header.IPv6MinimumMTU bytes.
  1625  		largePayload bool
  1626  		// badChecksum if true, will set an invalid checksum in the
  1627  		// header.
  1628  		badChecksum bool
  1629  	}{
  1630  		{context.UnicastV6, true, false, false},
  1631  		{context.UnicastV6, true, true, false},
  1632  		{context.UnicastV6, false, false, true},
  1633  		{context.UnicastV6, false, true, true},
  1634  		{context.MulticastV6, false, false, false},
  1635  		{context.MulticastV6, false, true, false},
  1636  	}
  1637  	checksumErrors := uint64(0)
  1638  	for _, tc := range testCases {
  1639  		t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) {
  1640  			payloadSize := arbitraryPayloadSize
  1641  			if tc.largePayload {
  1642  				payloadSize += header.IPv6MinimumMTU
  1643  			}
  1644  			payload := newRandomPayload(payloadSize)
  1645  			c.InjectPacket(tc.flow.NetProto(), context.BuildUDPPacket(payload, tc.flow, context.Incoming, testTOS, testTTL, tc.badChecksum))
  1646  			if tc.badChecksum {
  1647  				checksumErrors++
  1648  				if got, want := c.Stack.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want {
  1649  					t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
  1650  				}
  1651  			}
  1652  			if !tc.icmpRequired {
  1653  				if p := c.LinkEP.Read(); p != nil {
  1654  					t.Fatalf("unexpected packet received: %+v", p)
  1655  				}
  1656  				return
  1657  			}
  1658  
  1659  			// ICMP required.
  1660  			p := c.LinkEP.Read()
  1661  			if p == nil {
  1662  				t.Fatalf("packet wasn't written out")
  1663  			}
  1664  
  1665  			buf := p.ToBuffer()
  1666  			defer buf.Release()
  1667  			p.DecRef()
  1668  			pkt := buf.Flatten()
  1669  			if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
  1670  				t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
  1671  			}
  1672  
  1673  			hdr := buffer.NewViewWithData(pkt)
  1674  			defer hdr.Release()
  1675  			checker.IPv6(t, hdr, checker.ICMPv6(
  1676  				checker.ICMPv6Type(header.ICMPv6DstUnreachable),
  1677  				checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
  1678  
  1679  			icmpPkt := header.ICMPv6(header.IPv6(hdr.AsSlice()).Payload())
  1680  			payloadIPHeader := header.IPv6(icmpPkt.Payload())
  1681  			wantLen := len(payload)
  1682  			if tc.largePayload {
  1683  				wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
  1684  			}
  1685  			// In case of large payloads the IP packet may be truncated. Update
  1686  			// the length field before retrieving the udp datagram payload.
  1687  			payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
  1688  
  1689  			origDgram := header.UDP(payloadIPHeader.Payload())
  1690  			if got, want := len(origDgram.Payload()), wantLen; got != want {
  1691  				t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
  1692  			}
  1693  			if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
  1694  				t.Fatalf("unexpected payload got: %v, want: %v", got, want)
  1695  			}
  1696  		})
  1697  	}
  1698  }
  1699  
  1700  // TestIncrementMalformedPacketsReceived verifies if the malformed received
  1701  // global and endpoint stats are incremented.
  1702  func TestIncrementMalformedPacketsReceived(t *testing.T) {
  1703  	c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
  1704  	defer c.Cleanup()
  1705  
  1706  	c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber)
  1707  	// Bind to wildcard.
  1708  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  1709  		c.T.Fatalf("Bind failed: %s", err)
  1710  	}
  1711  
  1712  	payload := newRandomPayload(arbitraryPayloadSize)
  1713  	h := context.UnicastV6.MakeHeader4Tuple(context.Incoming)
  1714  	buf := context.BuildV6UDPPacket(payload, h, testTOS, testTTL, false)
  1715  
  1716  	// Invalidate the UDP header length field.
  1717  	u := header.UDP(buf[header.IPv6MinimumSize:])
  1718  	u.SetLength(u.Length() + 1)
  1719  	c.InjectPacket(header.IPv6ProtocolNumber, buf)
  1720  
  1721  	const want = 1
  1722  	if got := c.Stack.Stats().UDP.MalformedPacketsReceived.Value(); got != want {
  1723  		t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want)
  1724  	}
  1725  	if got := c.EP.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
  1726  		t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
  1727  	}
  1728  }
  1729  
  1730  // TestShortHeader verifies that when a packet with a too-short UDP header is
  1731  // received, the malformed received global stat gets incremented.
  1732  func TestShortHeader(t *testing.T) {
  1733  	c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
  1734  	defer c.Cleanup()
  1735  
  1736  	c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber)
  1737  	// Bind to wildcard.
  1738  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  1739  		c.T.Fatalf("Bind failed: %s", err)
  1740  	}
  1741  
  1742  	h := context.UnicastV6.MakeHeader4Tuple(context.Incoming)
  1743  
  1744  	// Allocate a buffer for an IPv6 and too-short UDP header.
  1745  	const udpSize = header.UDPMinimumSize - 1
  1746  	buf := make([]byte, header.IPv6MinimumSize+udpSize)
  1747  	// Initialize the IP header.
  1748  	ip := header.IPv6(buf)
  1749  	ip.Encode(&header.IPv6Fields{
  1750  		TrafficClass:      testTOS,
  1751  		PayloadLength:     uint16(udpSize),
  1752  		TransportProtocol: udp.ProtocolNumber,
  1753  		HopLimit:          testTTL,
  1754  		SrcAddr:           h.Src.Addr,
  1755  		DstAddr:           h.Dst.Addr,
  1756  	})
  1757  
  1758  	// Initialize the UDP header.
  1759  	udpHdr := header.UDP(make([]byte, header.UDPMinimumSize))
  1760  	udpHdr.Encode(&header.UDPFields{
  1761  		SrcPort: h.Src.Port,
  1762  		DstPort: h.Dst.Port,
  1763  		Length:  header.UDPMinimumSize,
  1764  	})
  1765  	// Calculate the UDP pseudo-header checksum.
  1766  	xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.Src.Addr, h.Dst.Addr, uint16(len(udpHdr)))
  1767  	udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
  1768  	// Copy all but the last byte of the UDP header into the packet.
  1769  	copy(buf[header.IPv6MinimumSize:], udpHdr)
  1770  
  1771  	// Inject packet.
  1772  	c.InjectPacket(header.IPv6ProtocolNumber, buf)
  1773  
  1774  	if got, want := c.Stack.Stats().NICs.MalformedL4RcvdPackets.Value(), uint64(1); got != want {
  1775  		t.Errorf("got c.Stack.Stats().NIC.MalformedL4RcvdPackets.Value() = %d, want = %d", got, want)
  1776  	}
  1777  }
  1778  
  1779  // TestBadChecksumErrors verifies if a checksum error is detected,
  1780  // global and endpoint stats are incremented.
  1781  func TestBadChecksumErrors(t *testing.T) {
  1782  	for _, flow := range []context.TestFlow{context.UnicastV4, context.UnicastV6} {
  1783  		t.Run(flow.String(), func(t *testing.T) {
  1784  			c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
  1785  			defer c.Cleanup()
  1786  
  1787  			c.CreateEndpoint(flow.SockProto(), udp.ProtocolNumber)
  1788  			// Bind to wildcard.
  1789  			if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  1790  				c.T.Fatalf("Bind failed: %s", err)
  1791  			}
  1792  
  1793  			c.InjectPacket(flow.NetProto(), context.BuildUDPPacket(newRandomPayload(arbitraryPayloadSize), flow, context.Incoming, testTOS, testTTL, true))
  1794  
  1795  			const want = 1
  1796  			if got := c.Stack.Stats().UDP.ChecksumErrors.Value(); got != want {
  1797  				t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
  1798  			}
  1799  			if got := c.EP.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
  1800  				t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
  1801  			}
  1802  		})
  1803  	}
  1804  }
  1805  
  1806  // TestPayloadModifiedV4 verifies if a checksum error is detected,
  1807  // global and endpoint stats are incremented.
  1808  func TestPayloadModifiedV4(t *testing.T) {
  1809  	c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
  1810  	defer c.Cleanup()
  1811  
  1812  	c.CreateEndpoint(ipv4.ProtocolNumber, udp.ProtocolNumber)
  1813  	// Bind to wildcard.
  1814  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  1815  		c.T.Fatalf("Bind failed: %s", err)
  1816  	}
  1817  
  1818  	payload := newRandomPayload(arbitraryPayloadSize)
  1819  	h := context.UnicastV4.MakeHeader4Tuple(context.Incoming)
  1820  	buf := context.BuildV4UDPPacket(payload, h, testTOS, testTTL, false)
  1821  	// Modify the payload so that the checksum value in the UDP header will be
  1822  	// incorrect.
  1823  	buf[len(buf)-1]++
  1824  	c.InjectPacket(header.IPv4ProtocolNumber, buf)
  1825  
  1826  	const want = 1
  1827  	if got := c.Stack.Stats().UDP.ChecksumErrors.Value(); got != want {
  1828  		t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
  1829  	}
  1830  	if got := c.EP.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
  1831  		t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
  1832  	}
  1833  }
  1834  
  1835  // TestPayloadModifiedV6 verifies if a checksum error is detected,
  1836  // global and endpoint stats are incremented.
  1837  func TestPayloadModifiedV6(t *testing.T) {
  1838  	c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
  1839  	defer c.Cleanup()
  1840  
  1841  	c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber)
  1842  	// Bind to wildcard.
  1843  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  1844  		c.T.Fatalf("Bind failed: %s", err)
  1845  	}
  1846  
  1847  	payload := newRandomPayload(arbitraryPayloadSize)
  1848  	h := context.UnicastV6.MakeHeader4Tuple(context.Incoming)
  1849  	buf := context.BuildV6UDPPacket(payload, h, testTOS, testTTL, false)
  1850  	// Modify the payload so that the checksum value in the UDP header will be
  1851  	// incorrect.
  1852  	buf[len(buf)-1]++
  1853  	c.InjectPacket(header.IPv6ProtocolNumber, buf)
  1854  
  1855  	const want = 1
  1856  	if got := c.Stack.Stats().UDP.ChecksumErrors.Value(); got != want {
  1857  		t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
  1858  	}
  1859  	if got := c.EP.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
  1860  		t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
  1861  	}
  1862  }
  1863  
  1864  // TestChecksumZeroV4 verifies if the checksum value is zero, global and
  1865  // endpoint states are *not* incremented (UDP checksum is optional on IPv4).
  1866  func TestChecksumZeroV4(t *testing.T) {
  1867  	c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
  1868  	defer c.Cleanup()
  1869  
  1870  	c.CreateEndpoint(ipv4.ProtocolNumber, udp.ProtocolNumber)
  1871  	// Bind to wildcard.
  1872  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  1873  		c.T.Fatalf("Bind failed: %s", err)
  1874  	}
  1875  
  1876  	payload := newRandomPayload(arbitraryPayloadSize)
  1877  	h := context.UnicastV4.MakeHeader4Tuple(context.Incoming)
  1878  	buf := context.BuildV4UDPPacket(payload, h, testTOS, testTTL, false)
  1879  	// Set the checksum field in the UDP header to zero.
  1880  	u := header.UDP(buf[header.IPv4MinimumSize:])
  1881  	u.SetChecksum(0)
  1882  	c.InjectPacket(header.IPv4ProtocolNumber, buf)
  1883  
  1884  	const want = 0
  1885  	if got := c.Stack.Stats().UDP.ChecksumErrors.Value(); got != want {
  1886  		t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
  1887  	}
  1888  	if got := c.EP.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
  1889  		t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
  1890  	}
  1891  }
  1892  
  1893  // TestChecksumZeroV6 verifies if the checksum value is zero, global and
  1894  // endpoint states are incremented (UDP checksum is *not* optional on IPv6).
  1895  func TestChecksumZeroV6(t *testing.T) {
  1896  	c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
  1897  	defer c.Cleanup()
  1898  
  1899  	c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber)
  1900  	// Bind to wildcard.
  1901  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  1902  		c.T.Fatalf("Bind failed: %s", err)
  1903  	}
  1904  
  1905  	payload := newRandomPayload(arbitraryPayloadSize)
  1906  	h := context.UnicastV6.MakeHeader4Tuple(context.Incoming)
  1907  	buf := context.BuildV6UDPPacket(payload, h, testTOS, testTTL, false)
  1908  	// Set the checksum field in the UDP header to zero.
  1909  	u := header.UDP(buf[header.IPv6MinimumSize:])
  1910  	u.SetChecksum(0)
  1911  	c.InjectPacket(header.IPv6ProtocolNumber, buf)
  1912  
  1913  	const want = 1
  1914  	if got := c.Stack.Stats().UDP.ChecksumErrors.Value(); got != want {
  1915  		t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
  1916  	}
  1917  	if got := c.EP.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
  1918  		t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
  1919  	}
  1920  }
  1921  
  1922  // TestShutdownRead verifies endpoint read shutdown and error
  1923  // stats increment on packet receive.
  1924  func TestShutdownRead(t *testing.T) {
  1925  	c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
  1926  	defer c.Cleanup()
  1927  
  1928  	c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber)
  1929  
  1930  	// Bind to wildcard.
  1931  	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
  1932  		c.T.Fatalf("Bind failed: %s", err)
  1933  	}
  1934  
  1935  	if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != nil {
  1936  		c.T.Fatalf("Connect failed: %s", err)
  1937  	}
  1938  
  1939  	if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
  1940  		t.Fatalf("Shutdown failed: %s", err)
  1941  	}
  1942  
  1943  	testFailingRead(c, context.UnicastV6, true /* expectReadError */)
  1944  
  1945  	var want uint64 = 1
  1946  	if got := c.Stack.Stats().UDP.ReceiveBufferErrors.Value(); got != want {
  1947  		t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want)
  1948  	}
  1949  	if got := c.EP.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want {
  1950  		t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want)
  1951  	}
  1952  }
  1953  
  1954  // TestShutdownWrite verifies endpoint write shutdown and error
  1955  // stats increment on packet write.
  1956  func TestShutdownWrite(t *testing.T) {
  1957  	for _, writeOpSequence := range writeOpSequences {
  1958  		c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
  1959  
  1960  		c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber)
  1961  
  1962  		if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != nil {
  1963  			c.T.Fatalf("Connect failed: %s", err)
  1964  		}
  1965  
  1966  		if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
  1967  			t.Fatalf("Shutdown failed: %s", err)
  1968  		}
  1969  
  1970  		testWriteOpSequenceFails(c, context.UnicastV6, writeOpSequence, &tcpip.ErrClosedForSend{})
  1971  		c.Cleanup()
  1972  	}
  1973  }
  1974  
  1975  func TestOutgoingSubnetBroadcast(t *testing.T) {
  1976  	const nicID1 = 1
  1977  
  1978  	ipv4Addr := tcpip.AddressWithPrefix{
  1979  		Address:   tcpip.AddrFromSlice([]byte("\xc0\xa8\x01\x3a")),
  1980  		PrefixLen: 24,
  1981  	}
  1982  	ipv4Subnet := ipv4Addr.Subnet()
  1983  	ipv4SubnetBcast := ipv4Subnet.Broadcast()
  1984  	ipv4Gateway := testutil.MustParse4("192.168.1.1")
  1985  	ipv4AddrPrefix31 := tcpip.AddressWithPrefix{
  1986  		Address:   tcpip.AddrFromSlice([]byte("\xc0\xa8\x01\x3a")),
  1987  		PrefixLen: 31,
  1988  	}
  1989  	ipv4Subnet31 := ipv4AddrPrefix31.Subnet()
  1990  	ipv4Subnet31Bcast := ipv4Subnet31.Broadcast()
  1991  	ipv4AddrPrefix32 := tcpip.AddressWithPrefix{
  1992  		Address:   tcpip.AddrFromSlice([]byte("\xc0\xa8\x01\x3a")),
  1993  		PrefixLen: 32,
  1994  	}
  1995  	ipv4Subnet32 := ipv4AddrPrefix32.Subnet()
  1996  	ipv4Subnet32Bcast := ipv4Subnet32.Broadcast()
  1997  	ipv6Addr := tcpip.AddressWithPrefix{
  1998  		Address:   tcpip.AddrFromSlice([]byte("\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")),
  1999  		PrefixLen: 64,
  2000  	}
  2001  	ipv6Subnet := ipv6Addr.Subnet()
  2002  	ipv6SubnetBcast := ipv6Subnet.Broadcast()
  2003  	remNetAddr := tcpip.AddressWithPrefix{
  2004  		Address:   tcpip.AddrFromSlice([]byte("\x64\x0a\x7b\x18")),
  2005  		PrefixLen: 24,
  2006  	}
  2007  	remNetSubnet := remNetAddr.Subnet()
  2008  	remNetSubnetBcast := remNetSubnet.Broadcast()
  2009  
  2010  	tests := []struct {
  2011  		name                 string
  2012  		nicAddr              tcpip.ProtocolAddress
  2013  		routes               []tcpip.Route
  2014  		remoteAddr           tcpip.Address
  2015  		requiresBroadcastOpt bool
  2016  	}{
  2017  		{
  2018  			name: "IPv4 Broadcast to local subnet",
  2019  			nicAddr: tcpip.ProtocolAddress{
  2020  				Protocol:          header.IPv4ProtocolNumber,
  2021  				AddressWithPrefix: ipv4Addr,
  2022  			},
  2023  			routes: []tcpip.Route{
  2024  				{
  2025  					Destination: ipv4Subnet,
  2026  					NIC:         nicID1,
  2027  				},
  2028  			},
  2029  			remoteAddr:           ipv4SubnetBcast,
  2030  			requiresBroadcastOpt: true,
  2031  		},
  2032  		{
  2033  			name: "IPv4 Broadcast to local /31 subnet",
  2034  			nicAddr: tcpip.ProtocolAddress{
  2035  				Protocol:          header.IPv4ProtocolNumber,
  2036  				AddressWithPrefix: ipv4AddrPrefix31,
  2037  			},
  2038  			routes: []tcpip.Route{
  2039  				{
  2040  					Destination: ipv4Subnet31,
  2041  					NIC:         nicID1,
  2042  				},
  2043  			},
  2044  			remoteAddr:           ipv4Subnet31Bcast,
  2045  			requiresBroadcastOpt: false,
  2046  		},
  2047  		{
  2048  			name: "IPv4 Broadcast to local /32 subnet",
  2049  			nicAddr: tcpip.ProtocolAddress{
  2050  				Protocol:          header.IPv4ProtocolNumber,
  2051  				AddressWithPrefix: ipv4AddrPrefix32,
  2052  			},
  2053  			routes: []tcpip.Route{
  2054  				{
  2055  					Destination: ipv4Subnet32,
  2056  					NIC:         nicID1,
  2057  				},
  2058  			},
  2059  			remoteAddr:           ipv4Subnet32Bcast,
  2060  			requiresBroadcastOpt: false,
  2061  		},
  2062  		// IPv6 has no notion of a broadcast.
  2063  		{
  2064  			name: "IPv6 'Broadcast' to local subnet",
  2065  			nicAddr: tcpip.ProtocolAddress{
  2066  				Protocol:          header.IPv6ProtocolNumber,
  2067  				AddressWithPrefix: ipv6Addr,
  2068  			},
  2069  			routes: []tcpip.Route{
  2070  				{
  2071  					Destination: ipv6Subnet,
  2072  					NIC:         nicID1,
  2073  				},
  2074  			},
  2075  			remoteAddr:           ipv6SubnetBcast,
  2076  			requiresBroadcastOpt: false,
  2077  		},
  2078  		{
  2079  			name: "IPv4 Broadcast to remote subnet",
  2080  			nicAddr: tcpip.ProtocolAddress{
  2081  				Protocol:          header.IPv4ProtocolNumber,
  2082  				AddressWithPrefix: ipv4Addr,
  2083  			},
  2084  			routes: []tcpip.Route{
  2085  				{
  2086  					Destination: remNetSubnet,
  2087  					Gateway:     ipv4Gateway,
  2088  					NIC:         nicID1,
  2089  				},
  2090  			},
  2091  			remoteAddr: remNetSubnetBcast,
  2092  			// TODO(gvisor.dev/issue/3938): Once we support marking a route as
  2093  			// broadcast, this test should require the broadcast option to be set.
  2094  			requiresBroadcastOpt: false,
  2095  		},
  2096  	}
  2097  
  2098  	for _, test := range tests {
  2099  		t.Run(test.name, func(t *testing.T) {
  2100  			s := stack.New(stack.Options{
  2101  				NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  2102  				TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
  2103  				Clock:              &faketime.NullClock{},
  2104  			})
  2105  			defer s.Destroy()
  2106  			e := channel.New(0, context.DefaultMTU, "")
  2107  			defer e.Close()
  2108  			if err := s.CreateNIC(nicID1, e); err != nil {
  2109  				t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
  2110  			}
  2111  			if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil {
  2112  				t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err)
  2113  			}
  2114  
  2115  			s.SetRouteTable(test.routes)
  2116  
  2117  			var netProto tcpip.NetworkProtocolNumber
  2118  			switch l := test.remoteAddr.Len(); l {
  2119  			case header.IPv4AddressSize:
  2120  				netProto = header.IPv4ProtocolNumber
  2121  			case header.IPv6AddressSize:
  2122  				netProto = header.IPv6ProtocolNumber
  2123  			default:
  2124  				t.Fatalf("got unexpected address length = %d bytes", l)
  2125  			}
  2126  
  2127  			wq := waiter.Queue{}
  2128  			ep, err := s.NewEndpoint(udp.ProtocolNumber, netProto, &wq)
  2129  			if err != nil {
  2130  				t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netProto, err)
  2131  			}
  2132  			defer ep.Close()
  2133  
  2134  			var r bytes.Reader
  2135  			data := []byte{1, 2, 3, 4}
  2136  			to := tcpip.FullAddress{
  2137  				Addr: test.remoteAddr,
  2138  				Port: 80,
  2139  			}
  2140  			opts := tcpip.WriteOptions{To: &to}
  2141  			expectedErrWithoutBcastOpt := func(err tcpip.Error) tcpip.Error {
  2142  				if _, ok := err.(*tcpip.ErrBroadcastDisabled); ok {
  2143  					return nil
  2144  				}
  2145  				return &tcpip.ErrBroadcastDisabled{}
  2146  			}
  2147  			if !test.requiresBroadcastOpt {
  2148  				expectedErrWithoutBcastOpt = nil
  2149  			}
  2150  
  2151  			r.Reset(data)
  2152  			{
  2153  				n, err := ep.Write(&r, opts)
  2154  				if expectedErrWithoutBcastOpt != nil {
  2155  					if want := expectedErrWithoutBcastOpt(err); want != nil {
  2156  						t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want)
  2157  					}
  2158  				} else if err != nil {
  2159  					t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
  2160  				}
  2161  			}
  2162  
  2163  			ep.SocketOptions().SetBroadcast(true)
  2164  
  2165  			r.Reset(data)
  2166  			if n, err := ep.Write(&r, opts); err != nil {
  2167  				t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
  2168  			}
  2169  
  2170  			ep.SocketOptions().SetBroadcast(false)
  2171  
  2172  			r.Reset(data)
  2173  			{
  2174  				n, err := ep.Write(&r, opts)
  2175  				if expectedErrWithoutBcastOpt != nil {
  2176  					if want := expectedErrWithoutBcastOpt(err); want != nil {
  2177  						t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want)
  2178  					}
  2179  				} else if err != nil {
  2180  					t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
  2181  				}
  2182  			}
  2183  		})
  2184  	}
  2185  }
  2186  
  2187  func TestChecksumWithZeroValueOnesComplementSum(t *testing.T) {
  2188  	c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol})
  2189  	defer c.Cleanup()
  2190  
  2191  	c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber)
  2192  	var writeOpts tcpip.WriteOptions
  2193  	h := context.UnicastV6.MakeHeader4Tuple(context.Outgoing)
  2194  	writeDstAddr := context.UnicastV6.MapAddrIfApplicable(h.Dst.Addr)
  2195  	writeOpts = tcpip.WriteOptions{
  2196  		To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.Dst.Port},
  2197  	}
  2198  
  2199  	// Write a packet to calculate what the checksum value will be with a zero
  2200  	// value payload. We will then take that checksum value to construct another
  2201  	// packet which would result in the ones complement of the packet to be zero.
  2202  	var payload [2]byte
  2203  	{
  2204  		var r bytes.Reader
  2205  		r.Reset(payload[:])
  2206  		n, err := c.EP.Write(&r, writeOpts)
  2207  		if err != nil {
  2208  			t.Fatalf("Write failed: %s", err)
  2209  		}
  2210  		if want := int64(len(payload)); n != want {
  2211  			t.Fatalf("got n = %d, want = %d", n, want)
  2212  		}
  2213  
  2214  		pkt := c.LinkEP.Read()
  2215  		if pkt == nil {
  2216  			t.Fatal("Packet wasn't written out")
  2217  		}
  2218  
  2219  		v := stack.PayloadSince(pkt.NetworkHeader())
  2220  		defer v.Release()
  2221  		pkt.DecRef()
  2222  		checker.IPv6(t, v, checker.UDP())
  2223  
  2224  		// Simply replacing the payload with the checksum value is enough to make
  2225  		// sure that we end up with an all ones value for the ones complement sum
  2226  		// because the checksum value is held the ones complement of the ones
  2227  		// complement sum.
  2228  		//
  2229  		// In ones complement arithmetic, adding a value A with a ones complement of
  2230  		// another value B is the same as subtracting B from A.
  2231  		//
  2232  		// The resulting ones complement will be  C' = C - C so we know C' will be
  2233  		// zero. The stack should never send a zero value though so we expect all
  2234  		// ones below.
  2235  		binary.BigEndian.PutUint16(payload[:], header.UDP(header.IPv6(v.AsSlice()).Payload()).Checksum())
  2236  	}
  2237  
  2238  	{
  2239  		var r bytes.Reader
  2240  		r.Reset(payload[:])
  2241  		n, err := c.EP.Write(&r, writeOpts)
  2242  		if err != nil {
  2243  			t.Fatalf("Write failed: %s", err)
  2244  		}
  2245  		if want := int64(len(payload)); n != want {
  2246  			t.Fatalf("got n = %d, want = %d", n, want)
  2247  		}
  2248  	}
  2249  
  2250  	{
  2251  		pkt := c.LinkEP.Read()
  2252  		if pkt == nil {
  2253  			t.Fatal("Packet wasn't written out")
  2254  		}
  2255  		defer pkt.DecRef()
  2256  
  2257  		v := stack.PayloadSince(pkt.NetworkHeader())
  2258  		defer v.Release()
  2259  		checker.IPv6(t, v, checker.UDP(checker.TransportChecksum(math.MaxUint16)))
  2260  
  2261  		// Make sure the all ones checksum is valid.
  2262  		hdr := header.IPv6(v.AsSlice())
  2263  		udp := header.UDP(hdr.Payload())
  2264  		if src, dst, payloadXsum := hdr.SourceAddress(), hdr.DestinationAddress(), checksum.Checksum(udp.Payload(), 0); !udp.IsChecksumValid(src, dst, payloadXsum) {
  2265  			t.Errorf("got udp.IsChecksumValid(%s, %s, %d) = false, want = true", src, dst, payloadXsum)
  2266  		}
  2267  	}
  2268  }
  2269  
  2270  // TestWritePayloadSizeTooBig verifies that writing anything bigger than
  2271  // header.UDPMaximumPacketSize fails.
  2272  func TestWritePayloadSizeTooBig(t *testing.T) {
  2273  	for _, writeOpSequence := range writeOpSequences {
  2274  		c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
  2275  
  2276  		c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber)
  2277  
  2278  		if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != nil {
  2279  			c.T.Fatalf("Connect failed: %s", err)
  2280  		}
  2281  
  2282  		testWriteOpSequenceSucceeds(c, context.UnicastV6, writeOpSequence)
  2283  
  2284  		for _, writeOp := range writeOpSequence {
  2285  			switch writeOp {
  2286  			case write:
  2287  				testWriteFails(c, context.UnicastV6, header.UDPMaximumPacketSize+1, &tcpip.ErrMessageTooLong{})
  2288  			case preflight:
  2289  				testPreflightSucceeds(c, context.UnicastV6)
  2290  			}
  2291  		}
  2292  		c.Cleanup()
  2293  	}
  2294  }
  2295  
  2296  func TestMain(m *testing.M) {
  2297  	refs.SetLeakMode(refs.LeaksPanic)
  2298  	code := m.Run()
  2299  	refs.DoLeakCheck()
  2300  	os.Exit(code)
  2301  }