gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/network/multicast_group_test.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package ip_test
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  	"testing"
    21  	"time"
    22  
    23  	"gvisor.dev/gvisor/pkg/buffer"
    24  	"gvisor.dev/gvisor/pkg/refs"
    25  	"gvisor.dev/gvisor/pkg/tcpip"
    26  	"gvisor.dev/gvisor/pkg/tcpip/checker"
    27  	"gvisor.dev/gvisor/pkg/tcpip/faketime"
    28  	"gvisor.dev/gvisor/pkg/tcpip/header"
    29  	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
    30  	"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
    31  	iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
    32  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    33  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
    34  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    35  	"gvisor.dev/gvisor/pkg/tcpip/testutil"
    36  )
    37  
    38  const (
    39  	linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
    40  
    41  	defaultIPv4PrefixLength = 24
    42  
    43  	igmpMembershipQuery    = uint8(header.IGMPMembershipQuery)
    44  	igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport)
    45  	igmpv2MembershipReport = uint8(header.IGMPv2MembershipReport)
    46  	igmpLeaveGroup         = uint8(header.IGMPLeaveGroup)
    47  	mldQuery               = uint8(header.ICMPv6MulticastListenerQuery)
    48  	mldReport              = uint8(header.ICMPv6MulticastListenerReport)
    49  	mldDone                = uint8(header.ICMPv6MulticastListenerDone)
    50  
    51  	maxUnsolicitedReports = 2
    52  )
    53  
    54  var (
    55  	stackIPv4Addr      = testutil.MustParse4("10.0.0.1")
    56  	linkLocalIPv6Addr1 = testutil.MustParse6("fe80::1")
    57  	linkLocalIPv6Addr2 = testutil.MustParse6("fe80::2")
    58  
    59  	ipv4MulticastAddr1 = testutil.MustParse4("224.0.0.3")
    60  	ipv4MulticastAddr2 = testutil.MustParse4("224.0.0.4")
    61  	ipv4MulticastAddr3 = testutil.MustParse4("224.0.0.5")
    62  	ipv6MulticastAddr1 = testutil.MustParse6("ff02::3")
    63  	ipv6MulticastAddr2 = testutil.MustParse6("ff02::4")
    64  	ipv6MulticastAddr3 = testutil.MustParse6("ff02::5")
    65  )
    66  
    67  var (
    68  	// unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the
    69  	// NIC will wait before sending an unsolicited report after joining a
    70  	// multicast group, in deciseconds.
    71  	unsolicitedIGMPReportIntervalMaxTenthSec = func() uint8 {
    72  		const decisecond = time.Second / 10
    73  		if ipv4.UnsolicitedReportIntervalMax%decisecond != 0 {
    74  			panic(fmt.Sprintf("UnsolicitedReportIntervalMax of %d is a lossy conversion to deciseconds", ipv4.UnsolicitedReportIntervalMax))
    75  		}
    76  		return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond)
    77  	}()
    78  
    79  	ipv6AddrSNMC = header.SolicitedNodeAddr(linkLocalIPv6Addr1)
    80  )
    81  
    82  // validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet
    83  // sent to the provided address with the passed fields set.
    84  func validateMLDPacket(t *testing.T, p *stack.PacketBuffer, remoteAddress tcpip.Address, mldType uint8, maxRespTime byte, groupAddress tcpip.Address) {
    85  	t.Helper()
    86  
    87  	payload := stack.PayloadSince(p.NetworkHeader())
    88  	defer payload.Release()
    89  	checker.IPv6WithExtHdr(t, payload,
    90  		checker.IPv6ExtHdr(
    91  			checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)),
    92  		),
    93  		checker.SrcAddr(linkLocalIPv6Addr1),
    94  		checker.DstAddr(remoteAddress),
    95  		// Hop Limit for an MLD message must be 1 as per RFC 2710 section 3.
    96  		checker.TTL(1),
    97  		checker.MLD(header.ICMPv6Type(mldType), header.MLDMinimumSize,
    98  			checker.MLDMaxRespDelay(time.Duration(maxRespTime)*time.Millisecond),
    99  			checker.MLDMulticastAddress(groupAddress),
   100  		),
   101  	)
   102  }
   103  
   104  func validateMLDv2ReportPacket(t *testing.T, p *stack.PacketBuffer, addrs []tcpip.Address, recordType header.MLDv2ReportRecordType) {
   105  	t.Helper()
   106  	payload := stack.PayloadSince(p.NetworkHeader())
   107  	defer payload.Release()
   108  	iptestutil.ValidateMLDv2Report(t, payload, linkLocalIPv6Addr1, addrs, recordType)
   109  }
   110  
   111  // validateIGMPPacket checks that a passed PacketInfo is an IPv4 IGMP packet
   112  // sent to the provided address with the passed fields set.
   113  func validateIGMPPacket(t *testing.T, p *stack.PacketBuffer, remoteAddress tcpip.Address, igmpType uint8, maxRespTime byte, groupAddress tcpip.Address) {
   114  	t.Helper()
   115  
   116  	payload := stack.PayloadSince(p.NetworkHeader())
   117  	defer payload.Release()
   118  	checker.IPv4(t, payload,
   119  		checker.SrcAddr(stackIPv4Addr),
   120  		checker.DstAddr(remoteAddress),
   121  		// TTL for an IGMP message must be 1 as per RFC 2236 section 2.
   122  		checker.TTL(1),
   123  		checker.IPv4RouterAlert(),
   124  		checker.IGMP(
   125  			checker.IGMPType(header.IGMPType(igmpType)),
   126  			checker.IGMPMaxRespTime(header.DecisecondToDuration(uint16(maxRespTime))),
   127  			checker.IGMPGroupAddress(groupAddress),
   128  		),
   129  	)
   130  }
   131  
   132  func validateIGMPv3ReportPacket(t *testing.T, p *stack.PacketBuffer, addrs []tcpip.Address, recordType header.IGMPv3ReportRecordType) {
   133  	t.Helper()
   134  
   135  	payload := stack.PayloadSince(p.NetworkHeader())
   136  	defer payload.Release()
   137  	iptestutil.ValidateIGMPv3Report(t, payload, stackIPv4Addr, addrs, recordType)
   138  }
   139  
// multicastTestContext bundles the objects a multicast group protocol test
// works with. Instances are built by newMulticastTestContext and torn down
// with cleanup.
type multicastTestContext struct {
	// s is the network stack under test.
	s     *stack.Stack
	// e is the channel endpoint attached to the stack's NIC; tests read sent
	// packets from it and inject received packets into it.
	e     *channel.Endpoint
	// clock is the manual clock driving the stack's timers.
	clock *faketime.ManualClock
}
   145  
   146  func newMulticastTestContext(t *testing.T, v4, mgpEnabled bool) multicastTestContext {
   147  	t.Helper()
   148  
   149  	e := channel.New(maxUnsolicitedReports, header.IPv6MinimumMTU, linkAddr)
   150  	s, clock := createStackWithLinkEndpoint(t, v4, mgpEnabled, e)
   151  	return multicastTestContext{
   152  		s:     s,
   153  		e:     e,
   154  		clock: clock,
   155  	}
   156  }
   157  
