gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/network/ipv4/igmp_test.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package ipv4_test
    16  
    17  import (
    18  	"testing"
    19  	"time"
    20  
    21  	"gvisor.dev/gvisor/pkg/buffer"
    22  	"gvisor.dev/gvisor/pkg/refs"
    23  	"gvisor.dev/gvisor/pkg/tcpip"
    24  	"gvisor.dev/gvisor/pkg/tcpip/checker"
    25  	"gvisor.dev/gvisor/pkg/tcpip/faketime"
    26  	"gvisor.dev/gvisor/pkg/tcpip/header"
    27  	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
    28  	iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
    29  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    30  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    31  	"gvisor.dev/gvisor/pkg/tcpip/testutil"
    32  )
    33  
    34  const (
    35  	linkAddr            = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
    36  	nicID               = 1
    37  	defaultTTL          = 1
    38  	defaultPrefixLength = 24
    39  )
    40  
    41  var (
    42  	stackAddr           = testutil.MustParse4("10.0.0.1")
    43  	remoteAddr          = testutil.MustParse4("10.0.0.2")
    44  	multicastAddr1      = testutil.MustParse4("224.0.0.3")
    45  	multicastAddr2      = testutil.MustParse4("224.0.0.4")
    46  	multicastAddr3      = testutil.MustParse4("224.0.0.5")
    47  	multicastAddr4      = testutil.MustParse4("224.0.0.6")
    48  	unusedMulticastAddr = testutil.MustParse4("224.0.0.7")
    49  )
    50  
    51  // validateIgmpPacket checks that a passed packet is an IPv4 IGMP packet sent
    52  // to the provided address with the passed fields set. Raises a t.Error if any
    53  // field does not match.
    54  func validateIgmpPacket(t *testing.T, pkt *stack.PacketBuffer, igmpType header.IGMPType, maxRespTime byte, srcAddr, dstAddr, groupAddress tcpip.Address) {
    55  	t.Helper()
    56  
    57  	payload := stack.PayloadSince(pkt.NetworkHeader())
    58  	defer payload.Release()
    59  	checker.IPv4(t, payload,
    60  		checker.SrcAddr(srcAddr),
    61  		checker.DstAddr(dstAddr),
    62  		// TTL for an IGMP message must be 1 as per RFC 2236 section 2.
    63  		checker.TTL(1),
    64  		checker.IPv4RouterAlert(),
    65  		checker.IGMP(
    66  			checker.IGMPType(igmpType),
    67  			checker.IGMPMaxRespTime(header.DecisecondToDuration(uint16(maxRespTime))),
    68  			checker.IGMPGroupAddress(groupAddress),
    69  		),
    70  	)
    71  }
    72  
    73  func validateIgmpv3ReportPacket(t *testing.T, pkt *stack.PacketBuffer, srcAddr, groupAddress tcpip.Address) {
    74  	t.Helper()
    75  
    76  	payload := stack.PayloadSince(pkt.NetworkHeader())
    77  	defer payload.Release()
    78  	iptestutil.ValidateIGMPv3Report(t, payload, srcAddr, []tcpip.Address{groupAddress}, header.IGMPv3ReportRecordChangeToExcludeMode)
    79  }
    80  
    81  type igmpTestContext struct {
    82  	s     *stack.Stack
    83  	ep    *channel.Endpoint
    84  	clock *faketime.ManualClock
    85  }
    86  
    87  func (ctx igmpTestContext) cleanup() {
    88  	ctx.s.Close()
    89  	ctx.s.Wait()
    90  	ctx.ep.Close()
    91  	refs.DoRepeatedLeakCheck()
    92  }
    93  
    94  func newIGMPTestContext(t *testing.T, igmpEnabled bool) igmpTestContext {
    95  	t.Helper()
    96  
    97  	// Create an endpoint of queue size 2, since no more than 2 packets are ever
    98  	// queued in the tests in this file.
    99  	e := channel.New(2, 1280, linkAddr)
   100  	clock := faketime.NewManualClock()
   101  	s := stack.New(stack.Options{
   102  		NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocolWithOptions(ipv4.Options{
   103  			IGMP: ipv4.IGMPOptions{
   104  				Enabled: igmpEnabled,
   105  			},
   106  		})},
   107  		Clock: clock,
   108  	})
   109  	if err := s.CreateNIC(nicID, e); err != nil {
   110  		t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
   111  	}
   112  
   113  	return igmpTestContext{
   114  		ep:    e,
   115  		s:     s,
   116  		clock: clock,
   117  	}
   118  }
   119  
   120  func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, maxRespTime byte, ttl uint8, srcAddr, dstAddr, groupAddress tcpip.Address, hasRouterAlertOption bool) {
   121  	var options header.IPv4OptionsSerializer
   122  	if hasRouterAlertOption {
   123  		options = header.IPv4OptionsSerializer{
   124  			&header.IPv4SerializableRouterAlertOption{},
   125  		}
   126  	}
   127  	buf := make([]byte, header.IPv4MinimumSize+int(options.Length())+header.IGMPQueryMinimumSize)
   128  
   129  	ip := header.IPv4(buf)
   130  	ip.Encode(&header.IPv4Fields{
   131  		TotalLength: uint16(len(buf)),
   132  		TTL:         ttl,
   133  		Protocol:    uint8(header.IGMPProtocolNumber),
   134  		SrcAddr:     srcAddr,
   135  		DstAddr:     dstAddr,
   136  		Options:     options,
   137  	})
   138  	ip.SetChecksum(^ip.CalculateChecksum())
   139  
   140  	igmp := header.IGMP(ip.Payload())
   141  	igmp.SetType(igmpType)
   142  	igmp.SetMaxRespTime(maxRespTime)
   143  	igmp.SetGroupAddress(groupAddress)
   144  	igmp.SetChecksum(header.IGMPCalculateChecksum(igmp))
   145  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   146  		Payload: buffer.MakeWithData(buf),
   147  	})
   148  	e.InjectInbound(ipv4.ProtocolNumber, pkt)
   149  	pkt.DecRef()
   150  }
   151  
   152  // TestIGMPV1Present tests the node's ability to fallback to V1 when a V1
   153  // router is detected. V1 present status is expected to be reset when the NIC
   154  // cycles.
   155  func TestIGMPV1Present(t *testing.T) {
   156  	ctx := newIGMPTestContext(t, true /* igmpEnabled */)
   157  	defer ctx.cleanup()
   158  	s := ctx.s
   159  	e := ctx.ep
   160  
   161  	protocolAddr := tcpip.ProtocolAddress{
   162  		Protocol:          ipv4.ProtocolNumber,
   163  		AddressWithPrefix: tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength},
   164  	}
   165  	if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
   166  		t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
   167  	}
   168  
   169  	if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr1); err != nil {
   170  		t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr1, err)
   171  	}
   172  
   173  	// This NIC will send an IGMPv3 report immediately, before this test can get
   174  	// the IGMPv1 General Membership Query in.
   175  	{
   176  		p := e.Read()
   177  		if p == nil {
   178  			t.Fatal("unable to Read IGMP packet, expected V3MembershipReport")
   179  		}
   180  		if got := s.Stats().IGMP.PacketsSent.V3MembershipReport.Value(); got != 1 {
   181  			t.Fatalf("got V3MembershipReport messages sent = %d, want = 1", got)
   182  		}
   183  		validateIgmpv3ReportPacket(t, p, stackAddr, multicastAddr1)
   184  		p.DecRef()
   185  	}
   186  	if t.Failed() {
   187  		t.FailNow()
   188  	}
   189  
   190  	// Inject an IGMPv1 General Membership Query which is identical to a standard
   191  	// membership query except the Max Response Time is set to 0, which will tell
   192  	// the stack that this is a router using IGMPv1.
   193  	createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, 0, defaultTTL, remoteAddr, stackAddr, multicastAddr1, true /* hasRouterAlertOption */)
   194  	if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 1 {
   195  		t.Fatalf("got Membership Queries received = %d, want = 1", got)
   196  	}
   197  
   198  	// Before advancing the clock, verify that this host has not sent a
   199  	// V1MembershipReport yet.
   200  	if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 0 {
   201  		t.Fatalf("got V1MembershipReport messages sent = %d, want = 0", got)
   202  	}
   203  
   204  	// Verify the solicited Membership Report is sent. Now that this NIC has seen
   205  	// an IGMPv1 query, it should send an IGMPv1 Membership Report.
   206  	if p := e.Read(); p != nil {
   207  		t.Fatalf("sent unexpected packet, expected V1MembershipReport only after advancing the clock = %+v", p)
   208  	}
   209  	ctx.clock.Advance(ipv4.UnsolicitedReportIntervalMax)
   210  	{
   211  		p := e.Read()
   212  		if p == nil {
   213  			t.Fatal("unable to Read IGMP packet, expected V1MembershipReport")
   214  		}
   215  		if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 {
   216  			t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got)
   217  		}
   218  		validateIgmpPacket(t, p, header.IGMPv1MembershipReport, 0, stackAddr, multicastAddr1, multicastAddr1)
   219  		p.DecRef()
   220  	}
   221  
   222  	// Cycling the interface should reset the V1 present flag.
   223  	if err := s.DisableNIC(nicID); err != nil {
   224  		t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
   225  	}
   226  	if err := s.EnableNIC(nicID); err != nil {
   227  		t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
   228  	}
   229  	{
   230  		p := e.Read()
   231  		if p == nil {
   232  			t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
   233  		}
   234  		if got := s.Stats().IGMP.PacketsSent.V3MembershipReport.Value(); got != 2 {
   235  			t.Fatalf("got V3MembershipReport messages sent = %d, want = 2", got)
   236  		}
   237  		validateIgmpv3ReportPacket(t, p, stackAddr, multicastAddr1)
   238  		p.DecRef()
   239  	}
   240  }
   241  
   242  func TestSendQueuedIGMPReports(t *testing.T) {
   243  	tests := []struct {
   244  		name            string
   245  		v2Compatibility bool
   246  		validate        func(t *testing.T, e *channel.Endpoint, localAddress tcpip.Address, groupAddresses []tcpip.Address)
   247  		checkStats      func(*testing.T, *stack.Stack, uint64, uint64, uint64)
   248  	}{
   249  		{
   250  			name:            "V2 Compatibility",
   251  			v2Compatibility: true,
   252  			validate: func(t *testing.T, e *channel.Endpoint, localAddress tcpip.Address, groupAddresses []tcpip.Address) {
   253  				t.Helper()
   254  
   255  				iptestutil.ValidMultipleIGMPv2ReportLeaves(t, e, localAddress, groupAddresses, false /* leave */)
   256  			},
   257  			checkStats: iptestutil.CheckIGMPv2Stats,
   258  		},
   259  		{
   260  			name:            "V3",
   261  			v2Compatibility: false,
   262  			validate: func(t *testing.T, e *channel.Endpoint, localAddress tcpip.Address, groupAddresses []tcpip.Address) {
   263  				t.Helper()
   264  
   265  				iptestutil.ValidateIGMPv3RecordsAcrossReports(t, e, localAddress, groupAddresses, header.IGMPv3ReportRecordChangeToExcludeMode)
   266  			},
   267  			checkStats: iptestutil.CheckIGMPv3Stats,
   268  		},
   269  	}
   270  
   271  	for _, test := range tests {
   272  		t.Run(test.name, func(t *testing.T) {
   273  			ctx := newIGMPTestContext(t, true /* igmpEnabled */)
   274  			defer ctx.cleanup()
   275  			s := ctx.s
   276  			e := ctx.ep
   277  			clock := ctx.clock
   278  
   279  			checkVersion := func() {
   280  				if test.v2Compatibility {
   281  					ep, err := s.GetNetworkEndpoint(nicID, header.IPv4ProtocolNumber)
   282  					if err != nil {
   283  						t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv4ProtocolNumber, err)
   284  					}
   285  
   286  					igmpEP, ok := ep.(ipv4.IGMPEndpoint)
   287  					if !ok {
   288  						t.Fatalf("got (%T).(%T) = (_, false), want = (_ true)", ep, igmpEP)
   289  					}
   290  
   291  					igmpEP.SetIGMPVersion(ipv4.IGMPVersion2)
   292  				}
   293  			}
   294  			protocolAddr := tcpip.ProtocolAddress{
   295  				Protocol: ipv4.ProtocolNumber,
   296  				AddressWithPrefix: tcpip.AddressWithPrefix{
   297  					Address:   stackAddr,
   298  					PrefixLen: defaultPrefixLength,
   299  				},
   300  			}
   301  			// Multicast traffic is not accepted unless we have an address so add an
   302  			// address and check the version which receives a multicast packet.
   303  			if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
   304  				t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
   305  			}
   306  			checkVersion()
   307  			if err := s.RemoveAddress(nicID, protocolAddr.AddressWithPrefix.Address); err != nil {
   308  				t.Fatalf("RemoveAddress(%d, %s): %s", nicID, protocolAddr.AddressWithPrefix.Address, err)
   309  			}
   310  
   311  			var reportCounter uint64
   312  			var doneCounter uint64
   313  			var reportV2Counter uint64
   314  			test.checkStats(t, s, reportCounter, doneCounter, reportV2Counter)
   315  
   316  			// Joining groups without an assigned address should queue IGMP packets;
   317  			// none should be sent without an assigned address.
   318  			multicastAddrs := []tcpip.Address{multicastAddr1, multicastAddr2}
   319  			for _, multicastAddr := range multicastAddrs {
   320  				if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
   321  					t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv4.ProtocolNumber, nicID, multicastAddr, err)
   322  				}
   323  			}
   324  			test.checkStats(t, s, reportCounter, doneCounter, reportV2Counter)
   325  			if p := e.Read(); p != nil {
   326  				t.Fatalf("got unexpected packet = %#v", p)
   327  			}
   328  
   329  			// The initial set of IGMP reports that were queued should be sent once an
   330  			// address is assigned.
   331  			if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
   332  				t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
   333  			}
   334  
   335  			// We expect two batches of reports to be sent (1 batch when the address
   336  			// is assigned, and another after the maximum unsolicited report interval.
   337  			for i := 0; i < 2; i++ {
   338  				// IGMPv2 always sends a single message per group.
   339  				//
   340  				// IGMPv3 sends a single message per group when we first get an
   341  				// address assigned, but later reports (sent by the state changed
   342  				// timer) coalesce records for groups.
   343  				if test.v2Compatibility || i == 0 {
   344  					reportCounter += uint64(len(multicastAddrs))
   345  				} else {
   346  					reportCounter++
   347  				}
   348  				test.checkStats(t, s, reportCounter, doneCounter, reportV2Counter)
   349  				test.validate(t, e, stackAddr, multicastAddrs)
   350  
   351  				if t.Failed() {
   352  					t.FailNow()
   353  				}
   354  
   355  				clock.Advance(ipv4.UnsolicitedReportIntervalMax)
   356  			}
   357  
   358  			// Should have no more packets to send after the initial set of unsolicited
   359  			// reports.
   360  			clock.Advance(time.Hour)
   361  			if p := e.Read(); p != nil {
   362  				t.Fatalf("got unexpected packet = %#v", p)
   363  			}
   364  		})
   365  	}
   366  }
   367  
   368  func TestIGMPPacketValidation(t *testing.T) {
   369  	tests := []struct {
   370  		name                     string
   371  		messageType              header.IGMPType
   372  		stackAddresses           []tcpip.AddressWithPrefix
   373  		srcAddr                  tcpip.Address
   374  		includeRouterAlertOption bool
   375  		ttl                      uint8
   376  		expectValidIGMP          bool
   377  		getMessageTypeStatValue  func(tcpip.Stats) uint64
   378  	}{
   379  		{
   380  			name:                     "valid",
   381  			messageType:              header.IGMPLeaveGroup,
   382  			includeRouterAlertOption: true,
   383  			stackAddresses:           []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
   384  			srcAddr:                  remoteAddr,
   385  			ttl:                      1,
   386  			expectValidIGMP:          true,
   387  			getMessageTypeStatValue:  func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.LeaveGroup.Value() },
   388  		},
   389  		{
   390  			name:                     "bad ttl",
   391  			messageType:              header.IGMPv1MembershipReport,
   392  			includeRouterAlertOption: true,
   393  			stackAddresses:           []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
   394  			srcAddr:                  remoteAddr,
   395  			ttl:                      2,
   396  			expectValidIGMP:          false,
   397  			getMessageTypeStatValue:  func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V1MembershipReport.Value() },
   398  		},
   399  		{
   400  			name:                     "missing router alert ip option",
   401  			messageType:              header.IGMPv2MembershipReport,
   402  			includeRouterAlertOption: false,
   403  			stackAddresses:           []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
   404  			srcAddr:                  remoteAddr,
   405  			ttl:                      1,
   406  			expectValidIGMP:          false,
   407  			getMessageTypeStatValue:  func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() },
   408  		},
   409  		{
   410  			name:                     "igmp leave group and src ip does not belong to nic subnet",
   411  			messageType:              header.IGMPLeaveGroup,
   412  			includeRouterAlertOption: true,
   413  			stackAddresses:           []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
   414  			srcAddr:                  testutil.MustParse4("10.0.1.2"),
   415  			ttl:                      1,
   416  			expectValidIGMP:          false,
   417  			getMessageTypeStatValue:  func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.LeaveGroup.Value() },
   418  		},
   419  		{
   420  			name:                     "igmp query and src ip does not belong to nic subnet",
   421  			messageType:              header.IGMPMembershipQuery,
   422  			includeRouterAlertOption: true,
   423  			stackAddresses:           []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
   424  			srcAddr:                  testutil.MustParse4("10.0.1.2"),
   425  			ttl:                      1,
   426  			expectValidIGMP:          true,
   427  			getMessageTypeStatValue:  func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.MembershipQuery.Value() },
   428  		},
   429  		{
   430  			name:                     "igmp report v1 and src ip does not belong to nic subnet",
   431  			messageType:              header.IGMPv1MembershipReport,
   432  			includeRouterAlertOption: true,
   433  			stackAddresses:           []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
   434  			srcAddr:                  testutil.MustParse4("10.0.1.2"),
   435  			ttl:                      1,
   436  			expectValidIGMP:          false,
   437  			getMessageTypeStatValue:  func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V1MembershipReport.Value() },
   438  		},
   439  		{
   440  			name:                     "igmp report v2 and src ip does not belong to nic subnet",
   441  			messageType:              header.IGMPv2MembershipReport,
   442  			includeRouterAlertOption: true,
   443  			stackAddresses:           []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
   444  			srcAddr:                  testutil.MustParse4("10.0.1.2"),
   445  			ttl:                      1,
   446  			expectValidIGMP:          false,
   447  			getMessageTypeStatValue:  func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() },
   448  		},
   449  		{
   450  			name:                     "src ip belongs to the subnet of the nic's second address",
   451  			messageType:              header.IGMPv2MembershipReport,
   452  			includeRouterAlertOption: true,
   453  			stackAddresses: []tcpip.AddressWithPrefix{
   454  				{Address: testutil.MustParse4("10.0.15.1"), PrefixLen: 24},
   455  				{Address: stackAddr, PrefixLen: 24},
   456  			},
   457  			srcAddr:                 remoteAddr,
   458  			ttl:                     1,
   459  			expectValidIGMP:         true,
   460  			getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() },
   461  		},
   462  	}
   463  
   464  	for _, test := range tests {
   465  		t.Run(test.name, func(t *testing.T) {
   466  			ctx := newIGMPTestContext(t, true /* igmpEnabled */)
   467  			defer ctx.cleanup()
   468  			s := ctx.s
   469  			e := ctx.ep
   470  
   471  			for _, address := range test.stackAddresses {
   472  				protocolAddr := tcpip.ProtocolAddress{
   473  					Protocol:          ipv4.ProtocolNumber,
   474  					AddressWithPrefix: address,
   475  				}
   476  				if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
   477  					t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
   478  				}
   479  			}
   480  			stats := s.Stats()
   481  			// Verify that every relevant stats is zero'd before we send a packet.
   482  			if got := test.getMessageTypeStatValue(s.Stats()); got != 0 {
   483  				t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 0", got)
   484  			}
   485  			if got := stats.IGMP.PacketsReceived.Invalid.Value(); got != 0 {
   486  				t.Errorf("got stats.IGMP.PacketsReceived.Invalid.Value() = %d, want = 0", got)
   487  			}
   488  			if got := stats.IP.PacketsDelivered.Value(); got != 0 {
   489  				t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 0", got)
   490  			}
   491  			createAndInjectIGMPPacket(e, test.messageType, 0, test.ttl, test.srcAddr, header.IPv4AllSystems, header.IPv4AllSystems, test.includeRouterAlertOption)
   492  			// We always expect the packet to pass IP validation.
   493  			if got := stats.IP.PacketsDelivered.Value(); got != 1 {
   494  				t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 1", got)
   495  			}
   496  			// Even when the IGMP-specific validation checks fail, we expect the
   497  			// corresponding IGMP counter to be incremented.
   498  			if got := test.getMessageTypeStatValue(s.Stats()); got != 1 {
   499  				t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 1", got)
   500  			}
   501  			var expectedInvalidCount uint64
   502  			if !test.expectValidIGMP {
   503  				expectedInvalidCount = 1
   504  			}
   505  			if got := stats.IGMP.PacketsReceived.Invalid.Value(); got != expectedInvalidCount {
   506  				t.Errorf("got stats.IGMP.PacketsReceived.Invalid.Value() = %d, want = %d", got, expectedInvalidCount)
   507  			}
   508  		})
   509  	}
   510  }
   511  
   512  func TestGetSetIGMPVersion(t *testing.T) {
   513  	const nicID = 1
   514  
   515  	c := newIGMPTestContext(t, true /* igmpEnabled */)
   516  	defer c.cleanup()
   517  	s := c.s
   518  	e := c.ep
   519  
   520  	ep, err := s.GetNetworkEndpoint(nicID, header.IPv4ProtocolNumber)
   521  	if err != nil {
   522  		t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv4ProtocolNumber, err)
   523  	}
   524  	igmpEP, ok := ep.(ipv4.IGMPEndpoint)
   525  	if !ok {
   526  		t.Fatalf("got (%T).(%T) = (_, false), want = (_ true)", ep, igmpEP)
   527  	}
   528  	if got := igmpEP.GetIGMPVersion(); got != ipv4.IGMPVersion3 {
   529  		t.Errorf("got igmpEP.GetIGMPVersion() = %d, want = %d", got, ipv4.IGMPVersion3)
   530  	}
   531  
   532  	protocolAddr := tcpip.ProtocolAddress{
   533  		Protocol:          ipv4.ProtocolNumber,
   534  		AddressWithPrefix: tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength},
   535  	}
   536  	if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
   537  		t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
   538  	}
   539  
   540  	if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr1); err != nil {
   541  		t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr1, err)
   542  	}
   543  	if p := e.Read(); p == nil {
   544  		t.Fatal("expected a report message to be sent")
   545  	} else {
   546  		validateIgmpv3ReportPacket(t, p, stackAddr, multicastAddr1)
   547  		p.DecRef()
   548  	}
   549  
   550  	if got := igmpEP.SetIGMPVersion(ipv4.IGMPVersion2); got != ipv4.IGMPVersion3 {
   551  		t.Errorf("got igmpEP.SetIGMPVersion(%d) = %d, want = %d", ipv4.IGMPVersion2, got, ipv4.IGMPVersion3)
   552  	}
   553  	if got := igmpEP.GetIGMPVersion(); got != ipv4.IGMPVersion2 {
   554  		t.Errorf("got igmpEP.GetIGMPVersion() = %d, want = %d", got, ipv4.IGMPVersion2)
   555  	}
   556  	if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr2); err != nil {
   557  		t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr2, err)
   558  	}
   559  	if p := e.Read(); p == nil {
   560  		t.Fatal("expected a report message to be sent")
   561  	} else {
   562  		validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr2, multicastAddr2)
   563  		p.DecRef()
   564  	}
   565  
   566  	if got := igmpEP.SetIGMPVersion(ipv4.IGMPVersion1); got != ipv4.IGMPVersion2 {
   567  		t.Errorf("got igmpEP.SetIGMPVersion(%d) = %d, want = %d", ipv4.IGMPVersion1, got, ipv4.IGMPVersion2)
   568  	}
   569  	if got := igmpEP.GetIGMPVersion(); got != ipv4.IGMPVersion1 {
   570  		t.Errorf("got igmpEP.GetIGMPVersion() = %d, want = %d", got, ipv4.IGMPVersion1)
   571  	}
   572  	if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr3); err != nil {
   573  		t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr3, err)
   574  	}
   575  	if p := e.Read(); p == nil {
   576  		t.Fatal("expected a report message to be sent")
   577  	} else {
   578  		validateIgmpPacket(t, p, header.IGMPv1MembershipReport, 0, stackAddr, multicastAddr3, multicastAddr3)
   579  		p.DecRef()
   580  	}
   581  
   582  	if got := igmpEP.SetIGMPVersion(ipv4.IGMPVersion3); got != ipv4.IGMPVersion1 {
   583  		t.Errorf("got igmpEP.SetIGMPVersion(%d) = %d, want = %d", ipv4.IGMPVersion3, got, ipv4.IGMPVersion1)
   584  	}
   585  	if got := igmpEP.GetIGMPVersion(); got != ipv4.IGMPVersion3 {
   586  		t.Errorf("got igmpEP.GetIGMPVersion() = %d, want = %d", got, ipv4.IGMPVersion3)
   587  	}
   588  	if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr4); err != nil {
   589  		t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr4, err)
   590  	}
   591  	if p := e.Read(); p == nil {
   592  		t.Fatal("expected a report message to be sent")
   593  	} else {
   594  		validateIgmpv3ReportPacket(t, p, stackAddr, multicastAddr4)
   595  		p.DecRef()
   596  	}
   597  }