github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/network/multicast_group_test.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package ip_test
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/SagerNet/gvisor/pkg/tcpip"
    24  	"github.com/SagerNet/gvisor/pkg/tcpip/buffer"
    25  	"github.com/SagerNet/gvisor/pkg/tcpip/checker"
    26  	"github.com/SagerNet/gvisor/pkg/tcpip/faketime"
    27  	"github.com/SagerNet/gvisor/pkg/tcpip/header"
    28  	"github.com/SagerNet/gvisor/pkg/tcpip/link/channel"
    29  	"github.com/SagerNet/gvisor/pkg/tcpip/link/loopback"
    30  	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv4"
    31  	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv6"
    32  	"github.com/SagerNet/gvisor/pkg/tcpip/stack"
    33  	"github.com/SagerNet/gvisor/pkg/tcpip/testutil"
    34  )
    35  
    36  const (
    37  	linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
    38  
    39  	defaultIPv4PrefixLength = 24
    40  
    41  	igmpMembershipQuery    = uint8(header.IGMPMembershipQuery)
    42  	igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport)
    43  	igmpv2MembershipReport = uint8(header.IGMPv2MembershipReport)
    44  	igmpLeaveGroup         = uint8(header.IGMPLeaveGroup)
    45  	mldQuery               = uint8(header.ICMPv6MulticastListenerQuery)
    46  	mldReport              = uint8(header.ICMPv6MulticastListenerReport)
    47  	mldDone                = uint8(header.ICMPv6MulticastListenerDone)
    48  
    49  	maxUnsolicitedReports = 2
    50  )
    51  
    52  var (
    53  	stackIPv4Addr      = testutil.MustParse4("10.0.0.1")
    54  	linkLocalIPv6Addr1 = testutil.MustParse6("fe80::1")
    55  	linkLocalIPv6Addr2 = testutil.MustParse6("fe80::2")
    56  
    57  	ipv4MulticastAddr1 = testutil.MustParse4("224.0.0.3")
    58  	ipv4MulticastAddr2 = testutil.MustParse4("224.0.0.4")
    59  	ipv4MulticastAddr3 = testutil.MustParse4("224.0.0.5")
    60  	ipv6MulticastAddr1 = testutil.MustParse6("ff02::3")
    61  	ipv6MulticastAddr2 = testutil.MustParse6("ff02::4")
    62  	ipv6MulticastAddr3 = testutil.MustParse6("ff02::5")
    63  )
    64  
    65  var (
    66  	// unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the
    67  	// NIC will wait before sending an unsolicited report after joining a
    68  	// multicast group, in deciseconds.
    69  	unsolicitedIGMPReportIntervalMaxTenthSec = func() uint8 {
    70  		const decisecond = time.Second / 10
    71  		if ipv4.UnsolicitedReportIntervalMax%decisecond != 0 {
    72  			panic(fmt.Sprintf("UnsolicitedReportIntervalMax of %d is a lossy conversion to deciseconds", ipv4.UnsolicitedReportIntervalMax))
    73  		}
    74  		return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond)
    75  	}()
    76  
    77  	ipv6AddrSNMC = header.SolicitedNodeAddr(linkLocalIPv6Addr1)
    78  )
    79  
    80  // validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet
    81  // sent to the provided address with the passed fields set.
    82  func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, mldType uint8, maxRespTime byte, groupAddress tcpip.Address) {
    83  	t.Helper()
    84  
    85  	payload := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader()))
    86  	checker.IPv6WithExtHdr(t, payload,
    87  		checker.IPv6ExtHdr(
    88  			checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)),
    89  		),
    90  		checker.SrcAddr(linkLocalIPv6Addr1),
    91  		checker.DstAddr(remoteAddress),
    92  		// Hop Limit for an MLD message must be 1 as per RFC 2710 section 3.
    93  		checker.TTL(1),
    94  		checker.MLD(header.ICMPv6Type(mldType), header.MLDMinimumSize,
    95  			checker.MLDMaxRespDelay(time.Duration(maxRespTime)*time.Millisecond),
    96  			checker.MLDMulticastAddress(groupAddress),
    97  		),
    98  	)
    99  }
   100  
   101  // validateIGMPPacket checks that a passed PacketInfo is an IPv4 IGMP packet
   102  // sent to the provided address with the passed fields set.
   103  func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType uint8, maxRespTime byte, groupAddress tcpip.Address) {
   104  	t.Helper()
   105  
   106  	payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
   107  	checker.IPv4(t, payload,
   108  		checker.SrcAddr(stackIPv4Addr),
   109  		checker.DstAddr(remoteAddress),
   110  		// TTL for an IGMP message must be 1 as per RFC 2236 section 2.
   111  		checker.TTL(1),
   112  		checker.IPv4RouterAlert(),
   113  		checker.IGMP(
   114  			checker.IGMPType(header.IGMPType(igmpType)),
   115  			checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)),
   116  			checker.IGMPGroupAddress(groupAddress),
   117  		),
   118  	)
   119  }
   120  
   121  func createStack(t *testing.T, v4, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
   122  	t.Helper()
   123  
   124  	e := channel.New(maxUnsolicitedReports, header.IPv6MinimumMTU, linkAddr)
   125  	s, clock := createStackWithLinkEndpoint(t, v4, mgpEnabled, e)
   126  	return e, s, clock
   127  }
   128  
   129  func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) {
   130  	t.Helper()
   131  
   132  	igmpEnabled := v4 && mgpEnabled
   133  	mldEnabled := !v4 && mgpEnabled
   134  
   135  	clock := faketime.NewManualClock()
   136  	s := stack.New(stack.Options{
   137  		NetworkProtocols: []stack.NetworkProtocolFactory{
   138  			ipv4.NewProtocolWithOptions(ipv4.Options{
   139  				IGMP: ipv4.IGMPOptions{
   140  					Enabled: igmpEnabled,
   141  				},
   142  			}),
   143  			ipv6.NewProtocolWithOptions(ipv6.Options{
   144  				MLD: ipv6.MLDOptions{
   145  					Enabled: mldEnabled,
   146  				},
   147  			}),
   148  		},
   149  		Clock: clock,
   150  	})
   151  	if err := s.CreateNIC(nicID, e); err != nil {
   152  		t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
   153  	}
   154  	addr := tcpip.AddressWithPrefix{
   155  		Address:   stackIPv4Addr,
   156  		PrefixLen: defaultIPv4PrefixLength,
   157  	}
   158  	if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil {
   159  		t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
   160  	}
   161  	if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1); err != nil {
   162  		t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1, err)
   163  	}
   164  
   165  	return s, clock
   166  }
   167  
   168  // checkInitialIPv6Groups checks the initial IPv6 groups that a NIC will join
   169  // when it is created with an IPv6 address.
   170  //
   171  // To not interfere with tests, checkInitialIPv6Groups will leave the added
   172  // address's solicited node multicast group so that the tests can all assume
   173  // the NIC has not joined any IPv6 groups.
   174  func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, clock *faketime.ManualClock) (reportCounter uint64, leaveCounter uint64) {
   175  	t.Helper()
   176  
   177  	stats := s.Stats().ICMP.V6.PacketsSent
   178  
   179  	reportCounter++
   180  	if got := stats.MulticastListenerReport.Value(); got != reportCounter {
   181  		t.Errorf("got stats.MulticastListenerReport.Value() = %d, want = %d", got, reportCounter)
   182  	}
   183  	if p, ok := e.Read(); !ok {
   184  		t.Fatal("expected a report message to be sent")
   185  	} else {
   186  		validateMLDPacket(t, p, ipv6AddrSNMC, mldReport, 0, ipv6AddrSNMC)
   187  	}
   188  
   189  	// Leave the group to not affect the tests. This is fine since we are not
   190  	// testing DAD or the solicited node address specifically.
   191  	if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, ipv6AddrSNMC); err != nil {
   192  		t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, ipv6AddrSNMC, err)
   193  	}
   194  	leaveCounter++
   195  	if got := stats.MulticastListenerDone.Value(); got != leaveCounter {
   196  		t.Errorf("got stats.MulticastListenerDone.Value() = %d, want = %d", got, leaveCounter)
   197  	}
   198  	if p, ok := e.Read(); !ok {
   199  		t.Fatal("expected a report message to be sent")
   200  	} else {
   201  		validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6AddrSNMC)
   202  	}
   203  
   204  	// Should not send any more packets.
   205  	clock.Advance(time.Hour)
   206  	if p, ok := e.Read(); ok {
   207  		t.Fatalf("sent unexpected packet = %#v", p)
   208  	}
   209  
   210  	return reportCounter, leaveCounter
   211  }
   212  
   213  // createAndInjectIGMPPacket creates and injects an IGMP packet with the
   214  // specified fields.
   215  func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType byte, maxRespTime byte, groupAddress tcpip.Address) {
   216  	options := header.IPv4OptionsSerializer{
   217  		&header.IPv4SerializableRouterAlertOption{},
   218  	}
   219  	buf := buffer.NewView(header.IPv4MinimumSize + int(options.Length()) + header.IGMPQueryMinimumSize)
   220  	ip := header.IPv4(buf)
   221  	ip.Encode(&header.IPv4Fields{
   222  		TotalLength: uint16(len(buf)),
   223  		TTL:         header.IGMPTTL,
   224  		Protocol:    uint8(header.IGMPProtocolNumber),
   225  		SrcAddr:     remoteIPv4Addr,
   226  		DstAddr:     header.IPv4AllSystems,
   227  		Options:     options,
   228  	})
   229  	ip.SetChecksum(^ip.CalculateChecksum())
   230  
   231  	igmp := header.IGMP(ip.Payload())
   232  	igmp.SetType(header.IGMPType(igmpType))
   233  	igmp.SetMaxRespTime(maxRespTime)
   234  	igmp.SetGroupAddress(groupAddress)
   235  	igmp.SetChecksum(header.IGMPCalculateChecksum(igmp))
   236  
   237  	e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
   238  		Data: buf.ToVectorisedView(),
   239  	}))
   240  }
   241  
   242  // createAndInjectMLDPacket creates and injects an MLD packet with the
   243  // specified fields.
   244  func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay byte, groupAddress tcpip.Address) {
   245  	extensionHeaders := header.IPv6ExtHdrSerializer{
   246  		header.IPv6SerializableHopByHopExtHdr{
   247  			&header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD},
   248  		},
   249  	}
   250  
   251  	extensionHeadersLength := extensionHeaders.Length()
   252  	payloadLength := extensionHeadersLength + header.ICMPv6HeaderSize + header.MLDMinimumSize
   253  	buf := buffer.NewView(header.IPv6MinimumSize + payloadLength)
   254  
   255  	ip := header.IPv6(buf)
   256  	ip.Encode(&header.IPv6Fields{
   257  		PayloadLength:     uint16(payloadLength),
   258  		HopLimit:          header.MLDHopLimit,
   259  		TransportProtocol: header.ICMPv6ProtocolNumber,
   260  		SrcAddr:           linkLocalIPv6Addr2,
   261  		DstAddr:           header.IPv6AllNodesMulticastAddress,
   262  		ExtensionHeaders:  extensionHeaders,
   263  	})
   264  
   265  	icmp := header.ICMPv6(ip.Payload()[extensionHeadersLength:])
   266  	icmp.SetType(header.ICMPv6Type(mldType))
   267  	mld := header.MLD(icmp.MessageBody())
   268  	mld.SetMaximumResponseDelay(uint16(maxRespDelay))
   269  	mld.SetMulticastAddress(groupAddress)
   270  	icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
   271  		Header: icmp,
   272  		Src:    linkLocalIPv6Addr2,
   273  		Dst:    header.IPv6AllNodesMulticastAddress,
   274  	}))
   275  
   276  	e.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
   277  		Data: buf.ToVectorisedView(),
   278  	}))
   279  }
   280  
   281  // TestMGPDisabled tests that the multicast group protocol is not enabled by
   282  // default.
   283  func TestMGPDisabled(t *testing.T) {
   284  	tests := []struct {
   285  		name              string
   286  		protoNum          tcpip.NetworkProtocolNumber
   287  		multicastAddr     tcpip.Address
   288  		sentReportStat    func(*stack.Stack) *tcpip.StatCounter
   289  		receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
   290  		rxQuery           func(*channel.Endpoint)
   291  	}{
   292  		{
   293  			name:          "IGMP",
   294  			protoNum:      ipv4.ProtocolNumber,
   295  			multicastAddr: ipv4MulticastAddr1,
   296  			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
   297  				return s.Stats().IGMP.PacketsSent.V2MembershipReport
   298  			},
   299  			receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
   300  				return s.Stats().IGMP.PacketsReceived.MembershipQuery
   301  			},
   302  			rxQuery: func(e *channel.Endpoint) {
   303  				createAndInjectIGMPPacket(e, igmpMembershipQuery, unsolicitedIGMPReportIntervalMaxTenthSec, header.IPv4Any)
   304  			},
   305  		},
   306  		{
   307  			name:          "MLD",
   308  			protoNum:      ipv6.ProtocolNumber,
   309  			multicastAddr: ipv6MulticastAddr1,
   310  			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
   311  				return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
   312  			},
   313  			receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
   314  				return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
   315  			},
   316  			rxQuery: func(e *channel.Endpoint) {
   317  				createAndInjectMLDPacket(e, mldQuery, 0, header.IPv6Any)
   318  			},
   319  		},
   320  	}
   321  
   322  	for _, test := range tests {
   323  		t.Run(test.name, func(t *testing.T) {
   324  			e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, false /* mgpEnabled */)
   325  
   326  			// This NIC may join multicast groups when it is enabled but since MGP is
   327  			// disabled, no reports should be sent.
   328  			sentReportStat := test.sentReportStat(s)
   329  			if got := sentReportStat.Value(); got != 0 {
   330  				t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
   331  			}
   332  			clock.Advance(time.Hour)
   333  			if p, ok := e.Read(); ok {
   334  				t.Fatalf("sent unexpected packet, stack with disabled MGP sent packet = %#v", p.Pkt)
   335  			}
   336  
   337  			// Test joining a specific group explicitly and verify that no reports are
   338  			// sent.
   339  			if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
   340  				t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
   341  			}
   342  			if got := sentReportStat.Value(); got != 0 {
   343  				t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
   344  			}
   345  			clock.Advance(time.Hour)
   346  			if p, ok := e.Read(); ok {
   347  				t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %#v", p.Pkt)
   348  			}
   349  
   350  			// Inject a general query message. This should only trigger a report to be
   351  			// sent if the MGP was enabled.
   352  			test.rxQuery(e)
   353  			if got := test.receivedQueryStat(s).Value(); got != 1 {
   354  				t.Fatalf("got receivedQueryStat(_).Value() = %d, want = 1", got)
   355  			}
   356  			clock.Advance(time.Hour)
   357  			if p, ok := e.Read(); ok {
   358  				t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt)
   359  			}
   360  		})
   361  	}
   362  }
   363  
   364  func TestMGPReceiveCounters(t *testing.T) {
   365  	tests := []struct {
   366  		name         string
   367  		headerType   uint8
   368  		maxRespTime  byte
   369  		groupAddress tcpip.Address
   370  		statCounter  func(*stack.Stack) *tcpip.StatCounter
   371  		rxMGPkt      func(*channel.Endpoint, byte, byte, tcpip.Address)
   372  	}{
   373  		{
   374  			name:         "IGMP Membership Query",
   375  			headerType:   igmpMembershipQuery,
   376  			maxRespTime:  unsolicitedIGMPReportIntervalMaxTenthSec,
   377  			groupAddress: header.IPv4Any,
   378  			statCounter: func(s *stack.Stack) *tcpip.StatCounter {
   379  				return s.Stats().IGMP.PacketsReceived.MembershipQuery
   380  			},
   381  			rxMGPkt: createAndInjectIGMPPacket,
   382  		},
   383  		{
   384  			name:         "IGMPv1 Membership Report",
   385  			headerType:   igmpv1MembershipReport,
   386  			maxRespTime:  0,
   387  			groupAddress: header.IPv4AllSystems,
   388  			statCounter: func(s *stack.Stack) *tcpip.StatCounter {
   389  				return s.Stats().IGMP.PacketsReceived.V1MembershipReport
   390  			},
   391  			rxMGPkt: createAndInjectIGMPPacket,
   392  		},
   393  		{
   394  			name:         "IGMPv2 Membership Report",
   395  			headerType:   igmpv2MembershipReport,
   396  			maxRespTime:  0,
   397  			groupAddress: header.IPv4AllSystems,
   398  			statCounter: func(s *stack.Stack) *tcpip.StatCounter {
   399  				return s.Stats().IGMP.PacketsReceived.V2MembershipReport
   400  			},
   401  			rxMGPkt: createAndInjectIGMPPacket,
   402  		},
   403  		{
   404  			name:         "IGMP Leave Group",
   405  			headerType:   igmpLeaveGroup,
   406  			maxRespTime:  0,
   407  			groupAddress: header.IPv4AllRoutersGroup,
   408  			statCounter: func(s *stack.Stack) *tcpip.StatCounter {
   409  				return s.Stats().IGMP.PacketsReceived.LeaveGroup
   410  			},
   411  			rxMGPkt: createAndInjectIGMPPacket,
   412  		},
   413  		{
   414  			name:         "MLD Query",
   415  			headerType:   mldQuery,
   416  			maxRespTime:  0,
   417  			groupAddress: header.IPv6Any,
   418  			statCounter: func(s *stack.Stack) *tcpip.StatCounter {
   419  				return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
   420  			},
   421  			rxMGPkt: createAndInjectMLDPacket,
   422  		},
   423  		{
   424  			name:         "MLD Report",
   425  			headerType:   mldReport,
   426  			maxRespTime:  0,
   427  			groupAddress: header.IPv6Any,
   428  			statCounter: func(s *stack.Stack) *tcpip.StatCounter {
   429  				return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerReport
   430  			},
   431  			rxMGPkt: createAndInjectMLDPacket,
   432  		},
   433  		{
   434  			name:         "MLD Done",
   435  			headerType:   mldDone,
   436  			maxRespTime:  0,
   437  			groupAddress: header.IPv6Any,
   438  			statCounter: func(s *stack.Stack) *tcpip.StatCounter {
   439  				return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerDone
   440  			},
   441  			rxMGPkt: createAndInjectMLDPacket,
   442  		},
   443  	}
   444  
   445  	for _, test := range tests {
   446  		t.Run(test.name, func(t *testing.T) {
   447  			e, s, _ := createStack(t, len(test.groupAddress) == header.IPv4AddressSize /* v4 */, true /* mgpEnabled */)
   448  
   449  			test.rxMGPkt(e, test.headerType, test.maxRespTime, test.groupAddress)
   450  			if got := test.statCounter(s).Value(); got != 1 {
   451  				t.Fatalf("got %s received = %d, want = 1", test.name, got)
   452  			}
   453  		})
   454  	}
   455  }
   456  
   457  // TestMGPJoinGroup tests that when explicitly joining a multicast group, the
   458  // stack schedules and sends correct Membership Reports.
   459  func TestMGPJoinGroup(t *testing.T) {
   460  	tests := []struct {
   461  		name                        string
   462  		protoNum                    tcpip.NetworkProtocolNumber
   463  		multicastAddr               tcpip.Address
   464  		maxUnsolicitedResponseDelay time.Duration
   465  		sentReportStat              func(*stack.Stack) *tcpip.StatCounter
   466  		receivedQueryStat           func(*stack.Stack) *tcpip.StatCounter
   467  		validateReport              func(*testing.T, channel.PacketInfo)
   468  		checkInitialGroups          func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
   469  	}{
   470  		{
   471  			name:                        "IGMP",
   472  			protoNum:                    ipv4.ProtocolNumber,
   473  			multicastAddr:               ipv4MulticastAddr1,
   474  			maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
   475  			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
   476  				return s.Stats().IGMP.PacketsSent.V2MembershipReport
   477  			},
   478  			receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
   479  				return s.Stats().IGMP.PacketsReceived.MembershipQuery
   480  			},
   481  			validateReport: func(t *testing.T, p channel.PacketInfo) {
   482  				t.Helper()
   483  
   484  				validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
   485  			},
   486  		},
   487  		{
   488  			name:                        "MLD",
   489  			protoNum:                    ipv6.ProtocolNumber,
   490  			multicastAddr:               ipv6MulticastAddr1,
   491  			maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
   492  			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
   493  				return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
   494  			},
   495  			receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
   496  				return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
   497  			},
   498  			validateReport: func(t *testing.T, p channel.PacketInfo) {
   499  				t.Helper()
   500  
   501  				validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
   502  			},
   503  			checkInitialGroups: checkInitialIPv6Groups,
   504  		},
   505  	}
   506  
   507  	for _, test := range tests {
   508  		t.Run(test.name, func(t *testing.T) {
   509  			e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
   510  
   511  			var reportCounter uint64
   512  			if test.checkInitialGroups != nil {
   513  				reportCounter, _ = test.checkInitialGroups(t, e, s, clock)
   514  			}
   515  
   516  			// Test joining a specific address explicitly and verify a Report is sent
   517  			// immediately.
   518  			if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
   519  				t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
   520  			}
   521  			reportCounter++
   522  			sentReportStat := test.sentReportStat(s)
   523  			if got := sentReportStat.Value(); got != reportCounter {
   524  				t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
   525  			}
   526  			if p, ok := e.Read(); !ok {
   527  				t.Fatal("expected a report message to be sent")
   528  			} else {
   529  				test.validateReport(t, p)
   530  			}
   531  			if t.Failed() {
   532  				t.FailNow()
   533  			}
   534  
   535  			// Verify the second report is sent by the maximum unsolicited response
   536  			// interval.
   537  			p, ok := e.Read()
   538  			if ok {
   539  				t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt)
   540  			}
   541  			clock.Advance(test.maxUnsolicitedResponseDelay)
   542  			reportCounter++
   543  			if got := sentReportStat.Value(); got != reportCounter {
   544  				t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
   545  			}
   546  			if p, ok := e.Read(); !ok {
   547  				t.Fatal("expected a report message to be sent")
   548  			} else {
   549  				test.validateReport(t, p)
   550  			}
   551  
   552  			// Should not send any more packets.
   553  			clock.Advance(time.Hour)
   554  			if p, ok := e.Read(); ok {
   555  				t.Fatalf("sent unexpected packet = %#v", p)
   556  			}
   557  		})
   558  	}
   559  }
   560  
   561  // TestMGPLeaveGroup tests that when leaving a previously joined multicast
   562  // group the stack sends a leave/done message.
   563  func TestMGPLeaveGroup(t *testing.T) {
   564  	tests := []struct {
   565  		name               string
   566  		protoNum           tcpip.NetworkProtocolNumber
   567  		multicastAddr      tcpip.Address
   568  		sentReportStat     func(*stack.Stack) *tcpip.StatCounter
   569  		sentLeaveStat      func(*stack.Stack) *tcpip.StatCounter
   570  		validateReport     func(*testing.T, channel.PacketInfo)
   571  		validateLeave      func(*testing.T, channel.PacketInfo)
   572  		checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
   573  	}{
   574  		{
   575  			name:          "IGMP",
   576  			protoNum:      ipv4.ProtocolNumber,
   577  			multicastAddr: ipv4MulticastAddr1,
   578  			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
   579  				return s.Stats().IGMP.PacketsSent.V2MembershipReport
   580  			},
   581  			sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
   582  				return s.Stats().IGMP.PacketsSent.LeaveGroup
   583  			},
   584  			validateReport: func(t *testing.T, p channel.PacketInfo) {
   585  				t.Helper()
   586  
   587  				validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
   588  			},
   589  			validateLeave: func(t *testing.T, p channel.PacketInfo) {
   590  				t.Helper()
   591  
   592  				validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr1)
   593  			},
   594  		},
   595  		{
   596  			name:          "MLD",
   597  			protoNum:      ipv6.ProtocolNumber,
   598  			multicastAddr: ipv6MulticastAddr1,
   599  			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
   600  				return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
   601  			},
   602  			sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
   603  				return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
   604  			},
   605  			validateReport: func(t *testing.T, p channel.PacketInfo) {
   606  				t.Helper()
   607  
   608  				validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
   609  			},
   610  			validateLeave: func(t *testing.T, p channel.PacketInfo) {
   611  				t.Helper()
   612  
   613  				validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6MulticastAddr1)
   614  			},
   615  			checkInitialGroups: checkInitialIPv6Groups,
   616  		},
   617  	}
   618  
   619  	for _, test := range tests {
   620  		t.Run(test.name, func(t *testing.T) {
   621  			e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
   622  
   623  			var reportCounter uint64
   624  			var leaveCounter uint64
   625  			if test.checkInitialGroups != nil {
   626  				reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
   627  			}
   628  
   629  			if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
   630  				t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
   631  			}
   632  			reportCounter++
   633  			if got := test.sentReportStat(s).Value(); got != reportCounter {
   634  				t.Errorf("got sentReportStat(_).Value() = %d, want = %d", got, reportCounter)
   635  			}
   636  			if p, ok := e.Read(); !ok {
   637  				t.Fatal("expected a report message to be sent")
   638  			} else {
   639  				test.validateReport(t, p)
   640  			}
   641  			if t.Failed() {
   642  				t.FailNow()
   643  			}
   644  
   645  			// Leaving the group should trigger an leave/done message to be sent.
   646  			if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
   647  				t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err)
   648  			}
   649  			leaveCounter++
   650  			if got := test.sentLeaveStat(s).Value(); got != leaveCounter {
   651  				t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter)
   652  			}
   653  			if p, ok := e.Read(); !ok {
   654  				t.Fatal("expected a leave message to be sent")
   655  			} else {
   656  				test.validateLeave(t, p)
   657  			}
   658  
   659  			// Should not send any more packets.
   660  			clock.Advance(time.Hour)
   661  			if p, ok := e.Read(); ok {
   662  				t.Fatalf("sent unexpected packet = %#v", p)
   663  			}
   664  		})
   665  	}
   666  }
   667  
   668  // TestMGPQueryMessages tests that a report is sent in response to query
   669  // messages.
   670  func TestMGPQueryMessages(t *testing.T) {
   671  	tests := []struct {
   672  		name                        string
   673  		protoNum                    tcpip.NetworkProtocolNumber
   674  		multicastAddr               tcpip.Address
   675  		maxUnsolicitedResponseDelay time.Duration
   676  		sentReportStat              func(*stack.Stack) *tcpip.StatCounter
   677  		receivedQueryStat           func(*stack.Stack) *tcpip.StatCounter
   678  		rxQuery                     func(*channel.Endpoint, uint8, tcpip.Address)
   679  		validateReport              func(*testing.T, channel.PacketInfo)
   680  		maxRespTimeToDuration       func(uint8) time.Duration
   681  		checkInitialGroups          func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
   682  	}{
   683  		{
   684  			name:                        "IGMP",
   685  			protoNum:                    ipv4.ProtocolNumber,
   686  			multicastAddr:               ipv4MulticastAddr1,
   687  			maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
   688  			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
   689  				return s.Stats().IGMP.PacketsSent.V2MembershipReport
   690  			},
   691  			receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
   692  				return s.Stats().IGMP.PacketsReceived.MembershipQuery
   693  			},
   694  			rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) {
   695  				createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress)
   696  			},
   697  			validateReport: func(t *testing.T, p channel.PacketInfo) {
   698  				t.Helper()
   699  
   700  				validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
   701  			},
   702  			maxRespTimeToDuration: header.DecisecondToDuration,
   703  		},
   704  		{
   705  			name:                        "MLD",
   706  			protoNum:                    ipv6.ProtocolNumber,
   707  			multicastAddr:               ipv6MulticastAddr1,
   708  			maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
   709  			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
   710  				return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
   711  			},
   712  			receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
   713  				return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
   714  			},
   715  			rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) {
   716  				createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress)
   717  			},
   718  			validateReport: func(t *testing.T, p channel.PacketInfo) {
   719  				t.Helper()
   720  
   721  				validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
   722  			},
   723  			maxRespTimeToDuration: func(d uint8) time.Duration {
   724  				return time.Duration(d) * time.Millisecond
   725  			},
   726  			checkInitialGroups: checkInitialIPv6Groups,
   727  		},
   728  	}
   729  
   730  	for _, test := range tests {
   731  		t.Run(test.name, func(t *testing.T) {
   732  			subTests := []struct {
   733  				name          string
   734  				multicastAddr tcpip.Address
   735  				expectReport  bool
   736  			}{
   737  				{
   738  					name:          "Unspecified",
   739  					multicastAddr: tcpip.Address(strings.Repeat("\x00", len(test.multicastAddr))),
   740  					expectReport:  true,
   741  				},
   742  				{
   743  					name:          "Specified",
   744  					multicastAddr: test.multicastAddr,
   745  					expectReport:  true,
   746  				},
   747  				{
   748  					name: "Specified other address",
   749  					multicastAddr: func() tcpip.Address {
   750  						addrBytes := []byte(test.multicastAddr)
   751  						addrBytes[len(addrBytes)-1]++
   752  						return tcpip.Address(addrBytes)
   753  					}(),
   754  					expectReport: false,
   755  				},
   756  			}
   757  
   758  			for _, subTest := range subTests {
   759  				t.Run(subTest.name, func(t *testing.T) {
   760  					e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
   761  
   762  					var reportCounter uint64
   763  					if test.checkInitialGroups != nil {
   764  						reportCounter, _ = test.checkInitialGroups(t, e, s, clock)
   765  					}
   766  
   767  					if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
   768  						t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
   769  					}
   770  					sentReportStat := test.sentReportStat(s)
   771  					for i := 0; i < maxUnsolicitedReports; i++ {
   772  						sentReportStat := test.sentReportStat(s)
   773  						reportCounter++
   774  						if got := sentReportStat.Value(); got != reportCounter {
   775  							t.Errorf("(i=%d) got sentReportStat.Value() = %d, want = %d", i, got, reportCounter)
   776  						}
   777  						if p, ok := e.Read(); !ok {
   778  							t.Fatalf("expected %d-th report message to be sent", i)
   779  						} else {
   780  							test.validateReport(t, p)
   781  						}
   782  						clock.Advance(test.maxUnsolicitedResponseDelay)
   783  					}
   784  					if t.Failed() {
   785  						t.FailNow()
   786  					}
   787  
   788  					// Should not send any more packets until a query.
   789  					clock.Advance(time.Hour)
   790  					if p, ok := e.Read(); ok {
   791  						t.Fatalf("sent unexpected packet = %#v", p)
   792  					}
   793  
   794  					// Receive a query message which should trigger a report to be sent at
   795  					// some time before the maximum response time if the report is
   796  					// targeted at the host.
   797  					const maxRespTime = 100
   798  					test.rxQuery(e, maxRespTime, subTest.multicastAddr)
   799  					if p, ok := e.Read(); ok {
   800  						t.Fatalf("sent unexpected packet = %#v", p.Pkt)
   801  					}
   802  
   803  					if subTest.expectReport {
   804  						clock.Advance(test.maxRespTimeToDuration(maxRespTime))
   805  						reportCounter++
   806  						if got := sentReportStat.Value(); got != reportCounter {
   807  							t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
   808  						}
   809  						if p, ok := e.Read(); !ok {
   810  							t.Fatal("expected a report message to be sent")
   811  						} else {
   812  							test.validateReport(t, p)
   813  						}
   814  					}
   815  
   816  					// Should not send any more packets.
   817  					clock.Advance(time.Hour)
   818  					if p, ok := e.Read(); ok {
   819  						t.Fatalf("sent unexpected packet = %#v", p)
   820  					}
   821  				})
   822  			}
   823  		})
   824  	}
   825  }
   826  
   827  // TestMGPQueryMessages tests that no further reports or leave/done messages
   828  // are sent after receiving a report.
   829  func TestMGPReportMessages(t *testing.T) {
   830  	tests := []struct {
   831  		name                  string
   832  		protoNum              tcpip.NetworkProtocolNumber
   833  		multicastAddr         tcpip.Address
   834  		sentReportStat        func(*stack.Stack) *tcpip.StatCounter
   835  		sentLeaveStat         func(*stack.Stack) *tcpip.StatCounter
   836  		rxReport              func(*channel.Endpoint)
   837  		validateReport        func(*testing.T, channel.PacketInfo)
   838  		maxRespTimeToDuration func(uint8) time.Duration
   839  		checkInitialGroups    func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
   840  	}{
   841  		{
   842  			name:          "IGMP",
   843  			protoNum:      ipv4.ProtocolNumber,
   844  			multicastAddr: ipv4MulticastAddr1,
   845  			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
   846  				return s.Stats().IGMP.PacketsSent.V2MembershipReport
   847  			},
   848  			sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
   849  				return s.Stats().IGMP.PacketsSent.LeaveGroup
   850  			},
   851  			rxReport: func(e *channel.Endpoint) {
   852  				createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
   853  			},
   854  			validateReport: func(t *testing.T, p channel.PacketInfo) {
   855  				t.Helper()
   856  
   857  				validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
   858  			},
   859  			maxRespTimeToDuration: header.DecisecondToDuration,
   860  		},
   861  		{
   862  			name:          "MLD",
   863  			protoNum:      ipv6.ProtocolNumber,
   864  			multicastAddr: ipv6MulticastAddr1,
   865  			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
   866  				return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
   867  			},
   868  			sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
   869  				return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
   870  			},
   871  			rxReport: func(e *channel.Endpoint) {
   872  				createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr1)
   873  			},
   874  			validateReport: func(t *testing.T, p channel.PacketInfo) {
   875  				t.Helper()
   876  
   877  				validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
   878  			},
   879  			maxRespTimeToDuration: func(d uint8) time.Duration {
   880  				return time.Duration(d) * time.Millisecond
   881  			},
   882  			checkInitialGroups: checkInitialIPv6Groups,
   883  		},
   884  	}
   885  
   886  	for _, test := range tests {
   887  		t.Run(test.name, func(t *testing.T) {
   888  			e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
   889  
   890  			var reportCounter uint64
   891  			var leaveCounter uint64
   892  			if test.checkInitialGroups != nil {
   893  				reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
   894  			}
   895  
   896  			if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
   897  				t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
   898  			}
   899  			sentReportStat := test.sentReportStat(s)
   900  			reportCounter++
   901  			if got := sentReportStat.Value(); got != reportCounter {
   902  				t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
   903  			}
   904  			if p, ok := e.Read(); !ok {
   905  				t.Fatal("expected a report message to be sent")
   906  			} else {
   907  				test.validateReport(t, p)
   908  			}
   909  			if t.Failed() {
   910  				t.FailNow()
   911  			}
   912  
   913  			// Receiving a report for a group we joined should cancel any further
   914  			// reports.
   915  			test.rxReport(e)
   916  			clock.Advance(time.Hour)
   917  			if got := sentReportStat.Value(); got != reportCounter {
   918  				t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
   919  			}
   920  			if p, ok := e.Read(); ok {
   921  				t.Errorf("sent unexpected packet = %#v", p)
   922  			}
   923  			if t.Failed() {
   924  				t.FailNow()
   925  			}
   926  
   927  			// Leaving a group after getting a report should not send a leave/done
   928  			// message.
   929  			if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
   930  				t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err)
   931  			}
   932  			clock.Advance(time.Hour)
   933  			if got := test.sentLeaveStat(s).Value(); got != leaveCounter {
   934  				t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter)
   935  			}
   936  
   937  			// Should not send any more packets.
   938  			clock.Advance(time.Hour)
   939  			if p, ok := e.Read(); ok {
   940  				t.Fatalf("sent unexpected packet = %#v", p)
   941  			}
   942  		})
   943  	}
   944  }
   945  
   946  func TestMGPWithNICLifecycle(t *testing.T) {
   947  	tests := []struct {
   948  		name                        string
   949  		protoNum                    tcpip.NetworkProtocolNumber
   950  		multicastAddrs              []tcpip.Address
   951  		finalMulticastAddr          tcpip.Address
   952  		maxUnsolicitedResponseDelay time.Duration
   953  		sentReportStat              func(*stack.Stack) *tcpip.StatCounter
   954  		sentLeaveStat               func(*stack.Stack) *tcpip.StatCounter
   955  		validateReport              func(*testing.T, channel.PacketInfo, tcpip.Address)
   956  		validateLeave               func(*testing.T, channel.PacketInfo, tcpip.Address)
   957  		getAndCheckGroupAddress     func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address
   958  		checkInitialGroups          func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
   959  	}{
   960  		{
   961  			name:                        "IGMP",
   962  			protoNum:                    ipv4.ProtocolNumber,
   963  			multicastAddrs:              []tcpip.Address{ipv4MulticastAddr1, ipv4MulticastAddr2},
   964  			finalMulticastAddr:          ipv4MulticastAddr3,
   965  			maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
   966  			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
   967  				return s.Stats().IGMP.PacketsSent.V2MembershipReport
   968  			},
   969  			sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
   970  				return s.Stats().IGMP.PacketsSent.LeaveGroup
   971  			},
   972  			validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
   973  				t.Helper()
   974  
   975  				validateIGMPPacket(t, p, addr, igmpv2MembershipReport, 0, addr)
   976  			},
   977  			validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
   978  				t.Helper()
   979  
   980  				validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, addr)
   981  			},
   982  			getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address {
   983  				t.Helper()
   984  
   985  				ipv4 := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
   986  				if got := tcpip.TransportProtocolNumber(ipv4.Protocol()); got != header.IGMPProtocolNumber {
   987  					t.Fatalf("got ipv4.Protocol() = %d, want = %d", got, header.IGMPProtocolNumber)
   988  				}
   989  				addr := header.IGMP(ipv4.Payload()).GroupAddress()
   990  				s, ok := seen[addr]
   991  				if !ok {
   992  					t.Fatalf("unexpectedly got a packet for group %s", addr)
   993  				}
   994  				if s {
   995  					t.Fatalf("already saw packet for group %s", addr)
   996  				}
   997  				seen[addr] = true
   998  				return addr
   999  			},
  1000  		},
  1001  		{
  1002  			name:                        "MLD",
  1003  			protoNum:                    ipv6.ProtocolNumber,
  1004  			multicastAddrs:              []tcpip.Address{ipv6MulticastAddr1, ipv6MulticastAddr2},
  1005  			finalMulticastAddr:          ipv6MulticastAddr3,
  1006  			maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
  1007  			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
  1008  				return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
  1009  			},
  1010  			sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
  1011  				return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
  1012  			},
  1013  			validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
  1014  				t.Helper()
  1015  
  1016  				validateMLDPacket(t, p, addr, mldReport, 0, addr)
  1017  			},
  1018  			validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
  1019  				t.Helper()
  1020  
  1021  				validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, addr)
  1022  			},
  1023  			getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address {
  1024  				t.Helper()
  1025  
  1026  				ipv6 := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader()))
  1027  
  1028  				ipv6HeaderIter := header.MakeIPv6PayloadIterator(
  1029  					header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()),
  1030  					buffer.View(ipv6.Payload()).ToVectorisedView(),
  1031  				)
  1032  
  1033  				var transport header.IPv6RawPayloadHeader
  1034  				for {
  1035  					h, done, err := ipv6HeaderIter.Next()
  1036  					if err != nil {
  1037  						t.Fatalf("ipv6HeaderIter.Next(): %s", err)
  1038  					}
  1039  					if done {
  1040  						t.Fatalf("ipv6HeaderIter.Next() = (%T, %t, _), want = (_, false, _)", h, done)
  1041  					}
  1042  					if t, ok := h.(header.IPv6RawPayloadHeader); ok {
  1043  						transport = t
  1044  						break
  1045  					}
  1046  				}
  1047  
  1048  				if got := tcpip.TransportProtocolNumber(transport.Identifier); got != header.ICMPv6ProtocolNumber {
  1049  					t.Fatalf("got ipv6.NextHeader() = %d, want = %d", got, header.ICMPv6ProtocolNumber)
  1050  				}
  1051  				icmpv6 := header.ICMPv6(transport.Buf.ToView())
  1052  				if got := icmpv6.Type(); got != header.ICMPv6MulticastListenerReport && got != header.ICMPv6MulticastListenerDone {
  1053  					t.Fatalf("got icmpv6.Type() = %d, want = %d or %d", got, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone)
  1054  				}
  1055  				addr := header.MLD(icmpv6.MessageBody()).MulticastAddress()
  1056  				s, ok := seen[addr]
  1057  				if !ok {
  1058  					t.Fatalf("unexpectedly got a packet for group %s", addr)
  1059  				}
  1060  				if s {
  1061  					t.Fatalf("already saw packet for group %s", addr)
  1062  				}
  1063  				seen[addr] = true
  1064  				return addr
  1065  			},
  1066  			checkInitialGroups: checkInitialIPv6Groups,
  1067  		},
  1068  	}
  1069  
  1070  	for _, test := range tests {
  1071  		t.Run(test.name, func(t *testing.T) {
  1072  			e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
  1073  
  1074  			var reportCounter uint64
  1075  			var leaveCounter uint64
  1076  			if test.checkInitialGroups != nil {
  1077  				reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
  1078  			}
  1079  
  1080  			sentReportStat := test.sentReportStat(s)
  1081  			for _, a := range test.multicastAddrs {
  1082  				if err := s.JoinGroup(test.protoNum, nicID, a); err != nil {
  1083  					t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err)
  1084  				}
  1085  				reportCounter++
  1086  				if got := sentReportStat.Value(); got != reportCounter {
  1087  					t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
  1088  				}
  1089  				if p, ok := e.Read(); !ok {
  1090  					t.Fatalf("expected a report message to be sent for %s", a)
  1091  				} else {
  1092  					test.validateReport(t, p, a)
  1093  				}
  1094  			}
  1095  			if t.Failed() {
  1096  				t.FailNow()
  1097  			}
  1098  
  1099  			// Leave messages should be sent for the joined groups when the NIC is
  1100  			// disabled.
  1101  			if err := s.DisableNIC(nicID); err != nil {
  1102  				t.Fatalf("DisableNIC(%d): %s", nicID, err)
  1103  			}
  1104  			sentLeaveStat := test.sentLeaveStat(s)
  1105  			leaveCounter += uint64(len(test.multicastAddrs))
  1106  			if got := sentLeaveStat.Value(); got != leaveCounter {
  1107  				t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
  1108  			}
  1109  			{
  1110  				seen := make(map[tcpip.Address]bool)
  1111  				for _, a := range test.multicastAddrs {
  1112  					seen[a] = false
  1113  				}
  1114  
  1115  				for i := range test.multicastAddrs {
  1116  					p, ok := e.Read()
  1117  					if !ok {
  1118  						t.Fatalf("expected (%d-th) leave message to be sent", i)
  1119  					}
  1120  
  1121  					test.validateLeave(t, p, test.getAndCheckGroupAddress(t, seen, p))
  1122  				}
  1123  			}
  1124  			if t.Failed() {
  1125  				t.FailNow()
  1126  			}
  1127  
  1128  			// Reports should be sent for the joined groups when the NIC is enabled.
  1129  			if err := s.EnableNIC(nicID); err != nil {
  1130  				t.Fatalf("EnableNIC(%d): %s", nicID, err)
  1131  			}
  1132  			reportCounter += uint64(len(test.multicastAddrs))
  1133  			if got := sentReportStat.Value(); got != reportCounter {
  1134  				t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
  1135  			}
  1136  			{
  1137  				seen := make(map[tcpip.Address]bool)
  1138  				for _, a := range test.multicastAddrs {
  1139  					seen[a] = false
  1140  				}
  1141  
  1142  				for i := range test.multicastAddrs {
  1143  					p, ok := e.Read()
  1144  					if !ok {
  1145  						t.Fatalf("expected (%d-th) report message to be sent", i)
  1146  					}
  1147  
  1148  					test.validateReport(t, p, test.getAndCheckGroupAddress(t, seen, p))
  1149  				}
  1150  			}
  1151  			if t.Failed() {
  1152  				t.FailNow()
  1153  			}
  1154  
  1155  			// Joining/leaving a group while disabled should not send any messages.
  1156  			if err := s.DisableNIC(nicID); err != nil {
  1157  				t.Fatalf("DisableNIC(%d): %s", nicID, err)
  1158  			}
  1159  			leaveCounter += uint64(len(test.multicastAddrs))
  1160  			if got := sentLeaveStat.Value(); got != leaveCounter {
  1161  				t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
  1162  			}
  1163  			for i := range test.multicastAddrs {
  1164  				if _, ok := e.Read(); !ok {
  1165  					t.Fatalf("expected (%d-th) leave message to be sent", i)
  1166  				}
  1167  			}
  1168  			for _, a := range test.multicastAddrs {
  1169  				if err := s.LeaveGroup(test.protoNum, nicID, a); err != nil {
  1170  					t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, a, err)
  1171  				}
  1172  				if got := sentLeaveStat.Value(); got != leaveCounter {
  1173  					t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
  1174  				}
  1175  				if p, ok := e.Read(); ok {
  1176  					t.Fatalf("leaving group %s on disabled NIC sent unexpected packet = %#v", a, p.Pkt)
  1177  				}
  1178  			}
  1179  			if err := s.JoinGroup(test.protoNum, nicID, test.finalMulticastAddr); err != nil {
  1180  				t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.finalMulticastAddr, err)
  1181  			}
  1182  			if got := sentReportStat.Value(); got != reportCounter {
  1183  				t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
  1184  			}
  1185  			if p, ok := e.Read(); ok {
  1186  				t.Fatalf("joining group %s on disabled NIC sent unexpected packet = %#v", test.finalMulticastAddr, p.Pkt)
  1187  			}
  1188  
  1189  			// A report should only be sent for the group we last joined after
  1190  			// enabling the NIC since the original groups were all left.
  1191  			if err := s.EnableNIC(nicID); err != nil {
  1192  				t.Fatalf("EnableNIC(%d): %s", nicID, err)
  1193  			}
  1194  			reportCounter++
  1195  			if got := sentReportStat.Value(); got != reportCounter {
  1196  				t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
  1197  			}
  1198  			if p, ok := e.Read(); !ok {
  1199  				t.Fatal("expected a report message to be sent")
  1200  			} else {
  1201  				test.validateReport(t, p, test.finalMulticastAddr)
  1202  			}
  1203  
  1204  			clock.Advance(test.maxUnsolicitedResponseDelay)
  1205  			reportCounter++
  1206  			if got := sentReportStat.Value(); got != reportCounter {
  1207  				t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
  1208  			}
  1209  			if p, ok := e.Read(); !ok {
  1210  				t.Fatal("expected a report message to be sent")
  1211  			} else {
  1212  				test.validateReport(t, p, test.finalMulticastAddr)
  1213  			}
  1214  
  1215  			// Should not send any more packets.
  1216  			clock.Advance(time.Hour)
  1217  			if p, ok := e.Read(); ok {
  1218  				t.Fatalf("sent unexpected packet = %#v", p)
  1219  			}
  1220  		})
  1221  	}
  1222  }
  1223  
  1224  // TestMGPDisabledOnLoopback tests that the multicast group protocol is not
  1225  // performed on loopback interfaces since they have no neighbours.
  1226  func TestMGPDisabledOnLoopback(t *testing.T) {
  1227  	tests := []struct {
  1228  		name           string
  1229  		protoNum       tcpip.NetworkProtocolNumber
  1230  		multicastAddr  tcpip.Address
  1231  		sentReportStat func(*stack.Stack) *tcpip.StatCounter
  1232  	}{
  1233  		{
  1234  			name:          "IGMP",
  1235  			protoNum:      ipv4.ProtocolNumber,
  1236  			multicastAddr: ipv4MulticastAddr1,
  1237  			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
  1238  				return s.Stats().IGMP.PacketsSent.V2MembershipReport
  1239  			},
  1240  		},
  1241  		{
  1242  			name:          "MLD",
  1243  			protoNum:      ipv6.ProtocolNumber,
  1244  			multicastAddr: ipv6MulticastAddr1,
  1245  			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
  1246  				return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
  1247  			},
  1248  		},
  1249  	}
  1250  
  1251  	for _, test := range tests {
  1252  		t.Run(test.name, func(t *testing.T) {
  1253  			s, clock := createStackWithLinkEndpoint(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */, loopback.New())
  1254  
  1255  			sentReportStat := test.sentReportStat(s)
  1256  			if got := sentReportStat.Value(); got != 0 {
  1257  				t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
  1258  			}
  1259  			clock.Advance(time.Hour)
  1260  			if got := sentReportStat.Value(); got != 0 {
  1261  				t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
  1262  			}
  1263  
  1264  			// Test joining a specific group explicitly and verify that no reports are
  1265  			// sent.
  1266  			if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
  1267  				t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
  1268  			}
  1269  			if got := sentReportStat.Value(); got != 0 {
  1270  				t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
  1271  			}
  1272  			clock.Advance(time.Hour)
  1273  			if got := sentReportStat.Value(); got != 0 {
  1274  				t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
  1275  			}
  1276  		})
  1277  	}
  1278  }