// cleanup tears down the test context: it closes the stack and waits on it,
// closes the channel endpoint, and finally runs the reference leak check.
func (ctx *multicastTestContext) cleanup() {
	ctx.s.Close()
	ctx.s.Wait()
	ctx.e.Close()
	// Run last so that everything closed above is accounted for.
	refs.DoRepeatedLeakCheck()
}
   164  
   165  func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) {
   166  	t.Helper()
   167  
   168  	igmpEnabled := v4 && mgpEnabled
   169  	mldEnabled := !v4 && mgpEnabled
   170  
   171  	clock := faketime.NewManualClock()
   172  	s := stack.New(stack.Options{
   173  		NetworkProtocols: []stack.NetworkProtocolFactory{
   174  			ipv4.NewProtocolWithOptions(ipv4.Options{
   175  				IGMP: ipv4.IGMPOptions{
   176  					Enabled: igmpEnabled,
   177  				},
   178  			}),
   179  			ipv6.NewProtocolWithOptions(ipv6.Options{
   180  				MLD: ipv6.MLDOptions{
   181  					Enabled: mldEnabled,
   182  				},
   183  			}),
   184  		},
   185  		Clock: clock,
   186  	})
   187  	if err := s.CreateNIC(nicID, e); err != nil {
   188  		t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
   189  	}
   190  	addr := tcpip.ProtocolAddress{
   191  		Protocol: ipv4.ProtocolNumber,
   192  		AddressWithPrefix: tcpip.AddressWithPrefix{
   193  			Address:   stackIPv4Addr,
   194  			PrefixLen: defaultIPv4PrefixLength,
   195  		},
   196  	}
   197  	if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
   198  		t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
   199  	}
   200  	protocolAddr := tcpip.ProtocolAddress{
   201  		Protocol:          ipv6.ProtocolNumber,
   202  		AddressWithPrefix: linkLocalIPv6Addr1.WithPrefix(),
   203  	}
   204  	if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
   205  		t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
   206  	}
   207  
   208  	return s, clock
   209  }
   210  
   211  // checkInitialIPv6Groups checks the initial IPv6 groups that a NIC will join
   212  // when it is created with an IPv6 address.
   213  //
   214  // To not interfere with tests, checkInitialIPv6Groups will leave the added
   215  // address's solicited node multicast group so that the tests can all assume
   216  // the NIC has not joined any IPv6 groups.
   217  func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, clock *faketime.ManualClock) uint64 {
   218  	t.Helper()
   219  
   220  	var reportCounter uint64
   221  
   222  	reportCounter++
   223  	iptestutil.CheckMLDv2Stats(t, s, 0, 0, reportCounter)
   224  	if p := e.Read(); p == nil {
   225  		t.Fatal("expected a report message to be sent")
   226  	} else {
   227  		v := stack.PayloadSince(p.NetworkHeader())
   228  		iptestutil.ValidateMLDv2Report(t, v, linkLocalIPv6Addr1, []tcpip.Address{ipv6AddrSNMC}, header.MLDv2ReportRecordChangeToExcludeMode)
   229  		v.Release()
   230  		p.DecRef()
   231  	}
   232  
   233  	// Leave the group to not affect the tests. This is fine since we are not
   234  	// testing DAD or the solicited node address specifically.
   235  	if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, ipv6AddrSNMC); err != nil {
   236  		t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, ipv6AddrSNMC, err)
   237  	}
   238  	for i := 0; i < 2; i++ {
   239  		reportCounter++
   240  		iptestutil.CheckMLDv2Stats(t, s, 0, 0, reportCounter)
   241  		if p := e.Read(); p == nil {
   242  			t.Fatal("expected a report message to be sent")
   243  		} else {
   244  			v := stack.PayloadSince(p.NetworkHeader())
   245  			iptestutil.ValidateMLDv2Report(t, v, linkLocalIPv6Addr1, []tcpip.Address{ipv6AddrSNMC}, header.MLDv2ReportRecordChangeToIncludeMode)
   246  			v.Release()
   247  			p.DecRef()
   248  		}
   249  
   250  		clock.Advance(ipv6.UnsolicitedReportIntervalMax)
   251  	}
   252  
   253  	// Should not send any more packets.
   254  	clock.Advance(time.Hour)
   255  	if p := e.Read(); p != nil {
   256  		t.Fatalf("sent unexpected packet = %#v", p)
   257  	}
   258  
   259  	return reportCounter
   260  }
   261  
   262  // createAndInjectIGMPPacket creates and injects an IGMP packet with the
   263  // specified fields.
   264  func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType byte, maxRespTime byte, groupAddress tcpip.Address, extraLength int) {
   265  	options := header.IPv4OptionsSerializer{
   266  		&header.IPv4SerializableRouterAlertOption{},
   267  	}
   268  	buf := make([]byte, header.IPv4MinimumSize+int(options.Length())+header.IGMPQueryMinimumSize+extraLength)
   269  	ip := header.IPv4(buf)
   270  	ip.Encode(&header.IPv4Fields{
   271  		TotalLength: uint16(len(buf)),
   272  		TTL:         header.IGMPTTL,
   273  		Protocol:    uint8(header.IGMPProtocolNumber),
   274  		SrcAddr:     remoteIPv4Addr,
   275  		DstAddr:     header.IPv4AllSystems,
   276  		Options:     options,
   277  	})
   278  	ip.SetChecksum(^ip.CalculateChecksum())
   279  
   280  	igmp := header.IGMP(ip.Payload())
   281  	igmp.SetType(header.IGMPType(igmpType))
   282  	igmp.SetMaxRespTime(maxRespTime)
   283  	igmp.SetGroupAddress(groupAddress)
   284  	igmp.SetChecksum(header.IGMPCalculateChecksum(igmp))
   285  
   286  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   287  		Payload: buffer.MakeWithData(buf),
   288  	})
   289  	e.InjectInbound(ipv4.ProtocolNumber, pkt)
   290  	pkt.DecRef()
   291  }
   292  
   293  // createAndInjectMLDPacket creates and injects an MLD packet with the
   294  // specified fields.
   295  func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay byte, groupAddress tcpip.Address, extraLength int) {
   296  	extensionHeaders := header.IPv6ExtHdrSerializer{
   297  		header.IPv6SerializableHopByHopExtHdr{
   298  			&header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD},
   299  		},
   300  	}
   301  
   302  	extensionHeadersLength := extensionHeaders.Length()
   303  	payloadLength := extensionHeadersLength + header.ICMPv6HeaderSize + header.MLDMinimumSize + extraLength
   304  	buf := make([]byte, header.IPv6MinimumSize+payloadLength)
   305  
   306  	ip := header.IPv6(buf)
   307  	ip.Encode(&header.IPv6Fields{
   308  		PayloadLength:     uint16(payloadLength),
   309  		HopLimit:          header.MLDHopLimit,
   310  		TransportProtocol: header.ICMPv6ProtocolNumber,
   311  		SrcAddr:           linkLocalIPv6Addr2,
   312  		DstAddr:           header.IPv6AllNodesMulticastAddress,
   313  		ExtensionHeaders:  extensionHeaders,
   314  	})
   315  
   316  	icmp := header.ICMPv6(ip.Payload()[extensionHeadersLength:])
   317  	icmp.SetType(header.ICMPv6Type(mldType))
   318  	mld := header.MLD(icmp.MessageBody())
   319  	mld.SetMaximumResponseDelay(uint16(maxRespDelay))
   320  	mld.SetMulticastAddress(groupAddress)
   321  	icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
   322  		Header: icmp,
   323  		Src:    linkLocalIPv6Addr2,
   324  		Dst:    header.IPv6AllNodesMulticastAddress,
   325  	}))
   326  
   327  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   328  		Payload: buffer.MakeWithData(buf),
   329  	})
   330  	e.InjectInbound(ipv6.ProtocolNumber, pkt)
   331  	pkt.DecRef()
   332  }
   333  
   334  // TestMGPDisabled tests that the multicast group protocol is not enabled by
   335  // default.
   336  func TestMGPDisabled(t *testing.T) {
   337  	tests := []struct {
   338  		name              string
   339  		protoNum          tcpip.NetworkProtocolNumber
   340  		multicastAddr     tcpip.Address
   341  		sentReportStat    func(*stack.Stack) *tcpip.StatCounter
   342  		receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
   343  		rxQuery           func(*channel.Endpoint)
   344  	}{
   345  		{
   346  			name:          "IGMP",
   347  			protoNum:      ipv4.ProtocolNumber,
   348  			multicastAddr: ipv4MulticastAddr1,
   349  			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
   350  				return s.Stats().IGMP.PacketsSent.V2MembershipReport
   351  			},
   352  			receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
   353  				return s.Stats().IGMP.PacketsReceived.MembershipQuery
   354  			},
   355  			rxQuery: func(e *channel.Endpoint) {
   356  				createAndInjectIGMPPacket(e, igmpMembershipQuery, unsolicitedIGMPReportIntervalMaxTenthSec, header.IPv4Any, 0 /* extraLength */)
   357  			},
   358  		},
   359  		{
   360  			name:          "MLD",
   361  			protoNum:      ipv6.ProtocolNumber,
   362  			multicastAddr: ipv6MulticastAddr1,
   363  			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
   364  				return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
   365  			},
   366  			receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
   367  				return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
   368  			},
   369  			rxQuery: func(e *channel.Endpoint) {
   370  				createAndInjectMLDPacket(e, mldQuery, 0, header.IPv6Any, 0 /* extraLength */)
   371  			},
   372  		},
   373  	}
   374  
   375  	for _, test := range tests {
   376  		t.Run(test.name, func(t *testing.T) {
   377  			ctx := newMulticastTestContext(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, false /* mgpEnabled */)
   378  			defer ctx.cleanup()
   379  			s := ctx.s
   380  			e := ctx.e
   381  			clock := ctx.clock
   382  
   383  			// This NIC may join multicast groups when it is enabled but since MGP is
   384  			// disabled, no reports should be sent.
   385  			sentReportStat := test.sentReportStat(s)
   386  			if got := sentReportStat.Value(); got != 0 {
   387  				t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
   388  			}
   389  			clock.Advance(time.Hour)
   390  			if p := e.Read(); p != nil {
   391  				t.Fatalf("sent unexpected packet, stack with disabled MGP sent packet = %#v", p)
   392  			}
   393  
   394  			// Test joining a specific group explicitly and verify that no reports are
   395  			// sent.
   396  			if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
   397  				t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
   398  			}
   399  			if got := sentReportStat.Value(); got != 0 {
   400  				t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
   401  			}
   402  			clock.Advance(time.Hour)
   403  			if p := e.Read(); p != nil {
   404  				t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %#v", p)
   405  			}
   406  
   407  			// Inject a general query message. This should only trigger a report to be
   408  			// sent if the MGP was enabled.
   409  			test.rxQuery(e)
   410  			if got := test.receivedQueryStat(s).Value(); got != 1 {
   411  				t.Fatalf("got receivedQueryStat(_).Value() = %d, want = 1", got)
   412  			}
   413  			clock.Advance(time.Hour)
   414  			if p := e.Read(); p != nil {
   415  				t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p)
   416  			}
   417  		})
   418  	}
   419  }
   420  
   421  func TestMGPReceiveCounters(t *testing.T) {
   422  	tests := []struct {
   423  		name         string
   424  		headerType   uint8
   425  		maxRespTime  byte
   426  		groupAddress tcpip.Address
   427  		statCounter  func(*stack.Stack) *tcpip.StatCounter
   428  		rxMGPkt      func(*channel.Endpoint, byte, byte, tcpip.Address, int)
   429  	}{
   430  		{
   431  			name:         "IGMP Membership Query",
   432  			headerType:   igmpMembershipQuery,
   433  			maxRespTime:  unsolicitedIGMPReportIntervalMaxTenthSec,
   434  			groupAddress: header.IPv4Any,
   435  			statCounter: func(s *stack.Stack) *tcpip.StatCounter {
   436  				return s.Stats().IGMP.PacketsReceived.MembershipQuery
   437  			},
   438  			rxMGPkt: createAndInjectIGMPPacket,
   439  		},
   440  		{
   441  			name:         "IGMPv1 Membership Report",
   442  			headerType:   igmpv1MembershipReport,
   443  			maxRespTime:  0,
   444  			groupAddress: header.IPv4AllSystems,
   445  			statCounter: func(s *stack.Stack) *tcpip.StatCounter {
   446  				return s.Stats().IGMP.PacketsReceived.V1MembershipReport
   447  			},
   448  			rxMGPkt: createAndInjectIGMPPacket,
   449  		},
   450  		{
   451  			name:         "IGMPv2 Membership Report",
   452  			headerType:   igmpv2MembershipReport,
   453  			maxRespTime:  0,
   454  			groupAddress: header.IPv4AllSystems,
   455  			statCounter: func(s *stack.Stack) *tcpip.StatCounter {
   456  				return s.Stats().IGMP.PacketsReceived.V2MembershipReport
   457  			},
   458  			rxMGPkt: createAndInjectIGMPPacket,
   459  		},
   460  		{
   461  			name:         "IGMP Leave Group",
   462  			headerType:   igmpLeaveGroup,
   463  			maxRespTime:  0,
   464  			groupAddress: header.IPv4AllRoutersGroup,
   465  			statCounter: func(s *stack.Stack) *tcpip.StatCounter {
   466  				return s.Stats().IGMP.PacketsReceived.LeaveGroup
   467  			},
   468  			rxMGPkt: createAndInjectIGMPPacket,
   469  		},
   470  		{
   471  			name:         "MLD Query",
   472  			headerType:   mldQuery,
   473  			maxRespTime:  0,
   474  			groupAddress: header.IPv6Any,
   475  			statCounter: func(s *stack.Stack) *tcpip.StatCounter {
   476  				return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
   477  			},
   478  			rxMGPkt: createAndInjectMLDPacket,
   479  		},
   480  		{
   481  			name:         "MLD Report",
   482  			headerType:   mldReport,
   483  			maxRespTime:  0,
   484  			groupAddress: header.IPv6Any,
   485  			statCounter: func(s *stack.Stack) *tcpip.StatCounter {
   486  				return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerReport
   487  			},
   488  			rxMGPkt: createAndInjectMLDPacket,
   489  		},
   490  		{
   491  			name:         "MLD Done",
   492  			headerType:   mldDone,
   493  			maxRespTime:  0,
   494  			groupAddress: header.IPv6Any,
   495  			statCounter: func(s *stack.Stack) *tcpip.StatCounter {
   496  				return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerDone
   497  			},
   498  			rxMGPkt: createAndInjectMLDPacket,
   499  		},
   500  	}
   501  
   502  	for _, test := range tests {
   503  		t.Run(test.name, func(t *testing.T) {
   504  			ctx := newMulticastTestContext(t, test.groupAddress.Len() == header.IPv4AddressSize /* v4 */, true /* mgpEnabled */)
   505  			defer ctx.cleanup()
   506  
   507  			test.rxMGPkt(ctx.e, test.headerType, test.maxRespTime, test.groupAddress, 0 /* extraLength */)
   508  			if got := test.statCounter(ctx.s).Value(); got != 1 {
   509  				t.Fatalf("got %s received = %d, want = 1", test.name, got)
   510  			}
   511  		})
   512  	}
   513  }
   514  
   515  // TestMGPJoinGroup tests that when explicitly joining a multicast group, the
   516  // stack schedules and sends correct Membership Reports.
   517  func TestMGPJoinGroup(t *testing.T) {
   518  	type subTest struct {
   519  		name           string
   520  		enterVersion   func(e *channel.Endpoint)
   521  		validateReport func(*testing.T, *stack.PacketBuffer)
   522  		checkStats     func(*testing.T, *stack.Stack, uint64, uint64, uint64)
   523  	}
   524  
   525  	tests := []struct {
   526  		name                        string
   527  		protoNum                    tcpip.NetworkProtocolNumber
   528  		multicastAddr               tcpip.Address
   529  		maxUnsolicitedResponseDelay time.Duration
   530  		receivedQueryStat           func(*stack.Stack) *tcpip.StatCounter
   531  		checkInitialGroups          func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) uint64
   532  		subTests                    []subTest
   533  	}{
   534  		{
   535  			name:                        "IGMP",
   536  			protoNum:                    ipv4.ProtocolNumber,
   537  			multicastAddr:               ipv4MulticastAddr1,
   538  			maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
   539  			receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
   540  				return s.Stats().IGMP.PacketsReceived.MembershipQuery
   541  			},
   542  			subTests: []subTest{
   543  				{
   544  					name: "V2",
   545  					enterVersion: func(e *channel.Endpoint) {
   546  						// V2 query for unrelated group.
   547  						createAndInjectIGMPPacket(e, igmpMembershipQuery, 1, ipv4MulticastAddr3, 0 /* extraLength */)
   548  					},
   549  					validateReport: func(t *testing.T, p *stack.PacketBuffer) {
   550  						t.Helper()
   551  
   552  						validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
   553  					},
   554  					checkStats: iptestutil.CheckIGMPv2Stats,
   555  				},
   556  				{
   557  					name:         "V3",
   558  					enterVersion: func(*channel.Endpoint) {},
   559  					validateReport: func(t *testing.T, p *stack.PacketBuffer) {
   560  						t.Helper()
   561  
   562  						validateIGMPv3ReportPacket(t, p, []tcpip.Address{ipv4MulticastAddr1}, header.IGMPv3ReportRecordChangeToExcludeMode)
   563  					},
   564  					checkStats: iptestutil.CheckIGMPv3Stats,
   565  				},
   566  			},
   567  		},
   568  		{
   569  			name:                        "MLD",
   570  			protoNum:                    ipv6.ProtocolNumber,
   571  			multicastAddr:               ipv6MulticastAddr1,
   572  			maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
   573  			receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
   574  				return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
   575  			},
   576  			checkInitialGroups: checkInitialIPv6Groups,
   577  			subTests: []subTest{
   578  				{
   579  					name: "V1",
   580  					enterVersion: func(e *channel.Endpoint) {
   581  						// V1 query for unrelated group.
   582  						createAndInjectMLDPacket(e, mldQuery, 0, ipv6MulticastAddr3, 0 /* extraLength */)
   583  					},
   584  					validateReport: func(t *testing.T, p *stack.PacketBuffer) {
   585  						t.Helper()
   586  
   587  						validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
   588  					},
   589  					checkStats: iptestutil.CheckMLDv1Stats,
   590  				},
   591  				{
   592  					name:         "V2",
   593  					enterVersion: func(*channel.Endpoint) {},
   594  					validateReport: func(t *testing.T, p *stack.PacketBuffer) {
   595  						t.Helper()
   596  
   597  						validateMLDv2ReportPacket(t, p, []tcpip.Address{ipv6MulticastAddr1}, header.MLDv2ReportRecordChangeToExcludeMode)
   598  					},
   599  					checkStats: iptestutil.CheckMLDv2Stats,
   600  				},
   601  			},
   602  		},
   603  	}
   604  
   605  	for _, test := range tests {
   606  		t.Run(test.name, func(t *testing.T) {
   607  			for _, subTest := range test.subTests {
   608  				t.Run(subTest.name, func(t *testing.T) {
   609  					ctx := newMulticastTestContext(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
   610  					defer ctx.cleanup()
   611  					s, e, clock := ctx.s, ctx.e, ctx.clock
   612  
   613  					var reportCounter uint64
   614  					var leaveCounter uint64
   615  					var reportV2Counter uint64
   616  					if test.checkInitialGroups != nil {
   617  						reportV2Counter = test.checkInitialGroups(t, e, s, clock)
   618  					}
   619  
   620  					subTest.enterVersion(e)
   621  
   622  					// Test joining a specific address explicitly and verify a Report is sent
   623  					// immediately.
   624  					if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
   625  						t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
   626  					}
   627  					reportCounter++
   628  					subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter)
   629  					if p := e.Read(); p == nil {
   630  						t.Fatal("expected a report message to be sent")
   631  					} else {
   632  						subTest.validateReport(t, p)
   633  						p.DecRef()
   634  					}
   635  					if t.Failed() {
   636  						t.FailNow()
   637  					}
   638  
   639  					// Verify the second report is sent by the maximum unsolicited response
   640  					// interval.
   641  					p := e.Read()
   642  					if p != nil {
   643  						t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p)
   644  					}
   645  					clock.Advance(test.maxUnsolicitedResponseDelay)
   646  					reportCounter++
   647  					subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter)
   648  					if p := e.Read(); p == nil {
   649  						t.Fatal("expected a report message to be sent")
   650  					} else {
   651  						subTest.validateReport(t, p)
   652  						p.DecRef()
   653  					}
   654  
   655  					// Should not send any more packets.
   656  					clock.Advance(time.Hour)
   657  					if p := e.Read(); p != nil {
   658  						t.Fatalf("sent unexpected packet = %#v", p)
   659  					}
   660  				})
   661  			}
   662  		})
   663  	}
   664  }
   665  
   666  // TestMGPLeaveGroup tests that when leaving a previously joined multicast
   667  // group the stack sends a leave/done message.
   668  func TestMGPLeaveGroup(t *testing.T) {
   669  	type subTest struct {
   670  		name           string
   671  		enterVersion   func(e *channel.Endpoint)
   672  		validateReport func(*testing.T, *stack.PacketBuffer)
   673  		validateLeave  func(*testing.T, *stack.PacketBuffer)
   674  		leaveCount     uint8
   675  		checkStats     func(*testing.T, *stack.Stack, uint64, uint64, uint64)
   676  	}
   677  
   678  	tests := []struct {
   679  		name                        string
   680  		protoNum                    tcpip.NetworkProtocolNumber
   681  		multicastAddr               tcpip.Address
   682  		maxUnsolicitedResponseDelay time.Duration
   683  		checkInitialGroups          func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) uint64
   684  		subTests                    []subTest
   685  	}{
   686  		{
   687  			name:                        "IGMP",
   688  			protoNum:                    ipv4.ProtocolNumber,
   689  			multicastAddr:               ipv4MulticastAddr1,
   690  			maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
   691  			subTests: []subTest{
   692  				{
   693  					name: "V2",
   694  					enterVersion: func(e *channel.Endpoint) {
   695  						// V2 query for unrelated group.
   696  						createAndInjectIGMPPacket(e, igmpMembershipQuery, 1, ipv4MulticastAddr3, 0 /* extraLength */)
   697  					},
   698  					validateReport: func(t *testing.T, p *stack.PacketBuffer) {
   699  						t.Helper()
   700  
   701  						validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
   702  					},
   703  					validateLeave: func(t *testing.T, p *stack.PacketBuffer) {
   704  						t.Helper()
   705  
   706  						validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr1)
   707  					},
   708  					leaveCount: 1,
   709  					checkStats: iptestutil.CheckIGMPv2Stats,
   710  				},
   711  				{
   712  					name:         "V3",
   713  					enterVersion: func(*channel.Endpoint) {},
   714  					validateReport: func(t *testing.T, p *stack.PacketBuffer) {
   715  						t.Helper()
   716  
   717  						validateIGMPv3ReportPacket(t, p, []tcpip.Address{ipv4MulticastAddr1}, header.IGMPv3ReportRecordChangeToExcludeMode)
   718  					},
   719  					validateLeave: func(t *testing.T, p *stack.PacketBuffer) {
   720  						t.Helper()
   721  
   722  						validateIGMPv3ReportPacket(t, p, []tcpip.Address{ipv4MulticastAddr1}, header.IGMPv3ReportRecordChangeToIncludeMode)
   723  					},
   724  					leaveCount: 2,
   725  					checkStats: iptestutil.CheckIGMPv3Stats,
   726  				},
   727  			},
   728  		},
   729  		{
   730  			name:                        "MLD",
   731  			protoNum:                    ipv6.ProtocolNumber,
   732  			multicastAddr:               ipv6MulticastAddr1,
   733  			maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
   734  			checkInitialGroups:          checkInitialIPv6Groups,
   735  			subTests: []subTest{
   736  				{
   737  					name: "V1",
   738  					enterVersion: func(e *channel.Endpoint) {
   739  						// V1 query for unrelated group.
   740  						createAndInjectMLDPacket(e, mldQuery, 0, ipv6MulticastAddr3, 0 /* extraLength */)
   741  					},
   742  					validateReport: func(t *testing.T, p *stack.PacketBuffer) {
   743  						t.Helper()
   744  
   745  						validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
   746  					},
   747  					validateLeave: func(t *testing.T, p *stack.PacketBuffer) {
   748  						t.Helper()
   749  
   750  						validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6MulticastAddr1)
   751  					},
   752  					leaveCount: 1,
   753  					checkStats: iptestutil.CheckMLDv1Stats,
   754  				},
   755  				{
   756  					name:         "V2",
   757  					enterVersion: func(*channel.Endpoint) {},
   758  					validateReport: func(t *testing.T, p *stack.PacketBuffer) {
   759  						t.Helper()
   760  
   761  						validateMLDv2ReportPacket(t, p, []tcpip.Address{ipv6MulticastAddr1}, header.MLDv2ReportRecordChangeToExcludeMode)
   762  					},
   763  					validateLeave: func(t *testing.T, p *stack.PacketBuffer) {
   764  						t.Helper()
   765  
   766  						validateMLDv2ReportPacket(t, p, []tcpip.Address{ipv6MulticastAddr1}, header.MLDv2ReportRecordChangeToIncludeMode)
   767  					},
   768  					leaveCount: 2,
   769  					checkStats: iptestutil.CheckMLDv2Stats,
   770  				},
   771  			},
   772  		},
   773  	}
   774  
   775  	for _, test := range tests {
   776  		t.Run(test.name, func(t *testing.T) {
   777  			for _, subTest := range test.subTests {
   778  				t.Run(subTest.name, func(t *testing.T) {
   779  					ctx := newMulticastTestContext(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
   780  					defer ctx.cleanup()
   781  					s, e, clock := ctx.s, ctx.e, ctx.clock
   782  
   783  					var reportCounter uint64
   784  					var leaveCounter uint64
   785  					var reportV2Counter uint64
   786  					if test.checkInitialGroups != nil {
   787  						reportV2Counter = test.checkInitialGroups(t, e, s, clock)
   788  					}
   789  
   790  					subTest.enterVersion(e)
   791  
   792  					if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
   793  						t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
   794  					}
   795  					reportCounter++
   796  					subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter)
   797  					if p := e.Read(); p == nil {
   798  						t.Fatal("expected a report message to be sent")
   799  					} else {
   800  						subTest.validateReport(t, p)
   801  						p.DecRef()
   802  					}
   803  					if t.Failed() {
   804  						t.FailNow()
   805  					}
   806  
   807  					// Leaving the group should trigger a leave/done message to be sent.
   808  					if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
   809  						t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err)
   810  					}
   811  					for i := subTest.leaveCount; i > 0; i-- {
   812  						leaveCounter++
   813  						subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter)
   814  						if p := e.Read(); p == nil {
   815  							t.Fatal("expected a leave message to be sent")
   816  						} else {
   817  							subTest.validateLeave(t, p)
   818  							p.DecRef()
   819  						}
   820  						clock.Advance(test.maxUnsolicitedResponseDelay)
   821  					}
   822  
   823  					// Should not send any more packets.
   824  					clock.Advance(time.Hour)
   825  					if p := e.Read(); p != nil {
   826  						t.Fatalf("sent unexpected packet = %#v", p)
   827  					}
   828  				})
   829  			}
   830  		})
   831  	}
   832  }
   833  
   834  // TestMGPQueryMessages tests that a report is sent in response to query
   835  // messages.
   836  func TestMGPQueryMessages(t *testing.T) {
   837  	type subTest struct {
   838  		name           string
   839  		enterVersion   func(e *channel.Endpoint)
   840  		validateReport func(*testing.T, *stack.PacketBuffer, bool)
   841  		checkStats     func(*testing.T, *stack.Stack, uint64, uint64, uint64)
   842  		rxQuery        func(*channel.Endpoint, uint8, tcpip.Address)
   843  	}
   844  
   845  	tests := []struct {
   846  		name                        string
   847  		protoNum                    tcpip.NetworkProtocolNumber
   848  		multicastAddr               tcpip.Address
   849  		maxUnsolicitedResponseDelay time.Duration
   850  		receivedQueryStat           func(*stack.Stack) *tcpip.StatCounter
   851  		maxRespTimeToDuration       func(uint16) time.Duration
   852  		checkInitialGroups          func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) uint64
   853  		subTests                    []subTest
   854  	}{
   855  		{
   856  			name:                        "IGMP",
   857  			protoNum:                    ipv4.ProtocolNumber,
   858  			multicastAddr:               ipv4MulticastAddr1,
   859  			maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
   860  			receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
   861  				return s.Stats().IGMP.PacketsReceived.MembershipQuery
   862  			},
   863  			maxRespTimeToDuration: header.DecisecondToDuration,
   864  			subTests: []subTest{
   865  				{
   866  					name: "V2",
   867  					enterVersion: func(e *channel.Endpoint) {
   868  						// V2 query for unrelated group.
   869  						createAndInjectIGMPPacket(e, igmpMembershipQuery, 1, ipv4MulticastAddr3, 0 /* extraLength */)
   870  					},
   871  					validateReport: func(t *testing.T, p *stack.PacketBuffer, _ bool) {
   872  						t.Helper()
   873  
   874  						validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
   875  					},
   876  					rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) {
   877  						createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress, 0 /* extraLength */)
   878  					},
   879  					checkStats: iptestutil.CheckIGMPv2Stats,
   880  				},
   881  				{
   882  					name:         "V3",
   883  					enterVersion: func(*channel.Endpoint) {},
   884  					validateReport: func(t *testing.T, p *stack.PacketBuffer, queryResponse bool) {
   885  						t.Helper()
   886  
   887  						recordType := header.IGMPv3ReportRecordChangeToExcludeMode
   888  						if queryResponse {
   889  							recordType = header.IGMPv3ReportRecordModeIsExclude
   890  						}
   891  
   892  						validateIGMPv3ReportPacket(t, p, []tcpip.Address{ipv4MulticastAddr1}, recordType)
   893  					},
   894  					rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) {
   895  						createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress, header.IGMPv3QueryMinimumSize-header.IGMPQueryMinimumSize /* extraLength */)
   896  					},
   897  					checkStats: iptestutil.CheckIGMPv3Stats,
   898  				},
   899  			},
   900  		},
   901  		{
   902  			name:                        "MLD",
   903  			protoNum:                    ipv6.ProtocolNumber,
   904  			multicastAddr:               ipv6MulticastAddr1,
   905  			maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
   906  			receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
   907  				return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
   908  			},
   909  			maxRespTimeToDuration: func(d uint16) time.Duration {
   910  				return time.Duration(d) * time.Millisecond
   911  			},
   912  			checkInitialGroups: checkInitialIPv6Groups,
   913  			subTests: []subTest{
   914  				{
   915  					name: "V1",
   916  					enterVersion: func(e *channel.Endpoint) {
   917  						// V1 query for unrelated group.
   918  						createAndInjectMLDPacket(e, mldQuery, 0, ipv6MulticastAddr3, 0 /* extraLength */)
   919  					},
   920  					validateReport: func(t *testing.T, p *stack.PacketBuffer, _ bool) {
   921  						t.Helper()
   922  
   923  						validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
   924  					},
   925  					rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) {
   926  						createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress, 0 /* extraLength */)
   927  					},
   928  					checkStats: iptestutil.CheckMLDv1Stats,
   929  				},
   930  				{
   931  					name:         "V2",
   932  					enterVersion: func(*channel.Endpoint) {},
   933  					validateReport: func(t *testing.T, p *stack.PacketBuffer, queryResponse bool) {
   934  						t.Helper()
   935  
   936  						recordType := header.MLDv2ReportRecordChangeToExcludeMode
   937  						if queryResponse {
   938  							recordType = header.MLDv2ReportRecordModeIsExclude
   939  						}
   940  
   941  						validateMLDv2ReportPacket(t, p, []tcpip.Address{ipv6MulticastAddr1}, recordType)
   942  					},
   943  					rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) {
   944  						createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress, header.MLDv2QueryMinimumSize-header.MLDMinimumSize /* extraLength */)
   945  					},
   946  					checkStats: iptestutil.CheckMLDv2Stats,
   947  				},
   948  			},
   949  		},
   950  	}
   951  
   952  	for _, test := range tests {
   953  		t.Run(test.name, func(t *testing.T) {
   954  			addrTests := []struct {
   955  				name          string
   956  				multicastAddr tcpip.Address
   957  				expectReport  bool
   958  			}{
   959  				{
   960  					name:          "Unspecified",
   961  					multicastAddr: tcpip.AddrFromSlice([]byte(strings.Repeat("\x00", test.multicastAddr.Len()))),
   962  					expectReport:  true,
   963  				},
   964  				{
   965  					name:          "Specified",
   966  					multicastAddr: test.multicastAddr,
   967  					expectReport:  true,
   968  				},
   969  				{
   970  					name: "Specified other address",
   971  					multicastAddr: func() tcpip.Address {
   972  						addrCopy := test.multicastAddr
   973  						addrBytes := addrCopy.AsSlice()
   974  						addrBytes[len(addrBytes)-1]++
   975  						return tcpip.AddrFromSlice(addrBytes)
   976  					}(),
   977  					expectReport: false,
   978  				},
   979  			}
   980  
   981  			for _, addrTest := range addrTests {
   982  				t.Run(addrTest.name, func(t *testing.T) {
   983  					for _, subTest := range test.subTests {
   984  						t.Run(subTest.name, func(t *testing.T) {
   985  							ctx := newMulticastTestContext(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
   986  							defer ctx.cleanup()
   987  							s, e, clock := ctx.s, ctx.e, ctx.clock
   988  
   989  							var reportCounter uint64
   990  							var leaveCounter uint64
   991  							var reportV2Counter uint64
   992  							if test.checkInitialGroups != nil {
   993  								reportV2Counter = test.checkInitialGroups(t, e, s, clock)
   994  							}
   995  
   996  							subTest.enterVersion(e)
   997  
   998  							if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
   999  								t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
  1000  							}
  1001  							for i := 0; i < maxUnsolicitedReports; i++ {
  1002  								reportCounter++
  1003  								subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter)
  1004  								if p := e.Read(); p == nil {
  1005  									t.Fatalf("expected %d-th report message to be sent", i)
  1006  								} else {
  1007  									subTest.validateReport(t, p, false /* queryResponse */)
  1008  									p.DecRef()
  1009  								}
  1010  								clock.Advance(test.maxUnsolicitedResponseDelay)
  1011  							}
  1012  							if t.Failed() {
  1013  								t.FailNow()
  1014  							}
  1015  
  1016  							// Should not send any more packets until a query.
  1017  							clock.Advance(time.Hour)
  1018  							if p := e.Read(); p != nil {
  1019  								t.Fatalf("sent unexpected packet = %#v", p)
  1020  							}
  1021  
  1022  							// Receive a query message which should trigger a report to be sent at
  1023  							// some time before the maximum response time if the report is
  1024  							// targeted at the host.
  1025  							const maxRespTime = 100
  1026  							subTest.rxQuery(e, maxRespTime, addrTest.multicastAddr)
  1027  							if p := e.Read(); p != nil {
  1028  								t.Fatalf("sent unexpected packet = %#v", p)
  1029  							}
  1030  
  1031  							if addrTest.expectReport {
  1032  								clock.Advance(test.maxRespTimeToDuration(maxRespTime))
  1033  								reportCounter++
  1034  								subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter)
  1035  								if p := e.Read(); p == nil {
  1036  									t.Fatal("expected a report message to be sent")
  1037  								} else {
  1038  									subTest.validateReport(t, p, true /* queryResponse */)
  1039  									p.DecRef()
  1040  								}
  1041  							}
  1042  
  1043  							// Should not send any more packets.
  1044  							clock.Advance(time.Hour)
  1045  							if p := e.Read(); p != nil {
  1046  								t.Fatalf("sent unexpected packet = %#v", p)
  1047  							}
  1048  						})
  1049  					}
  1050  				})
  1051  			}
  1052  		})
  1053  	}
  1054  }
  1055  
  1056  // TestMGPQueryMessages tests that no further reports or leave/done messages
  1057  // are sent after receiving a report.
  1058  func TestMGPReportMessages(t *testing.T) {
  1059  	type subTest struct {
  1060  		name           string
  1061  		enterVersion   func(e *channel.Endpoint)
  1062  		validateReport func(*testing.T, *stack.PacketBuffer)
  1063  		validateLeave  func(*testing.T, *stack.PacketBuffer)
  1064  		leaveCount     uint8
  1065  		checkStats     func(*testing.T, *stack.Stack, uint64, uint64, uint64)
  1066  	}
  1067  
  1068  	tests := []struct {
  1069  		name                        string
  1070  		protoNum                    tcpip.NetworkProtocolNumber
  1071  		multicastAddr               tcpip.Address
  1072  		maxUnsolicitedResponseDelay time.Duration
  1073  		rxReport                    func(*channel.Endpoint)
  1074  		checkInitialGroups          func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) uint64
  1075  		subTests                    []subTest
  1076  	}{
  1077  		{
  1078  			name:          "IGMP",
  1079  			protoNum:      ipv4.ProtocolNumber,
  1080  			multicastAddr: ipv4MulticastAddr1,
  1081  			rxReport: func(e *channel.Endpoint) {
  1082  				createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr1, 0 /* extraLength */)
  1083  			},
  1084  			maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
  1085  			subTests: []subTest{
  1086  				{
  1087  					name: "V2",
  1088  					enterVersion: func(e *channel.Endpoint) {
  1089  						// V2 query for unrelated group.
  1090  						createAndInjectIGMPPacket(e, igmpMembershipQuery, 1, ipv4MulticastAddr3, 0 /* extraLength */)
  1091  					},
  1092  					validateReport: func(t *testing.T, p *stack.PacketBuffer) {
  1093  						t.Helper()
  1094  
  1095  						validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
  1096  					},
  1097  					leaveCount: 0,
  1098  					checkStats: iptestutil.CheckIGMPv2Stats,
  1099  				},
  1100  				{
  1101  					name:         "V3",
  1102  					enterVersion: func(*channel.Endpoint) {},
  1103  					validateReport: func(t *testing.T, p *stack.PacketBuffer) {
  1104  						t.Helper()
  1105  
  1106  						validateIGMPv3ReportPacket(t, p, []tcpip.Address{ipv4MulticastAddr1}, header.IGMPv3ReportRecordChangeToExcludeMode)
  1107  					},
  1108  					validateLeave: func(t *testing.T, p *stack.PacketBuffer) {
  1109  						t.Helper()
  1110  
  1111  						validateIGMPv3ReportPacket(t, p, []tcpip.Address{ipv4MulticastAddr1}, header.IGMPv3ReportRecordChangeToIncludeMode)
  1112  					},
  1113  					leaveCount: 2,
  1114  					checkStats: iptestutil.CheckIGMPv3Stats,
  1115  				},
  1116  			},
  1117  		},
  1118  		{
  1119  			name:          "MLD",
  1120  			protoNum:      ipv6.ProtocolNumber,
  1121  			multicastAddr: ipv6MulticastAddr1,
  1122  			rxReport: func(e *channel.Endpoint) {
  1123  				createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr1, 0 /* extraLength */)
  1124  			},
  1125  			maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
  1126  			checkInitialGroups:          checkInitialIPv6Groups,
  1127  			subTests: []subTest{
  1128  				{
  1129  					name: "V1",
  1130  					enterVersion: func(e *channel.Endpoint) {
  1131  						// V1 query for unrelated group.
  1132  						createAndInjectMLDPacket(e, mldQuery, 0, ipv6MulticastAddr3, 0 /* extraLength */)
  1133  					},
  1134  					validateReport: func(t *testing.T, p *stack.PacketBuffer) {
  1135  						t.Helper()
  1136  
  1137  						validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
  1138  					},
  1139  					leaveCount: 0,
  1140  					checkStats: iptestutil.CheckMLDv1Stats,
  1141  				},
  1142  				{
  1143  					name:         "V2",
  1144  					enterVersion: func(*channel.Endpoint) {},
  1145  					validateReport: func(t *testing.T, p *stack.PacketBuffer) {
  1146  						t.Helper()
  1147  
  1148  						validateMLDv2ReportPacket(t, p, []tcpip.Address{ipv6MulticastAddr1}, header.MLDv2ReportRecordChangeToExcludeMode)
  1149  					},
  1150  					validateLeave: func(t *testing.T, p *stack.PacketBuffer) {
  1151  						t.Helper()
  1152  
  1153  						validateMLDv2ReportPacket(t, p, []tcpip.Address{ipv6MulticastAddr1}, header.MLDv2ReportRecordChangeToIncludeMode)
  1154  					},
  1155  					leaveCount: 2,
  1156  					checkStats: iptestutil.CheckMLDv2Stats,
  1157  				},
  1158  			},
  1159  		},
  1160  	}
  1161  
  1162  	for _, test := range tests {
  1163  		t.Run(test.name, func(t *testing.T) {
  1164  			for _, subTest := range test.subTests {
  1165  				t.Run(subTest.name, func(t *testing.T) {
  1166  					ctx := newMulticastTestContext(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
  1167  					defer ctx.cleanup()
  1168  					s, e, clock := ctx.s, ctx.e, ctx.clock
  1169  
  1170  					var reportCounter uint64
  1171  					var leaveCounter uint64
  1172  					var reportV2Counter uint64
  1173  					if test.checkInitialGroups != nil {
  1174  						reportV2Counter = test.checkInitialGroups(t, e, s, clock)
  1175  					}
  1176  
  1177  					subTest.enterVersion(e)
  1178  
  1179  					if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
  1180  						t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
  1181  					}
  1182  					reportCounter++
  1183  					subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter)
  1184  					if p := e.Read(); p == nil {
  1185  						t.Fatal("expected a report message to be sent")
  1186  					} else {
  1187  						subTest.validateReport(t, p)
  1188  						p.DecRef()
  1189  					}
  1190  					if t.Failed() {
  1191  						t.FailNow()
  1192  					}
  1193  
  1194  					// Receiving a report for a group we joined should cancel any further
  1195  					// reports.
  1196  					test.rxReport(e)
  1197  					clock.Advance(time.Hour)
  1198  					subTest.enterVersion(e)
  1199  					subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter)
  1200  					if p := e.Read(); p != nil {
  1201  						t.Errorf("sent unexpected packet = %#v", p)
  1202  					}
  1203  					if t.Failed() {
  1204  						t.FailNow()
  1205  					}
  1206  
  1207  					// Leaving a group after getting a report should not send a leave/done
  1208  					// message.
  1209  					if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
  1210  						t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err)
  1211  					}
  1212  					for i := subTest.leaveCount; i > 0; i-- {
  1213  						leaveCounter++
  1214  						subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter)
  1215  						if p := e.Read(); p == nil {
  1216  							t.Fatal("expected a leave message to be sent")
  1217  						} else {
  1218  							subTest.validateLeave(t, p)
  1219  							p.DecRef()
  1220  						}
  1221  						clock.Advance(test.maxUnsolicitedResponseDelay)
  1222  					}
  1223  
  1224  					// Should not send any more packets.
  1225  					clock.Advance(time.Hour)
  1226  					subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter)
  1227  					if p := e.Read(); p != nil {
  1228  						t.Fatalf("sent unexpected packet = %#v", p)
  1229  					}
  1230  				})
  1231  			}
  1232  		})
  1233  	}
  1234  }
  1235  
  1236  func TestMGPWithNICLifecycle(t *testing.T) {
  1237  	type subTest struct {
  1238  		name            string
  1239  		v1Compatibility bool
  1240  		enterVersion    func(e *channel.Endpoint)
  1241  		validateReport  func(*testing.T, *stack.PacketBuffer, tcpip.Address)
  1242  		validateLeave   func(*testing.T, *channel.Endpoint, []tcpip.Address)
  1243  		checkStats      func(*testing.T, *stack.Stack, uint64, uint64, uint64)
  1244  	}
  1245  
  1246  	tests := []struct {
  1247  		name                        string
  1248  		protoNum                    tcpip.NetworkProtocolNumber
  1249  		multicastAddrs              []tcpip.Address
  1250  		finalMulticastAddr          tcpip.Address
  1251  		maxUnsolicitedResponseDelay time.Duration
  1252  		sentReportStat              func(*stack.Stack) *tcpip.StatCounter
  1253  		sentLeaveStat               func(*stack.Stack) *tcpip.StatCounter
  1254  		validateReport              func(*testing.T, *channel.Endpoint, []tcpip.Address)
  1255  		validateLeave               func(*testing.T, *stack.PacketBuffer, tcpip.Address)
  1256  		checkInitialGroups          func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) uint64
  1257  		checkStats                  func(*testing.T, *stack.Stack, uint64, uint64, uint64)
  1258  		subTests                    []subTest
  1259  	}{
  1260  		{
  1261  			name:                        "IGMP",
  1262  			protoNum:                    ipv4.ProtocolNumber,
  1263  			multicastAddrs:              []tcpip.Address{ipv4MulticastAddr1, ipv4MulticastAddr2},
  1264  			finalMulticastAddr:          ipv4MulticastAddr3,
  1265  			maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
  1266  			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
  1267  				return s.Stats().IGMP.PacketsSent.V2MembershipReport
  1268  			},
  1269  			sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
  1270  				return s.Stats().IGMP.PacketsSent.LeaveGroup
  1271  			},
  1272  			validateReport: func(t *testing.T, e *channel.Endpoint, addrs []tcpip.Address) {
  1273  				t.Helper()
  1274  				iptestutil.ValidateIGMPv3RecordsAcrossReports(t, e, stackIPv4Addr, addrs, header.IGMPv3ReportRecordChangeToExcludeMode)
  1275  			},
  1276  			validateLeave: func(t *testing.T, p *stack.PacketBuffer, addr tcpip.Address) {
  1277  				t.Helper()
  1278  
  1279  				validateIGMPv3ReportPacket(t, p, []tcpip.Address{addr}, header.IGMPv3ReportRecordChangeToIncludeMode)
  1280  			},
  1281  			checkStats: iptestutil.CheckIGMPv3Stats,
  1282  			subTests: []subTest{
  1283  				{
  1284  					name:            "V2",
  1285  					v1Compatibility: true,
  1286  					enterVersion: func(e *channel.Endpoint) {
  1287  						// V2 query for unrelated group.
  1288  						createAndInjectIGMPPacket(e, igmpMembershipQuery, 1, ipv4MulticastAddr3, 0 /* extraLength */)
  1289  					},
  1290  					validateReport: func(t *testing.T, p *stack.PacketBuffer, addr tcpip.Address) {
  1291  						t.Helper()
  1292  
  1293  						validateIGMPPacket(t, p, addr, igmpv2MembershipReport, 0, addr)
  1294  					},
  1295  					validateLeave: func(t *testing.T, e *channel.Endpoint, addrs []tcpip.Address) {
  1296  						t.Helper()
  1297  						iptestutil.ValidMultipleIGMPv2ReportLeaves(t, e, stackIPv4Addr, addrs, true /* leave */)
  1298  					},
  1299  					checkStats: iptestutil.CheckIGMPv2Stats,
  1300  				},
  1301  				{
  1302  					name:            "V3",
  1303  					v1Compatibility: false,
  1304  					enterVersion:    func(*channel.Endpoint) {},
  1305  					validateReport: func(t *testing.T, p *stack.PacketBuffer, addr tcpip.Address) {
  1306  						t.Helper()
  1307  
  1308  						validateIGMPv3ReportPacket(t, p, []tcpip.Address{addr}, header.IGMPv3ReportRecordChangeToExcludeMode)
  1309  					},
  1310  					validateLeave: func(t *testing.T, e *channel.Endpoint, addrs []tcpip.Address) {
  1311  						t.Helper()
  1312  						iptestutil.ValidateIGMPv3RecordsAcrossReports(t, e, stackIPv4Addr, addrs, header.IGMPv3ReportRecordChangeToIncludeMode)
  1313  					},
  1314  					checkStats: iptestutil.CheckIGMPv3Stats,
  1315  				},
  1316  			},
  1317  		},
  1318  		{
  1319  			name:                        "MLD",
  1320  			protoNum:                    ipv6.ProtocolNumber,
  1321  			multicastAddrs:              []tcpip.Address{ipv6MulticastAddr1, ipv6MulticastAddr2},
  1322  			finalMulticastAddr:          ipv6MulticastAddr3,
  1323  			maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
  1324  			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
  1325  				return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
  1326  			},
  1327  			sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
  1328  				return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
  1329  			},
  1330  			validateReport: func(t *testing.T, e *channel.Endpoint, addrs []tcpip.Address) {
  1331  				t.Helper()
  1332  
  1333  				iptestutil.ValidateMLDv2RecordsAcrossReports(t, e, linkLocalIPv6Addr1, addrs, header.MLDv2ReportRecordChangeToExcludeMode)
  1334  			},
  1335  			validateLeave: func(t *testing.T, p *stack.PacketBuffer, addr tcpip.Address) {
  1336  				t.Helper()
  1337  
  1338  				validateMLDv2ReportPacket(t, p, []tcpip.Address{addr}, header.MLDv2ReportRecordChangeToIncludeMode)
  1339  			},
  1340  			checkInitialGroups: checkInitialIPv6Groups,
  1341  			checkStats:         iptestutil.CheckMLDv2Stats,
  1342  			subTests: []subTest{
  1343  				{
  1344  					name:            "V1",
  1345  					v1Compatibility: true,
  1346  					enterVersion: func(e *channel.Endpoint) {
  1347  						// V1 query for unrelated group.
  1348  						createAndInjectMLDPacket(e, mldQuery, 0, ipv6MulticastAddr3, 0 /* extraLength */)
  1349  					},
  1350  					validateReport: func(t *testing.T, p *stack.PacketBuffer, addr tcpip.Address) {
  1351  						t.Helper()
  1352  
  1353  						validateMLDPacket(t, p, addr, mldReport, 0, addr)
  1354  					},
  1355  					validateLeave: func(t *testing.T, e *channel.Endpoint, addrs []tcpip.Address) {
  1356  						t.Helper()
  1357  
  1358  						iptestutil.ValidMultipleMLDv1ReportLeaves(t, e, linkLocalIPv6Addr1, addrs, true /* leave */)
  1359  					},
  1360  					checkStats: iptestutil.CheckMLDv1Stats,
  1361  				},
  1362  				{
  1363  					name:            "V2",
  1364  					v1Compatibility: false,
  1365  					enterVersion:    func(*channel.Endpoint) {},
  1366  					validateReport: func(t *testing.T, p *stack.PacketBuffer, addr tcpip.Address) {
  1367  						t.Helper()
  1368  
  1369  						validateMLDv2ReportPacket(t, p, []tcpip.Address{addr}, header.MLDv2ReportRecordChangeToExcludeMode)
  1370  					},
  1371  					validateLeave: func(t *testing.T, e *channel.Endpoint, addrs []tcpip.Address) {
  1372  						t.Helper()
  1373  
  1374  						iptestutil.ValidateMLDv2RecordsAcrossReports(t, e, linkLocalIPv6Addr1, addrs, header.MLDv2ReportRecordChangeToIncludeMode)
  1375  					},
  1376  					checkStats: iptestutil.CheckMLDv2Stats,
  1377  				},
  1378  			},
  1379  		},
  1380  	}
  1381  
  1382  	for _, test := range tests {
  1383  		t.Run(test.name, func(t *testing.T) {
  1384  			for _, subTest := range test.subTests {
  1385  				t.Run(subTest.name, func(t *testing.T) {
  1386  					ctx := newMulticastTestContext(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
  1387  					defer ctx.cleanup()
  1388  					s, e, clock := ctx.s, ctx.e, ctx.clock
  1389  
  1390  					var reportCounter uint64
  1391  					var leaveCounter uint64
  1392  					var reportV2Counter uint64
  1393  					if test.checkInitialGroups != nil {
  1394  						reportV2Counter = test.checkInitialGroups(t, e, s, clock)
  1395  					}
  1396  
  1397  					subTest.enterVersion(e)
  1398  
  1399  					for _, a := range test.multicastAddrs {
  1400  						if err := s.JoinGroup(test.protoNum, nicID, a); err != nil {
  1401  							t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err)
  1402  						}
  1403  						reportCounter++
  1404  						subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter)
  1405  						if p := e.Read(); p == nil {
  1406  							t.Fatalf("expected a report message to be sent for %s", a)
  1407  						} else {
  1408  							subTest.validateReport(t, p, a)
  1409  							p.DecRef()
  1410  						}
  1411  					}
  1412  					if t.Failed() {
  1413  						t.FailNow()
  1414  					}
  1415  
  1416  					// Leave messages should be sent for the joined groups when the NIC is
  1417  					// disabled.
  1418  					if err := s.DisableNIC(nicID); err != nil {
  1419  						t.Fatalf("DisableNIC(%d): %s", nicID, err)
  1420  					}
  1421  					{
  1422  						numMessages := 1
  1423  						if subTest.v1Compatibility {
  1424  							numMessages = len(test.multicastAddrs)
  1425  						}
  1426  						leaveCounter += uint64(numMessages)
  1427  						subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter)
  1428  						subTest.validateLeave(t, e, test.multicastAddrs)
  1429  					}
  1430  					if t.Failed() {
  1431  						t.FailNow()
  1432  					}
  1433  
  1434  					// Reports should be sent for the joined groups when the NIC is enabled.
  1435  					if err := s.EnableNIC(nicID); err != nil {
  1436  						t.Fatalf("EnableNIC(%d): %s", nicID, err)
  1437  					}
  1438  					reportV2Counter++
  1439  					subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter)
  1440  					test.validateReport(t, e, test.multicastAddrs)
  1441  					if t.Failed() {
  1442  						t.FailNow()
  1443  					}
  1444  					subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter)
  1445  
  1446  					// Joining/leaving a group while disabled should not send any messages.
  1447  					if err := s.DisableNIC(nicID); err != nil {
  1448  						t.Fatalf("DisableNIC(%d): %s", nicID, err)
  1449  					}
  1450  					reportV2Counter++
  1451  					subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter)
  1452  					if p := e.Read(); p == nil {
  1453  						t.Fatal("expected leave message to be sent")
  1454  					} else {
  1455  						p.DecRef()
  1456  					}
  1457  					for _, a := range test.multicastAddrs {
  1458  						if err := s.LeaveGroup(test.protoNum, nicID, a); err != nil {
  1459  							t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, a, err)
  1460  						}
  1461  						subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter)
  1462  						if p := e.Read(); p != nil {
  1463  							t.Fatalf("leaving group %s on disabled NIC sent unexpected packet = %#v", a, p)
  1464  						}
  1465  					}
  1466  					if err := s.JoinGroup(test.protoNum, nicID, test.finalMulticastAddr); err != nil {
  1467  						t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.finalMulticastAddr, err)
  1468  					}
  1469  					subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter)
  1470  					if p := e.Read(); p != nil {
  1471  						t.Fatalf("joining group %s on disabled NIC sent unexpected packet = %#v", test.finalMulticastAddr, p)
  1472  					}
  1473  
  1474  					// A report should only be sent for the group we last joined after
  1475  					// enabling the NIC since the original groups were all left.
  1476  					if err := s.EnableNIC(nicID); err != nil {
  1477  						t.Fatalf("EnableNIC(%d): %s", nicID, err)
  1478  					}
  1479  					reportV2Counter++
  1480  					subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter)
  1481  					test.validateReport(t, e, []tcpip.Address{test.finalMulticastAddr})
  1482  
  1483  					clock.Advance(test.maxUnsolicitedResponseDelay)
  1484  					reportV2Counter++
  1485  					subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter)
  1486  					test.validateReport(t, e, []tcpip.Address{test.finalMulticastAddr})
  1487  
  1488  					// Should not send any more packets.
  1489  					clock.Advance(time.Hour)
  1490  					if p := e.Read(); p != nil {
  1491  						t.Fatalf("sent unexpected packet = %#v", p)
  1492  					}
  1493  				})
  1494  			}
  1495  		})
  1496  	}
  1497  }
  1498  
  1499  // TestMGPDisabledOnLoopback tests that the multicast group protocol is not
  1500  // performed on loopback interfaces since they have no neighbours.
  1501  func TestMGPDisabledOnLoopback(t *testing.T) {
  1502  	tests := []struct {
  1503  		name           string
  1504  		protoNum       tcpip.NetworkProtocolNumber
  1505  		multicastAddr  tcpip.Address
  1506  		sentReportStat func(*stack.Stack) *tcpip.StatCounter
  1507  	}{
  1508  		{
  1509  			name:          "IGMP",
  1510  			protoNum:      ipv4.ProtocolNumber,
  1511  			multicastAddr: ipv4MulticastAddr1,
  1512  			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
  1513  				return s.Stats().IGMP.PacketsSent.V2MembershipReport
  1514  			},
  1515  		},
  1516  		{
  1517  			name:          "MLD",
  1518  			protoNum:      ipv6.ProtocolNumber,
  1519  			multicastAddr: ipv6MulticastAddr1,
  1520  			sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
  1521  				return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
  1522  			},
  1523  		},
  1524  	}
  1525  
  1526  	for _, test := range tests {
  1527  		t.Run(test.name, func(t *testing.T) {
  1528  			s, clock := createStackWithLinkEndpoint(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */, loopback.New())
  1529  			defer func() {
  1530  				s.Close()
  1531  				s.Wait()
  1532  			}()
  1533  			sentReportStat := test.sentReportStat(s)
  1534  			if got := sentReportStat.Value(); got != 0 {
  1535  				t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
  1536  			}
  1537  			clock.Advance(time.Hour)
  1538  			if got := sentReportStat.Value(); got != 0 {
  1539  				t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
  1540  			}
  1541  
  1542  			// Test joining a specific group explicitly and verify that no reports are
  1543  			// sent.
  1544  			if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
  1545  				t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
  1546  			}
  1547  			if got := sentReportStat.Value(); got != 0 {
  1548  				t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
  1549  			}
  1550  			clock.Advance(time.Hour)
  1551  			if got := sentReportStat.Value(); got != 0 {
  1552  				t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
  1553  			}
  1554  		})
  1555  	}
  1556  }
  1557  
  1558  func TestMGPCoalescedQueryResponseRecords(t *testing.T) {
  1559  	const igmpv3MLDv2ReportRecordHeaderLen = 4
  1560  
  1561  	type subTest struct {
  1562  		name           string
  1563  		enterVersion   func(e *channel.Endpoint)
  1564  		validateReport func(*testing.T, *stack.PacketBuffer)
  1565  		checkStats     func(*testing.T, *stack.Stack, uint64, uint64, uint64)
  1566  	}
  1567  
  1568  	genAddr := func(bytes []byte, i uint16) tcpip.Address {
  1569  		bytes[len(bytes)-1] = byte(i & 0xFF)
  1570  		bytes[len(bytes)-2] = byte(i >> 8)
  1571  		return tcpip.AddrFromSlice(bytes[:])
  1572  	}
  1573  
  1574  	calcMaxRecordsPerMessage := func(hdrLen, recordLen uint16) uint16 {
  1575  		return (header.IPv6MinimumMTU - hdrLen) / recordLen
  1576  	}
  1577  
  1578  	tests := []struct {
  1579  		name                              string
  1580  		protoNum                          tcpip.NetworkProtocolNumber
  1581  		maxUnsolicitedResponseDelay       time.Duration
  1582  		receivedQueryStat                 func(*stack.Stack) *tcpip.StatCounter
  1583  		checkInitialGroups                func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) uint64
  1584  		validateReport                    func(*testing.T, *stack.PacketBuffer, tcpip.Address)
  1585  		checkStats                        func(*testing.T, *stack.Stack, uint64)
  1586  		genAddr                           func(uint16) tcpip.Address
  1587  		maxRecordsPerMessage              uint16
  1588  		rxQuery                           func(*channel.Endpoint, uint8)
  1589  		validateReportWithMultipleRecords func(*testing.T, *channel.Endpoint, []tcpip.Address)
  1590  	}{
  1591  		{
  1592  			name:                        "IGMP",
  1593  			protoNum:                    ipv4.ProtocolNumber,
  1594  			maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
  1595  			receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
  1596  				return s.Stats().IGMP.PacketsReceived.MembershipQuery
  1597  			},
  1598  			validateReport: func(t *testing.T, p *stack.PacketBuffer, addr tcpip.Address) {
  1599  				t.Helper()
  1600  
  1601  				validateIGMPv3ReportPacket(t, p, []tcpip.Address{addr}, header.IGMPv3ReportRecordChangeToExcludeMode)
  1602  			},
  1603  			checkStats: func(t *testing.T, s *stack.Stack, reports uint64) {
  1604  				t.Helper()
  1605  				iptestutil.CheckIGMPv3Stats(t, s, 0, 0, reports)
  1606  			},
  1607  			genAddr: func(i uint16) tcpip.Address {
  1608  				bytes := [header.IPv4AddressSize]byte{224, 1, 0, 0}
  1609  				return genAddr(bytes[:], i)
  1610  			},
  1611  			maxRecordsPerMessage: calcMaxRecordsPerMessage(header.IPv4MinimumSize+8 /* size of IGMPv3 report header */, igmpv3MLDv2ReportRecordHeaderLen+header.IPv4AddressSize),
  1612  			rxQuery: func(e *channel.Endpoint, maxRespTime uint8) {
  1613  				createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, header.IPv4Any, header.IGMPv3QueryMinimumSize-header.IGMPQueryMinimumSize /* extraLength */)
  1614  			},
  1615  			validateReportWithMultipleRecords: func(t *testing.T, e *channel.Endpoint, addrs []tcpip.Address) {
  1616  				t.Helper()
  1617  				iptestutil.ValidateIGMPv3RecordsAcrossReports(t, e, stackIPv4Addr, addrs, header.IGMPv3ReportRecordModeIsExclude)
  1618  			},
  1619  		},
  1620  		{
  1621  			name:                        "MLD",
  1622  			protoNum:                    ipv6.ProtocolNumber,
  1623  			maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
  1624  			receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
  1625  				return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
  1626  			},
  1627  			checkInitialGroups: checkInitialIPv6Groups,
  1628  			validateReport: func(t *testing.T, p *stack.PacketBuffer, addr tcpip.Address) {
  1629  				t.Helper()
  1630  
  1631  				validateMLDv2ReportPacket(t, p, []tcpip.Address{addr}, header.MLDv2ReportRecordChangeToExcludeMode)
  1632  			},
  1633  			checkStats: func(t *testing.T, s *stack.Stack, reports uint64) {
  1634  				t.Helper()
  1635  				iptestutil.CheckMLDv2Stats(t, s, 0, 0, reports)
  1636  			},
  1637  			genAddr: func(i uint16) tcpip.Address {
  1638  				bytes := [header.IPv6AddressSize]byte{0xFF, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0}
  1639  				return genAddr(bytes[:], i)
  1640  			},
  1641  			maxRecordsPerMessage: calcMaxRecordsPerMessage(header.IPv6MinimumSize+8 /* size of MLDv2 report header */, igmpv3MLDv2ReportRecordHeaderLen+header.IPv6AddressSize),
  1642  			rxQuery: func(e *channel.Endpoint, maxRespTime uint8) {
  1643  				createAndInjectMLDPacket(e, mldQuery, maxRespTime, header.IPv6Any, header.MLDv2QueryMinimumSize-header.MLDMinimumSize /* extraLength */)
  1644  			},
  1645  			validateReportWithMultipleRecords: func(t *testing.T, e *channel.Endpoint, addrs []tcpip.Address) {
  1646  				t.Helper()
  1647  
  1648  				iptestutil.ValidateMLDv2RecordsAcrossReports(t, e, linkLocalIPv6Addr1, addrs, header.MLDv2ReportRecordModeIsExclude)
  1649  			},
  1650  		},
  1651  	}
  1652  
  1653  	subTests := []struct {
  1654  		name            string
  1655  		extraRecords    uint16
  1656  		expectedReports uint64
  1657  	}{
  1658  		{
  1659  			name:            "No extra records",
  1660  			extraRecords:    0,
  1661  			expectedReports: 1,
  1662  		},
  1663  		{
  1664  			name:            "One extra record",
  1665  			extraRecords:    1,
  1666  			expectedReports: 2,
  1667  		},
  1668  		{
  1669  			name:            "Two extra records",
  1670  			extraRecords:    2,
  1671  			expectedReports: 2,
  1672  		},
  1673  	}
  1674  
  1675  	for _, test := range tests {
  1676  		t.Run(test.name, func(t *testing.T) {
  1677  			for _, subTest := range subTests {
  1678  				t.Run(subTest.name, func(t *testing.T) {
  1679  					ctx := newMulticastTestContext(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
  1680  					defer ctx.cleanup()
  1681  					s, e, clock := ctx.s, ctx.e, ctx.clock
  1682  
  1683  					var reportV2Counter uint64
  1684  					if test.checkInitialGroups != nil {
  1685  						reportV2Counter = test.checkInitialGroups(t, e, s, clock)
  1686  					}
  1687  
  1688  					addrs := make([]tcpip.Address, test.maxRecordsPerMessage+subTest.extraRecords)
  1689  					for i := 0; i < len(addrs); i++ {
  1690  						addr := test.genAddr(uint16(i))
  1691  						addrs[i] = addr
  1692  
  1693  						if err := s.JoinGroup(test.protoNum, nicID, addr); err != nil {
  1694  							t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, addr, err)
  1695  						}
  1696  						reportV2Counter++
  1697  						test.checkStats(t, s, reportV2Counter)
  1698  						if p := e.Read(); p == nil {
  1699  							t.Fatal("expected a report message to be sent")
  1700  						} else {
  1701  							test.validateReport(t, p, addr)
  1702  							p.DecRef()
  1703  						}
  1704  						if t.Failed() {
  1705  							t.FailNow()
  1706  						}
  1707  
  1708  						// Verify the second report is sent by the maximum unsolicited response
  1709  						// interval.
  1710  						p := e.Read()
  1711  						if p != nil {
  1712  							t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p)
  1713  						}
  1714  						clock.Advance(test.maxUnsolicitedResponseDelay)
  1715  						reportV2Counter++
  1716  						test.checkStats(t, s, reportV2Counter)
  1717  						if p := e.Read(); p == nil {
  1718  							t.Fatal("expected a report message to be sent")
  1719  						} else {
  1720  							test.validateReport(t, p, addr)
  1721  							p.DecRef()
  1722  						}
  1723  					}
  1724  
  1725  					// Should not send any more packets.
  1726  					clock.Advance(time.Hour)
  1727  					if p := e.Read(); p != nil {
  1728  						t.Fatalf("sent unexpected packet = %#v", p)
  1729  					}
  1730  					test.checkStats(t, s, reportV2Counter)
  1731  
  1732  					// Receive a query which should send a few reports which together hold
  1733  					// records for all the groups we joined.
  1734  					test.rxQuery(e, 1)
  1735  					clock.Advance(time.Second)
  1736  					reportV2Counter += subTest.expectedReports
  1737  					test.checkStats(t, s, reportV2Counter)
  1738  					test.validateReportWithMultipleRecords(t, e, addrs)
  1739  				})
  1740  			}
  1741  		})
  1742  	}
  1743  }