gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/tests/integration/multicast_forward_test.go (about)

     1  // Copyright 2022 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package multicast_forward_test
    16  
    17  import (
    18  	"fmt"
    19  	"os"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/google/go-cmp/cmp"
    24  	"github.com/google/go-cmp/cmp/cmpopts"
    25  	"gvisor.dev/gvisor/pkg/refs"
    26  	"gvisor.dev/gvisor/pkg/tcpip"
    27  	"gvisor.dev/gvisor/pkg/tcpip/checker"
    28  	"gvisor.dev/gvisor/pkg/tcpip/faketime"
    29  	"gvisor.dev/gvisor/pkg/tcpip/header"
    30  	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
    31  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    32  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
    33  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    34  	"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
    35  	"gvisor.dev/gvisor/pkg/tcpip/testutil"
    36  	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
    37  )
    38  
    39  const (
    40  	incomingNICID      = 1
    41  	outgoingNICID      = 2
    42  	otherOutgoingNICID = 3
    43  	otherNICID         = 4
    44  	unknownNICID       = 5
    45  	packetTTL          = 64
    46  	routeMinTTL        = 2
    47  )
    48  
    49  type addrType int
    50  
    51  const (
    52  	emptyAddr addrType = iota
    53  	anyAddr
    54  	linkLocalMulticastAddr
    55  	linkLocalUnicastAddr
    56  	multicastAddr
    57  	otherMulticastAddr
    58  	remoteUnicastAddr
    59  )
    60  
    61  type endpointAddrType int
    62  
    63  const (
    64  	incomingEndpointAddr endpointAddrType = iota
    65  	otherEndpointAddr
    66  	outgoingEndpointAddr
    67  	otherOutgoingEndpointAddr
    68  )
    69  
    70  type onMissingRouteData struct {
    71  	context stack.MulticastPacketContext
    72  }
    73  
    74  type onUnexpectedInputInterfaceData struct {
    75  	context                stack.MulticastPacketContext
    76  	expectedInputInterface tcpip.NICID
    77  }
    78  
    79  var _ stack.MulticastForwardingEventDispatcher = (*fakeMulticastEventDispatcher)(nil)
    80  
    81  type fakeMulticastEventDispatcher struct {
    82  	onMissingRouteData             *onMissingRouteData
    83  	onUnexpectedInputInterfaceData *onUnexpectedInputInterfaceData
    84  }
    85  
    86  func (m *fakeMulticastEventDispatcher) OnMissingRoute(context stack.MulticastPacketContext) {
    87  	m.onMissingRouteData = &onMissingRouteData{context}
    88  }
    89  
    90  func (m *fakeMulticastEventDispatcher) OnUnexpectedInputInterface(context stack.MulticastPacketContext, expectedInputInterface tcpip.NICID) {
    91  	m.onUnexpectedInputInterfaceData = &onUnexpectedInputInterfaceData{
    92  		context,
    93  		expectedInputInterface,
    94  	}
    95  }
    96  
    97  var (
    98  	v4Addrs = map[addrType]tcpip.Address{
    99  		anyAddr:                header.IPv4Any,
   100  		emptyAddr:              tcpip.Address{},
   101  		linkLocalMulticastAddr: testutil.MustParse4("224.0.0.1"),
   102  		linkLocalUnicastAddr:   testutil.MustParse4("169.254.0.10"),
   103  		multicastAddr:          testutil.MustParse4("225.0.0.0"),
   104  		otherMulticastAddr:     testutil.MustParse4("225.0.0.1"),
   105  		remoteUnicastAddr:      utils.RemoteIPv4Addr,
   106  	}
   107  
   108  	v6Addrs = map[addrType]tcpip.Address{
   109  		anyAddr:                header.IPv6Any,
   110  		emptyAddr:              tcpip.Address{},
   111  		linkLocalMulticastAddr: testutil.MustParse6("ff02::a"),
   112  		linkLocalUnicastAddr:   testutil.MustParse6("fe80::a"),
   113  		multicastAddr:          testutil.MustParse6("ff0e::a"),
   114  		otherMulticastAddr:     testutil.MustParse6("ff0e::b"),
   115  		remoteUnicastAddr:      utils.RemoteIPv6Addr,
   116  	}
   117  
   118  	v4EndpointAddrs = map[endpointAddrType]tcpip.AddressWithPrefix{
   119  		incomingEndpointAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix,
   120  		otherEndpointAddr:    utils.Host1IPv4Addr.AddressWithPrefix,
   121  		outgoingEndpointAddr: utils.RouterNIC2IPv4Addr.AddressWithPrefix,
   122  		otherOutgoingNICID:   utils.Host2IPv4Addr.AddressWithPrefix,
   123  	}
   124  
   125  	v6EndpointAddrs = map[endpointAddrType]tcpip.AddressWithPrefix{
   126  		incomingEndpointAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix,
   127  		otherEndpointAddr:    utils.Host1IPv6Addr.AddressWithPrefix,
   128  		outgoingEndpointAddr: utils.RouterNIC2IPv6Addr.AddressWithPrefix,
   129  		otherOutgoingNICID:   utils.Host2IPv6Addr.AddressWithPrefix,
   130  	}
   131  )
   132  
   133  func getAddr(protocol tcpip.NetworkProtocolNumber, addrType addrType) tcpip.Address {
   134  	switch protocol {
   135  	case ipv4.ProtocolNumber:
   136  		if addr, ok := v4Addrs[addrType]; ok {
   137  			return addr
   138  		}
   139  		panic(fmt.Sprintf("unsupported addrType: %d", addrType))
   140  	case ipv6.ProtocolNumber:
   141  		if addr, ok := v6Addrs[addrType]; ok {
   142  			return addr
   143  		}
   144  		panic(fmt.Sprintf("unsupported addrType: %d", addrType))
   145  	default:
   146  		panic(fmt.Sprintf("unsupported protocol: %d", protocol))
   147  	}
   148  }
   149  
   150  func getEndpointAddr(protocol tcpip.NetworkProtocolNumber, addrType endpointAddrType) tcpip.AddressWithPrefix {
   151  	switch protocol {
   152  	case ipv4.ProtocolNumber:
   153  		if addr, ok := v4EndpointAddrs[addrType]; ok {
   154  			return addr
   155  		}
   156  		panic(fmt.Sprintf("unsupported endpointAddrType: %d", addrType))
   157  	case ipv6.ProtocolNumber:
   158  		if addr, ok := v6EndpointAddrs[addrType]; ok {
   159  			return addr
   160  		}
   161  		panic(fmt.Sprintf("unsupported endpointAddrType: %d", addrType))
   162  	default:
   163  		panic(fmt.Sprintf("unsupported protocol: %d", protocol))
   164  	}
   165  }
   166  
   167  func checkEchoRequest(t *testing.T, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, srcAddr, dstAddr tcpip.Address, ttl uint8) {
   168  	payload := stack.PayloadSince(pkt.NetworkHeader())
   169  	defer payload.Release()
   170  	switch protocol {
   171  	case ipv4.ProtocolNumber:
   172  		checker.IPv4(t, payload,
   173  			checker.SrcAddr(srcAddr),
   174  			checker.DstAddr(dstAddr),
   175  			checker.TTL(ttl),
   176  			checker.ICMPv4(
   177  				checker.ICMPv4Type(header.ICMPv4Echo),
   178  			),
   179  		)
   180  	case ipv6.ProtocolNumber:
   181  		checker.IPv6(t, payload,
   182  			checker.SrcAddr(srcAddr),
   183  			checker.DstAddr(dstAddr),
   184  			checker.TTL(ttl),
   185  			checker.ICMPv6(
   186  				checker.ICMPv6Type(header.ICMPv6EchoRequest),
   187  			),
   188  		)
   189  	default:
   190  		panic(fmt.Sprintf("unsupported protocol: %d", protocol))
   191  	}
   192  }
   193  
   194  func checkEchoReply(t *testing.T, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, srcAddr, dstAddr tcpip.Address) {
   195  	payload := stack.PayloadSince(pkt.NetworkHeader())
   196  	defer payload.Release()
   197  	switch protocol {
   198  	case ipv4.ProtocolNumber:
   199  		checker.IPv4(t, payload,
   200  			checker.SrcAddr(srcAddr),
   201  			checker.DstAddr(dstAddr),
   202  			checker.ICMPv4(
   203  				checker.ICMPv4Type(header.ICMPv4EchoReply),
   204  			),
   205  		)
   206  	case ipv6.ProtocolNumber:
   207  		checker.IPv6(t, payload,
   208  			checker.SrcAddr(srcAddr),
   209  			checker.DstAddr(dstAddr),
   210  			checker.ICMPv6(
   211  				checker.ICMPv6Type(header.ICMPv6EchoReply),
   212  			),
   213  		)
   214  	default:
   215  		panic(fmt.Sprintf("unsupported protocol: %d", protocol))
   216  	}
   217  }
   218  
   219  func injectPacket(ep *channel.Endpoint, protocol tcpip.NetworkProtocolNumber, srcAddr, dstAddr tcpip.Address, ttl uint8) {
   220  	switch protocol {
   221  	case ipv4.ProtocolNumber:
   222  		utils.RxICMPv4EchoRequest(ep, srcAddr, dstAddr, ttl)
   223  	case ipv6.ProtocolNumber:
   224  		utils.RxICMPv6EchoRequest(ep, srcAddr, dstAddr, ttl)
   225  	default:
   226  		panic(fmt.Sprintf("unsupported protocol: %d", protocol))
   227  	}
   228  }
   229  
   230  func TestAddMulticastRoute(t *testing.T) {
   231  	endpointConfigs := map[tcpip.NICID]endpointAddrType{
   232  		incomingNICID: incomingEndpointAddr,
   233  		outgoingNICID: outgoingEndpointAddr,
   234  		otherNICID:    otherEndpointAddr,
   235  	}
   236  
   237  	type multicastForwardingEvent int
   238  	const (
   239  		enabledForProtocol multicastForwardingEvent = iota
   240  		enabledForNIC
   241  		injectPendingPacket
   242  	)
   243  
   244  	type multicastForwardingStateBeforeAddRouteCalled struct {
   245  		multicastForwardingEvents []multicastForwardingEvent
   246  	}
   247  
   248  	tests := []struct {
   249  		name                                          string
   250  		srcAddr, dstAddr                              addrType
   251  		routeIncomingNICID                            tcpip.NICID
   252  		routeOutgoingNICID                            tcpip.NICID
   253  		omitOutgoingInterfaces                        bool
   254  		multicastForwardingEventsBeforeAddRouteCalled []multicastForwardingEvent
   255  		expectForward                                 bool
   256  		wantErr                                       tcpip.Error
   257  	}{
   258  		{
   259  			name:               "no pending packets",
   260  			srcAddr:            remoteUnicastAddr,
   261  			dstAddr:            multicastAddr,
   262  			routeIncomingNICID: incomingNICID,
   263  			routeOutgoingNICID: outgoingNICID,
   264  			multicastForwardingEventsBeforeAddRouteCalled: []multicastForwardingEvent{enabledForNIC, enabledForProtocol},
   265  			wantErr: nil,
   266  		},
   267  		{
   268  			name:               "packet arrived after forwarding enabled but before add route called",
   269  			srcAddr:            remoteUnicastAddr,
   270  			dstAddr:            multicastAddr,
   271  			routeIncomingNICID: incomingNICID,
   272  			routeOutgoingNICID: outgoingNICID,
   273  			multicastForwardingEventsBeforeAddRouteCalled: []multicastForwardingEvent{enabledForNIC, enabledForProtocol, injectPendingPacket},
   274  			expectForward: true,
   275  		},
   276  		{
   277  			name:               "packet arrived before multicast forwarding enabled",
   278  			srcAddr:            remoteUnicastAddr,
   279  			dstAddr:            multicastAddr,
   280  			routeIncomingNICID: incomingNICID,
   281  			routeOutgoingNICID: outgoingNICID,
   282  			multicastForwardingEventsBeforeAddRouteCalled: []multicastForwardingEvent{enabledForNIC, injectPendingPacket, enabledForProtocol},
   283  			expectForward: false,
   284  		},
   285  		{
   286  			name:    "unexpected input interface",
   287  			srcAddr: remoteUnicastAddr,
   288  			dstAddr: multicastAddr,
   289  			// The added route's incoming NICID does not match the pending packet's
   290  			// incoming NICID. As a result, the packet should not be forwarded.
   291  			routeIncomingNICID: otherNICID,
   292  			routeOutgoingNICID: outgoingNICID,
   293  			multicastForwardingEventsBeforeAddRouteCalled: []multicastForwardingEvent{enabledForNIC, enabledForProtocol},
   294  		},
   295  		{
   296  			name:               "multicast forwarding disabled for NIC",
   297  			srcAddr:            remoteUnicastAddr,
   298  			dstAddr:            multicastAddr,
   299  			routeIncomingNICID: incomingNICID,
   300  			routeOutgoingNICID: outgoingNICID,
   301  			multicastForwardingEventsBeforeAddRouteCalled: []multicastForwardingEvent{enabledForProtocol},
   302  			expectForward: false,
   303  			wantErr:       nil,
   304  		},
   305  		{
   306  			name:               "multicast forwarding disabled for protocol",
   307  			srcAddr:            remoteUnicastAddr,
   308  			dstAddr:            multicastAddr,
   309  			routeIncomingNICID: incomingNICID,
   310  			routeOutgoingNICID: outgoingNICID,
   311  			multicastForwardingEventsBeforeAddRouteCalled: []multicastForwardingEvent{enabledForNIC},
   312  			wantErr: &tcpip.ErrNotPermitted{},
   313  		},
   314  		{
   315  			name:               "multicast source",
   316  			srcAddr:            multicastAddr,
   317  			dstAddr:            multicastAddr,
   318  			routeIncomingNICID: incomingNICID,
   319  			routeOutgoingNICID: outgoingNICID,
   320  			multicastForwardingEventsBeforeAddRouteCalled: []multicastForwardingEvent{enabledForNIC, enabledForProtocol},
   321  			wantErr: &tcpip.ErrBadAddress{},
   322  		},
   323  		{
   324  			name:               "any source",
   325  			srcAddr:            anyAddr,
   326  			dstAddr:            multicastAddr,
   327  			routeIncomingNICID: incomingNICID,
   328  			routeOutgoingNICID: outgoingNICID,
   329  			multicastForwardingEventsBeforeAddRouteCalled: []multicastForwardingEvent{enabledForNIC, enabledForProtocol},
   330  			wantErr: &tcpip.ErrBadAddress{},
   331  		},
   332  		{
   333  			name:               "link-local unicast source",
   334  			srcAddr:            linkLocalUnicastAddr,
   335  			dstAddr:            multicastAddr,
   336  			routeIncomingNICID: incomingNICID,
   337  			routeOutgoingNICID: outgoingNICID,
   338  			multicastForwardingEventsBeforeAddRouteCalled: []multicastForwardingEvent{enabledForNIC, enabledForProtocol},
   339  			wantErr: &tcpip.ErrBadAddress{},
   340  		},
   341  		{
   342  			name:               "empty source",
   343  			srcAddr:            emptyAddr,
   344  			dstAddr:            multicastAddr,
   345  			routeIncomingNICID: incomingNICID,
   346  			routeOutgoingNICID: outgoingNICID,
   347  			multicastForwardingEventsBeforeAddRouteCalled: []multicastForwardingEvent{enabledForNIC, enabledForProtocol},
   348  			wantErr: &tcpip.ErrBadAddress{},
   349  		},
   350  		{
   351  			name:               "unicast destination",
   352  			srcAddr:            remoteUnicastAddr,
   353  			dstAddr:            remoteUnicastAddr,
   354  			routeIncomingNICID: incomingNICID,
   355  			routeOutgoingNICID: outgoingNICID,
   356  			multicastForwardingEventsBeforeAddRouteCalled: []multicastForwardingEvent{enabledForNIC, enabledForProtocol},
   357  			wantErr: &tcpip.ErrBadAddress{},
   358  		},
   359  		{
   360  			name:               "empty destination",
   361  			srcAddr:            remoteUnicastAddr,
   362  			dstAddr:            emptyAddr,
   363  			routeIncomingNICID: incomingNICID,
   364  			routeOutgoingNICID: outgoingNICID,
   365  			multicastForwardingEventsBeforeAddRouteCalled: []multicastForwardingEvent{enabledForNIC, enabledForProtocol},
   366  			wantErr: &tcpip.ErrBadAddress{},
   367  		},
   368  		{
   369  			name:               "link-local multicast destination",
   370  			srcAddr:            remoteUnicastAddr,
   371  			dstAddr:            linkLocalMulticastAddr,
   372  			routeIncomingNICID: incomingNICID,
   373  			routeOutgoingNICID: outgoingNICID,
   374  			multicastForwardingEventsBeforeAddRouteCalled: []multicastForwardingEvent{enabledForNIC, enabledForProtocol},
   375  			wantErr: &tcpip.ErrBadAddress{},
   376  		},
   377  		{
   378  			name:               "unknown input NICID",
   379  			srcAddr:            remoteUnicastAddr,
   380  			dstAddr:            multicastAddr,
   381  			routeIncomingNICID: unknownNICID,
   382  			routeOutgoingNICID: outgoingNICID,
   383  			multicastForwardingEventsBeforeAddRouteCalled: []multicastForwardingEvent{enabledForNIC, enabledForProtocol},
   384  			wantErr: &tcpip.ErrUnknownNICID{},
   385  		},
   386  		{
   387  			name:               "unknown output NICID",
   388  			srcAddr:            remoteUnicastAddr,
   389  			dstAddr:            multicastAddr,
   390  			routeIncomingNICID: incomingNICID,
   391  			routeOutgoingNICID: unknownNICID,
   392  			multicastForwardingEventsBeforeAddRouteCalled: []multicastForwardingEvent{enabledForNIC, enabledForProtocol},
   393  			wantErr: &tcpip.ErrUnknownNICID{},
   394  		},
   395  		{
   396  			name:               "input NIC matches output NIC",
   397  			srcAddr:            remoteUnicastAddr,
   398  			dstAddr:            multicastAddr,
   399  			routeIncomingNICID: incomingNICID,
   400  			routeOutgoingNICID: incomingNICID,
   401  			multicastForwardingEventsBeforeAddRouteCalled: []multicastForwardingEvent{enabledForNIC, enabledForProtocol},
   402  			wantErr: &tcpip.ErrMulticastInputCannotBeOutput{},
   403  		},
   404  		{
   405  			name:                   "empty outgoing interfaces",
   406  			srcAddr:                remoteUnicastAddr,
   407  			dstAddr:                multicastAddr,
   408  			routeIncomingNICID:     incomingNICID,
   409  			routeOutgoingNICID:     outgoingNICID,
   410  			omitOutgoingInterfaces: true,
   411  			multicastForwardingEventsBeforeAddRouteCalled: []multicastForwardingEvent{enabledForNIC, enabledForProtocol},
   412  			wantErr: &tcpip.ErrMissingRequiredFields{},
   413  		},
   414  	}
   415  
   416  	for _, test := range tests {
   417  		for _, protocol := range []tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber, ipv6.ProtocolNumber} {
   418  			t.Run(fmt.Sprintf("%s %d", test.name, protocol), func(t *testing.T) {
   419  				eventDispatcher := &fakeMulticastEventDispatcher{}
   420  				s := stack.New(stack.Options{
   421  					NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   422  				})
   423  				defer s.Destroy()
   424  
   425  				endpoints := make(map[tcpip.NICID]*channel.Endpoint)
   426  				for nicID, addrType := range endpointConfigs {
   427  					ep := channel.New(1, ipv4.MaxTotalSize, "")
   428  					defer ep.Close()
   429  
   430  					if err := s.CreateNIC(nicID, ep); err != nil {
   431  						t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
   432  					}
   433  					addr := tcpip.ProtocolAddress{
   434  						Protocol:          protocol,
   435  						AddressWithPrefix: getEndpointAddr(protocol, addrType),
   436  					}
   437  					if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
   438  						t.Fatalf("s.AddProtocolAddress(%d, %#v, {}): %s", nicID, addr, err)
   439  					}
   440  					endpoints[nicID] = ep
   441  				}
   442  
   443  				srcAddr := getAddr(protocol, test.srcAddr)
   444  				dstAddr := getAddr(protocol, test.dstAddr)
   445  
   446  				for _, event := range test.multicastForwardingEventsBeforeAddRouteCalled {
   447  					switch event {
   448  					case enabledForNIC:
   449  						for nicID := range endpoints {
   450  							s.SetNICMulticastForwarding(nicID, protocol, true /* enable */)
   451  						}
   452  					case enabledForProtocol:
   453  						if _, err := s.EnableMulticastForwardingForProtocol(protocol, eventDispatcher); err != nil {
   454  							t.Fatalf("s.EnableMulticastForwardingForProtocol(%d, _): (_, %s)", protocol, err)
   455  						}
   456  					case injectPendingPacket:
   457  						incomingEp, ok := endpoints[incomingNICID]
   458  						if !ok {
   459  							t.Fatalf("got endpoints[%d] = (_, false), want (_, true)", incomingNICID)
   460  						}
   461  
   462  						injectPacket(incomingEp, protocol, srcAddr, dstAddr, packetTTL)
   463  						p := incomingEp.Read()
   464  
   465  						if p != nil {
   466  							// An ICMP error should never be sent in response to a multicast packet.
   467  							t.Fatalf("got incomingEp.Read() = %#v, want = nil", p)
   468  						}
   469  					default:
   470  						panic(fmt.Sprintf("unsupported multicastForwardingEvent: %d", event))
   471  					}
   472  				}
   473  
   474  				outgoingInterfaces := []stack.MulticastRouteOutgoingInterface{
   475  					{ID: test.routeOutgoingNICID, MinTTL: routeMinTTL},
   476  				}
   477  				if test.omitOutgoingInterfaces {
   478  					outgoingInterfaces = nil
   479  				}
   480  
   481  				addresses := stack.UnicastSourceAndMulticastDestination{
   482  					Source:      srcAddr,
   483  					Destination: dstAddr,
   484  				}
   485  
   486  				route := stack.MulticastRoute{
   487  					ExpectedInputInterface: test.routeIncomingNICID,
   488  					OutgoingInterfaces:     outgoingInterfaces,
   489  				}
   490  
   491  				err := s.AddMulticastRoute(protocol, addresses, route)
   492  
   493  				if !cmp.Equal(err, test.wantErr, cmpopts.EquateErrors()) {
   494  					t.Errorf("got s.AddMulticastRoute(%d, %#v, %#v) = %s, want %s", protocol, addresses, route, err, test.wantErr)
   495  				}
   496  
   497  				outgoingEp, ok := endpoints[outgoingNICID]
   498  				if !ok {
   499  					t.Fatalf("got endpoints[%d] = (_, false), want (_, true)", outgoingNICID)
   500  				}
   501  
   502  				p := outgoingEp.Read()
   503  
   504  				if (p != nil) != test.expectForward {
   505  					t.Fatalf("got outgoingEp.Read() = %#v, want = (_ == nil) = %t", p, test.expectForward)
   506  				}
   507  
   508  				if test.expectForward {
   509  					checkEchoRequest(t, protocol, p, srcAddr, dstAddr, packetTTL-1)
   510  					p.DecRef()
   511  				}
   512  			})
   513  		}
   514  	}
   515  }
   516  
   517  func TestEnableMulticastForwardingE(t *testing.T) {
   518  	eventDispatcher := &fakeMulticastEventDispatcher{}
   519  
   520  	type enableMulticastForwardingResult struct {
   521  		AlreadyEnabled bool
   522  		Err            tcpip.Error
   523  	}
   524  
   525  	tests := []struct {
   526  		name            string
   527  		eventDispatcher stack.MulticastForwardingEventDispatcher
   528  		wantResult      []enableMulticastForwardingResult
   529  	}{
   530  		{
   531  			name:            "success",
   532  			eventDispatcher: eventDispatcher,
   533  			wantResult:      []enableMulticastForwardingResult{{false, nil}},
   534  		},
   535  		{
   536  			name:            "already enabled",
   537  			eventDispatcher: eventDispatcher,
   538  			wantResult:      []enableMulticastForwardingResult{{false, nil}, {true, nil}},
   539  		},
   540  		{
   541  			name:            "invalid event dispatcher",
   542  			eventDispatcher: nil,
   543  			wantResult:      []enableMulticastForwardingResult{{false, &tcpip.ErrInvalidOptionValue{}}},
   544  		},
   545  	}
   546  	for _, test := range tests {
   547  		for _, protocol := range []tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber, ipv6.ProtocolNumber} {
   548  			t.Run(fmt.Sprintf("%s %d", test.name, protocol), func(t *testing.T) {
   549  				s := stack.New(stack.Options{
   550  					NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   551  					TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   552  				})
   553  				defer s.Destroy()
   554  
   555  				for _, wantResult := range test.wantResult {
   556  					alreadyEnabled, err := s.EnableMulticastForwardingForProtocol(protocol, test.eventDispatcher)
   557  					result := enableMulticastForwardingResult{alreadyEnabled, err}
   558  					if !cmp.Equal(result, wantResult, cmpopts.EquateErrors()) {
   559  						t.Errorf("s.EnableMulticastForwardingForProtocol(%d, %#v) = (%t, %s), want = (%t, %s)", protocol, test.eventDispatcher, alreadyEnabled, err, wantResult.AlreadyEnabled, wantResult.Err)
   560  					}
   561  				}
   562  			})
   563  		}
   564  	}
   565  }
   566  
   567  func TestMulticastRouteLastUsedTime(t *testing.T) {
   568  	endpointConfigs := map[tcpip.NICID]endpointAddrType{
   569  		incomingNICID: incomingEndpointAddr,
   570  		outgoingNICID: outgoingEndpointAddr,
   571  		otherNICID:    otherEndpointAddr,
   572  	}
   573  
   574  	tests := []struct {
   575  		name             string
   576  		srcAddr, dstAddr addrType
   577  		wantErr          tcpip.Error
   578  	}{
   579  		{
   580  			name:    "success",
   581  			srcAddr: remoteUnicastAddr,
   582  			dstAddr: multicastAddr,
   583  			wantErr: nil,
   584  		},
   585  		{
   586  			name:    "no matching route",
   587  			srcAddr: remoteUnicastAddr,
   588  			dstAddr: otherMulticastAddr,
   589  			wantErr: &tcpip.ErrHostUnreachable{},
   590  		},
   591  		{
   592  			name:    "multicast source",
   593  			srcAddr: multicastAddr,
   594  			dstAddr: multicastAddr,
   595  			wantErr: &tcpip.ErrBadAddress{},
   596  		},
   597  		{
   598  			name:    "any source",
   599  			srcAddr: anyAddr,
   600  			dstAddr: multicastAddr,
   601  			wantErr: &tcpip.ErrBadAddress{},
   602  		},
   603  		{
   604  			name:    "link-local unicast source",
   605  			srcAddr: linkLocalUnicastAddr,
   606  			dstAddr: multicastAddr,
   607  			wantErr: &tcpip.ErrBadAddress{},
   608  		},
   609  		{
   610  			name:    "empty source",
   611  			srcAddr: emptyAddr,
   612  			dstAddr: multicastAddr,
   613  			wantErr: &tcpip.ErrBadAddress{},
   614  		},
   615  		{
   616  			name:    "unicast destination",
   617  			srcAddr: remoteUnicastAddr,
   618  			dstAddr: remoteUnicastAddr,
   619  			wantErr: &tcpip.ErrBadAddress{},
   620  		},
   621  		{
   622  			name:    "empty destination",
   623  			srcAddr: remoteUnicastAddr,
   624  			dstAddr: emptyAddr,
   625  			wantErr: &tcpip.ErrBadAddress{},
   626  		},
   627  		{
   628  			name:    "link-local multicast destination",
   629  			srcAddr: remoteUnicastAddr,
   630  			dstAddr: linkLocalMulticastAddr,
   631  			wantErr: &tcpip.ErrBadAddress{},
   632  		},
   633  	}
   634  
   635  	for _, test := range tests {
   636  		for _, protocol := range []tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber, ipv6.ProtocolNumber} {
   637  			t.Run(fmt.Sprintf("%s %d", test.name, protocol), func(t *testing.T) {
   638  				clock := faketime.NewManualClock()
   639  				s := stack.New(stack.Options{
   640  					NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   641  					TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   642  					Clock:              clock,
   643  				})
   644  				defer s.Destroy()
   645  
   646  				if _, err := s.EnableMulticastForwardingForProtocol(protocol, &fakeMulticastEventDispatcher{}); err != nil {
   647  					t.Fatalf("s.EnableMulticastForwardingForProtocol(%d, _): (_, %s)", protocol, err)
   648  				}
   649  
   650  				endpoints := make(map[tcpip.NICID]*channel.Endpoint)
   651  				for nicID, addrType := range endpointConfigs {
   652  					ep := channel.New(1, ipv4.MaxTotalSize, "")
   653  					defer ep.Close()
   654  
   655  					if err := s.CreateNIC(nicID, ep); err != nil {
   656  						t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
   657  					}
   658  					addr := tcpip.ProtocolAddress{
   659  						Protocol:          protocol,
   660  						AddressWithPrefix: getEndpointAddr(protocol, addrType),
   661  					}
   662  					if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
   663  						t.Fatalf("s.AddProtocolAddress(%d, %#v, {}): %s", nicID, addr, err)
   664  					}
   665  					s.SetNICMulticastForwarding(nicID, protocol, true /* enabled */)
   666  					endpoints[nicID] = ep
   667  				}
   668  
   669  				srcAddr := getAddr(protocol, remoteUnicastAddr)
   670  				dstAddr := getAddr(protocol, multicastAddr)
   671  
   672  				outgoingInterfaces := []stack.MulticastRouteOutgoingInterface{
   673  					{ID: outgoingNICID, MinTTL: routeMinTTL},
   674  				}
   675  
   676  				addresses := stack.UnicastSourceAndMulticastDestination{
   677  					Source:      srcAddr,
   678  					Destination: dstAddr,
   679  				}
   680  
   681  				route := stack.MulticastRoute{
   682  					ExpectedInputInterface: incomingNICID,
   683  					OutgoingInterfaces:     outgoingInterfaces,
   684  				}
   685  
   686  				if err := s.AddMulticastRoute(protocol, addresses, route); err != nil {
   687  					t.Fatalf("s.AddMulticastRoute(%d, %#v, %#v) = %s, want = nil", protocol, addresses, route, err)
   688  				}
   689  
   690  				incomingEp, ok := endpoints[incomingNICID]
   691  				if !ok {
   692  					t.Fatalf("Got endpoints[%d] = (_, false), want (_, true)", incomingNICID)
   693  				}
   694  
   695  				clock.Advance(10 * time.Second)
   696  
   697  				injectPacket(incomingEp, protocol, srcAddr, dstAddr, packetTTL)
   698  				p := incomingEp.Read()
   699  
   700  				if p != nil {
   701  					t.Fatalf("Expected no ICMP packet through incoming NIC, instead found: %#v", p)
   702  				}
   703  
   704  				addresses = stack.UnicastSourceAndMulticastDestination{
   705  					Source:      getAddr(protocol, test.srcAddr),
   706  					Destination: getAddr(protocol, test.dstAddr),
   707  				}
   708  				timestamp, err := s.MulticastRouteLastUsedTime(protocol, addresses)
   709  
   710  				if !cmp.Equal(err, test.wantErr, cmpopts.EquateErrors()) {
   711  					t.Errorf("s.MulticastRouteLastUsedTime(%d, %#v) = (_, %s), want = (_, %s)", protocol, addresses, err, test.wantErr)
   712  				}
   713  
   714  				if test.wantErr == nil {
   715  					wantTimestamp := clock.NowMonotonic()
   716  					if diff := cmp.Diff(wantTimestamp, timestamp, cmp.AllowUnexported(tcpip.MonotonicTime{})); diff != "" {
   717  						t.Errorf("s.MulticastRouteLastUsedTime(%d, %#v) timestamp mismatch (-want +got):\n%s", protocol, addresses, diff)
   718  					}
   719  				}
   720  			})
   721  		}
   722  	}
   723  }
   724  
   725  func TestRemoveMulticastRoute(t *testing.T) {
   726  	endpointConfigs := map[tcpip.NICID]endpointAddrType{
   727  		incomingNICID: incomingEndpointAddr,
   728  		outgoingNICID: outgoingEndpointAddr,
   729  		otherNICID:    otherEndpointAddr,
   730  	}
   731  
   732  	tests := []struct {
   733  		name             string
   734  		srcAddr, dstAddr addrType
   735  		wantErr          tcpip.Error
   736  	}{
   737  		{
   738  			name:    "success",
   739  			srcAddr: remoteUnicastAddr,
   740  			dstAddr: multicastAddr,
   741  			wantErr: nil,
   742  		},
   743  		{
   744  			name:    "no matching route",
   745  			srcAddr: remoteUnicastAddr,
   746  			dstAddr: otherMulticastAddr,
   747  			wantErr: &tcpip.ErrHostUnreachable{},
   748  		},
   749  		{
   750  			name:    "multicast source",
   751  			srcAddr: multicastAddr,
   752  			dstAddr: multicastAddr,
   753  			wantErr: &tcpip.ErrBadAddress{},
   754  		},
   755  		{
   756  			name:    "any source",
   757  			srcAddr: anyAddr,
   758  			dstAddr: multicastAddr,
   759  			wantErr: &tcpip.ErrBadAddress{},
   760  		},
   761  		{
   762  			name:    "link-local unicast source",
   763  			srcAddr: linkLocalUnicastAddr,
   764  			dstAddr: multicastAddr,
   765  			wantErr: &tcpip.ErrBadAddress{},
   766  		},
   767  		{
   768  			name:    "empty source",
   769  			srcAddr: emptyAddr,
   770  			dstAddr: multicastAddr,
   771  			wantErr: &tcpip.ErrBadAddress{},
   772  		},
   773  		{
   774  			name:    "unicast destination",
   775  			srcAddr: remoteUnicastAddr,
   776  			dstAddr: remoteUnicastAddr,
   777  			wantErr: &tcpip.ErrBadAddress{},
   778  		},
   779  		{
   780  			name:    "empty destination",
   781  			srcAddr: remoteUnicastAddr,
   782  			dstAddr: emptyAddr,
   783  			wantErr: &tcpip.ErrBadAddress{},
   784  		},
   785  		{
   786  			name:    "link-local multicast destination",
   787  			srcAddr: remoteUnicastAddr,
   788  			dstAddr: linkLocalMulticastAddr,
   789  			wantErr: &tcpip.ErrBadAddress{},
   790  		},
   791  	}
   792  
   793  	for _, test := range tests {
   794  		for _, protocol := range []tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber, ipv6.ProtocolNumber} {
   795  			t.Run(fmt.Sprintf("%s %d", test.name, protocol), func(t *testing.T) {
   796  				s := stack.New(stack.Options{
   797  					NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   798  					TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   799  				})
   800  				defer s.Destroy()
   801  
   802  				if _, err := s.EnableMulticastForwardingForProtocol(protocol, &fakeMulticastEventDispatcher{}); err != nil {
   803  					t.Fatalf("s.EnableMulticastForwardingForProtocol(%d, _): (_, %s)", protocol, err)
   804  				}
   805  
   806  				endpoints := make(map[tcpip.NICID]*channel.Endpoint)
   807  				for nicID, addrType := range endpointConfigs {
   808  					ep := channel.New(1, ipv4.MaxTotalSize, "")
   809  					defer ep.Close()
   810  
   811  					if err := s.CreateNIC(nicID, ep); err != nil {
   812  						t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
   813  					}
   814  					addr := tcpip.ProtocolAddress{
   815  						Protocol:          protocol,
   816  						AddressWithPrefix: getEndpointAddr(protocol, addrType),
   817  					}
   818  					if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
   819  						t.Fatalf("s.AddProtocolAddress(%d, %#v, {}): %s", nicID, addr, err)
   820  					}
   821  					s.SetNICMulticastForwarding(nicID, protocol, true /* enabled */)
   822  					endpoints[nicID] = ep
   823  				}
   824  
   825  				srcAddr := getAddr(protocol, remoteUnicastAddr)
   826  				dstAddr := getAddr(protocol, multicastAddr)
   827  
   828  				outgoingInterfaces := []stack.MulticastRouteOutgoingInterface{
   829  					{ID: outgoingNICID, MinTTL: routeMinTTL},
   830  				}
   831  
   832  				addresses := stack.UnicastSourceAndMulticastDestination{
   833  					Source:      srcAddr,
   834  					Destination: dstAddr,
   835  				}
   836  
   837  				route := stack.MulticastRoute{
   838  					ExpectedInputInterface: incomingNICID,
   839  					OutgoingInterfaces:     outgoingInterfaces,
   840  				}
   841  
   842  				if err := s.AddMulticastRoute(protocol, addresses, route); err != nil {
   843  					t.Fatalf("got s.AddMulticastRoute(%d, %#v, %#v) = %s, want = nil", protocol, addresses, route, err)
   844  				}
   845  
   846  				addresses = stack.UnicastSourceAndMulticastDestination{
   847  					Source:      getAddr(protocol, test.srcAddr),
   848  					Destination: getAddr(protocol, test.dstAddr),
   849  				}
   850  				err := s.RemoveMulticastRoute(protocol, addresses)
   851  
   852  				if !cmp.Equal(err, test.wantErr, cmpopts.EquateErrors()) {
   853  					t.Errorf("got s.RemoveMulticastRoute(%d, %#v) = %s, want %s", protocol, addresses, err, test.wantErr)
   854  				}
   855  
   856  				incomingEp, ok := endpoints[incomingNICID]
   857  				if !ok {
   858  					t.Fatalf("got endpoints[%d] = (_, false), want (_, true)", incomingNICID)
   859  				}
   860  
   861  				injectPacket(incomingEp, protocol, srcAddr, dstAddr, packetTTL)
   862  				p := incomingEp.Read()
   863  
   864  				if p != nil {
   865  					// An ICMP error should never be sent in response to a multicast
   866  					// packet.
   867  					t.Errorf("expected no ICMP packet through incoming NIC, instead found: %#v", p)
   868  				}
   869  
   870  				outgoingEp, ok := endpoints[outgoingNICID]
   871  				if !ok {
   872  					t.Fatalf("got endpoints[%d] = (_, false), want (_, true)", outgoingNICID)
   873  				}
   874  
   875  				p = outgoingEp.Read()
   876  
   877  				// If the route was successfully removed, then the packet should not be
   878  				// forwarded.
   879  				expectForward := test.wantErr != nil
   880  				if (p != nil) != expectForward {
   881  					t.Fatalf("got outgoingEp.Read() = %#v, want = (_ == nil) = %t", p, expectForward)
   882  				}
   883  
   884  				if expectForward {
   885  					checkEchoRequest(t, protocol, p, srcAddr, dstAddr, packetTTL-1)
   886  					p.DecRef()
   887  				}
   888  			})
   889  		}
   890  	}
   891  }
   892  
   893  func TestMulticastForwarding(t *testing.T) {
   894  	endpointConfigs := map[tcpip.NICID]endpointAddrType{
   895  		incomingNICID:      incomingEndpointAddr,
   896  		outgoingNICID:      outgoingEndpointAddr,
   897  		otherOutgoingNICID: otherOutgoingEndpointAddr,
   898  		otherNICID:         otherEndpointAddr,
   899  	}
   900  
   901  	contains := func(want tcpip.NICID, items []tcpip.NICID) bool {
   902  		for _, item := range items {
   903  			if want == item {
   904  				return true
   905  			}
   906  		}
   907  		return false
   908  	}
   909  
   910  	tests := []struct {
   911  		name                                 string
   912  		dstAddr                              addrType
   913  		ttl                                  uint8
   914  		routeInputInterface                  tcpip.NICID
   915  		disableMulticastForwardingForNIC     bool
   916  		updateMulticastForwardingForProtocol func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber, stack.MulticastForwardingEventDispatcher)
   917  		removeOutputInterface                tcpip.NICID
   918  		expectMissingRouteEvent              bool
   919  		expectUnexpectedInputInterfaceEvent  bool
   920  		joinMulticastGroup                   bool
   921  		expectedForwardingInterfaces         []tcpip.NICID
   922  	}{
   923  		{
   924  			name:                         "forward only",
   925  			dstAddr:                      multicastAddr,
   926  			ttl:                          packetTTL,
   927  			routeInputInterface:          incomingNICID,
   928  			expectedForwardingInterfaces: []tcpip.NICID{outgoingNICID, otherOutgoingNICID},
   929  		},
   930  		{
   931  			name:                         "forward and local",
   932  			dstAddr:                      multicastAddr,
   933  			ttl:                          packetTTL,
   934  			routeInputInterface:          incomingNICID,
   935  			joinMulticastGroup:           true,
   936  			expectedForwardingInterfaces: []tcpip.NICID{outgoingNICID, otherOutgoingNICID},
   937  		},
   938  		{
   939  			name:                         "local only",
   940  			dstAddr:                      linkLocalMulticastAddr,
   941  			ttl:                          packetTTL,
   942  			routeInputInterface:          incomingNICID,
   943  			joinMulticastGroup:           true,
   944  			expectedForwardingInterfaces: []tcpip.NICID{},
   945  		},
   946  		{
   947  			name:                             "multicast forwarding disabled for NIC",
   948  			disableMulticastForwardingForNIC: true,
   949  			dstAddr:                          multicastAddr,
   950  			ttl:                              packetTTL,
   951  			routeInputInterface:              incomingNICID,
   952  			expectedForwardingInterfaces:     []tcpip.NICID{},
   953  		},
   954  		{
   955  			name:    "multicast forwarding disabled for protocol",
   956  			dstAddr: multicastAddr,
   957  			updateMulticastForwardingForProtocol: func(t *testing.T, s *stack.Stack, protocol tcpip.NetworkProtocolNumber, disp stack.MulticastForwardingEventDispatcher) {
   958  				s.DisableMulticastForwardingForProtocol(protocol)
   959  			},
   960  			ttl:                          packetTTL,
   961  			routeInputInterface:          incomingNICID,
   962  			expectedForwardingInterfaces: []tcpip.NICID{},
   963  		},
   964  		{
   965  			name:    "route table cleared after multicast forwarding disabled for protocol",
   966  			dstAddr: multicastAddr,
   967  			updateMulticastForwardingForProtocol: func(t *testing.T, s *stack.Stack, protocol tcpip.NetworkProtocolNumber, disp stack.MulticastForwardingEventDispatcher) {
   968  				t.Helper()
   969  
   970  				s.DisableMulticastForwardingForProtocol(protocol)
   971  				if _, err := s.EnableMulticastForwardingForProtocol(protocol, disp); err != nil {
   972  					t.Fatalf("s.EnableMulticastForwardingForProtocol(%d, _): (_, %s)", protocol, err)
   973  				}
   974  			},
   975  			ttl:                          packetTTL,
   976  			routeInputInterface:          incomingNICID,
   977  			expectMissingRouteEvent:      true,
   978  			expectedForwardingInterfaces: []tcpip.NICID{},
   979  		},
   980  		{
   981  			name:                                "unexpected input interface",
   982  			dstAddr:                             multicastAddr,
   983  			ttl:                                 packetTTL,
   984  			routeInputInterface:                 otherNICID,
   985  			expectUnexpectedInputInterfaceEvent: true,
   986  			expectedForwardingInterfaces:        []tcpip.NICID{},
   987  		},
   988  		{
   989  			name:                         "output interface removed",
   990  			dstAddr:                      multicastAddr,
   991  			ttl:                          packetTTL,
   992  			routeInputInterface:          incomingNICID,
   993  			removeOutputInterface:        outgoingNICID,
   994  			expectedForwardingInterfaces: []tcpip.NICID{otherOutgoingNICID},
   995  		},
   996  		{
   997  			name:                         "ttl greater than outgoingNICID route min",
   998  			dstAddr:                      multicastAddr,
   999  			ttl:                          routeMinTTL + 1,
  1000  			routeInputInterface:          incomingNICID,
  1001  			expectedForwardingInterfaces: []tcpip.NICID{outgoingNICID, otherOutgoingNICID},
  1002  		},
  1003  		{
  1004  			name:                         "ttl same as outgoingNICID route min",
  1005  			dstAddr:                      multicastAddr,
  1006  			ttl:                          routeMinTTL,
  1007  			routeInputInterface:          incomingNICID,
  1008  			expectedForwardingInterfaces: []tcpip.NICID{outgoingNICID},
  1009  		},
  1010  		{
  1011  			name:                         "ttl less than outgoingNICID route min",
  1012  			dstAddr:                      multicastAddr,
  1013  			ttl:                          routeMinTTL - 1,
  1014  			routeInputInterface:          incomingNICID,
  1015  			expectedForwardingInterfaces: []tcpip.NICID{},
  1016  		},
  1017  		{
  1018  			name:                         "no matching route",
  1019  			dstAddr:                      otherMulticastAddr,
  1020  			ttl:                          packetTTL,
  1021  			routeInputInterface:          incomingNICID,
  1022  			expectMissingRouteEvent:      true,
  1023  			expectedForwardingInterfaces: []tcpip.NICID{},
  1024  		},
  1025  	}
  1026  
  1027  	for _, test := range tests {
  1028  		for _, protocol := range []tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber, ipv6.ProtocolNumber} {
  1029  			ipv4EventDispatcher := &fakeMulticastEventDispatcher{}
  1030  			ipv6EventDispatcher := &fakeMulticastEventDispatcher{}
  1031  
  1032  			eventDispatchers := map[tcpip.NetworkProtocolNumber]*fakeMulticastEventDispatcher{
  1033  				ipv4.ProtocolNumber: ipv4EventDispatcher,
  1034  				ipv6.ProtocolNumber: ipv6EventDispatcher,
  1035  			}
  1036  
  1037  			t.Run(fmt.Sprintf("%s %d", test.name, protocol), func(t *testing.T) {
  1038  				s := stack.New(stack.Options{
  1039  					NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  1040  					TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
  1041  				})
  1042  				defer s.Destroy()
  1043  
  1044  				eventDispatcher, ok := eventDispatchers[protocol]
  1045  				if !ok {
  1046  					t.Fatalf("eventDispatchers[%d] = (_, false), want (_, true)", protocol)
  1047  				}
  1048  
  1049  				if _, err := s.EnableMulticastForwardingForProtocol(protocol, eventDispatcher); err != nil {
  1050  					t.Fatalf("s.EnableMulticastForwardingForProtocol(%d, %#v): (_, %s)", protocol, eventDispatcher, err)
  1051  				}
  1052  
  1053  				endpoints := make(map[tcpip.NICID]*channel.Endpoint)
  1054  				for nicID, addrType := range endpointConfigs {
  1055  					ep := channel.New(1, ipv4.MaxTotalSize, "")
  1056  					defer ep.Close()
  1057  
  1058  					if err := s.CreateNIC(nicID, ep); err != nil {
  1059  						t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
  1060  					}
  1061  					addr := tcpip.ProtocolAddress{
  1062  						Protocol:          protocol,
  1063  						AddressWithPrefix: getEndpointAddr(protocol, addrType),
  1064  					}
  1065  					if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
  1066  						t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
  1067  					}
  1068  
  1069  					s.SetNICMulticastForwarding(nicID, protocol, true /* enable */)
  1070  					endpoints[nicID] = ep
  1071  				}
  1072  
  1073  				if err := s.SetForwardingDefaultAndAllNICs(protocol, true /* enabled */); err != nil {
  1074  					t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", protocol, err)
  1075  				}
  1076  
  1077  				srcAddr := getAddr(protocol, remoteUnicastAddr)
  1078  				dstAddr := getAddr(protocol, test.dstAddr)
  1079  
  1080  				outgoingInterfaces := []stack.MulticastRouteOutgoingInterface{
  1081  					{ID: outgoingNICID, MinTTL: routeMinTTL},
  1082  					{ID: otherOutgoingNICID, MinTTL: routeMinTTL + 1},
  1083  				}
  1084  				addresses := stack.UnicastSourceAndMulticastDestination{
  1085  					Source:      srcAddr,
  1086  					Destination: getAddr(protocol, multicastAddr),
  1087  				}
  1088  
  1089  				route := stack.MulticastRoute{
  1090  					ExpectedInputInterface: test.routeInputInterface,
  1091  					OutgoingInterfaces:     outgoingInterfaces,
  1092  				}
  1093  
  1094  				if err := s.AddMulticastRoute(protocol, addresses, route); err != nil {
  1095  					t.Fatalf("AddMulticastRoute(%d, %#v, %#v): %s", protocol, addresses, route, err)
  1096  				}
  1097  
  1098  				if test.disableMulticastForwardingForNIC {
  1099  					for nicID := range endpoints {
  1100  						s.SetNICMulticastForwarding(nicID, protocol, false /* enable */)
  1101  					}
  1102  				}
  1103  
  1104  				if test.updateMulticastForwardingForProtocol != nil {
  1105  					test.updateMulticastForwardingForProtocol(t, s, protocol, eventDispatcher)
  1106  				}
  1107  
  1108  				if test.removeOutputInterface != 0 {
  1109  					if err := s.RemoveNIC(test.removeOutputInterface); err != nil {
  1110  						t.Fatalf("RemoveNIC(%d): %s", test.removeOutputInterface, err)
  1111  					}
  1112  				}
  1113  
  1114  				// Add a route that can be used to send an ICMP echo reply (if the packet
  1115  				// is delivered locally).
  1116  				s.SetRouteTable([]tcpip.Route{
  1117  					{
  1118  						Destination: header.IPv4EmptySubnet,
  1119  						NIC:         otherNICID,
  1120  					},
  1121  					{
  1122  						Destination: header.IPv6EmptySubnet,
  1123  						NIC:         otherNICID,
  1124  					},
  1125  				})
  1126  
  1127  				if test.joinMulticastGroup {
  1128  					if err := s.JoinGroup(protocol, incomingNICID, dstAddr); err != nil {
  1129  						t.Fatalf("JoinGroup(%d, %d, %s): %s", protocol, incomingNICID, dstAddr, err)
  1130  					}
  1131  				}
  1132  
  1133  				incomingEp, ok := endpoints[incomingNICID]
  1134  				if !ok {
  1135  					t.Fatalf("got endpoints[%d] = (_, false), want (_, true)", incomingNICID)
  1136  				}
  1137  
  1138  				injectPacket(incomingEp, protocol, srcAddr, dstAddr, test.ttl)
  1139  				p := incomingEp.Read()
  1140  
  1141  				if p != nil {
  1142  					// An ICMP error should never be sent in response to a multicast packet.
  1143  					t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", p)
  1144  				}
  1145  
  1146  				for _, nicID := range []tcpip.NICID{outgoingNICID, otherOutgoingNICID} {
  1147  					outgoingEp, ok := endpoints[nicID]
  1148  					if !ok {
  1149  						t.Fatalf("got endpoints[%d] = (_, false), want (_, true)", nicID)
  1150  					}
  1151  
  1152  					p := outgoingEp.Read()
  1153  
  1154  					expectForward := contains(nicID, test.expectedForwardingInterfaces)
  1155  
  1156  					if (p != nil) != expectForward {
  1157  						t.Fatalf("got outgoingEp.Read() = %#v, want = (_ == nil) = %t", p, expectForward)
  1158  					}
  1159  
  1160  					if expectForward {
  1161  						checkEchoRequest(t, protocol, p, srcAddr, dstAddr, test.ttl-1)
  1162  						p.DecRef()
  1163  					}
  1164  				}
  1165  
  1166  				otherEp, ok := endpoints[otherNICID]
  1167  				if !ok {
  1168  					t.Fatalf("got endpoints[%d] = (_, false), want (_, true)", otherNICID)
  1169  				}
  1170  
  1171  				p = otherEp.Read()
  1172  
  1173  				if (p != nil) != test.joinMulticastGroup {
  1174  					t.Fatalf("got otherEp.Read() = %#v, want = (_ == nil) = %t", p, test.joinMulticastGroup)
  1175  				}
  1176  
  1177  				incomingEpAddrType, ok := endpointConfigs[incomingNICID]
  1178  				if !ok {
  1179  					t.Fatalf("got endpointConfigs[%d] = (_, false), want (_, true)", incomingNICID)
  1180  				}
  1181  
  1182  				if test.joinMulticastGroup {
  1183  					checkEchoReply(t, protocol, p, getEndpointAddr(protocol, incomingEpAddrType).Address, srcAddr)
  1184  					p.DecRef()
  1185  				}
  1186  
  1187  				wantUnexpectedInputInterfaceEvent := func() *onUnexpectedInputInterfaceData {
  1188  					if test.expectUnexpectedInputInterfaceEvent {
  1189  						return &onUnexpectedInputInterfaceData{stack.MulticastPacketContext{stack.UnicastSourceAndMulticastDestination{srcAddr, dstAddr}, incomingNICID}, test.routeInputInterface}
  1190  					}
  1191  					return nil
  1192  				}()
  1193  
  1194  				if diff := cmp.Diff(wantUnexpectedInputInterfaceEvent, eventDispatcher.onUnexpectedInputInterfaceData, cmp.AllowUnexported(onUnexpectedInputInterfaceData{})); diff != "" {
  1195  					t.Errorf("onUnexpectedInputInterfaceData mismatch (-want +got):\n%s", diff)
  1196  				}
  1197  
  1198  				wantMissingRouteEvent := func() *onMissingRouteData {
  1199  					if test.expectMissingRouteEvent {
  1200  						return &onMissingRouteData{stack.MulticastPacketContext{stack.UnicastSourceAndMulticastDestination{srcAddr, dstAddr}, incomingNICID}}
  1201  					}
  1202  					return nil
  1203  				}()
  1204  
  1205  				if diff := cmp.Diff(wantMissingRouteEvent, eventDispatcher.onMissingRouteData, cmp.AllowUnexported(onMissingRouteData{})); diff != "" {
  1206  					t.Errorf("onMissingRouteData mismatch (-want +got):\n%s", diff)
  1207  				}
  1208  			})
  1209  		}
  1210  	}
  1211  }
  1212  
  1213  func TestMain(m *testing.M) {
  1214  	refs.SetLeakMode(refs.LeaksPanic)
  1215  	code := m.Run()
  1216  	refs.DoLeakCheck()
  1217  	os.Exit(code)
  1218  }