gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/tests/integration/multicast_broadcast_test.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package multicast_broadcast_test
    16  
    17  import (
    18  	"bytes"
    19  	"testing"
    20  
    21  	"github.com/google/go-cmp/cmp"
    22  	"gvisor.dev/gvisor/pkg/buffer"
    23  	"gvisor.dev/gvisor/pkg/tcpip"
    24  	"gvisor.dev/gvisor/pkg/tcpip/checker"
    25  	"gvisor.dev/gvisor/pkg/tcpip/checksum"
    26  	"gvisor.dev/gvisor/pkg/tcpip/header"
    27  	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
    28  	"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
    29  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    30  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
    31  	"gvisor.dev/gvisor/pkg/tcpip/prependable"
    32  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    33  	"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
    34  	"gvisor.dev/gvisor/pkg/tcpip/testutil"
    35  	"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
    36  	"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
    37  	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
    38  	"gvisor.dev/gvisor/pkg/waiter"
    39  )
    40  
    41  const (
    42  	defaultMTU = 1280
    43  	ttl        = 255
    44  )
    45  
    46  // TestPingMulticastBroadcast tests that responding to an Echo Request destined
    47  // to a multicast or broadcast address uses a unicast source address for the
    48  // reply.
    49  func TestPingMulticastBroadcast(t *testing.T) {
    50  	const (
    51  		nicID = 1
    52  		ttl   = 64
    53  	)
    54  
    55  	tests := []struct {
    56  		name        string
    57  		protoNum    tcpip.NetworkProtocolNumber
    58  		rxICMP      func(*channel.Endpoint, tcpip.Address, tcpip.Address, uint8)
    59  		srcAddr     tcpip.Address
    60  		dstAddr     tcpip.Address
    61  		expectedSrc tcpip.Address
    62  	}{
    63  		{
    64  			name:        "IPv4 unicast",
    65  			protoNum:    header.IPv4ProtocolNumber,
    66  			dstAddr:     utils.Ipv4Addr.Address,
    67  			srcAddr:     utils.RemoteIPv4Addr,
    68  			rxICMP:      utils.RxICMPv4EchoRequest,
    69  			expectedSrc: utils.Ipv4Addr.Address,
    70  		},
    71  		{
    72  			name:        "IPv4 directed broadcast",
    73  			protoNum:    header.IPv4ProtocolNumber,
    74  			rxICMP:      utils.RxICMPv4EchoRequest,
    75  			srcAddr:     utils.RemoteIPv4Addr,
    76  			dstAddr:     utils.Ipv4SubnetBcast,
    77  			expectedSrc: utils.Ipv4Addr.Address,
    78  		},
    79  		{
    80  			name:        "IPv4 broadcast",
    81  			protoNum:    header.IPv4ProtocolNumber,
    82  			rxICMP:      utils.RxICMPv4EchoRequest,
    83  			srcAddr:     utils.RemoteIPv4Addr,
    84  			dstAddr:     header.IPv4Broadcast,
    85  			expectedSrc: utils.Ipv4Addr.Address,
    86  		},
    87  		{
    88  			name:        "IPv4 all-systems multicast",
    89  			protoNum:    header.IPv4ProtocolNumber,
    90  			rxICMP:      utils.RxICMPv4EchoRequest,
    91  			srcAddr:     utils.RemoteIPv4Addr,
    92  			dstAddr:     header.IPv4AllSystems,
    93  			expectedSrc: utils.Ipv4Addr.Address,
    94  		},
    95  		{
    96  			name:        "IPv6 unicast",
    97  			protoNum:    header.IPv6ProtocolNumber,
    98  			rxICMP:      utils.RxICMPv6EchoRequest,
    99  			srcAddr:     utils.RemoteIPv6Addr,
   100  			dstAddr:     utils.Ipv6Addr.Address,
   101  			expectedSrc: utils.Ipv6Addr.Address,
   102  		},
   103  		{
   104  			name:        "IPv6 all-nodes multicast",
   105  			protoNum:    header.IPv6ProtocolNumber,
   106  			rxICMP:      utils.RxICMPv6EchoRequest,
   107  			srcAddr:     utils.RemoteIPv6Addr,
   108  			dstAddr:     header.IPv6AllNodesMulticastAddress,
   109  			expectedSrc: utils.Ipv6Addr.Address,
   110  		},
   111  	}
   112  
   113  	for _, test := range tests {
   114  		t.Run(test.name, func(t *testing.T) {
   115  			s := stack.New(stack.Options{
   116  				NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   117  				TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
   118  			})
   119  			// We only expect a single packet in response to our ICMP Echo Request.
   120  			e := channel.New(1, defaultMTU, "")
   121  			defer e.Close()
   122  			if err := s.CreateNIC(nicID, e); err != nil {
   123  				t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
   124  			}
   125  			ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: utils.Ipv4Addr}
   126  			if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil {
   127  				t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err)
   128  			}
   129  			ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: utils.Ipv6Addr}
   130  			if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr, stack.AddressProperties{}); err != nil {
   131  				t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtoAddr, err)
   132  			}
   133  
   134  			// Default routes for IPv4 and IPv6 so ICMP can find a route to the remote
   135  			// node when attempting to send the ICMP Echo Reply.
   136  			s.SetRouteTable([]tcpip.Route{
   137  				{
   138  					Destination: header.IPv6EmptySubnet,
   139  					NIC:         nicID,
   140  				},
   141  				{
   142  					Destination: header.IPv4EmptySubnet,
   143  					NIC:         nicID,
   144  				},
   145  			})
   146  
   147  			test.rxICMP(e, test.srcAddr, test.dstAddr, ttl)
   148  			pkt := e.Read()
   149  			if pkt == nil {
   150  				t.Fatal("expected ICMP response")
   151  			}
   152  			defer pkt.DecRef()
   153  
   154  			if pkt.EgressRoute.LocalAddress != test.expectedSrc {
   155  				t.Errorf("got pkt.EgressRoute.LocalAddress = %s, want = %s", pkt.EgressRoute.LocalAddress, test.expectedSrc)
   156  			}
   157  			// The destination of the response packet should be the source of the
   158  			// original packet.
   159  			if pkt.EgressRoute.RemoteAddress != test.srcAddr {
   160  				t.Errorf("got pkt.EgressRoute.RemoteAddress = %s, want = %s", pkt.EgressRoute.RemoteAddress, test.srcAddr)
   161  			}
   162  
   163  			v := stack.PayloadSince(pkt.NetworkHeader())
   164  			defer v.Release()
   165  			src, dst := s.NetworkProtocolInstance(test.protoNum).ParseAddresses(v.AsSlice())
   166  			if src != test.expectedSrc {
   167  				t.Errorf("got pkt source = %s, want = %s", src, test.expectedSrc)
   168  			}
   169  			// The destination of the response packet should be the source of the
   170  			// original packet.
   171  			if dst != test.srcAddr {
   172  				t.Errorf("got pkt destination = %s, want = %s", dst, test.srcAddr)
   173  			}
   174  		})
   175  	}
   176  
   177  }
   178  
   179  func rxIPv4UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) {
   180  	payloadLen := header.UDPMinimumSize + len(data)
   181  	totalLen := header.IPv4MinimumSize + payloadLen
   182  	hdr := prependable.New(totalLen)
   183  	u := header.UDP(hdr.Prepend(payloadLen))
   184  	u.Encode(&header.UDPFields{
   185  		SrcPort: utils.RemotePort,
   186  		DstPort: utils.LocalPort,
   187  		Length:  uint16(payloadLen),
   188  	})
   189  	copy(u.Payload(), data)
   190  	sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen))
   191  	sum = checksum.Checksum(data, sum)
   192  	u.SetChecksum(^u.CalculateChecksum(sum))
   193  
   194  	ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
   195  	ip.Encode(&header.IPv4Fields{
   196  		TotalLength: uint16(totalLen),
   197  		Protocol:    uint8(udp.ProtocolNumber),
   198  		TTL:         ttl,
   199  		SrcAddr:     src,
   200  		DstAddr:     dst,
   201  	})
   202  	ip.SetChecksum(^ip.CalculateChecksum())
   203  
   204  	e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
   205  		Payload: buffer.MakeWithData(hdr.View()),
   206  	}))
   207  }
   208  
   209  func rxIPv6UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) {
   210  	payloadLen := header.UDPMinimumSize + len(data)
   211  	hdr := prependable.New(header.IPv6MinimumSize + payloadLen)
   212  	u := header.UDP(hdr.Prepend(payloadLen))
   213  	u.Encode(&header.UDPFields{
   214  		SrcPort: utils.RemotePort,
   215  		DstPort: utils.LocalPort,
   216  		Length:  uint16(payloadLen),
   217  	})
   218  	copy(u.Payload(), data)
   219  	sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen))
   220  	sum = checksum.Checksum(data, sum)
   221  	u.SetChecksum(^u.CalculateChecksum(sum))
   222  
   223  	ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
   224  	ip.Encode(&header.IPv6Fields{
   225  		PayloadLength:     uint16(payloadLen),
   226  		TransportProtocol: udp.ProtocolNumber,
   227  		HopLimit:          ttl,
   228  		SrcAddr:           src,
   229  		DstAddr:           dst,
   230  	})
   231  
   232  	e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
   233  		Payload: buffer.MakeWithData(hdr.View()),
   234  	}))
   235  }
   236  
   237  // TestIncomingMulticastAndBroadcast tests receiving a packet destined to some
   238  // multicast or broadcast address.
   239  func TestIncomingMulticastAndBroadcast(t *testing.T) {
   240  	const nicID = 1
   241  
   242  	data := []byte{1, 2, 3, 4}
   243  
   244  	tests := []struct {
   245  		name       string
   246  		proto      tcpip.NetworkProtocolNumber
   247  		remoteAddr tcpip.Address
   248  		localAddr  tcpip.AddressWithPrefix
   249  		rxUDP      func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte)
   250  		bindAddr   tcpip.Address
   251  		dstAddr    tcpip.Address
   252  		expectRx   bool
   253  	}{
   254  		{
   255  			name:       "IPv4 unicast binding to unicast",
   256  			proto:      header.IPv4ProtocolNumber,
   257  			remoteAddr: utils.RemoteIPv4Addr,
   258  			localAddr:  utils.Ipv4Addr,
   259  			rxUDP:      rxIPv4UDP,
   260  			bindAddr:   utils.Ipv4Addr.Address,
   261  			dstAddr:    utils.Ipv4Addr.Address,
   262  			expectRx:   true,
   263  		},
   264  		{
   265  			name:       "IPv4 unicast binding to broadcast",
   266  			proto:      header.IPv4ProtocolNumber,
   267  			remoteAddr: utils.RemoteIPv4Addr,
   268  			localAddr:  utils.Ipv4Addr,
   269  			rxUDP:      rxIPv4UDP,
   270  			bindAddr:   header.IPv4Broadcast,
   271  			dstAddr:    utils.Ipv4Addr.Address,
   272  			expectRx:   false,
   273  		},
   274  		{
   275  			name:       "IPv4 unicast binding to wildcard",
   276  			proto:      header.IPv4ProtocolNumber,
   277  			remoteAddr: utils.RemoteIPv4Addr,
   278  			localAddr:  utils.Ipv4Addr,
   279  			rxUDP:      rxIPv4UDP,
   280  			dstAddr:    utils.Ipv4Addr.Address,
   281  			expectRx:   true,
   282  		},
   283  
   284  		{
   285  			name:       "IPv4 directed broadcast binding to subnet broadcast",
   286  			proto:      header.IPv4ProtocolNumber,
   287  			remoteAddr: utils.RemoteIPv4Addr,
   288  			localAddr:  utils.Ipv4Addr,
   289  			rxUDP:      rxIPv4UDP,
   290  			bindAddr:   utils.Ipv4SubnetBcast,
   291  			dstAddr:    utils.Ipv4SubnetBcast,
   292  			expectRx:   true,
   293  		},
   294  		{
   295  			name:       "IPv4 directed broadcast binding to broadcast",
   296  			proto:      header.IPv4ProtocolNumber,
   297  			remoteAddr: utils.RemoteIPv4Addr,
   298  			localAddr:  utils.Ipv4Addr,
   299  			rxUDP:      rxIPv4UDP,
   300  			bindAddr:   header.IPv4Broadcast,
   301  			dstAddr:    utils.Ipv4SubnetBcast,
   302  			expectRx:   false,
   303  		},
   304  		{
   305  			name:       "IPv4 directed broadcast binding to wildcard",
   306  			proto:      header.IPv4ProtocolNumber,
   307  			remoteAddr: utils.RemoteIPv4Addr,
   308  			localAddr:  utils.Ipv4Addr,
   309  			rxUDP:      rxIPv4UDP,
   310  			dstAddr:    utils.Ipv4SubnetBcast,
   311  			expectRx:   true,
   312  		},
   313  
   314  		{
   315  			name:       "IPv4 broadcast binding to broadcast",
   316  			proto:      header.IPv4ProtocolNumber,
   317  			remoteAddr: utils.RemoteIPv4Addr,
   318  			localAddr:  utils.Ipv4Addr,
   319  			rxUDP:      rxIPv4UDP,
   320  			bindAddr:   header.IPv4Broadcast,
   321  			dstAddr:    header.IPv4Broadcast,
   322  			expectRx:   true,
   323  		},
   324  		{
   325  			name:       "IPv4 broadcast binding to subnet broadcast",
   326  			proto:      header.IPv4ProtocolNumber,
   327  			remoteAddr: utils.RemoteIPv4Addr,
   328  			localAddr:  utils.Ipv4Addr,
   329  			rxUDP:      rxIPv4UDP,
   330  			bindAddr:   utils.Ipv4SubnetBcast,
   331  			dstAddr:    header.IPv4Broadcast,
   332  			expectRx:   false,
   333  		},
   334  		{
   335  			name:       "IPv4 broadcast binding to wildcard",
   336  			proto:      header.IPv4ProtocolNumber,
   337  			remoteAddr: utils.RemoteIPv4Addr,
   338  			localAddr:  utils.Ipv4Addr,
   339  			rxUDP:      rxIPv4UDP,
   340  			dstAddr:    utils.Ipv4SubnetBcast,
   341  			expectRx:   true,
   342  		},
   343  
   344  		{
   345  			name:       "IPv4 all-systems multicast binding to all-systems multicast",
   346  			proto:      header.IPv4ProtocolNumber,
   347  			remoteAddr: utils.RemoteIPv4Addr,
   348  			localAddr:  utils.Ipv4Addr,
   349  			rxUDP:      rxIPv4UDP,
   350  			bindAddr:   header.IPv4AllSystems,
   351  			dstAddr:    header.IPv4AllSystems,
   352  			expectRx:   true,
   353  		},
   354  		{
   355  			name:       "IPv4 all-systems multicast binding to wildcard",
   356  			proto:      header.IPv4ProtocolNumber,
   357  			remoteAddr: utils.RemoteIPv4Addr,
   358  			localAddr:  utils.Ipv4Addr,
   359  			rxUDP:      rxIPv4UDP,
   360  			dstAddr:    header.IPv4AllSystems,
   361  			expectRx:   true,
   362  		},
   363  		{
   364  			name:       "IPv4 all-systems multicast binding to unicast",
   365  			proto:      header.IPv4ProtocolNumber,
   366  			remoteAddr: utils.RemoteIPv4Addr,
   367  			localAddr:  utils.Ipv4Addr,
   368  			rxUDP:      rxIPv4UDP,
   369  			bindAddr:   utils.Ipv4Addr.Address,
   370  			dstAddr:    header.IPv4AllSystems,
   371  			expectRx:   false,
   372  		},
   373  
   374  		// IPv6 has no notion of a broadcast.
   375  		{
   376  			name:       "IPv6 unicast binding to wildcard",
   377  			dstAddr:    utils.Ipv6Addr.Address,
   378  			proto:      header.IPv6ProtocolNumber,
   379  			remoteAddr: utils.RemoteIPv6Addr,
   380  			localAddr:  utils.Ipv6Addr,
   381  			rxUDP:      rxIPv6UDP,
   382  			expectRx:   true,
   383  		},
   384  		{
   385  			name:       "IPv6 broadcast-like address binding to wildcard",
   386  			dstAddr:    utils.Ipv6SubnetBcast,
   387  			proto:      header.IPv6ProtocolNumber,
   388  			remoteAddr: utils.RemoteIPv6Addr,
   389  			localAddr:  utils.Ipv6Addr,
   390  			rxUDP:      rxIPv6UDP,
   391  			expectRx:   false,
   392  		},
   393  	}
   394  
   395  	for _, test := range tests {
   396  		t.Run(test.name, func(t *testing.T) {
   397  			s := stack.New(stack.Options{
   398  				NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   399  				TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   400  			})
   401  			e := channel.New(0, defaultMTU, "")
   402  			defer e.Close()
   403  			if err := s.CreateNIC(nicID, e); err != nil {
   404  				t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
   405  			}
   406  			protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
   407  			if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
   408  				t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
   409  			}
   410  
   411  			var wq waiter.Queue
   412  			ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq)
   413  			if err != nil {
   414  				t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err)
   415  			}
   416  			defer ep.Close()
   417  
   418  			bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: utils.LocalPort}
   419  			if err := ep.Bind(bindAddr); err != nil {
   420  				t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
   421  			}
   422  
   423  			test.rxUDP(e, test.remoteAddr, test.dstAddr, data)
   424  			var buf bytes.Buffer
   425  			var opts tcpip.ReadOptions
   426  			if res, err := ep.Read(&buf, opts); test.expectRx {
   427  				if err != nil {
   428  					t.Fatalf("ep.Read(_, %#v): %s", opts, err)
   429  				}
   430  				if diff := cmp.Diff(tcpip.ReadResult{
   431  					Count: buf.Len(),
   432  					Total: buf.Len(),
   433  				}, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
   434  					t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
   435  				}
   436  				if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
   437  					t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
   438  				}
   439  			} else if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
   440  				t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{})
   441  			}
   442  		})
   443  	}
   444  }
   445  
   446  // TestReuseAddrAndBroadcast makes sure broadcast packets are received by all
   447  // interested endpoints.
   448  func TestReuseAddrAndBroadcast(t *testing.T) {
   449  	const (
   450  		nicID     = 1
   451  		localPort = 9000
   452  	)
   453  	loopbackBroadcast := testutil.MustParse4("127.255.255.255")
   454  
   455  	tests := []struct {
   456  		name          string
   457  		broadcastAddr tcpip.Address
   458  	}{
   459  		{
   460  			name:          "Subnet directed broadcast",
   461  			broadcastAddr: loopbackBroadcast,
   462  		},
   463  		{
   464  			name:          "IPv4 broadcast",
   465  			broadcastAddr: header.IPv4Broadcast,
   466  		},
   467  	}
   468  
   469  	for _, test := range tests {
   470  		t.Run(test.name, func(t *testing.T) {
   471  			s := stack.New(stack.Options{
   472  				NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   473  				TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   474  			})
   475  			if err := s.CreateNIC(nicID, loopback.New()); err != nil {
   476  				t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
   477  			}
   478  			protoAddr := tcpip.ProtocolAddress{
   479  				Protocol: header.IPv4ProtocolNumber,
   480  				AddressWithPrefix: tcpip.AddressWithPrefix{
   481  					Address:   tcpip.AddrFromSlice([]byte("\x7f\x00\x00\x01")),
   482  					PrefixLen: 8,
   483  				},
   484  			}
   485  			if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
   486  				t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
   487  			}
   488  
   489  			s.SetRouteTable([]tcpip.Route{
   490  				{
   491  					// We use the empty subnet instead of just the loopback subnet so we
   492  					// also have a route to the IPv4 Broadcast address.
   493  					Destination: header.IPv4EmptySubnet,
   494  					NIC:         nicID,
   495  				},
   496  			})
   497  
   498  			type endpointAndWaiter struct {
   499  				ep tcpip.Endpoint
   500  				ch chan struct{}
   501  			}
   502  			var eps []endpointAndWaiter
   503  			// We create endpoints that bind to both the wildcard address and the
   504  			// broadcast address to make sure both of these types of "broadcast
   505  			// interested" endpoints receive broadcast packets.
   506  			for _, bindWildcard := range []bool{false, true} {
   507  				// Create multiple endpoints for each type of "broadcast interested"
   508  				// endpoint so we can test that all endpoints receive the broadcast
   509  				// packet.
   510  				for i := 0; i < 2; i++ {
   511  					var wq waiter.Queue
   512  					we, ch := waiter.NewChannelEntry(waiter.ReadableEvents)
   513  					wq.EventRegister(&we)
   514  					ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
   515  					if err != nil {
   516  						t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err)
   517  					}
   518  					defer ep.Close()
   519  
   520  					ep.SocketOptions().SetReuseAddress(true)
   521  					ep.SocketOptions().SetBroadcast(true)
   522  
   523  					bindAddr := tcpip.FullAddress{Port: localPort}
   524  					if bindWildcard {
   525  						if err := ep.Bind(bindAddr); err != nil {
   526  							t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err)
   527  						}
   528  					} else {
   529  						bindAddr.Addr = test.broadcastAddr
   530  						if err := ep.Bind(bindAddr); err != nil {
   531  							t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err)
   532  						}
   533  					}
   534  
   535  					eps = append(eps, endpointAndWaiter{ep: ep, ch: ch})
   536  				}
   537  			}
   538  
   539  			for i, wep := range eps {
   540  				writeOpts := tcpip.WriteOptions{
   541  					To: &tcpip.FullAddress{
   542  						Addr: test.broadcastAddr,
   543  						Port: localPort,
   544  					},
   545  				}
   546  				data := []byte{byte(i), 2, 3, 4}
   547  				var r bytes.Reader
   548  				r.Reset(data)
   549  				if n, err := wep.ep.Write(&r, writeOpts); err != nil {
   550  					t.Fatalf("eps[%d].Write(_, _): %s", i, err)
   551  				} else if want := int64(len(data)); n != want {
   552  					t.Fatalf("got eps[%d].Write(_, _) = (%d, nil), want = (%d, nil)", i, n, want)
   553  				}
   554  
   555  				for j, rep := range eps {
   556  					// Wait for the endpoint to become readable.
   557  					<-rep.ch
   558  
   559  					var buf bytes.Buffer
   560  					result, err := rep.ep.Read(&buf, tcpip.ReadOptions{})
   561  					if err != nil {
   562  						t.Errorf("(eps[%d] write) eps[%d].Read: %s", i, j, err)
   563  						continue
   564  					}
   565  					if diff := cmp.Diff(tcpip.ReadResult{
   566  						Count: buf.Len(),
   567  						Total: buf.Len(),
   568  					}, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
   569  						t.Errorf("(eps[%d] write) eps[%d].Read: unexpected result (-want +got):\n%s", i, j, diff)
   570  					}
   571  					if diff := cmp.Diff([]byte(data), buf.Bytes()); diff != "" {
   572  						t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff)
   573  					}
   574  				}
   575  			}
   576  		})
   577  	}
   578  }
   579  
   580  func TestUDPAddRemoveMembershipSocketOption(t *testing.T) {
   581  	const (
   582  		nicID = 1
   583  	)
   584  
   585  	data := []byte{1, 2, 3, 4}
   586  
   587  	tests := []struct {
   588  		name          string
   589  		proto         tcpip.NetworkProtocolNumber
   590  		remoteAddr    tcpip.Address
   591  		localAddr     tcpip.AddressWithPrefix
   592  		rxUDP         func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte)
   593  		multicastAddr tcpip.Address
   594  	}{
   595  		{
   596  			name:          "IPv4 unicast binding to unicast",
   597  			multicastAddr: tcpip.AddrFromSlice([]byte("\xe0\x01\x02\x03")),
   598  			proto:         header.IPv4ProtocolNumber,
   599  			remoteAddr:    utils.RemoteIPv4Addr,
   600  			localAddr:     utils.Ipv4Addr,
   601  			rxUDP:         rxIPv4UDP,
   602  		},
   603  		{
   604  			name:          "IPv6 broadcast-like address binding to wildcard",
   605  			multicastAddr: tcpip.AddrFromSlice([]byte("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04")),
   606  			proto:         header.IPv6ProtocolNumber,
   607  			remoteAddr:    utils.RemoteIPv6Addr,
   608  			localAddr:     utils.Ipv6Addr,
   609  			rxUDP:         rxIPv6UDP,
   610  		},
   611  	}
   612  
   613  	subTests := []struct {
   614  		name           string
   615  		specifyNICID   bool
   616  		specifyNICAddr bool
   617  	}{
   618  		{
   619  			name:           "Specify NIC ID and NIC address",
   620  			specifyNICID:   true,
   621  			specifyNICAddr: true,
   622  		},
   623  		{
   624  			name:           "Don't specify NIC ID or NIC address",
   625  			specifyNICID:   false,
   626  			specifyNICAddr: false,
   627  		},
   628  		{
   629  			name:           "Specify NIC ID but don't specify NIC address",
   630  			specifyNICID:   true,
   631  			specifyNICAddr: false,
   632  		},
   633  		{
   634  			name:           "Don't specify NIC ID but specify NIC address",
   635  			specifyNICID:   false,
   636  			specifyNICAddr: true,
   637  		},
   638  	}
   639  
   640  	for _, test := range tests {
   641  		t.Run(test.name, func(t *testing.T) {
   642  			for _, subTest := range subTests {
   643  				t.Run(subTest.name, func(t *testing.T) {
   644  					s := stack.New(stack.Options{
   645  						NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   646  						TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   647  					})
   648  					e := channel.New(0, defaultMTU, "")
   649  					defer e.Close()
   650  					if err := s.CreateNIC(nicID, e); err != nil {
   651  						t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
   652  					}
   653  					protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
   654  					if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
   655  						t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
   656  					}
   657  
   658  					// Set the route table so that UDP can find a NIC that is
   659  					// routable to the multicast address when the NIC isn't specified.
   660  					if !subTest.specifyNICID && !subTest.specifyNICAddr {
   661  						s.SetRouteTable([]tcpip.Route{
   662  							{
   663  								Destination: header.IPv6EmptySubnet,
   664  								NIC:         nicID,
   665  							},
   666  							{
   667  								Destination: header.IPv4EmptySubnet,
   668  								NIC:         nicID,
   669  							},
   670  						})
   671  					}
   672  
   673  					var wq waiter.Queue
   674  					ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq)
   675  					if err != nil {
   676  						t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err)
   677  					}
   678  					defer ep.Close()
   679  
   680  					bindAddr := tcpip.FullAddress{Port: utils.LocalPort}
   681  					if err := ep.Bind(bindAddr); err != nil {
   682  						t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
   683  					}
   684  
   685  					memOpt := tcpip.MembershipOption{MulticastAddr: test.multicastAddr}
   686  					if subTest.specifyNICID {
   687  						memOpt.NIC = nicID
   688  					}
   689  					if subTest.specifyNICAddr {
   690  						memOpt.InterfaceAddr = test.localAddr.Address
   691  					}
   692  
   693  					// We should receive UDP packets to the group once we join the
   694  					// multicast group.
   695  					addOpt := tcpip.AddMembershipOption(memOpt)
   696  					if err := ep.SetSockOpt(&addOpt); err != nil {
   697  						t.Fatalf("ep.SetSockOpt(&%#v): %s", addOpt, err)
   698  					}
   699  					test.rxUDP(e, test.remoteAddr, test.multicastAddr, data)
   700  					var buf bytes.Buffer
   701  					result, err := ep.Read(&buf, tcpip.ReadOptions{})
   702  					if err != nil {
   703  						t.Fatalf("ep.Read: %s", err)
   704  					} else {
   705  						if diff := cmp.Diff(tcpip.ReadResult{
   706  							Count: buf.Len(),
   707  							Total: buf.Len(),
   708  						}, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
   709  							t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
   710  						}
   711  						if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
   712  							t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
   713  						}
   714  					}
   715  
   716  					// We should not receive UDP packets to the group once we leave
   717  					// the multicast group.
   718  					removeOpt := tcpip.RemoveMembershipOption(memOpt)
   719  					if err := ep.SetSockOpt(&removeOpt); err != nil {
   720  						t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err)
   721  					}
   722  					{
   723  						_, err := ep.Read(&buf, tcpip.ReadOptions{})
   724  						if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
   725  							t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, &tcpip.ErrWouldBlock{})
   726  						}
   727  					}
   728  				})
   729  			}
   730  		})
   731  	}
   732  }
   733  
   734  func TestAddMembershipInterfacePrecedence(t *testing.T) {
   735  	const nicID = 1
   736  	multicastAddr := tcpip.AddrFromSlice([]byte("\xe0\x01\x02\x03"))
   737  	proto := header.IPv4ProtocolNumber
   738  	// This address is nonsensical. If the precedence is correct, this should not
   739  	// matter, because ADD_IP_MEMBERSHIP should consider the interface index
   740  	// and use that before checking the address.
   741  	localAddr := tcpip.AddressWithPrefix{
   742  		Address:   testutil.MustParse4("8.0.8.0"),
   743  		PrefixLen: 24,
   744  	}
   745  	s := stack.New(stack.Options{
   746  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   747  		TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   748  	})
   749  	e := channel.New(0, defaultMTU, "")
   750  	defer e.Close()
   751  	if err := s.CreateNIC(nicID, e); err != nil {
   752  		t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
   753  	}
   754  	protoAddr := tcpip.ProtocolAddress{Protocol: proto, AddressWithPrefix: localAddr}
   755  	if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
   756  		t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
   757  	}
   758  
   759  	var wq waiter.Queue
   760  	ep, err := s.NewEndpoint(udp.ProtocolNumber, proto, &wq)
   761  	if err != nil {
   762  		t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, proto, err)
   763  	}
   764  	defer ep.Close()
   765  
   766  	bindAddr := tcpip.FullAddress{Port: utils.LocalPort}
   767  	if err := ep.Bind(bindAddr); err != nil {
   768  		t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
   769  	}
   770  
   771  	memOpt := tcpip.MembershipOption{MulticastAddr: multicastAddr}
   772  	memOpt.NIC = nicID
   773  	memOpt.InterfaceAddr = localAddr.Address
   774  
   775  	// Add membership should succeed when the interface index is specified,
   776  	// even if a bad interface address is specified.
   777  	addOpt := tcpip.AddMembershipOption(memOpt)
   778  	if err := ep.SetSockOpt(&addOpt); err != nil {
   779  		t.Fatalf("ep.SetSockOpt(&%#v): %s", addOpt, err)
   780  	}
   781  }
   782  
   783  func TestMismatchedMulticastAddressAndProtocol(t *testing.T) {
   784  	const nicID = 1
   785  	// MulticastAddr is IPv4, but proto is IPv6.
   786  	multicastAddr := tcpip.AddrFromSlice([]byte("\xe0\x01\x02\x03"))
   787  	s := stack.New(stack.Options{
   788  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   789  		TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
   790  		RawFactory:         raw.EndpointFactory{},
   791  	})
   792  	e := channel.New(0, defaultMTU, "")
   793  	defer e.Close()
   794  	if err := s.CreateNIC(nicID, e); err != nil {
   795  		t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
   796  	}
   797  	protoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: utils.Ipv4Addr}
   798  	if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
   799  		t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
   800  	}
   801  
   802  	var wq waiter.Queue
   803  	ep, err := s.NewRawEndpoint(header.ICMPv6ProtocolNumber, header.IPv6ProtocolNumber, &wq, false)
   804  	if err != nil {
   805  		t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, header.IPv6ProtocolNumber, err)
   806  	}
   807  	defer ep.Close()
   808  
   809  	bindAddr := tcpip.FullAddress{Port: utils.LocalPort}
   810  	if err := ep.Bind(bindAddr); err != nil {
   811  		t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
   812  	}
   813  
   814  	memOpt := tcpip.MembershipOption{
   815  		MulticastAddr: multicastAddr,
   816  		NIC:           0,
   817  		InterfaceAddr: utils.Ipv4Addr.Address,
   818  	}
   819  
   820  	// Add/remove membership should succeed when the interface index is specified,
   821  	// even if a bad interface address is specified.
   822  	addOpt := tcpip.AddMembershipOption(memOpt)
   823  	expErr := &tcpip.ErrInvalidOptionValue{}
   824  	if err := ep.SetSockOpt(&addOpt); err != expErr {
   825  		t.Fatalf("ep.SetSockOpt(&%#v): want %q, got %q", addOpt, expErr, err)
   826  	}
   827  
   828  	removeOpt := tcpip.RemoveMembershipOption(memOpt)
   829  	if err := ep.SetSockOpt(&removeOpt); err != expErr {
   830  		t.Fatalf("ep.SetSockOpt(&%#v): want %q, got %q", addOpt, expErr, err)
   831  	}
   832  }