gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/network/multicast_group_test.go (about) 1 // Copyright 2020 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package ip_test 16 17 import ( 18 "fmt" 19 "strings" 20 "testing" 21 "time" 22 23 "gvisor.dev/gvisor/pkg/buffer" 24 "gvisor.dev/gvisor/pkg/refs" 25 "gvisor.dev/gvisor/pkg/tcpip" 26 "gvisor.dev/gvisor/pkg/tcpip/checker" 27 "gvisor.dev/gvisor/pkg/tcpip/faketime" 28 "gvisor.dev/gvisor/pkg/tcpip/header" 29 "gvisor.dev/gvisor/pkg/tcpip/link/channel" 30 "gvisor.dev/gvisor/pkg/tcpip/link/loopback" 31 iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" 32 "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" 33 "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" 34 "gvisor.dev/gvisor/pkg/tcpip/stack" 35 "gvisor.dev/gvisor/pkg/tcpip/testutil" 36 ) 37 38 const ( 39 linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") 40 41 defaultIPv4PrefixLength = 24 42 43 igmpMembershipQuery = uint8(header.IGMPMembershipQuery) 44 igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport) 45 igmpv2MembershipReport = uint8(header.IGMPv2MembershipReport) 46 igmpLeaveGroup = uint8(header.IGMPLeaveGroup) 47 mldQuery = uint8(header.ICMPv6MulticastListenerQuery) 48 mldReport = uint8(header.ICMPv6MulticastListenerReport) 49 mldDone = uint8(header.ICMPv6MulticastListenerDone) 50 51 maxUnsolicitedReports = 2 52 ) 53 54 var ( 55 stackIPv4Addr = testutil.MustParse4("10.0.0.1") 56 linkLocalIPv6Addr1 = testutil.MustParse6("fe80::1") 57 linkLocalIPv6Addr2 = testutil.MustParse6("fe80::2") 58 59 ipv4MulticastAddr1 = testutil.MustParse4("224.0.0.3") 60 ipv4MulticastAddr2 = testutil.MustParse4("224.0.0.4") 61 ipv4MulticastAddr3 = testutil.MustParse4("224.0.0.5") 62 ipv6MulticastAddr1 = testutil.MustParse6("ff02::3") 63 ipv6MulticastAddr2 = testutil.MustParse6("ff02::4") 64 ipv6MulticastAddr3 = testutil.MustParse6("ff02::5") 65 ) 66 67 var ( 68 // unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the 69 // NIC will wait before sending an unsolicited report after joining a 70 // multicast group, in deciseconds. 71 unsolicitedIGMPReportIntervalMaxTenthSec = func() uint8 { 72 const decisecond = time.Second / 10 73 if ipv4.UnsolicitedReportIntervalMax%decisecond != 0 { 74 panic(fmt.Sprintf("UnsolicitedReportIntervalMax of %d is a lossy conversion to deciseconds", ipv4.UnsolicitedReportIntervalMax)) 75 } 76 return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond) 77 }() 78 79 ipv6AddrSNMC = header.SolicitedNodeAddr(linkLocalIPv6Addr1) 80 ) 81 82 // validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet 83 // sent to the provided address with the passed fields set. 84 func validateMLDPacket(t *testing.T, p *stack.PacketBuffer, remoteAddress tcpip.Address, mldType uint8, maxRespTime byte, groupAddress tcpip.Address) { 85 t.Helper() 86 87 payload := stack.PayloadSince(p.NetworkHeader()) 88 defer payload.Release() 89 checker.IPv6WithExtHdr(t, payload, 90 checker.IPv6ExtHdr( 91 checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)), 92 ), 93 checker.SrcAddr(linkLocalIPv6Addr1), 94 checker.DstAddr(remoteAddress), 95 // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. 96 checker.TTL(1), 97 checker.MLD(header.ICMPv6Type(mldType), header.MLDMinimumSize, 98 checker.MLDMaxRespDelay(time.Duration(maxRespTime)*time.Millisecond), 99 checker.MLDMulticastAddress(groupAddress), 100 ), 101 ) 102 } 103 104 func validateMLDv2ReportPacket(t *testing.T, p *stack.PacketBuffer, addrs []tcpip.Address, recordType header.MLDv2ReportRecordType) { 105 t.Helper() 106 payload := stack.PayloadSince(p.NetworkHeader()) 107 defer payload.Release() 108 iptestutil.ValidateMLDv2Report(t, payload, linkLocalIPv6Addr1, addrs, recordType) 109 } 110 111 // validateIGMPPacket checks that a passed PacketInfo is an IPv4 IGMP packet 112 // sent to the provided address with the passed fields set. 113 func validateIGMPPacket(t *testing.T, p *stack.PacketBuffer, remoteAddress tcpip.Address, igmpType uint8, maxRespTime byte, groupAddress tcpip.Address) { 114 t.Helper() 115 116 payload := stack.PayloadSince(p.NetworkHeader()) 117 defer payload.Release() 118 checker.IPv4(t, payload, 119 checker.SrcAddr(stackIPv4Addr), 120 checker.DstAddr(remoteAddress), 121 // TTL for an IGMP message must be 1 as per RFC 2236 section 2. 122 checker.TTL(1), 123 checker.IPv4RouterAlert(), 124 checker.IGMP( 125 checker.IGMPType(header.IGMPType(igmpType)), 126 checker.IGMPMaxRespTime(header.DecisecondToDuration(uint16(maxRespTime))), 127 checker.IGMPGroupAddress(groupAddress), 128 ), 129 ) 130 } 131 132 func validateIGMPv3ReportPacket(t *testing.T, p *stack.PacketBuffer, addrs []tcpip.Address, recordType header.IGMPv3ReportRecordType) { 133 t.Helper() 134 135 payload := stack.PayloadSince(p.NetworkHeader()) 136 defer payload.Release() 137 iptestutil.ValidateIGMPv3Report(t, payload, stackIPv4Addr, addrs, recordType) 138 } 139 140 type multicastTestContext struct { 141 s *stack.Stack 142 e *channel.Endpoint 143 clock *faketime.ManualClock 144 } 145 146 func newMulticastTestContext(t *testing.T, v4, mgpEnabled bool) multicastTestContext { 147 t.Helper() 148 149 e := channel.New(maxUnsolicitedReports, header.IPv6MinimumMTU, linkAddr) 150 s, clock := createStackWithLinkEndpoint(t, v4, mgpEnabled, e) 151 return multicastTestContext{ 152 s: s, 153 e: e, 154 clock: clock, 155 } 156 } 157 158 func (ctx *multicastTestContext) cleanup() { 159 ctx.s.Close() 160 ctx.s.Wait() 161 ctx.e.Close() 162 refs.DoRepeatedLeakCheck() 163 } 164 165 func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) { 166 t.Helper() 167 168 igmpEnabled := v4 && mgpEnabled 169 mldEnabled := !v4 && mgpEnabled 170 171 clock := faketime.NewManualClock() 172 s := stack.New(stack.Options{ 173 NetworkProtocols: []stack.NetworkProtocolFactory{ 174 ipv4.NewProtocolWithOptions(ipv4.Options{ 175 IGMP: ipv4.IGMPOptions{ 176 Enabled: igmpEnabled, 177 }, 178 }), 179 ipv6.NewProtocolWithOptions(ipv6.Options{ 180 MLD: ipv6.MLDOptions{ 181 Enabled: mldEnabled, 182 }, 183 }), 184 }, 185 Clock: clock, 186 }) 187 if err := s.CreateNIC(nicID, e); err != nil { 188 t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) 189 } 190 addr := tcpip.ProtocolAddress{ 191 Protocol: ipv4.ProtocolNumber, 192 AddressWithPrefix: tcpip.AddressWithPrefix{ 193 Address: stackIPv4Addr, 194 PrefixLen: defaultIPv4PrefixLength, 195 }, 196 } 197 if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil { 198 t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err) 199 } 200 protocolAddr := tcpip.ProtocolAddress{ 201 Protocol: ipv6.ProtocolNumber, 202 AddressWithPrefix: linkLocalIPv6Addr1.WithPrefix(), 203 } 204 if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { 205 t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) 206 } 207 208 return s, clock 209 } 210 211 // checkInitialIPv6Groups checks the initial IPv6 groups that a NIC will join 212 // when it is created with an IPv6 address. 213 // 214 // To not interfere with tests, checkInitialIPv6Groups will leave the added 215 // address's solicited node multicast group so that the tests can all assume 216 // the NIC has not joined any IPv6 groups. 217 func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, clock *faketime.ManualClock) uint64 { 218 t.Helper() 219 220 var reportCounter uint64 221 222 reportCounter++ 223 iptestutil.CheckMLDv2Stats(t, s, 0, 0, reportCounter) 224 if p := e.Read(); p == nil { 225 t.Fatal("expected a report message to be sent") 226 } else { 227 v := stack.PayloadSince(p.NetworkHeader()) 228 iptestutil.ValidateMLDv2Report(t, v, linkLocalIPv6Addr1, []tcpip.Address{ipv6AddrSNMC}, header.MLDv2ReportRecordChangeToExcludeMode) 229 v.Release() 230 p.DecRef() 231 } 232 233 // Leave the group to not affect the tests. This is fine since we are not 234 // testing DAD or the solicited node address specifically. 235 if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, ipv6AddrSNMC); err != nil { 236 t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, ipv6AddrSNMC, err) 237 } 238 for i := 0; i < 2; i++ { 239 reportCounter++ 240 iptestutil.CheckMLDv2Stats(t, s, 0, 0, reportCounter) 241 if p := e.Read(); p == nil { 242 t.Fatal("expected a report message to be sent") 243 } else { 244 v := stack.PayloadSince(p.NetworkHeader()) 245 iptestutil.ValidateMLDv2Report(t, v, linkLocalIPv6Addr1, []tcpip.Address{ipv6AddrSNMC}, header.MLDv2ReportRecordChangeToIncludeMode) 246 v.Release() 247 p.DecRef() 248 } 249 250 clock.Advance(ipv6.UnsolicitedReportIntervalMax) 251 } 252 253 // Should not send any more packets. 254 clock.Advance(time.Hour) 255 if p := e.Read(); p != nil { 256 t.Fatalf("sent unexpected packet = %#v", p) 257 } 258 259 return reportCounter 260 } 261 262 // createAndInjectIGMPPacket creates and injects an IGMP packet with the 263 // specified fields. 264 func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType byte, maxRespTime byte, groupAddress tcpip.Address, extraLength int) { 265 options := header.IPv4OptionsSerializer{ 266 &header.IPv4SerializableRouterAlertOption{}, 267 } 268 buf := make([]byte, header.IPv4MinimumSize+int(options.Length())+header.IGMPQueryMinimumSize+extraLength) 269 ip := header.IPv4(buf) 270 ip.Encode(&header.IPv4Fields{ 271 TotalLength: uint16(len(buf)), 272 TTL: header.IGMPTTL, 273 Protocol: uint8(header.IGMPProtocolNumber), 274 SrcAddr: remoteIPv4Addr, 275 DstAddr: header.IPv4AllSystems, 276 Options: options, 277 }) 278 ip.SetChecksum(^ip.CalculateChecksum()) 279 280 igmp := header.IGMP(ip.Payload()) 281 igmp.SetType(header.IGMPType(igmpType)) 282 igmp.SetMaxRespTime(maxRespTime) 283 igmp.SetGroupAddress(groupAddress) 284 igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) 285 286 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 287 Payload: buffer.MakeWithData(buf), 288 }) 289 e.InjectInbound(ipv4.ProtocolNumber, pkt) 290 pkt.DecRef() 291 } 292 293 // createAndInjectMLDPacket creates and injects an MLD packet with the 294 // specified fields. 295 func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay byte, groupAddress tcpip.Address, extraLength int) { 296 extensionHeaders := header.IPv6ExtHdrSerializer{ 297 header.IPv6SerializableHopByHopExtHdr{ 298 &header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD}, 299 }, 300 } 301 302 extensionHeadersLength := extensionHeaders.Length() 303 payloadLength := extensionHeadersLength + header.ICMPv6HeaderSize + header.MLDMinimumSize + extraLength 304 buf := make([]byte, header.IPv6MinimumSize+payloadLength) 305 306 ip := header.IPv6(buf) 307 ip.Encode(&header.IPv6Fields{ 308 PayloadLength: uint16(payloadLength), 309 HopLimit: header.MLDHopLimit, 310 TransportProtocol: header.ICMPv6ProtocolNumber, 311 SrcAddr: linkLocalIPv6Addr2, 312 DstAddr: header.IPv6AllNodesMulticastAddress, 313 ExtensionHeaders: extensionHeaders, 314 }) 315 316 icmp := header.ICMPv6(ip.Payload()[extensionHeadersLength:]) 317 icmp.SetType(header.ICMPv6Type(mldType)) 318 mld := header.MLD(icmp.MessageBody()) 319 mld.SetMaximumResponseDelay(uint16(maxRespDelay)) 320 mld.SetMulticastAddress(groupAddress) 321 icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ 322 Header: icmp, 323 Src: linkLocalIPv6Addr2, 324 Dst: header.IPv6AllNodesMulticastAddress, 325 })) 326 327 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 328 Payload: buffer.MakeWithData(buf), 329 }) 330 e.InjectInbound(ipv6.ProtocolNumber, pkt) 331 pkt.DecRef() 332 } 333 334 // TestMGPDisabled tests that the multicast group protocol is not enabled by 335 // default. 336 func TestMGPDisabled(t *testing.T) { 337 tests := []struct { 338 name string 339 protoNum tcpip.NetworkProtocolNumber 340 multicastAddr tcpip.Address 341 sentReportStat func(*stack.Stack) *tcpip.StatCounter 342 receivedQueryStat func(*stack.Stack) *tcpip.StatCounter 343 rxQuery func(*channel.Endpoint) 344 }{ 345 { 346 name: "IGMP", 347 protoNum: ipv4.ProtocolNumber, 348 multicastAddr: ipv4MulticastAddr1, 349 sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { 350 return s.Stats().IGMP.PacketsSent.V2MembershipReport 351 }, 352 receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { 353 return s.Stats().IGMP.PacketsReceived.MembershipQuery 354 }, 355 rxQuery: func(e *channel.Endpoint) { 356 createAndInjectIGMPPacket(e, igmpMembershipQuery, unsolicitedIGMPReportIntervalMaxTenthSec, header.IPv4Any, 0 /* extraLength */) 357 }, 358 }, 359 { 360 name: "MLD", 361 protoNum: ipv6.ProtocolNumber, 362 multicastAddr: ipv6MulticastAddr1, 363 sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { 364 return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport 365 }, 366 receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { 367 return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery 368 }, 369 rxQuery: func(e *channel.Endpoint) { 370 createAndInjectMLDPacket(e, mldQuery, 0, header.IPv6Any, 0 /* extraLength */) 371 }, 372 }, 373 } 374 375 for _, test := range tests { 376 t.Run(test.name, func(t *testing.T) { 377 ctx := newMulticastTestContext(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, false /* mgpEnabled */) 378 defer ctx.cleanup() 379 s := ctx.s 380 e := ctx.e 381 clock := ctx.clock 382 383 // This NIC may join multicast groups when it is enabled but since MGP is 384 // disabled, no reports should be sent. 385 sentReportStat := test.sentReportStat(s) 386 if got := sentReportStat.Value(); got != 0 { 387 t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) 388 } 389 clock.Advance(time.Hour) 390 if p := e.Read(); p != nil { 391 t.Fatalf("sent unexpected packet, stack with disabled MGP sent packet = %#v", p) 392 } 393 394 // Test joining a specific group explicitly and verify that no reports are 395 // sent. 396 if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { 397 t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) 398 } 399 if got := sentReportStat.Value(); got != 0 { 400 t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) 401 } 402 clock.Advance(time.Hour) 403 if p := e.Read(); p != nil { 404 t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %#v", p) 405 } 406 407 // Inject a general query message. This should only trigger a report to be 408 // sent if the MGP was enabled. 409 test.rxQuery(e) 410 if got := test.receivedQueryStat(s).Value(); got != 1 { 411 t.Fatalf("got receivedQueryStat(_).Value() = %d, want = 1", got) 412 } 413 clock.Advance(time.Hour) 414 if p := e.Read(); p != nil { 415 t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p) 416 } 417 }) 418 } 419 } 420 421 func TestMGPReceiveCounters(t *testing.T) { 422 tests := []struct { 423 name string 424 headerType uint8 425 maxRespTime byte 426 groupAddress tcpip.Address 427 statCounter func(*stack.Stack) *tcpip.StatCounter 428 rxMGPkt func(*channel.Endpoint, byte, byte, tcpip.Address, int) 429 }{ 430 { 431 name: "IGMP Membership Query", 432 headerType: igmpMembershipQuery, 433 maxRespTime: unsolicitedIGMPReportIntervalMaxTenthSec, 434 groupAddress: header.IPv4Any, 435 statCounter: func(s *stack.Stack) *tcpip.StatCounter { 436 return s.Stats().IGMP.PacketsReceived.MembershipQuery 437 }, 438 rxMGPkt: createAndInjectIGMPPacket, 439 }, 440 { 441 name: "IGMPv1 Membership Report", 442 headerType: igmpv1MembershipReport, 443 maxRespTime: 0, 444 groupAddress: header.IPv4AllSystems, 445 statCounter: func(s *stack.Stack) *tcpip.StatCounter { 446 return s.Stats().IGMP.PacketsReceived.V1MembershipReport 447 }, 448 rxMGPkt: createAndInjectIGMPPacket, 449 }, 450 { 451 name: "IGMPv2 Membership Report", 452 headerType: igmpv2MembershipReport, 453 maxRespTime: 0, 454 groupAddress: header.IPv4AllSystems, 455 statCounter: func(s *stack.Stack) *tcpip.StatCounter { 456 return s.Stats().IGMP.PacketsReceived.V2MembershipReport 457 }, 458 rxMGPkt: createAndInjectIGMPPacket, 459 }, 460 { 461 name: "IGMP Leave Group", 462 headerType: igmpLeaveGroup, 463 maxRespTime: 0, 464 groupAddress: header.IPv4AllRoutersGroup, 465 statCounter: func(s *stack.Stack) *tcpip.StatCounter { 466 return s.Stats().IGMP.PacketsReceived.LeaveGroup 467 }, 468 rxMGPkt: createAndInjectIGMPPacket, 469 }, 470 { 471 name: "MLD Query", 472 headerType: mldQuery, 473 maxRespTime: 0, 474 groupAddress: header.IPv6Any, 475 statCounter: func(s *stack.Stack) *tcpip.StatCounter { 476 return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery 477 }, 478 rxMGPkt: createAndInjectMLDPacket, 479 }, 480 { 481 name: "MLD Report", 482 headerType: mldReport, 483 maxRespTime: 0, 484 groupAddress: header.IPv6Any, 485 statCounter: func(s *stack.Stack) *tcpip.StatCounter { 486 return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerReport 487 }, 488 rxMGPkt: createAndInjectMLDPacket, 489 }, 490 { 491 name: "MLD Done", 492 headerType: mldDone, 493 maxRespTime: 0, 494 groupAddress: header.IPv6Any, 495 statCounter: func(s *stack.Stack) *tcpip.StatCounter { 496 return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerDone 497 }, 498 rxMGPkt: createAndInjectMLDPacket, 499 }, 500 } 501 502 for _, test := range tests { 503 t.Run(test.name, func(t *testing.T) { 504 ctx := newMulticastTestContext(t, test.groupAddress.Len() == header.IPv4AddressSize /* v4 */, true /* mgpEnabled */) 505 defer ctx.cleanup() 506 507 test.rxMGPkt(ctx.e, test.headerType, test.maxRespTime, test.groupAddress, 0 /* extraLength */) 508 if got := test.statCounter(ctx.s).Value(); got != 1 { 509 t.Fatalf("got %s received = %d, want = 1", test.name, got) 510 } 511 }) 512 } 513 } 514 515 // TestMGPJoinGroup tests that when explicitly joining a multicast group, the 516 // stack schedules and sends correct Membership Reports. 517 func TestMGPJoinGroup(t *testing.T) { 518 type subTest struct { 519 name string 520 enterVersion func(e *channel.Endpoint) 521 validateReport func(*testing.T, *stack.PacketBuffer) 522 checkStats func(*testing.T, *stack.Stack, uint64, uint64, uint64) 523 } 524 525 tests := []struct { 526 name string 527 protoNum tcpip.NetworkProtocolNumber 528 multicastAddr tcpip.Address 529 maxUnsolicitedResponseDelay time.Duration 530 receivedQueryStat func(*stack.Stack) *tcpip.StatCounter 531 checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) uint64 532 subTests []subTest 533 }{ 534 { 535 name: "IGMP", 536 protoNum: ipv4.ProtocolNumber, 537 multicastAddr: ipv4MulticastAddr1, 538 maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, 539 receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { 540 return s.Stats().IGMP.PacketsReceived.MembershipQuery 541 }, 542 subTests: []subTest{ 543 { 544 name: "V2", 545 enterVersion: func(e *channel.Endpoint) { 546 // V2 query for unrelated group. 547 createAndInjectIGMPPacket(e, igmpMembershipQuery, 1, ipv4MulticastAddr3, 0 /* extraLength */) 548 }, 549 validateReport: func(t *testing.T, p *stack.PacketBuffer) { 550 t.Helper() 551 552 validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) 553 }, 554 checkStats: iptestutil.CheckIGMPv2Stats, 555 }, 556 { 557 name: "V3", 558 enterVersion: func(*channel.Endpoint) {}, 559 validateReport: func(t *testing.T, p *stack.PacketBuffer) { 560 t.Helper() 561 562 validateIGMPv3ReportPacket(t, p, []tcpip.Address{ipv4MulticastAddr1}, header.IGMPv3ReportRecordChangeToExcludeMode) 563 }, 564 checkStats: iptestutil.CheckIGMPv3Stats, 565 }, 566 }, 567 }, 568 { 569 name: "MLD", 570 protoNum: ipv6.ProtocolNumber, 571 multicastAddr: ipv6MulticastAddr1, 572 maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, 573 receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { 574 return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery 575 }, 576 checkInitialGroups: checkInitialIPv6Groups, 577 subTests: []subTest{ 578 { 579 name: "V1", 580 enterVersion: func(e *channel.Endpoint) { 581 // V1 query for unrelated group. 582 createAndInjectMLDPacket(e, mldQuery, 0, ipv6MulticastAddr3, 0 /* extraLength */) 583 }, 584 validateReport: func(t *testing.T, p *stack.PacketBuffer) { 585 t.Helper() 586 587 validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) 588 }, 589 checkStats: iptestutil.CheckMLDv1Stats, 590 }, 591 { 592 name: "V2", 593 enterVersion: func(*channel.Endpoint) {}, 594 validateReport: func(t *testing.T, p *stack.PacketBuffer) { 595 t.Helper() 596 597 validateMLDv2ReportPacket(t, p, []tcpip.Address{ipv6MulticastAddr1}, header.MLDv2ReportRecordChangeToExcludeMode) 598 }, 599 checkStats: iptestutil.CheckMLDv2Stats, 600 }, 601 }, 602 }, 603 } 604 605 for _, test := range tests { 606 t.Run(test.name, func(t *testing.T) { 607 for _, subTest := range test.subTests { 608 t.Run(subTest.name, func(t *testing.T) { 609 ctx := newMulticastTestContext(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) 610 defer ctx.cleanup() 611 s, e, clock := ctx.s, ctx.e, ctx.clock 612 613 var reportCounter uint64 614 var leaveCounter uint64 615 var reportV2Counter uint64 616 if test.checkInitialGroups != nil { 617 reportV2Counter = test.checkInitialGroups(t, e, s, clock) 618 } 619 620 subTest.enterVersion(e) 621 622 // Test joining a specific address explicitly and verify a Report is sent 623 // immediately. 624 if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { 625 t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) 626 } 627 reportCounter++ 628 subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter) 629 if p := e.Read(); p == nil { 630 t.Fatal("expected a report message to be sent") 631 } else { 632 subTest.validateReport(t, p) 633 p.DecRef() 634 } 635 if t.Failed() { 636 t.FailNow() 637 } 638 639 // Verify the second report is sent by the maximum unsolicited response 640 // interval. 641 p := e.Read() 642 if p != nil { 643 t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p) 644 } 645 clock.Advance(test.maxUnsolicitedResponseDelay) 646 reportCounter++ 647 subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter) 648 if p := e.Read(); p == nil { 649 t.Fatal("expected a report message to be sent") 650 } else { 651 subTest.validateReport(t, p) 652 p.DecRef() 653 } 654 655 // Should not send any more packets. 656 clock.Advance(time.Hour) 657 if p := e.Read(); p != nil { 658 t.Fatalf("sent unexpected packet = %#v", p) 659 } 660 }) 661 } 662 }) 663 } 664 } 665 666 // TestMGPLeaveGroup tests that when leaving a previously joined multicast 667 // group the stack sends a leave/done message. 668 func TestMGPLeaveGroup(t *testing.T) { 669 type subTest struct { 670 name string 671 enterVersion func(e *channel.Endpoint) 672 validateReport func(*testing.T, *stack.PacketBuffer) 673 validateLeave func(*testing.T, *stack.PacketBuffer) 674 leaveCount uint8 675 checkStats func(*testing.T, *stack.Stack, uint64, uint64, uint64) 676 } 677 678 tests := []struct { 679 name string 680 protoNum tcpip.NetworkProtocolNumber 681 multicastAddr tcpip.Address 682 maxUnsolicitedResponseDelay time.Duration 683 checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) uint64 684 subTests []subTest 685 }{ 686 { 687 name: "IGMP", 688 protoNum: ipv4.ProtocolNumber, 689 multicastAddr: ipv4MulticastAddr1, 690 maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, 691 subTests: []subTest{ 692 { 693 name: "V2", 694 enterVersion: func(e *channel.Endpoint) { 695 // V2 query for unrelated group. 696 createAndInjectIGMPPacket(e, igmpMembershipQuery, 1, ipv4MulticastAddr3, 0 /* extraLength */) 697 }, 698 validateReport: func(t *testing.T, p *stack.PacketBuffer) { 699 t.Helper() 700 701 validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) 702 }, 703 validateLeave: func(t *testing.T, p *stack.PacketBuffer) { 704 t.Helper() 705 706 validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr1) 707 }, 708 leaveCount: 1, 709 checkStats: iptestutil.CheckIGMPv2Stats, 710 }, 711 { 712 name: "V3", 713 enterVersion: func(*channel.Endpoint) {}, 714 validateReport: func(t *testing.T, p *stack.PacketBuffer) { 715 t.Helper() 716 717 validateIGMPv3ReportPacket(t, p, []tcpip.Address{ipv4MulticastAddr1}, header.IGMPv3ReportRecordChangeToExcludeMode) 718 }, 719 validateLeave: func(t *testing.T, p *stack.PacketBuffer) { 720 t.Helper() 721 722 validateIGMPv3ReportPacket(t, p, []tcpip.Address{ipv4MulticastAddr1}, header.IGMPv3ReportRecordChangeToIncludeMode) 723 }, 724 leaveCount: 2, 725 checkStats: iptestutil.CheckIGMPv3Stats, 726 }, 727 }, 728 }, 729 { 730 name: "MLD", 731 protoNum: ipv6.ProtocolNumber, 732 multicastAddr: ipv6MulticastAddr1, 733 maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, 734 checkInitialGroups: checkInitialIPv6Groups, 735 subTests: []subTest{ 736 { 737 name: "V1", 738 enterVersion: func(e *channel.Endpoint) { 739 // V1 query for unrelated group. 740 createAndInjectMLDPacket(e, mldQuery, 0, ipv6MulticastAddr3, 0 /* extraLength */) 741 }, 742 validateReport: func(t *testing.T, p *stack.PacketBuffer) { 743 t.Helper() 744 745 validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) 746 }, 747 validateLeave: func(t *testing.T, p *stack.PacketBuffer) { 748 t.Helper() 749 750 validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6MulticastAddr1) 751 }, 752 leaveCount: 1, 753 checkStats: iptestutil.CheckMLDv1Stats, 754 }, 755 { 756 name: "V2", 757 enterVersion: func(*channel.Endpoint) {}, 758 validateReport: func(t *testing.T, p *stack.PacketBuffer) { 759 t.Helper() 760 761 validateMLDv2ReportPacket(t, p, []tcpip.Address{ipv6MulticastAddr1}, header.MLDv2ReportRecordChangeToExcludeMode) 762 }, 763 validateLeave: func(t *testing.T, p *stack.PacketBuffer) { 764 t.Helper() 765 766 validateMLDv2ReportPacket(t, p, []tcpip.Address{ipv6MulticastAddr1}, header.MLDv2ReportRecordChangeToIncludeMode) 767 }, 768 leaveCount: 2, 769 checkStats: iptestutil.CheckMLDv2Stats, 770 }, 771 }, 772 }, 773 } 774 775 for _, test := range tests { 776 t.Run(test.name, func(t *testing.T) { 777 for _, subTest := range test.subTests { 778 t.Run(subTest.name, func(t *testing.T) { 779 ctx := newMulticastTestContext(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) 780 defer ctx.cleanup() 781 s, e, clock := ctx.s, ctx.e, ctx.clock 782 783 var reportCounter uint64 784 var leaveCounter uint64 785 var reportV2Counter uint64 786 if test.checkInitialGroups != nil { 787 reportV2Counter = test.checkInitialGroups(t, e, s, clock) 788 } 789 790 subTest.enterVersion(e) 791 792 if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { 793 t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) 794 } 795 reportCounter++ 796 subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter) 797 if p := e.Read(); p == nil { 798 t.Fatal("expected a report message to be sent") 799 } else { 800 subTest.validateReport(t, p) 801 p.DecRef() 802 } 803 if t.Failed() { 804 t.FailNow() 805 } 806 807 // Leaving the group should trigger an leave/done message to be sent. 808 if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil { 809 t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) 810 } 811 for i := subTest.leaveCount; i > 0; i-- { 812 leaveCounter++ 813 subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter) 814 if p := e.Read(); p == nil { 815 t.Fatal("expected a leave message to be sent") 816 } else { 817 subTest.validateLeave(t, p) 818 p.DecRef() 819 } 820 clock.Advance(test.maxUnsolicitedResponseDelay) 821 } 822 823 // Should not send any more packets. 824 clock.Advance(time.Hour) 825 if p := e.Read(); p != nil { 826 t.Fatalf("sent unexpected packet = %#v", p) 827 } 828 }) 829 } 830 }) 831 } 832 } 833 834 // TestMGPQueryMessages tests that a report is sent in response to query 835 // messages. 836 func TestMGPQueryMessages(t *testing.T) { 837 type subTest struct { 838 name string 839 enterVersion func(e *channel.Endpoint) 840 validateReport func(*testing.T, *stack.PacketBuffer, bool) 841 checkStats func(*testing.T, *stack.Stack, uint64, uint64, uint64) 842 rxQuery func(*channel.Endpoint, uint8, tcpip.Address) 843 } 844 845 tests := []struct { 846 name string 847 protoNum tcpip.NetworkProtocolNumber 848 multicastAddr tcpip.Address 849 maxUnsolicitedResponseDelay time.Duration 850 receivedQueryStat func(*stack.Stack) *tcpip.StatCounter 851 maxRespTimeToDuration func(uint16) time.Duration 852 checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) uint64 853 subTests []subTest 854 }{ 855 { 856 name: "IGMP", 857 protoNum: ipv4.ProtocolNumber, 858 multicastAddr: ipv4MulticastAddr1, 859 maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, 860 receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { 861 return s.Stats().IGMP.PacketsReceived.MembershipQuery 862 }, 863 maxRespTimeToDuration: header.DecisecondToDuration, 864 subTests: []subTest{ 865 { 866 name: "V2", 867 enterVersion: func(e *channel.Endpoint) { 868 // V2 query for unrelated group. 869 createAndInjectIGMPPacket(e, igmpMembershipQuery, 1, ipv4MulticastAddr3, 0 /* extraLength */) 870 }, 871 validateReport: func(t *testing.T, p *stack.PacketBuffer, _ bool) { 872 t.Helper() 873 874 validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) 875 }, 876 rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) { 877 createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress, 0 /* extraLength */) 878 }, 879 checkStats: iptestutil.CheckIGMPv2Stats, 880 }, 881 { 882 name: "V3", 883 enterVersion: func(*channel.Endpoint) {}, 884 validateReport: func(t *testing.T, p *stack.PacketBuffer, queryResponse bool) { 885 t.Helper() 886 887 recordType := header.IGMPv3ReportRecordChangeToExcludeMode 888 if queryResponse { 889 recordType = header.IGMPv3ReportRecordModeIsExclude 890 } 891 892 validateIGMPv3ReportPacket(t, p, []tcpip.Address{ipv4MulticastAddr1}, recordType) 893 }, 894 rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) { 895 createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress, header.IGMPv3QueryMinimumSize-header.IGMPQueryMinimumSize /* extraLength */) 896 }, 897 checkStats: iptestutil.CheckIGMPv3Stats, 898 }, 899 }, 900 }, 901 { 902 name: "MLD", 903 protoNum: ipv6.ProtocolNumber, 904 multicastAddr: ipv6MulticastAddr1, 905 maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, 906 receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { 907 return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery 908 }, 909 maxRespTimeToDuration: func(d uint16) time.Duration { 910 return time.Duration(d) * time.Millisecond 911 }, 912 checkInitialGroups: checkInitialIPv6Groups, 913 subTests: []subTest{ 914 { 915 name: "V1", 916 enterVersion: func(e *channel.Endpoint) { 917 // V1 query for unrelated group. 918 createAndInjectMLDPacket(e, mldQuery, 0, ipv6MulticastAddr3, 0 /* extraLength */) 919 }, 920 validateReport: func(t *testing.T, p *stack.PacketBuffer, _ bool) { 921 t.Helper() 922 923 validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) 924 }, 925 rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) { 926 createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress, 0 /* extraLength */) 927 }, 928 checkStats: iptestutil.CheckMLDv1Stats, 929 }, 930 { 931 name: "V2", 932 enterVersion: func(*channel.Endpoint) {}, 933 validateReport: func(t *testing.T, p *stack.PacketBuffer, queryResponse bool) { 934 t.Helper() 935 936 recordType := header.MLDv2ReportRecordChangeToExcludeMode 937 if queryResponse { 938 recordType = header.MLDv2ReportRecordModeIsExclude 939 } 940 941 validateMLDv2ReportPacket(t, p, []tcpip.Address{ipv6MulticastAddr1}, recordType) 942 }, 943 rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) { 944 createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress, header.MLDv2QueryMinimumSize-header.MLDMinimumSize /* extraLength */) 945 }, 946 checkStats: iptestutil.CheckMLDv2Stats, 947 }, 948 }, 949 }, 950 } 951 952 for _, test := range tests { 953 t.Run(test.name, func(t *testing.T) { 954 addrTests := []struct { 955 name string 956 multicastAddr tcpip.Address 957 expectReport bool 958 }{ 959 { 960 name: "Unspecified", 961 multicastAddr: tcpip.AddrFromSlice([]byte(strings.Repeat("\x00", test.multicastAddr.Len()))), 962 expectReport: true, 963 }, 964 { 965 name: "Specified", 966 multicastAddr: test.multicastAddr, 967 expectReport: true, 968 }, 969 { 970 name: "Specified other address", 971 multicastAddr: func() tcpip.Address { 972 addrCopy := test.multicastAddr 973 addrBytes := addrCopy.AsSlice() 974 addrBytes[len(addrBytes)-1]++ 975 return tcpip.AddrFromSlice(addrBytes) 976 }(), 977 expectReport: false, 978 }, 979 } 980 981 for _, addrTest := range addrTests { 982 t.Run(addrTest.name, func(t *testing.T) { 983 for _, subTest := range test.subTests { 984 t.Run(subTest.name, func(t *testing.T) { 985 ctx := newMulticastTestContext(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) 986 defer ctx.cleanup() 987 s, e, clock := ctx.s, ctx.e, ctx.clock 988 989 var reportCounter uint64 990 var leaveCounter uint64 991 var reportV2Counter uint64 992 if test.checkInitialGroups != nil { 993 reportV2Counter = test.checkInitialGroups(t, e, s, clock) 994 } 995 996 subTest.enterVersion(e) 997 998 if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { 999 t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) 1000 } 1001 for i := 0; i < maxUnsolicitedReports; i++ { 1002 reportCounter++ 1003 subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter) 1004 if p := e.Read(); p == nil { 1005 t.Fatalf("expected %d-th report message to be sent", i) 1006 } else { 1007 subTest.validateReport(t, p, false /* queryResponse */) 1008 p.DecRef() 1009 } 1010 clock.Advance(test.maxUnsolicitedResponseDelay) 1011 } 1012 if t.Failed() { 1013 t.FailNow() 1014 } 1015 1016 // Should not send any more packets until a query. 1017 clock.Advance(time.Hour) 1018 if p := e.Read(); p != nil { 1019 t.Fatalf("sent unexpected packet = %#v", p) 1020 } 1021 1022 // Receive a query message which should trigger a report to be sent at 1023 // some time before the maximum response time if the report is 1024 // targeted at the host. 1025 const maxRespTime = 100 1026 subTest.rxQuery(e, maxRespTime, addrTest.multicastAddr) 1027 if p := e.Read(); p != nil { 1028 t.Fatalf("sent unexpected packet = %#v", p) 1029 } 1030 1031 if addrTest.expectReport { 1032 clock.Advance(test.maxRespTimeToDuration(maxRespTime)) 1033 reportCounter++ 1034 subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter) 1035 if p := e.Read(); p == nil { 1036 t.Fatal("expected a report message to be sent") 1037 } else { 1038 subTest.validateReport(t, p, true /* queryResponse */) 1039 p.DecRef() 1040 } 1041 } 1042 1043 // Should not send any more packets. 1044 clock.Advance(time.Hour) 1045 if p := e.Read(); p != nil { 1046 t.Fatalf("sent unexpected packet = %#v", p) 1047 } 1048 }) 1049 } 1050 }) 1051 } 1052 }) 1053 } 1054 } 1055 1056 // TestMGPQueryMessages tests that no further reports or leave/done messages 1057 // are sent after receiving a report. 1058 func TestMGPReportMessages(t *testing.T) { 1059 type subTest struct { 1060 name string 1061 enterVersion func(e *channel.Endpoint) 1062 validateReport func(*testing.T, *stack.PacketBuffer) 1063 validateLeave func(*testing.T, *stack.PacketBuffer) 1064 leaveCount uint8 1065 checkStats func(*testing.T, *stack.Stack, uint64, uint64, uint64) 1066 } 1067 1068 tests := []struct { 1069 name string 1070 protoNum tcpip.NetworkProtocolNumber 1071 multicastAddr tcpip.Address 1072 maxUnsolicitedResponseDelay time.Duration 1073 rxReport func(*channel.Endpoint) 1074 checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) uint64 1075 subTests []subTest 1076 }{ 1077 { 1078 name: "IGMP", 1079 protoNum: ipv4.ProtocolNumber, 1080 multicastAddr: ipv4MulticastAddr1, 1081 rxReport: func(e *channel.Endpoint) { 1082 createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr1, 0 /* extraLength */) 1083 }, 1084 maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, 1085 subTests: []subTest{ 1086 { 1087 name: "V2", 1088 enterVersion: func(e *channel.Endpoint) { 1089 // V2 query for unrelated group. 1090 createAndInjectIGMPPacket(e, igmpMembershipQuery, 1, ipv4MulticastAddr3, 0 /* extraLength */) 1091 }, 1092 validateReport: func(t *testing.T, p *stack.PacketBuffer) { 1093 t.Helper() 1094 1095 validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) 1096 }, 1097 leaveCount: 0, 1098 checkStats: iptestutil.CheckIGMPv2Stats, 1099 }, 1100 { 1101 name: "V3", 1102 enterVersion: func(*channel.Endpoint) {}, 1103 validateReport: func(t *testing.T, p *stack.PacketBuffer) { 1104 t.Helper() 1105 1106 validateIGMPv3ReportPacket(t, p, []tcpip.Address{ipv4MulticastAddr1}, header.IGMPv3ReportRecordChangeToExcludeMode) 1107 }, 1108 validateLeave: func(t *testing.T, p *stack.PacketBuffer) { 1109 t.Helper() 1110 1111 validateIGMPv3ReportPacket(t, p, []tcpip.Address{ipv4MulticastAddr1}, header.IGMPv3ReportRecordChangeToIncludeMode) 1112 }, 1113 leaveCount: 2, 1114 checkStats: iptestutil.CheckIGMPv3Stats, 1115 }, 1116 }, 1117 }, 1118 { 1119 name: "MLD", 1120 protoNum: ipv6.ProtocolNumber, 1121 multicastAddr: ipv6MulticastAddr1, 1122 rxReport: func(e *channel.Endpoint) { 1123 createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr1, 0 /* extraLength */) 1124 }, 1125 maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, 1126 checkInitialGroups: checkInitialIPv6Groups, 1127 subTests: []subTest{ 1128 { 1129 name: "V1", 1130 enterVersion: func(e *channel.Endpoint) { 1131 // V1 query for unrelated group. 1132 createAndInjectMLDPacket(e, mldQuery, 0, ipv6MulticastAddr3, 0 /* extraLength */) 1133 }, 1134 validateReport: func(t *testing.T, p *stack.PacketBuffer) { 1135 t.Helper() 1136 1137 validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) 1138 }, 1139 leaveCount: 0, 1140 checkStats: iptestutil.CheckMLDv1Stats, 1141 }, 1142 { 1143 name: "V2", 1144 enterVersion: func(*channel.Endpoint) {}, 1145 validateReport: func(t *testing.T, p *stack.PacketBuffer) { 1146 t.Helper() 1147 1148 validateMLDv2ReportPacket(t, p, []tcpip.Address{ipv6MulticastAddr1}, header.MLDv2ReportRecordChangeToExcludeMode) 1149 }, 1150 validateLeave: func(t *testing.T, p *stack.PacketBuffer) { 1151 t.Helper() 1152 1153 validateMLDv2ReportPacket(t, p, []tcpip.Address{ipv6MulticastAddr1}, header.MLDv2ReportRecordChangeToIncludeMode) 1154 }, 1155 leaveCount: 2, 1156 checkStats: iptestutil.CheckMLDv2Stats, 1157 }, 1158 }, 1159 }, 1160 } 1161 1162 for _, test := range tests { 1163 t.Run(test.name, func(t *testing.T) { 1164 for _, subTest := range test.subTests { 1165 t.Run(subTest.name, func(t *testing.T) { 1166 ctx := newMulticastTestContext(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) 1167 defer ctx.cleanup() 1168 s, e, clock := ctx.s, ctx.e, ctx.clock 1169 1170 var reportCounter uint64 1171 var leaveCounter uint64 1172 var reportV2Counter uint64 1173 if test.checkInitialGroups != nil { 1174 reportV2Counter = test.checkInitialGroups(t, e, s, clock) 1175 } 1176 1177 subTest.enterVersion(e) 1178 1179 if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { 1180 t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) 1181 } 1182 reportCounter++ 1183 subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter) 1184 if p := e.Read(); p == nil { 1185 t.Fatal("expected a report message to be sent") 1186 } else { 1187 subTest.validateReport(t, p) 1188 p.DecRef() 1189 } 1190 if t.Failed() { 1191 t.FailNow() 1192 } 1193 1194 // Receiving a report for a group we joined should cancel any further 1195 // reports. 1196 test.rxReport(e) 1197 clock.Advance(time.Hour) 1198 subTest.enterVersion(e) 1199 subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter) 1200 if p := e.Read(); p != nil { 1201 t.Errorf("sent unexpected packet = %#v", p) 1202 } 1203 if t.Failed() { 1204 t.FailNow() 1205 } 1206 1207 // Leaving a group after getting a report should not send a leave/done 1208 // message. 1209 if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil { 1210 t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) 1211 } 1212 for i := subTest.leaveCount; i > 0; i-- { 1213 leaveCounter++ 1214 subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter) 1215 if p := e.Read(); p == nil { 1216 t.Fatal("expected a leave message to be sent") 1217 } else { 1218 subTest.validateLeave(t, p) 1219 p.DecRef() 1220 } 1221 clock.Advance(test.maxUnsolicitedResponseDelay) 1222 } 1223 1224 // Should not send any more packets. 1225 clock.Advance(time.Hour) 1226 subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter) 1227 if p := e.Read(); p != nil { 1228 t.Fatalf("sent unexpected packet = %#v", p) 1229 } 1230 }) 1231 } 1232 }) 1233 } 1234 } 1235 1236 func TestMGPWithNICLifecycle(t *testing.T) { 1237 type subTest struct { 1238 name string 1239 v1Compatibility bool 1240 enterVersion func(e *channel.Endpoint) 1241 validateReport func(*testing.T, *stack.PacketBuffer, tcpip.Address) 1242 validateLeave func(*testing.T, *channel.Endpoint, []tcpip.Address) 1243 checkStats func(*testing.T, *stack.Stack, uint64, uint64, uint64) 1244 } 1245 1246 tests := []struct { 1247 name string 1248 protoNum tcpip.NetworkProtocolNumber 1249 multicastAddrs []tcpip.Address 1250 finalMulticastAddr tcpip.Address 1251 maxUnsolicitedResponseDelay time.Duration 1252 sentReportStat func(*stack.Stack) *tcpip.StatCounter 1253 sentLeaveStat func(*stack.Stack) *tcpip.StatCounter 1254 validateReport func(*testing.T, *channel.Endpoint, []tcpip.Address) 1255 validateLeave func(*testing.T, *stack.PacketBuffer, tcpip.Address) 1256 checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) uint64 1257 checkStats func(*testing.T, *stack.Stack, uint64, uint64, uint64) 1258 subTests []subTest 1259 }{ 1260 { 1261 name: "IGMP", 1262 protoNum: ipv4.ProtocolNumber, 1263 multicastAddrs: []tcpip.Address{ipv4MulticastAddr1, ipv4MulticastAddr2}, 1264 finalMulticastAddr: ipv4MulticastAddr3, 1265 maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, 1266 sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { 1267 return s.Stats().IGMP.PacketsSent.V2MembershipReport 1268 }, 1269 sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { 1270 return s.Stats().IGMP.PacketsSent.LeaveGroup 1271 }, 1272 validateReport: func(t *testing.T, e *channel.Endpoint, addrs []tcpip.Address) { 1273 t.Helper() 1274 iptestutil.ValidateIGMPv3RecordsAcrossReports(t, e, stackIPv4Addr, addrs, header.IGMPv3ReportRecordChangeToExcludeMode) 1275 }, 1276 validateLeave: func(t *testing.T, p *stack.PacketBuffer, addr tcpip.Address) { 1277 t.Helper() 1278 1279 validateIGMPv3ReportPacket(t, p, []tcpip.Address{addr}, header.IGMPv3ReportRecordChangeToIncludeMode) 1280 }, 1281 checkStats: iptestutil.CheckIGMPv3Stats, 1282 subTests: []subTest{ 1283 { 1284 name: "V2", 1285 v1Compatibility: true, 1286 enterVersion: func(e *channel.Endpoint) { 1287 // V2 query for unrelated group. 1288 createAndInjectIGMPPacket(e, igmpMembershipQuery, 1, ipv4MulticastAddr3, 0 /* extraLength */) 1289 }, 1290 validateReport: func(t *testing.T, p *stack.PacketBuffer, addr tcpip.Address) { 1291 t.Helper() 1292 1293 validateIGMPPacket(t, p, addr, igmpv2MembershipReport, 0, addr) 1294 }, 1295 validateLeave: func(t *testing.T, e *channel.Endpoint, addrs []tcpip.Address) { 1296 t.Helper() 1297 iptestutil.ValidMultipleIGMPv2ReportLeaves(t, e, stackIPv4Addr, addrs, true /* leave */) 1298 }, 1299 checkStats: iptestutil.CheckIGMPv2Stats, 1300 }, 1301 { 1302 name: "V3", 1303 v1Compatibility: false, 1304 enterVersion: func(*channel.Endpoint) {}, 1305 validateReport: func(t *testing.T, p *stack.PacketBuffer, addr tcpip.Address) { 1306 t.Helper() 1307 1308 validateIGMPv3ReportPacket(t, p, []tcpip.Address{addr}, header.IGMPv3ReportRecordChangeToExcludeMode) 1309 }, 1310 validateLeave: func(t *testing.T, e *channel.Endpoint, addrs []tcpip.Address) { 1311 t.Helper() 1312 iptestutil.ValidateIGMPv3RecordsAcrossReports(t, e, stackIPv4Addr, addrs, header.IGMPv3ReportRecordChangeToIncludeMode) 1313 }, 1314 checkStats: iptestutil.CheckIGMPv3Stats, 1315 }, 1316 }, 1317 }, 1318 { 1319 name: "MLD", 1320 protoNum: ipv6.ProtocolNumber, 1321 multicastAddrs: []tcpip.Address{ipv6MulticastAddr1, ipv6MulticastAddr2}, 1322 finalMulticastAddr: ipv6MulticastAddr3, 1323 maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, 1324 sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { 1325 return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport 1326 }, 1327 sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { 1328 return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone 1329 }, 1330 validateReport: func(t *testing.T, e *channel.Endpoint, addrs []tcpip.Address) { 1331 t.Helper() 1332 1333 iptestutil.ValidateMLDv2RecordsAcrossReports(t, e, linkLocalIPv6Addr1, addrs, header.MLDv2ReportRecordChangeToExcludeMode) 1334 }, 1335 validateLeave: func(t *testing.T, p *stack.PacketBuffer, addr tcpip.Address) { 1336 t.Helper() 1337 1338 validateMLDv2ReportPacket(t, p, []tcpip.Address{addr}, header.MLDv2ReportRecordChangeToIncludeMode) 1339 }, 1340 checkInitialGroups: checkInitialIPv6Groups, 1341 checkStats: iptestutil.CheckMLDv2Stats, 1342 subTests: []subTest{ 1343 { 1344 name: "V1", 1345 v1Compatibility: true, 1346 enterVersion: func(e *channel.Endpoint) { 1347 // V1 query for unrelated group. 1348 createAndInjectMLDPacket(e, mldQuery, 0, ipv6MulticastAddr3, 0 /* extraLength */) 1349 }, 1350 validateReport: func(t *testing.T, p *stack.PacketBuffer, addr tcpip.Address) { 1351 t.Helper() 1352 1353 validateMLDPacket(t, p, addr, mldReport, 0, addr) 1354 }, 1355 validateLeave: func(t *testing.T, e *channel.Endpoint, addrs []tcpip.Address) { 1356 t.Helper() 1357 1358 iptestutil.ValidMultipleMLDv1ReportLeaves(t, e, linkLocalIPv6Addr1, addrs, true /* leave */) 1359 }, 1360 checkStats: iptestutil.CheckMLDv1Stats, 1361 }, 1362 { 1363 name: "V2", 1364 v1Compatibility: false, 1365 enterVersion: func(*channel.Endpoint) {}, 1366 validateReport: func(t *testing.T, p *stack.PacketBuffer, addr tcpip.Address) { 1367 t.Helper() 1368 1369 validateMLDv2ReportPacket(t, p, []tcpip.Address{addr}, header.MLDv2ReportRecordChangeToExcludeMode) 1370 }, 1371 validateLeave: func(t *testing.T, e *channel.Endpoint, addrs []tcpip.Address) { 1372 t.Helper() 1373 1374 iptestutil.ValidateMLDv2RecordsAcrossReports(t, e, linkLocalIPv6Addr1, addrs, header.MLDv2ReportRecordChangeToIncludeMode) 1375 }, 1376 checkStats: iptestutil.CheckMLDv2Stats, 1377 }, 1378 }, 1379 }, 1380 } 1381 1382 for _, test := range tests { 1383 t.Run(test.name, func(t *testing.T) { 1384 for _, subTest := range test.subTests { 1385 t.Run(subTest.name, func(t *testing.T) { 1386 ctx := newMulticastTestContext(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) 1387 defer ctx.cleanup() 1388 s, e, clock := ctx.s, ctx.e, ctx.clock 1389 1390 var reportCounter uint64 1391 var leaveCounter uint64 1392 var reportV2Counter uint64 1393 if test.checkInitialGroups != nil { 1394 reportV2Counter = test.checkInitialGroups(t, e, s, clock) 1395 } 1396 1397 subTest.enterVersion(e) 1398 1399 for _, a := range test.multicastAddrs { 1400 if err := s.JoinGroup(test.protoNum, nicID, a); err != nil { 1401 t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err) 1402 } 1403 reportCounter++ 1404 subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter) 1405 if p := e.Read(); p == nil { 1406 t.Fatalf("expected a report message to be sent for %s", a) 1407 } else { 1408 subTest.validateReport(t, p, a) 1409 p.DecRef() 1410 } 1411 } 1412 if t.Failed() { 1413 t.FailNow() 1414 } 1415 1416 // Leave messages should be sent for the joined groups when the NIC is 1417 // disabled. 1418 if err := s.DisableNIC(nicID); err != nil { 1419 t.Fatalf("DisableNIC(%d): %s", nicID, err) 1420 } 1421 { 1422 numMessages := 1 1423 if subTest.v1Compatibility { 1424 numMessages = len(test.multicastAddrs) 1425 } 1426 leaveCounter += uint64(numMessages) 1427 subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter) 1428 subTest.validateLeave(t, e, test.multicastAddrs) 1429 } 1430 if t.Failed() { 1431 t.FailNow() 1432 } 1433 1434 // Reports should be sent for the joined groups when the NIC is enabled. 1435 if err := s.EnableNIC(nicID); err != nil { 1436 t.Fatalf("EnableNIC(%d): %s", nicID, err) 1437 } 1438 reportV2Counter++ 1439 subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter) 1440 test.validateReport(t, e, test.multicastAddrs) 1441 if t.Failed() { 1442 t.FailNow() 1443 } 1444 subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter) 1445 1446 // Joining/leaving a group while disabled should not send any messages. 1447 if err := s.DisableNIC(nicID); err != nil { 1448 t.Fatalf("DisableNIC(%d): %s", nicID, err) 1449 } 1450 reportV2Counter++ 1451 subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter) 1452 if p := e.Read(); p == nil { 1453 t.Fatal("expected leave message to be sent") 1454 } else { 1455 p.DecRef() 1456 } 1457 for _, a := range test.multicastAddrs { 1458 if err := s.LeaveGroup(test.protoNum, nicID, a); err != nil { 1459 t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, a, err) 1460 } 1461 subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter) 1462 if p := e.Read(); p != nil { 1463 t.Fatalf("leaving group %s on disabled NIC sent unexpected packet = %#v", a, p) 1464 } 1465 } 1466 if err := s.JoinGroup(test.protoNum, nicID, test.finalMulticastAddr); err != nil { 1467 t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.finalMulticastAddr, err) 1468 } 1469 subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter) 1470 if p := e.Read(); p != nil { 1471 t.Fatalf("joining group %s on disabled NIC sent unexpected packet = %#v", test.finalMulticastAddr, p) 1472 } 1473 1474 // A report should only be sent for the group we last joined after 1475 // enabling the NIC since the original groups were all left. 1476 if err := s.EnableNIC(nicID); err != nil { 1477 t.Fatalf("EnableNIC(%d): %s", nicID, err) 1478 } 1479 reportV2Counter++ 1480 subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter) 1481 test.validateReport(t, e, []tcpip.Address{test.finalMulticastAddr}) 1482 1483 clock.Advance(test.maxUnsolicitedResponseDelay) 1484 reportV2Counter++ 1485 subTest.checkStats(t, s, reportCounter, leaveCounter, reportV2Counter) 1486 test.validateReport(t, e, []tcpip.Address{test.finalMulticastAddr}) 1487 1488 // Should not send any more packets. 1489 clock.Advance(time.Hour) 1490 if p := e.Read(); p != nil { 1491 t.Fatalf("sent unexpected packet = %#v", p) 1492 } 1493 }) 1494 } 1495 }) 1496 } 1497 } 1498 1499 // TestMGPDisabledOnLoopback tests that the multicast group protocol is not 1500 // performed on loopback interfaces since they have no neighbours. 1501 func TestMGPDisabledOnLoopback(t *testing.T) { 1502 tests := []struct { 1503 name string 1504 protoNum tcpip.NetworkProtocolNumber 1505 multicastAddr tcpip.Address 1506 sentReportStat func(*stack.Stack) *tcpip.StatCounter 1507 }{ 1508 { 1509 name: "IGMP", 1510 protoNum: ipv4.ProtocolNumber, 1511 multicastAddr: ipv4MulticastAddr1, 1512 sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { 1513 return s.Stats().IGMP.PacketsSent.V2MembershipReport 1514 }, 1515 }, 1516 { 1517 name: "MLD", 1518 protoNum: ipv6.ProtocolNumber, 1519 multicastAddr: ipv6MulticastAddr1, 1520 sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { 1521 return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport 1522 }, 1523 }, 1524 } 1525 1526 for _, test := range tests { 1527 t.Run(test.name, func(t *testing.T) { 1528 s, clock := createStackWithLinkEndpoint(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */, loopback.New()) 1529 defer func() { 1530 s.Close() 1531 s.Wait() 1532 }() 1533 sentReportStat := test.sentReportStat(s) 1534 if got := sentReportStat.Value(); got != 0 { 1535 t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) 1536 } 1537 clock.Advance(time.Hour) 1538 if got := sentReportStat.Value(); got != 0 { 1539 t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) 1540 } 1541 1542 // Test joining a specific group explicitly and verify that no reports are 1543 // sent. 1544 if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { 1545 t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) 1546 } 1547 if got := sentReportStat.Value(); got != 0 { 1548 t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) 1549 } 1550 clock.Advance(time.Hour) 1551 if got := sentReportStat.Value(); got != 0 { 1552 t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) 1553 } 1554 }) 1555 } 1556 } 1557 1558 func TestMGPCoalescedQueryResponseRecords(t *testing.T) { 1559 const igmpv3MLDv2ReportRecordHeaderLen = 4 1560 1561 type subTest struct { 1562 name string 1563 enterVersion func(e *channel.Endpoint) 1564 validateReport func(*testing.T, *stack.PacketBuffer) 1565 checkStats func(*testing.T, *stack.Stack, uint64, uint64, uint64) 1566 } 1567 1568 genAddr := func(bytes []byte, i uint16) tcpip.Address { 1569 bytes[len(bytes)-1] = byte(i & 0xFF) 1570 bytes[len(bytes)-2] = byte(i >> 8) 1571 return tcpip.AddrFromSlice(bytes[:]) 1572 } 1573 1574 calcMaxRecordsPerMessage := func(hdrLen, recordLen uint16) uint16 { 1575 return (header.IPv6MinimumMTU - hdrLen) / recordLen 1576 } 1577 1578 tests := []struct { 1579 name string 1580 protoNum tcpip.NetworkProtocolNumber 1581 maxUnsolicitedResponseDelay time.Duration 1582 receivedQueryStat func(*stack.Stack) *tcpip.StatCounter 1583 checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) uint64 1584 validateReport func(*testing.T, *stack.PacketBuffer, tcpip.Address) 1585 checkStats func(*testing.T, *stack.Stack, uint64) 1586 genAddr func(uint16) tcpip.Address 1587 maxRecordsPerMessage uint16 1588 rxQuery func(*channel.Endpoint, uint8) 1589 validateReportWithMultipleRecords func(*testing.T, *channel.Endpoint, []tcpip.Address) 1590 }{ 1591 { 1592 name: "IGMP", 1593 protoNum: ipv4.ProtocolNumber, 1594 maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, 1595 receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { 1596 return s.Stats().IGMP.PacketsReceived.MembershipQuery 1597 }, 1598 validateReport: func(t *testing.T, p *stack.PacketBuffer, addr tcpip.Address) { 1599 t.Helper() 1600 1601 validateIGMPv3ReportPacket(t, p, []tcpip.Address{addr}, header.IGMPv3ReportRecordChangeToExcludeMode) 1602 }, 1603 checkStats: func(t *testing.T, s *stack.Stack, reports uint64) { 1604 t.Helper() 1605 iptestutil.CheckIGMPv3Stats(t, s, 0, 0, reports) 1606 }, 1607 genAddr: func(i uint16) tcpip.Address { 1608 bytes := [header.IPv4AddressSize]byte{224, 1, 0, 0} 1609 return genAddr(bytes[:], i) 1610 }, 1611 maxRecordsPerMessage: calcMaxRecordsPerMessage(header.IPv4MinimumSize+8 /* size of IGMPv3 report header */, igmpv3MLDv2ReportRecordHeaderLen+header.IPv4AddressSize), 1612 rxQuery: func(e *channel.Endpoint, maxRespTime uint8) { 1613 createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, header.IPv4Any, header.IGMPv3QueryMinimumSize-header.IGMPQueryMinimumSize /* extraLength */) 1614 }, 1615 validateReportWithMultipleRecords: func(t *testing.T, e *channel.Endpoint, addrs []tcpip.Address) { 1616 t.Helper() 1617 iptestutil.ValidateIGMPv3RecordsAcrossReports(t, e, stackIPv4Addr, addrs, header.IGMPv3ReportRecordModeIsExclude) 1618 }, 1619 }, 1620 { 1621 name: "MLD", 1622 protoNum: ipv6.ProtocolNumber, 1623 maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, 1624 receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { 1625 return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery 1626 }, 1627 checkInitialGroups: checkInitialIPv6Groups, 1628 validateReport: func(t *testing.T, p *stack.PacketBuffer, addr tcpip.Address) { 1629 t.Helper() 1630 1631 validateMLDv2ReportPacket(t, p, []tcpip.Address{addr}, header.MLDv2ReportRecordChangeToExcludeMode) 1632 }, 1633 checkStats: func(t *testing.T, s *stack.Stack, reports uint64) { 1634 t.Helper() 1635 iptestutil.CheckMLDv2Stats(t, s, 0, 0, reports) 1636 }, 1637 genAddr: func(i uint16) tcpip.Address { 1638 bytes := [header.IPv6AddressSize]byte{0xFF, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0} 1639 return genAddr(bytes[:], i) 1640 }, 1641 maxRecordsPerMessage: calcMaxRecordsPerMessage(header.IPv6MinimumSize+8 /* size of MLDv2 report header */, igmpv3MLDv2ReportRecordHeaderLen+header.IPv6AddressSize), 1642 rxQuery: func(e *channel.Endpoint, maxRespTime uint8) { 1643 createAndInjectMLDPacket(e, mldQuery, maxRespTime, header.IPv6Any, header.MLDv2QueryMinimumSize-header.MLDMinimumSize /* extraLength */) 1644 }, 1645 validateReportWithMultipleRecords: func(t *testing.T, e *channel.Endpoint, addrs []tcpip.Address) { 1646 t.Helper() 1647 1648 iptestutil.ValidateMLDv2RecordsAcrossReports(t, e, linkLocalIPv6Addr1, addrs, header.MLDv2ReportRecordModeIsExclude) 1649 }, 1650 }, 1651 } 1652 1653 subTests := []struct { 1654 name string 1655 extraRecords uint16 1656 expectedReports uint64 1657 }{ 1658 { 1659 name: "No extra records", 1660 extraRecords: 0, 1661 expectedReports: 1, 1662 }, 1663 { 1664 name: "One extra record", 1665 extraRecords: 1, 1666 expectedReports: 2, 1667 }, 1668 { 1669 name: "Two extra records", 1670 extraRecords: 2, 1671 expectedReports: 2, 1672 }, 1673 } 1674 1675 for _, test := range tests { 1676 t.Run(test.name, func(t *testing.T) { 1677 for _, subTest := range subTests { 1678 t.Run(subTest.name, func(t *testing.T) { 1679 ctx := newMulticastTestContext(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) 1680 defer ctx.cleanup() 1681 s, e, clock := ctx.s, ctx.e, ctx.clock 1682 1683 var reportV2Counter uint64 1684 if test.checkInitialGroups != nil { 1685 reportV2Counter = test.checkInitialGroups(t, e, s, clock) 1686 } 1687 1688 addrs := make([]tcpip.Address, test.maxRecordsPerMessage+subTest.extraRecords) 1689 for i := 0; i < len(addrs); i++ { 1690 addr := test.genAddr(uint16(i)) 1691 addrs[i] = addr 1692 1693 if err := s.JoinGroup(test.protoNum, nicID, addr); err != nil { 1694 t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, addr, err) 1695 } 1696 reportV2Counter++ 1697 test.checkStats(t, s, reportV2Counter) 1698 if p := e.Read(); p == nil { 1699 t.Fatal("expected a report message to be sent") 1700 } else { 1701 test.validateReport(t, p, addr) 1702 p.DecRef() 1703 } 1704 if t.Failed() { 1705 t.FailNow() 1706 } 1707 1708 // Verify the second report is sent by the maximum unsolicited response 1709 // interval. 1710 p := e.Read() 1711 if p != nil { 1712 t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p) 1713 } 1714 clock.Advance(test.maxUnsolicitedResponseDelay) 1715 reportV2Counter++ 1716 test.checkStats(t, s, reportV2Counter) 1717 if p := e.Read(); p == nil { 1718 t.Fatal("expected a report message to be sent") 1719 } else { 1720 test.validateReport(t, p, addr) 1721 p.DecRef() 1722 } 1723 } 1724 1725 // Should not send any more packets. 1726 clock.Advance(time.Hour) 1727 if p := e.Read(); p != nil { 1728 t.Fatalf("sent unexpected packet = %#v", p) 1729 } 1730 test.checkStats(t, s, reportV2Counter) 1731 1732 // Receive a query which should send a few reports which together hold 1733 // records for all the groups we joined. 1734 test.rxQuery(e, 1) 1735 clock.Advance(time.Second) 1736 reportV2Counter += subTest.expectedReports 1737 test.checkStats(t, s, reportV2Counter) 1738 test.validateReportWithMultipleRecords(t, e, addrs) 1739 }) 1740 } 1741 }) 1742 } 1743 }