github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/tests/integration/multicast_broadcast_test.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package multicast_broadcast_test
    16  
    17  import (
    18  	"bytes"
    19  	"testing"
    20  
    21  	"github.com/google/go-cmp/cmp"
    22  	"github.com/SagerNet/gvisor/pkg/tcpip"
    23  	"github.com/SagerNet/gvisor/pkg/tcpip/buffer"
    24  	"github.com/SagerNet/gvisor/pkg/tcpip/checker"
    25  	"github.com/SagerNet/gvisor/pkg/tcpip/header"
    26  	"github.com/SagerNet/gvisor/pkg/tcpip/link/channel"
    27  	"github.com/SagerNet/gvisor/pkg/tcpip/link/loopback"
    28  	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv4"
    29  	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv6"
    30  	"github.com/SagerNet/gvisor/pkg/tcpip/stack"
    31  	"github.com/SagerNet/gvisor/pkg/tcpip/tests/utils"
    32  	"github.com/SagerNet/gvisor/pkg/tcpip/testutil"
    33  	"github.com/SagerNet/gvisor/pkg/tcpip/transport/icmp"
    34  	"github.com/SagerNet/gvisor/pkg/tcpip/transport/udp"
    35  	"github.com/SagerNet/gvisor/pkg/waiter"
    36  )
    37  
    38  const (
    39  	defaultMTU = 1280
    40  	ttl        = 255
    41  )
    42  
    43  // TestPingMulticastBroadcast tests that responding to an Echo Request destined
    44  // to a multicast or broadcast address uses a unicast source address for the
    45  // reply.
    46  func TestPingMulticastBroadcast(t *testing.T) {
    47  	const (
    48  		nicID = 1
    49  		ttl   = 64
    50  	)
    51  
    52  	tests := []struct {
    53  		name        string
    54  		protoNum    tcpip.NetworkProtocolNumber
    55  		rxICMP      func(*channel.Endpoint, tcpip.Address, tcpip.Address, uint8)
    56  		srcAddr     tcpip.Address
    57  		dstAddr     tcpip.Address
    58  		expectedSrc tcpip.Address
    59  	}{
    60  		{
    61  			name:        "IPv4 unicast",
    62  			protoNum:    header.IPv4ProtocolNumber,
    63  			dstAddr:     utils.Ipv4Addr.Address,
    64  			srcAddr:     utils.RemoteIPv4Addr,
    65  			rxICMP:      utils.RxICMPv4EchoRequest,
    66  			expectedSrc: utils.Ipv4Addr.Address,
    67  		},
    68  		{
    69  			name:        "IPv4 directed broadcast",
    70  			protoNum:    header.IPv4ProtocolNumber,
    71  			rxICMP:      utils.RxICMPv4EchoRequest,
    72  			srcAddr:     utils.RemoteIPv4Addr,
    73  			dstAddr:     utils.Ipv4SubnetBcast,
    74  			expectedSrc: utils.Ipv4Addr.Address,
    75  		},
    76  		{
    77  			name:        "IPv4 broadcast",
    78  			protoNum:    header.IPv4ProtocolNumber,
    79  			rxICMP:      utils.RxICMPv4EchoRequest,
    80  			srcAddr:     utils.RemoteIPv4Addr,
    81  			dstAddr:     header.IPv4Broadcast,
    82  			expectedSrc: utils.Ipv4Addr.Address,
    83  		},
    84  		{
    85  			name:        "IPv4 all-systems multicast",
    86  			protoNum:    header.IPv4ProtocolNumber,
    87  			rxICMP:      utils.RxICMPv4EchoRequest,
    88  			srcAddr:     utils.RemoteIPv4Addr,
    89  			dstAddr:     header.IPv4AllSystems,
    90  			expectedSrc: utils.Ipv4Addr.Address,
    91  		},
    92  		{
    93  			name:        "IPv6 unicast",
    94  			protoNum:    header.IPv6ProtocolNumber,
    95  			rxICMP:      utils.RxICMPv6EchoRequest,
    96  			srcAddr:     utils.RemoteIPv6Addr,
    97  			dstAddr:     utils.Ipv6Addr.Address,
    98  			expectedSrc: utils.Ipv6Addr.Address,
    99  		},
   100  		{
   101  			name:        "IPv6 all-nodes multicast",
   102  			protoNum:    header.IPv6ProtocolNumber,
   103  			rxICMP:      utils.RxICMPv6EchoRequest,
   104  			srcAddr:     utils.RemoteIPv6Addr,
   105  			dstAddr:     header.IPv6AllNodesMulticastAddress,
   106  			expectedSrc: utils.Ipv6Addr.Address,
   107  		},
   108  	}
   109  
   110  	for _, test := range tests {
   111  		t.Run(test.name, func(t *testing.T) {
   112  			s := stack.New(stack.Options{
   113  				NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   114  				TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
   115  			})
   116  			// We only expect a single packet in response to our ICMP Echo Request.
   117  			e := channel.New(1, defaultMTU, "")
   118  			if err := s.CreateNIC(nicID, e); err != nil {
   119  				t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
   120  			}
   121  			ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: utils.Ipv4Addr}
   122  			if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
   123  				t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
   124  			}
   125  			ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: utils.Ipv6Addr}
   126  			if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil {
   127  				t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv6ProtoAddr, err)
   128  			}
   129  
   130  			// Default routes for IPv4 and IPv6 so ICMP can find a route to the remote
   131  			// node when attempting to send the ICMP Echo Reply.
   132  			s.SetRouteTable([]tcpip.Route{
   133  				{
   134  					Destination: header.IPv6EmptySubnet,
   135  					NIC:         nicID,
   136  				},
   137  				{
   138  					Destination: header.IPv4EmptySubnet,
   139  					NIC:         nicID,
   140  				},
   141  			})
   142  
   143  			test.rxICMP(e, test.srcAddr, test.dstAddr, ttl)
   144  			pkt, ok := e.Read()
   145  			if !ok {
   146  				t.Fatal("expected ICMP response")
   147  			}
   148  
   149  			if pkt.Route.LocalAddress != test.expectedSrc {
   150  				t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.expectedSrc)
   151  			}
   152  			// The destination of the response packet should be the source of the
   153  			// original packet.
   154  			if pkt.Route.RemoteAddress != test.srcAddr {
   155  				t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.srcAddr)
   156  			}
   157  
   158  			src, dst := s.NetworkProtocolInstance(test.protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
   159  			if src != test.expectedSrc {
   160  				t.Errorf("got pkt source = %s, want = %s", src, test.expectedSrc)
   161  			}
   162  			// The destination of the response packet should be the source of the
   163  			// original packet.
   164  			if dst != test.srcAddr {
   165  				t.Errorf("got pkt destination = %s, want = %s", dst, test.srcAddr)
   166  			}
   167  		})
   168  	}
   169  
   170  }
   171  
   172  func rxIPv4UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) {
   173  	payloadLen := header.UDPMinimumSize + len(data)
   174  	totalLen := header.IPv4MinimumSize + payloadLen
   175  	hdr := buffer.NewPrependable(totalLen)
   176  	u := header.UDP(hdr.Prepend(payloadLen))
   177  	u.Encode(&header.UDPFields{
   178  		SrcPort: utils.RemotePort,
   179  		DstPort: utils.LocalPort,
   180  		Length:  uint16(payloadLen),
   181  	})
   182  	copy(u.Payload(), data)
   183  	sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen))
   184  	sum = header.Checksum(data, sum)
   185  	u.SetChecksum(^u.CalculateChecksum(sum))
   186  
   187  	ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
   188  	ip.Encode(&header.IPv4Fields{
   189  		TotalLength: uint16(totalLen),
   190  		Protocol:    uint8(udp.ProtocolNumber),
   191  		TTL:         ttl,
   192  		SrcAddr:     src,
   193  		DstAddr:     dst,
   194  	})
   195  	ip.SetChecksum(^ip.CalculateChecksum())
   196  
   197  	e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
   198  		Data: hdr.View().ToVectorisedView(),
   199  	}))
   200  }
   201  
   202  func rxIPv6UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) {
   203  	payloadLen := header.UDPMinimumSize + len(data)
   204  	hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen)
   205  	u := header.UDP(hdr.Prepend(payloadLen))
   206  	u.Encode(&header.UDPFields{
   207  		SrcPort: utils.RemotePort,
   208  		DstPort: utils.LocalPort,
   209  		Length:  uint16(payloadLen),
   210  	})
   211  	copy(u.Payload(), data)
   212  	sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen))
   213  	sum = header.Checksum(data, sum)
   214  	u.SetChecksum(^u.CalculateChecksum(sum))
   215  
   216  	ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
   217  	ip.Encode(&header.IPv6Fields{
   218  		PayloadLength:     uint16(payloadLen),
   219  		TransportProtocol: udp.ProtocolNumber,
   220  		HopLimit:          ttl,
   221  		SrcAddr:           src,
   222  		DstAddr:           dst,
   223  	})
   224  
   225  	e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
   226  		Data: hdr.View().ToVectorisedView(),
   227  	}))
   228  }
   229  
   230  // TestIncomingMulticastAndBroadcast tests receiving a packet destined to some
   231  // multicast or broadcast address.
   232  func TestIncomingMulticastAndBroadcast(t *testing.T) {
   233  	const nicID = 1
   234  
   235  	data := []byte{1, 2, 3, 4}
   236  
   237  	tests := []struct {
   238  		name       string
   239  		proto      tcpip.NetworkProtocolNumber
   240  		remoteAddr tcpip.Address
   241  		localAddr  tcpip.AddressWithPrefix
   242  		rxUDP      func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte)
   243  		bindAddr   tcpip.Address
   244  		dstAddr    tcpip.Address
   245  		expectRx   bool
   246  	}{
   247  		{
   248  			name:       "IPv4 unicast binding to unicast",
   249  			proto:      header.IPv4ProtocolNumber,
   250  			remoteAddr: utils.RemoteIPv4Addr,
   251  			localAddr:  utils.Ipv4Addr,
   252  			rxUDP:      rxIPv4UDP,
   253  			bindAddr:   utils.Ipv4Addr.Address,
   254  			dstAddr:    utils.Ipv4Addr.Address,
   255  			expectRx:   true,
   256  		},
   257  		{
   258  			name:       "IPv4 unicast binding to broadcast",
   259  			proto:      header.IPv4ProtocolNumber,
   260  			remoteAddr: utils.RemoteIPv4Addr,
   261  			localAddr:  utils.Ipv4Addr,
   262  			rxUDP:      rxIPv4UDP,
   263  			bindAddr:   header.IPv4Broadcast,
   264  			dstAddr:    utils.Ipv4Addr.Address,
   265  			expectRx:   false,
   266  		},
   267  		{
   268  			name:       "IPv4 unicast binding to wildcard",
   269  			proto:      header.IPv4ProtocolNumber,
   270  			remoteAddr: utils.RemoteIPv4Addr,
   271  			localAddr:  utils.Ipv4Addr,
   272  			rxUDP:      rxIPv4UDP,
   273  			dstAddr:    utils.Ipv4Addr.Address,
   274  			expectRx:   true,
   275  		},
   276  
   277  		{
   278  			name:       "IPv4 directed broadcast binding to subnet broadcast",
   279  			proto:      header.IPv4ProtocolNumber,
   280  			remoteAddr: utils.RemoteIPv4Addr,
   281  			localAddr:  utils.Ipv4Addr,
   282  			rxUDP:      rxIPv4UDP,
   283  			bindAddr:   utils.Ipv4SubnetBcast,
   284  			dstAddr:    utils.Ipv4SubnetBcast,
   285  			expectRx:   true,
   286  		},
   287  		{
   288  			name:       "IPv4 directed broadcast binding to broadcast",
   289  			proto:      header.IPv4ProtocolNumber,
   290  			remoteAddr: utils.RemoteIPv4Addr,
   291  			localAddr:  utils.Ipv4Addr,
   292  			rxUDP:      rxIPv4UDP,
   293  			bindAddr:   header.IPv4Broadcast,
   294  			dstAddr:    utils.Ipv4SubnetBcast,
   295  			expectRx:   false,
   296  		},
   297  		{
   298  			name:       "IPv4 directed broadcast binding to wildcard",
   299  			proto:      header.IPv4ProtocolNumber,
   300  			remoteAddr: utils.RemoteIPv4Addr,
   301  			localAddr:  utils.Ipv4Addr,
   302  			rxUDP:      rxIPv4UDP,
   303  			dstAddr:    utils.Ipv4SubnetBcast,
   304  			expectRx:   true,
   305  		},
   306  
   307  		{
   308  			name:       "IPv4 broadcast binding to broadcast",
   309  			proto:      header.IPv4ProtocolNumber,
   310  			remoteAddr: utils.RemoteIPv4Addr,
   311  			localAddr:  utils.Ipv4Addr,
   312  			rxUDP:      rxIPv4UDP,
   313  			bindAddr:   header.IPv4Broadcast,
   314  			dstAddr:    header.IPv4Broadcast,
   315  			expectRx:   true,
   316  		},
   317  		{
   318  			name:       "IPv4 broadcast binding to subnet broadcast",
   319  			proto:      header.IPv4ProtocolNumber,
   320  			remoteAddr: utils.RemoteIPv4Addr,
   321  			localAddr:  utils.Ipv4Addr,
   322  			rxUDP:      rxIPv4UDP,
   323  			bindAddr:   utils.Ipv4SubnetBcast,
   324  			dstAddr:    header.IPv4Broadcast,
   325  			expectRx:   false,
   326  		},
   327  		{
   328  			name:       "IPv4 broadcast binding to wildcard",
   329  			proto:      header.IPv4ProtocolNumber,
   330  			remoteAddr: utils.RemoteIPv4Addr,
   331  			localAddr:  utils.Ipv4Addr,
   332  			rxUDP:      rxIPv4UDP,
   333  			dstAddr:    utils.Ipv4SubnetBcast,
   334  			expectRx:   true,
   335  		},
   336  
   337  		{
   338  			name:       "IPv4 all-systems multicast binding to all-systems multicast",
   339  			proto:      header.IPv4ProtocolNumber,
   340  			remoteAddr: utils.RemoteIPv4Addr,
   341  			localAddr:  utils.Ipv4Addr,
   342  			rxUDP:      rxIPv4UDP,
   343  			bindAddr:   header.IPv4AllSystems,
   344  			dstAddr:    header.IPv4AllSystems,
   345  			expectRx:   true,
   346  		},
   347  		{
   348  			name:       "IPv4 all-systems multicast binding to wildcard",
   349  			proto:      header.IPv4ProtocolNumber,
   350  			remoteAddr: utils.RemoteIPv4Addr,
   351  			localAddr:  utils.Ipv4Addr,
   352  			rxUDP:      rxIPv4UDP,
   353  			dstAddr:    header.IPv4AllSystems,
   354  			expectRx:   true,
   355  		},
   356  		{
   357  			name:       "IPv4 all-systems multicast binding to unicast",
   358  			proto:      header.IPv4ProtocolNumber,
   359  			remoteAddr: utils.RemoteIPv4Addr,
   360  			localAddr:  utils.Ipv4Addr,
   361  			rxUDP:      rxIPv4UDP,
   362  			bindAddr:   utils.Ipv4Addr.Address,
   363  			dstAddr:    header.IPv4AllSystems,
   364  			expectRx:   false,
   365  		},
   366  
   367  		// IPv6 has no notion of a broadcast.
   368  		{
   369  			name:       "IPv6 unicast binding to wildcard",
   370  			dstAddr:    utils.Ipv6Addr.Address,
   371  			proto:      header.IPv6ProtocolNumber,
   372  			remoteAddr: utils.RemoteIPv6Addr,
   373  			localAddr:  utils.Ipv6Addr,
   374  			rxUDP:      rxIPv6UDP,
   375  			expectRx:   true,
   376  		},
   377  		{
   378  			name:       "IPv6 broadcast-like address binding to wildcard",
   379  			dstAddr:    utils.Ipv6SubnetBcast,
   380  			proto:      header.IPv6ProtocolNumber,
   381  			remoteAddr: utils.RemoteIPv6Addr,
   382  			localAddr:  utils.Ipv6Addr,
   383  			rxUDP:      rxIPv6UDP,
   384  			expectRx:   false,
   385  		},
   386  	}
   387  
   388  	for _, test := range tests {
   389  		t.Run(test.name, func(t *testing.T) {
   390  			s := stack.New(stack.Options{
   391  				NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   392  				TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   393  			})
   394  			e := channel.New(0, defaultMTU, "")
   395  			if err := s.CreateNIC(nicID, e); err != nil {
   396  				t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
   397  			}
   398  			protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
   399  			if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
   400  				t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
   401  			}
   402  
   403  			var wq waiter.Queue
   404  			ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq)
   405  			if err != nil {
   406  				t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err)
   407  			}
   408  			defer ep.Close()
   409  
   410  			bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: utils.LocalPort}
   411  			if err := ep.Bind(bindAddr); err != nil {
   412  				t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
   413  			}
   414  
   415  			test.rxUDP(e, test.remoteAddr, test.dstAddr, data)
   416  			var buf bytes.Buffer
   417  			var opts tcpip.ReadOptions
   418  			if res, err := ep.Read(&buf, opts); test.expectRx {
   419  				if err != nil {
   420  					t.Fatalf("ep.Read(_, %#v): %s", opts, err)
   421  				}
   422  				if diff := cmp.Diff(tcpip.ReadResult{
   423  					Count: buf.Len(),
   424  					Total: buf.Len(),
   425  				}, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
   426  					t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
   427  				}
   428  				if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
   429  					t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
   430  				}
   431  			} else if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
   432  				t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{})
   433  			}
   434  		})
   435  	}
   436  }
   437  
   438  // TestReuseAddrAndBroadcast makes sure broadcast packets are received by all
   439  // interested endpoints.
   440  func TestReuseAddrAndBroadcast(t *testing.T) {
   441  	const (
   442  		nicID     = 1
   443  		localPort = 9000
   444  	)
   445  	loopbackBroadcast := testutil.MustParse4("127.255.255.255")
   446  
   447  	tests := []struct {
   448  		name          string
   449  		broadcastAddr tcpip.Address
   450  	}{
   451  		{
   452  			name:          "Subnet directed broadcast",
   453  			broadcastAddr: loopbackBroadcast,
   454  		},
   455  		{
   456  			name:          "IPv4 broadcast",
   457  			broadcastAddr: header.IPv4Broadcast,
   458  		},
   459  	}
   460  
   461  	for _, test := range tests {
   462  		t.Run(test.name, func(t *testing.T) {
   463  			s := stack.New(stack.Options{
   464  				NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   465  				TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   466  			})
   467  			if err := s.CreateNIC(nicID, loopback.New()); err != nil {
   468  				t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
   469  			}
   470  			protoAddr := tcpip.ProtocolAddress{
   471  				Protocol: header.IPv4ProtocolNumber,
   472  				AddressWithPrefix: tcpip.AddressWithPrefix{
   473  					Address:   "\x7f\x00\x00\x01",
   474  					PrefixLen: 8,
   475  				},
   476  			}
   477  			if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
   478  				t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
   479  			}
   480  
   481  			s.SetRouteTable([]tcpip.Route{
   482  				{
   483  					// We use the empty subnet instead of just the loopback subnet so we
   484  					// also have a route to the IPv4 Broadcast address.
   485  					Destination: header.IPv4EmptySubnet,
   486  					NIC:         nicID,
   487  				},
   488  			})
   489  
   490  			type endpointAndWaiter struct {
   491  				ep tcpip.Endpoint
   492  				ch chan struct{}
   493  			}
   494  			var eps []endpointAndWaiter
   495  			// We create endpoints that bind to both the wildcard address and the
   496  			// broadcast address to make sure both of these types of "broadcast
   497  			// interested" endpoints receive broadcast packets.
   498  			for _, bindWildcard := range []bool{false, true} {
   499  				// Create multiple endpoints for each type of "broadcast interested"
   500  				// endpoint so we can test that all endpoints receive the broadcast
   501  				// packet.
   502  				for i := 0; i < 2; i++ {
   503  					var wq waiter.Queue
   504  					we, ch := waiter.NewChannelEntry(nil)
   505  					wq.EventRegister(&we, waiter.ReadableEvents)
   506  					ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
   507  					if err != nil {
   508  						t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err)
   509  					}
   510  					defer ep.Close()
   511  
   512  					ep.SocketOptions().SetReuseAddress(true)
   513  					ep.SocketOptions().SetBroadcast(true)
   514  
   515  					bindAddr := tcpip.FullAddress{Port: localPort}
   516  					if bindWildcard {
   517  						if err := ep.Bind(bindAddr); err != nil {
   518  							t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err)
   519  						}
   520  					} else {
   521  						bindAddr.Addr = test.broadcastAddr
   522  						if err := ep.Bind(bindAddr); err != nil {
   523  							t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err)
   524  						}
   525  					}
   526  
   527  					eps = append(eps, endpointAndWaiter{ep: ep, ch: ch})
   528  				}
   529  			}
   530  
   531  			for i, wep := range eps {
   532  				writeOpts := tcpip.WriteOptions{
   533  					To: &tcpip.FullAddress{
   534  						Addr: test.broadcastAddr,
   535  						Port: localPort,
   536  					},
   537  				}
   538  				data := []byte{byte(i), 2, 3, 4}
   539  				var r bytes.Reader
   540  				r.Reset(data)
   541  				if n, err := wep.ep.Write(&r, writeOpts); err != nil {
   542  					t.Fatalf("eps[%d].Write(_, _): %s", i, err)
   543  				} else if want := int64(len(data)); n != want {
   544  					t.Fatalf("got eps[%d].Write(_, _) = (%d, nil), want = (%d, nil)", i, n, want)
   545  				}
   546  
   547  				for j, rep := range eps {
   548  					// Wait for the endpoint to become readable.
   549  					<-rep.ch
   550  
   551  					var buf bytes.Buffer
   552  					result, err := rep.ep.Read(&buf, tcpip.ReadOptions{})
   553  					if err != nil {
   554  						t.Errorf("(eps[%d] write) eps[%d].Read: %s", i, j, err)
   555  						continue
   556  					}
   557  					if diff := cmp.Diff(tcpip.ReadResult{
   558  						Count: buf.Len(),
   559  						Total: buf.Len(),
   560  					}, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
   561  						t.Errorf("(eps[%d] write) eps[%d].Read: unexpected result (-want +got):\n%s", i, j, diff)
   562  					}
   563  					if diff := cmp.Diff([]byte(data), buf.Bytes()); diff != "" {
   564  						t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff)
   565  					}
   566  				}
   567  			}
   568  		})
   569  	}
   570  }
   571  
   572  func TestUDPAddRemoveMembershipSocketOption(t *testing.T) {
   573  	const (
   574  		nicID = 1
   575  	)
   576  
   577  	data := []byte{1, 2, 3, 4}
   578  
   579  	tests := []struct {
   580  		name          string
   581  		proto         tcpip.NetworkProtocolNumber
   582  		remoteAddr    tcpip.Address
   583  		localAddr     tcpip.AddressWithPrefix
   584  		rxUDP         func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte)
   585  		multicastAddr tcpip.Address
   586  	}{
   587  		{
   588  			name:          "IPv4 unicast binding to unicast",
   589  			multicastAddr: "\xe0\x01\x02\x03",
   590  			proto:         header.IPv4ProtocolNumber,
   591  			remoteAddr:    utils.RemoteIPv4Addr,
   592  			localAddr:     utils.Ipv4Addr,
   593  			rxUDP:         rxIPv4UDP,
   594  		},
   595  		{
   596  			name:          "IPv6 broadcast-like address binding to wildcard",
   597  			multicastAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04",
   598  			proto:         header.IPv6ProtocolNumber,
   599  			remoteAddr:    utils.RemoteIPv6Addr,
   600  			localAddr:     utils.Ipv6Addr,
   601  			rxUDP:         rxIPv6UDP,
   602  		},
   603  	}
   604  
   605  	subTests := []struct {
   606  		name           string
   607  		specifyNICID   bool
   608  		specifyNICAddr bool
   609  	}{
   610  		{
   611  			name:           "Specify NIC ID and NIC address",
   612  			specifyNICID:   true,
   613  			specifyNICAddr: true,
   614  		},
   615  		{
   616  			name:           "Don't specify NIC ID or NIC address",
   617  			specifyNICID:   false,
   618  			specifyNICAddr: false,
   619  		},
   620  		{
   621  			name:           "Specify NIC ID but don't specify NIC address",
   622  			specifyNICID:   true,
   623  			specifyNICAddr: false,
   624  		},
   625  		{
   626  			name:           "Don't specify NIC ID but specify NIC address",
   627  			specifyNICID:   false,
   628  			specifyNICAddr: true,
   629  		},
   630  	}
   631  
   632  	for _, test := range tests {
   633  		t.Run(test.name, func(t *testing.T) {
   634  			for _, subTest := range subTests {
   635  				t.Run(subTest.name, func(t *testing.T) {
   636  					s := stack.New(stack.Options{
   637  						NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   638  						TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   639  					})
   640  					e := channel.New(0, defaultMTU, "")
   641  					if err := s.CreateNIC(nicID, e); err != nil {
   642  						t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
   643  					}
   644  					protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
   645  					if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
   646  						t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
   647  					}
   648  
   649  					// Set the route table so that UDP can find a NIC that is
   650  					// routable to the multicast address when the NIC isn't specified.
   651  					if !subTest.specifyNICID && !subTest.specifyNICAddr {
   652  						s.SetRouteTable([]tcpip.Route{
   653  							{
   654  								Destination: header.IPv6EmptySubnet,
   655  								NIC:         nicID,
   656  							},
   657  							{
   658  								Destination: header.IPv4EmptySubnet,
   659  								NIC:         nicID,
   660  							},
   661  						})
   662  					}
   663  
   664  					var wq waiter.Queue
   665  					ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq)
   666  					if err != nil {
   667  						t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err)
   668  					}
   669  					defer ep.Close()
   670  
   671  					bindAddr := tcpip.FullAddress{Port: utils.LocalPort}
   672  					if err := ep.Bind(bindAddr); err != nil {
   673  						t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
   674  					}
   675  
   676  					memOpt := tcpip.MembershipOption{MulticastAddr: test.multicastAddr}
   677  					if subTest.specifyNICID {
   678  						memOpt.NIC = nicID
   679  					}
   680  					if subTest.specifyNICAddr {
   681  						memOpt.InterfaceAddr = test.localAddr.Address
   682  					}
   683  
   684  					// We should receive UDP packets to the group once we join the
   685  					// multicast group.
   686  					addOpt := tcpip.AddMembershipOption(memOpt)
   687  					if err := ep.SetSockOpt(&addOpt); err != nil {
   688  						t.Fatalf("ep.SetSockOpt(&%#v): %s", addOpt, err)
   689  					}
   690  					test.rxUDP(e, test.remoteAddr, test.multicastAddr, data)
   691  					var buf bytes.Buffer
   692  					result, err := ep.Read(&buf, tcpip.ReadOptions{})
   693  					if err != nil {
   694  						t.Fatalf("ep.Read: %s", err)
   695  					} else {
   696  						if diff := cmp.Diff(tcpip.ReadResult{
   697  							Count: buf.Len(),
   698  							Total: buf.Len(),
   699  						}, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
   700  							t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
   701  						}
   702  						if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
   703  							t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
   704  						}
   705  					}
   706  
   707  					// We should not receive UDP packets to the group once we leave
   708  					// the multicast group.
   709  					removeOpt := tcpip.RemoveMembershipOption(memOpt)
   710  					if err := ep.SetSockOpt(&removeOpt); err != nil {
   711  						t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err)
   712  					}
   713  					{
   714  						_, err := ep.Read(&buf, tcpip.ReadOptions{})
   715  						if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
   716  							t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, &tcpip.ErrWouldBlock{})
   717  						}
   718  					}
   719  				})
   720  			}
   721  		})
   722  	}
   723  }