github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/network/multicast_group_test.go (about) 1 // Copyright 2020 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package ip_test 16 17 import ( 18 "fmt" 19 "strings" 20 "testing" 21 "time" 22 23 "github.com/SagerNet/gvisor/pkg/tcpip" 24 "github.com/SagerNet/gvisor/pkg/tcpip/buffer" 25 "github.com/SagerNet/gvisor/pkg/tcpip/checker" 26 "github.com/SagerNet/gvisor/pkg/tcpip/faketime" 27 "github.com/SagerNet/gvisor/pkg/tcpip/header" 28 "github.com/SagerNet/gvisor/pkg/tcpip/link/channel" 29 "github.com/SagerNet/gvisor/pkg/tcpip/link/loopback" 30 "github.com/SagerNet/gvisor/pkg/tcpip/network/ipv4" 31 "github.com/SagerNet/gvisor/pkg/tcpip/network/ipv6" 32 "github.com/SagerNet/gvisor/pkg/tcpip/stack" 33 "github.com/SagerNet/gvisor/pkg/tcpip/testutil" 34 ) 35 36 const ( 37 linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") 38 39 defaultIPv4PrefixLength = 24 40 41 igmpMembershipQuery = uint8(header.IGMPMembershipQuery) 42 igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport) 43 igmpv2MembershipReport = uint8(header.IGMPv2MembershipReport) 44 igmpLeaveGroup = uint8(header.IGMPLeaveGroup) 45 mldQuery = uint8(header.ICMPv6MulticastListenerQuery) 46 mldReport = uint8(header.ICMPv6MulticastListenerReport) 47 mldDone = uint8(header.ICMPv6MulticastListenerDone) 48 49 maxUnsolicitedReports = 2 50 ) 51 52 var ( 53 stackIPv4Addr = testutil.MustParse4("10.0.0.1") 54 linkLocalIPv6Addr1 = testutil.MustParse6("fe80::1") 55 linkLocalIPv6Addr2 = testutil.MustParse6("fe80::2") 56 57 ipv4MulticastAddr1 = testutil.MustParse4("224.0.0.3") 58 ipv4MulticastAddr2 = testutil.MustParse4("224.0.0.4") 59 ipv4MulticastAddr3 = testutil.MustParse4("224.0.0.5") 60 ipv6MulticastAddr1 = testutil.MustParse6("ff02::3") 61 ipv6MulticastAddr2 = testutil.MustParse6("ff02::4") 62 ipv6MulticastAddr3 = testutil.MustParse6("ff02::5") 63 ) 64 65 var ( 66 // unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the 67 // NIC will wait before sending an unsolicited report after joining a 68 // multicast group, in deciseconds. 69 unsolicitedIGMPReportIntervalMaxTenthSec = func() uint8 { 70 const decisecond = time.Second / 10 71 if ipv4.UnsolicitedReportIntervalMax%decisecond != 0 { 72 panic(fmt.Sprintf("UnsolicitedReportIntervalMax of %d is a lossy conversion to deciseconds", ipv4.UnsolicitedReportIntervalMax)) 73 } 74 return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond) 75 }() 76 77 ipv6AddrSNMC = header.SolicitedNodeAddr(linkLocalIPv6Addr1) 78 ) 79 80 // validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet 81 // sent to the provided address with the passed fields set. 82 func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, mldType uint8, maxRespTime byte, groupAddress tcpip.Address) { 83 t.Helper() 84 85 payload := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) 86 checker.IPv6WithExtHdr(t, payload, 87 checker.IPv6ExtHdr( 88 checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)), 89 ), 90 checker.SrcAddr(linkLocalIPv6Addr1), 91 checker.DstAddr(remoteAddress), 92 // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. 93 checker.TTL(1), 94 checker.MLD(header.ICMPv6Type(mldType), header.MLDMinimumSize, 95 checker.MLDMaxRespDelay(time.Duration(maxRespTime)*time.Millisecond), 96 checker.MLDMulticastAddress(groupAddress), 97 ), 98 ) 99 } 100 101 // validateIGMPPacket checks that a passed PacketInfo is an IPv4 IGMP packet 102 // sent to the provided address with the passed fields set. 103 func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType uint8, maxRespTime byte, groupAddress tcpip.Address) { 104 t.Helper() 105 106 payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) 107 checker.IPv4(t, payload, 108 checker.SrcAddr(stackIPv4Addr), 109 checker.DstAddr(remoteAddress), 110 // TTL for an IGMP message must be 1 as per RFC 2236 section 2. 111 checker.TTL(1), 112 checker.IPv4RouterAlert(), 113 checker.IGMP( 114 checker.IGMPType(header.IGMPType(igmpType)), 115 checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)), 116 checker.IGMPGroupAddress(groupAddress), 117 ), 118 ) 119 } 120 121 func createStack(t *testing.T, v4, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { 122 t.Helper() 123 124 e := channel.New(maxUnsolicitedReports, header.IPv6MinimumMTU, linkAddr) 125 s, clock := createStackWithLinkEndpoint(t, v4, mgpEnabled, e) 126 return e, s, clock 127 } 128 129 func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) { 130 t.Helper() 131 132 igmpEnabled := v4 && mgpEnabled 133 mldEnabled := !v4 && mgpEnabled 134 135 clock := faketime.NewManualClock() 136 s := stack.New(stack.Options{ 137 NetworkProtocols: []stack.NetworkProtocolFactory{ 138 ipv4.NewProtocolWithOptions(ipv4.Options{ 139 IGMP: ipv4.IGMPOptions{ 140 Enabled: igmpEnabled, 141 }, 142 }), 143 ipv6.NewProtocolWithOptions(ipv6.Options{ 144 MLD: ipv6.MLDOptions{ 145 Enabled: mldEnabled, 146 }, 147 }), 148 }, 149 Clock: clock, 150 }) 151 if err := s.CreateNIC(nicID, e); err != nil { 152 t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) 153 } 154 addr := tcpip.AddressWithPrefix{ 155 Address: stackIPv4Addr, 156 PrefixLen: defaultIPv4PrefixLength, 157 } 158 if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil { 159 t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) 160 } 161 if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1); err != nil { 162 t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1, err) 163 } 164 165 return s, clock 166 } 167 168 // checkInitialIPv6Groups checks the initial IPv6 groups that a NIC will join 169 // when it is created with an IPv6 address. 170 // 171 // To not interfere with tests, checkInitialIPv6Groups will leave the added 172 // address's solicited node multicast group so that the tests can all assume 173 // the NIC has not joined any IPv6 groups. 174 func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, clock *faketime.ManualClock) (reportCounter uint64, leaveCounter uint64) { 175 t.Helper() 176 177 stats := s.Stats().ICMP.V6.PacketsSent 178 179 reportCounter++ 180 if got := stats.MulticastListenerReport.Value(); got != reportCounter { 181 t.Errorf("got stats.MulticastListenerReport.Value() = %d, want = %d", got, reportCounter) 182 } 183 if p, ok := e.Read(); !ok { 184 t.Fatal("expected a report message to be sent") 185 } else { 186 validateMLDPacket(t, p, ipv6AddrSNMC, mldReport, 0, ipv6AddrSNMC) 187 } 188 189 // Leave the group to not affect the tests. This is fine since we are not 190 // testing DAD or the solicited node address specifically. 191 if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, ipv6AddrSNMC); err != nil { 192 t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, ipv6AddrSNMC, err) 193 } 194 leaveCounter++ 195 if got := stats.MulticastListenerDone.Value(); got != leaveCounter { 196 t.Errorf("got stats.MulticastListenerDone.Value() = %d, want = %d", got, leaveCounter) 197 } 198 if p, ok := e.Read(); !ok { 199 t.Fatal("expected a report message to be sent") 200 } else { 201 validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6AddrSNMC) 202 } 203 204 // Should not send any more packets. 205 clock.Advance(time.Hour) 206 if p, ok := e.Read(); ok { 207 t.Fatalf("sent unexpected packet = %#v", p) 208 } 209 210 return reportCounter, leaveCounter 211 } 212 213 // createAndInjectIGMPPacket creates and injects an IGMP packet with the 214 // specified fields. 215 func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType byte, maxRespTime byte, groupAddress tcpip.Address) { 216 options := header.IPv4OptionsSerializer{ 217 &header.IPv4SerializableRouterAlertOption{}, 218 } 219 buf := buffer.NewView(header.IPv4MinimumSize + int(options.Length()) + header.IGMPQueryMinimumSize) 220 ip := header.IPv4(buf) 221 ip.Encode(&header.IPv4Fields{ 222 TotalLength: uint16(len(buf)), 223 TTL: header.IGMPTTL, 224 Protocol: uint8(header.IGMPProtocolNumber), 225 SrcAddr: remoteIPv4Addr, 226 DstAddr: header.IPv4AllSystems, 227 Options: options, 228 }) 229 ip.SetChecksum(^ip.CalculateChecksum()) 230 231 igmp := header.IGMP(ip.Payload()) 232 igmp.SetType(header.IGMPType(igmpType)) 233 igmp.SetMaxRespTime(maxRespTime) 234 igmp.SetGroupAddress(groupAddress) 235 igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) 236 237 e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ 238 Data: buf.ToVectorisedView(), 239 })) 240 } 241 242 // createAndInjectMLDPacket creates and injects an MLD packet with the 243 // specified fields. 244 func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay byte, groupAddress tcpip.Address) { 245 extensionHeaders := header.IPv6ExtHdrSerializer{ 246 header.IPv6SerializableHopByHopExtHdr{ 247 &header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD}, 248 }, 249 } 250 251 extensionHeadersLength := extensionHeaders.Length() 252 payloadLength := extensionHeadersLength + header.ICMPv6HeaderSize + header.MLDMinimumSize 253 buf := buffer.NewView(header.IPv6MinimumSize + payloadLength) 254 255 ip := header.IPv6(buf) 256 ip.Encode(&header.IPv6Fields{ 257 PayloadLength: uint16(payloadLength), 258 HopLimit: header.MLDHopLimit, 259 TransportProtocol: header.ICMPv6ProtocolNumber, 260 SrcAddr: linkLocalIPv6Addr2, 261 DstAddr: header.IPv6AllNodesMulticastAddress, 262 ExtensionHeaders: extensionHeaders, 263 }) 264 265 icmp := header.ICMPv6(ip.Payload()[extensionHeadersLength:]) 266 icmp.SetType(header.ICMPv6Type(mldType)) 267 mld := header.MLD(icmp.MessageBody()) 268 mld.SetMaximumResponseDelay(uint16(maxRespDelay)) 269 mld.SetMulticastAddress(groupAddress) 270 icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ 271 Header: icmp, 272 Src: linkLocalIPv6Addr2, 273 Dst: header.IPv6AllNodesMulticastAddress, 274 })) 275 276 e.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ 277 Data: buf.ToVectorisedView(), 278 })) 279 } 280 281 // TestMGPDisabled tests that the multicast group protocol is not enabled by 282 // default. 283 func TestMGPDisabled(t *testing.T) { 284 tests := []struct { 285 name string 286 protoNum tcpip.NetworkProtocolNumber 287 multicastAddr tcpip.Address 288 sentReportStat func(*stack.Stack) *tcpip.StatCounter 289 receivedQueryStat func(*stack.Stack) *tcpip.StatCounter 290 rxQuery func(*channel.Endpoint) 291 }{ 292 { 293 name: "IGMP", 294 protoNum: ipv4.ProtocolNumber, 295 multicastAddr: ipv4MulticastAddr1, 296 sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { 297 return s.Stats().IGMP.PacketsSent.V2MembershipReport 298 }, 299 receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { 300 return s.Stats().IGMP.PacketsReceived.MembershipQuery 301 }, 302 rxQuery: func(e *channel.Endpoint) { 303 createAndInjectIGMPPacket(e, igmpMembershipQuery, unsolicitedIGMPReportIntervalMaxTenthSec, header.IPv4Any) 304 }, 305 }, 306 { 307 name: "MLD", 308 protoNum: ipv6.ProtocolNumber, 309 multicastAddr: ipv6MulticastAddr1, 310 sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { 311 return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport 312 }, 313 receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { 314 return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery 315 }, 316 rxQuery: func(e *channel.Endpoint) { 317 createAndInjectMLDPacket(e, mldQuery, 0, header.IPv6Any) 318 }, 319 }, 320 } 321 322 for _, test := range tests { 323 t.Run(test.name, func(t *testing.T) { 324 e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, false /* mgpEnabled */) 325 326 // This NIC may join multicast groups when it is enabled but since MGP is 327 // disabled, no reports should be sent. 328 sentReportStat := test.sentReportStat(s) 329 if got := sentReportStat.Value(); got != 0 { 330 t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) 331 } 332 clock.Advance(time.Hour) 333 if p, ok := e.Read(); ok { 334 t.Fatalf("sent unexpected packet, stack with disabled MGP sent packet = %#v", p.Pkt) 335 } 336 337 // Test joining a specific group explicitly and verify that no reports are 338 // sent. 339 if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { 340 t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) 341 } 342 if got := sentReportStat.Value(); got != 0 { 343 t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) 344 } 345 clock.Advance(time.Hour) 346 if p, ok := e.Read(); ok { 347 t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %#v", p.Pkt) 348 } 349 350 // Inject a general query message. This should only trigger a report to be 351 // sent if the MGP was enabled. 352 test.rxQuery(e) 353 if got := test.receivedQueryStat(s).Value(); got != 1 { 354 t.Fatalf("got receivedQueryStat(_).Value() = %d, want = 1", got) 355 } 356 clock.Advance(time.Hour) 357 if p, ok := e.Read(); ok { 358 t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt) 359 } 360 }) 361 } 362 } 363 364 func TestMGPReceiveCounters(t *testing.T) { 365 tests := []struct { 366 name string 367 headerType uint8 368 maxRespTime byte 369 groupAddress tcpip.Address 370 statCounter func(*stack.Stack) *tcpip.StatCounter 371 rxMGPkt func(*channel.Endpoint, byte, byte, tcpip.Address) 372 }{ 373 { 374 name: "IGMP Membership Query", 375 headerType: igmpMembershipQuery, 376 maxRespTime: unsolicitedIGMPReportIntervalMaxTenthSec, 377 groupAddress: header.IPv4Any, 378 statCounter: func(s *stack.Stack) *tcpip.StatCounter { 379 return s.Stats().IGMP.PacketsReceived.MembershipQuery 380 }, 381 rxMGPkt: createAndInjectIGMPPacket, 382 }, 383 { 384 name: "IGMPv1 Membership Report", 385 headerType: igmpv1MembershipReport, 386 maxRespTime: 0, 387 groupAddress: header.IPv4AllSystems, 388 statCounter: func(s *stack.Stack) *tcpip.StatCounter { 389 return s.Stats().IGMP.PacketsReceived.V1MembershipReport 390 }, 391 rxMGPkt: createAndInjectIGMPPacket, 392 }, 393 { 394 name: "IGMPv2 Membership Report", 395 headerType: igmpv2MembershipReport, 396 maxRespTime: 0, 397 groupAddress: header.IPv4AllSystems, 398 statCounter: func(s *stack.Stack) *tcpip.StatCounter { 399 return s.Stats().IGMP.PacketsReceived.V2MembershipReport 400 }, 401 rxMGPkt: createAndInjectIGMPPacket, 402 }, 403 { 404 name: "IGMP Leave Group", 405 headerType: igmpLeaveGroup, 406 maxRespTime: 0, 407 groupAddress: header.IPv4AllRoutersGroup, 408 statCounter: func(s *stack.Stack) *tcpip.StatCounter { 409 return s.Stats().IGMP.PacketsReceived.LeaveGroup 410 }, 411 rxMGPkt: createAndInjectIGMPPacket, 412 }, 413 { 414 name: "MLD Query", 415 headerType: mldQuery, 416 maxRespTime: 0, 417 groupAddress: header.IPv6Any, 418 statCounter: func(s *stack.Stack) *tcpip.StatCounter { 419 return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery 420 }, 421 rxMGPkt: createAndInjectMLDPacket, 422 }, 423 { 424 name: "MLD Report", 425 headerType: mldReport, 426 maxRespTime: 0, 427 groupAddress: header.IPv6Any, 428 statCounter: func(s *stack.Stack) *tcpip.StatCounter { 429 return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerReport 430 }, 431 rxMGPkt: createAndInjectMLDPacket, 432 }, 433 { 434 name: "MLD Done", 435 headerType: mldDone, 436 maxRespTime: 0, 437 groupAddress: header.IPv6Any, 438 statCounter: func(s *stack.Stack) *tcpip.StatCounter { 439 return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerDone 440 }, 441 rxMGPkt: createAndInjectMLDPacket, 442 }, 443 } 444 445 for _, test := range tests { 446 t.Run(test.name, func(t *testing.T) { 447 e, s, _ := createStack(t, len(test.groupAddress) == header.IPv4AddressSize /* v4 */, true /* mgpEnabled */) 448 449 test.rxMGPkt(e, test.headerType, test.maxRespTime, test.groupAddress) 450 if got := test.statCounter(s).Value(); got != 1 { 451 t.Fatalf("got %s received = %d, want = 1", test.name, got) 452 } 453 }) 454 } 455 } 456 457 // TestMGPJoinGroup tests that when explicitly joining a multicast group, the 458 // stack schedules and sends correct Membership Reports. 459 func TestMGPJoinGroup(t *testing.T) { 460 tests := []struct { 461 name string 462 protoNum tcpip.NetworkProtocolNumber 463 multicastAddr tcpip.Address 464 maxUnsolicitedResponseDelay time.Duration 465 sentReportStat func(*stack.Stack) *tcpip.StatCounter 466 receivedQueryStat func(*stack.Stack) *tcpip.StatCounter 467 validateReport func(*testing.T, channel.PacketInfo) 468 checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) 469 }{ 470 { 471 name: "IGMP", 472 protoNum: ipv4.ProtocolNumber, 473 multicastAddr: ipv4MulticastAddr1, 474 maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, 475 sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { 476 return s.Stats().IGMP.PacketsSent.V2MembershipReport 477 }, 478 receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { 479 return s.Stats().IGMP.PacketsReceived.MembershipQuery 480 }, 481 validateReport: func(t *testing.T, p channel.PacketInfo) { 482 t.Helper() 483 484 validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) 485 }, 486 }, 487 { 488 name: "MLD", 489 protoNum: ipv6.ProtocolNumber, 490 multicastAddr: ipv6MulticastAddr1, 491 maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, 492 sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { 493 return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport 494 }, 495 receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { 496 return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery 497 }, 498 validateReport: func(t *testing.T, p channel.PacketInfo) { 499 t.Helper() 500 501 validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) 502 }, 503 checkInitialGroups: checkInitialIPv6Groups, 504 }, 505 } 506 507 for _, test := range tests { 508 t.Run(test.name, func(t *testing.T) { 509 e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) 510 511 var reportCounter uint64 512 if test.checkInitialGroups != nil { 513 reportCounter, _ = test.checkInitialGroups(t, e, s, clock) 514 } 515 516 // Test joining a specific address explicitly and verify a Report is sent 517 // immediately. 518 if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { 519 t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) 520 } 521 reportCounter++ 522 sentReportStat := test.sentReportStat(s) 523 if got := sentReportStat.Value(); got != reportCounter { 524 t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) 525 } 526 if p, ok := e.Read(); !ok { 527 t.Fatal("expected a report message to be sent") 528 } else { 529 test.validateReport(t, p) 530 } 531 if t.Failed() { 532 t.FailNow() 533 } 534 535 // Verify the second report is sent by the maximum unsolicited response 536 // interval. 537 p, ok := e.Read() 538 if ok { 539 t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt) 540 } 541 clock.Advance(test.maxUnsolicitedResponseDelay) 542 reportCounter++ 543 if got := sentReportStat.Value(); got != reportCounter { 544 t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) 545 } 546 if p, ok := e.Read(); !ok { 547 t.Fatal("expected a report message to be sent") 548 } else { 549 test.validateReport(t, p) 550 } 551 552 // Should not send any more packets. 553 clock.Advance(time.Hour) 554 if p, ok := e.Read(); ok { 555 t.Fatalf("sent unexpected packet = %#v", p) 556 } 557 }) 558 } 559 } 560 561 // TestMGPLeaveGroup tests that when leaving a previously joined multicast 562 // group the stack sends a leave/done message. 563 func TestMGPLeaveGroup(t *testing.T) { 564 tests := []struct { 565 name string 566 protoNum tcpip.NetworkProtocolNumber 567 multicastAddr tcpip.Address 568 sentReportStat func(*stack.Stack) *tcpip.StatCounter 569 sentLeaveStat func(*stack.Stack) *tcpip.StatCounter 570 validateReport func(*testing.T, channel.PacketInfo) 571 validateLeave func(*testing.T, channel.PacketInfo) 572 checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) 573 }{ 574 { 575 name: "IGMP", 576 protoNum: ipv4.ProtocolNumber, 577 multicastAddr: ipv4MulticastAddr1, 578 sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { 579 return s.Stats().IGMP.PacketsSent.V2MembershipReport 580 }, 581 sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { 582 return s.Stats().IGMP.PacketsSent.LeaveGroup 583 }, 584 validateReport: func(t *testing.T, p channel.PacketInfo) { 585 t.Helper() 586 587 validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) 588 }, 589 validateLeave: func(t *testing.T, p channel.PacketInfo) { 590 t.Helper() 591 592 validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr1) 593 }, 594 }, 595 { 596 name: "MLD", 597 protoNum: ipv6.ProtocolNumber, 598 multicastAddr: ipv6MulticastAddr1, 599 sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { 600 return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport 601 }, 602 sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { 603 return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone 604 }, 605 validateReport: func(t *testing.T, p channel.PacketInfo) { 606 t.Helper() 607 608 validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) 609 }, 610 validateLeave: func(t *testing.T, p channel.PacketInfo) { 611 t.Helper() 612 613 validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6MulticastAddr1) 614 }, 615 checkInitialGroups: checkInitialIPv6Groups, 616 }, 617 } 618 619 for _, test := range tests { 620 t.Run(test.name, func(t *testing.T) { 621 e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) 622 623 var reportCounter uint64 624 var leaveCounter uint64 625 if test.checkInitialGroups != nil { 626 reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) 627 } 628 629 if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { 630 t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) 631 } 632 reportCounter++ 633 if got := test.sentReportStat(s).Value(); got != reportCounter { 634 t.Errorf("got sentReportStat(_).Value() = %d, want = %d", got, reportCounter) 635 } 636 if p, ok := e.Read(); !ok { 637 t.Fatal("expected a report message to be sent") 638 } else { 639 test.validateReport(t, p) 640 } 641 if t.Failed() { 642 t.FailNow() 643 } 644 645 // Leaving the group should trigger an leave/done message to be sent. 646 if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil { 647 t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) 648 } 649 leaveCounter++ 650 if got := test.sentLeaveStat(s).Value(); got != leaveCounter { 651 t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter) 652 } 653 if p, ok := e.Read(); !ok { 654 t.Fatal("expected a leave message to be sent") 655 } else { 656 test.validateLeave(t, p) 657 } 658 659 // Should not send any more packets. 660 clock.Advance(time.Hour) 661 if p, ok := e.Read(); ok { 662 t.Fatalf("sent unexpected packet = %#v", p) 663 } 664 }) 665 } 666 } 667 668 // TestMGPQueryMessages tests that a report is sent in response to query 669 // messages. 670 func TestMGPQueryMessages(t *testing.T) { 671 tests := []struct { 672 name string 673 protoNum tcpip.NetworkProtocolNumber 674 multicastAddr tcpip.Address 675 maxUnsolicitedResponseDelay time.Duration 676 sentReportStat func(*stack.Stack) *tcpip.StatCounter 677 receivedQueryStat func(*stack.Stack) *tcpip.StatCounter 678 rxQuery func(*channel.Endpoint, uint8, tcpip.Address) 679 validateReport func(*testing.T, channel.PacketInfo) 680 maxRespTimeToDuration func(uint8) time.Duration 681 checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) 682 }{ 683 { 684 name: "IGMP", 685 protoNum: ipv4.ProtocolNumber, 686 multicastAddr: ipv4MulticastAddr1, 687 maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, 688 sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { 689 return s.Stats().IGMP.PacketsSent.V2MembershipReport 690 }, 691 receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { 692 return s.Stats().IGMP.PacketsReceived.MembershipQuery 693 }, 694 rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) { 695 createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress) 696 }, 697 validateReport: func(t *testing.T, p channel.PacketInfo) { 698 t.Helper() 699 700 validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) 701 }, 702 maxRespTimeToDuration: header.DecisecondToDuration, 703 }, 704 { 705 name: "MLD", 706 protoNum: ipv6.ProtocolNumber, 707 multicastAddr: ipv6MulticastAddr1, 708 maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, 709 sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { 710 return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport 711 }, 712 receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { 713 return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery 714 }, 715 rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) { 716 createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress) 717 }, 718 validateReport: func(t *testing.T, p channel.PacketInfo) { 719 t.Helper() 720 721 validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) 722 }, 723 maxRespTimeToDuration: func(d uint8) time.Duration { 724 return time.Duration(d) * time.Millisecond 725 }, 726 checkInitialGroups: checkInitialIPv6Groups, 727 }, 728 } 729 730 for _, test := range tests { 731 t.Run(test.name, func(t *testing.T) { 732 subTests := []struct { 733 name string 734 multicastAddr tcpip.Address 735 expectReport bool 736 }{ 737 { 738 name: "Unspecified", 739 multicastAddr: tcpip.Address(strings.Repeat("\x00", len(test.multicastAddr))), 740 expectReport: true, 741 }, 742 { 743 name: "Specified", 744 multicastAddr: test.multicastAddr, 745 expectReport: true, 746 }, 747 { 748 name: "Specified other address", 749 multicastAddr: func() tcpip.Address { 750 addrBytes := []byte(test.multicastAddr) 751 addrBytes[len(addrBytes)-1]++ 752 return tcpip.Address(addrBytes) 753 }(), 754 expectReport: false, 755 }, 756 } 757 758 for _, subTest := range subTests { 759 t.Run(subTest.name, func(t *testing.T) { 760 e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) 761 762 var reportCounter uint64 763 if test.checkInitialGroups != nil { 764 reportCounter, _ = test.checkInitialGroups(t, e, s, clock) 765 } 766 767 if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { 768 t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) 769 } 770 sentReportStat := test.sentReportStat(s) 771 for i := 0; i < maxUnsolicitedReports; i++ { 772 sentReportStat := test.sentReportStat(s) 773 reportCounter++ 774 if got := sentReportStat.Value(); got != reportCounter { 775 t.Errorf("(i=%d) got sentReportStat.Value() = %d, want = %d", i, got, reportCounter) 776 } 777 if p, ok := e.Read(); !ok { 778 t.Fatalf("expected %d-th report message to be sent", i) 779 } else { 780 test.validateReport(t, p) 781 } 782 clock.Advance(test.maxUnsolicitedResponseDelay) 783 } 784 if t.Failed() { 785 t.FailNow() 786 } 787 788 // Should not send any more packets until a query. 789 clock.Advance(time.Hour) 790 if p, ok := e.Read(); ok { 791 t.Fatalf("sent unexpected packet = %#v", p) 792 } 793 794 // Receive a query message which should trigger a report to be sent at 795 // some time before the maximum response time if the report is 796 // targeted at the host. 797 const maxRespTime = 100 798 test.rxQuery(e, maxRespTime, subTest.multicastAddr) 799 if p, ok := e.Read(); ok { 800 t.Fatalf("sent unexpected packet = %#v", p.Pkt) 801 } 802 803 if subTest.expectReport { 804 clock.Advance(test.maxRespTimeToDuration(maxRespTime)) 805 reportCounter++ 806 if got := sentReportStat.Value(); got != reportCounter { 807 t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) 808 } 809 if p, ok := e.Read(); !ok { 810 t.Fatal("expected a report message to be sent") 811 } else { 812 test.validateReport(t, p) 813 } 814 } 815 816 // Should not send any more packets. 817 clock.Advance(time.Hour) 818 if p, ok := e.Read(); ok { 819 t.Fatalf("sent unexpected packet = %#v", p) 820 } 821 }) 822 } 823 }) 824 } 825 } 826 827 // TestMGPQueryMessages tests that no further reports or leave/done messages 828 // are sent after receiving a report. 829 func TestMGPReportMessages(t *testing.T) { 830 tests := []struct { 831 name string 832 protoNum tcpip.NetworkProtocolNumber 833 multicastAddr tcpip.Address 834 sentReportStat func(*stack.Stack) *tcpip.StatCounter 835 sentLeaveStat func(*stack.Stack) *tcpip.StatCounter 836 rxReport func(*channel.Endpoint) 837 validateReport func(*testing.T, channel.PacketInfo) 838 maxRespTimeToDuration func(uint8) time.Duration 839 checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) 840 }{ 841 { 842 name: "IGMP", 843 protoNum: ipv4.ProtocolNumber, 844 multicastAddr: ipv4MulticastAddr1, 845 sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { 846 return s.Stats().IGMP.PacketsSent.V2MembershipReport 847 }, 848 sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { 849 return s.Stats().IGMP.PacketsSent.LeaveGroup 850 }, 851 rxReport: func(e *channel.Endpoint) { 852 createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr1) 853 }, 854 validateReport: func(t *testing.T, p channel.PacketInfo) { 855 t.Helper() 856 857 validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) 858 }, 859 maxRespTimeToDuration: header.DecisecondToDuration, 860 }, 861 { 862 name: "MLD", 863 protoNum: ipv6.ProtocolNumber, 864 multicastAddr: ipv6MulticastAddr1, 865 sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { 866 return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport 867 }, 868 sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { 869 return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone 870 }, 871 rxReport: func(e *channel.Endpoint) { 872 createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr1) 873 }, 874 validateReport: func(t *testing.T, p channel.PacketInfo) { 875 t.Helper() 876 877 validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) 878 }, 879 maxRespTimeToDuration: func(d uint8) time.Duration { 880 return time.Duration(d) * time.Millisecond 881 }, 882 checkInitialGroups: checkInitialIPv6Groups, 883 }, 884 } 885 886 for _, test := range tests { 887 t.Run(test.name, func(t *testing.T) { 888 e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) 889 890 var reportCounter uint64 891 var leaveCounter uint64 892 if test.checkInitialGroups != nil { 893 reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) 894 } 895 896 if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { 897 t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) 898 } 899 sentReportStat := test.sentReportStat(s) 900 reportCounter++ 901 if got := sentReportStat.Value(); got != reportCounter { 902 t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) 903 } 904 if p, ok := e.Read(); !ok { 905 t.Fatal("expected a report message to be sent") 906 } else { 907 test.validateReport(t, p) 908 } 909 if t.Failed() { 910 t.FailNow() 911 } 912 913 // Receiving a report for a group we joined should cancel any further 914 // reports. 915 test.rxReport(e) 916 clock.Advance(time.Hour) 917 if got := sentReportStat.Value(); got != reportCounter { 918 t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) 919 } 920 if p, ok := e.Read(); ok { 921 t.Errorf("sent unexpected packet = %#v", p) 922 } 923 if t.Failed() { 924 t.FailNow() 925 } 926 927 // Leaving a group after getting a report should not send a leave/done 928 // message. 929 if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil { 930 t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) 931 } 932 clock.Advance(time.Hour) 933 if got := test.sentLeaveStat(s).Value(); got != leaveCounter { 934 t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter) 935 } 936 937 // Should not send any more packets. 938 clock.Advance(time.Hour) 939 if p, ok := e.Read(); ok { 940 t.Fatalf("sent unexpected packet = %#v", p) 941 } 942 }) 943 } 944 } 945 946 func TestMGPWithNICLifecycle(t *testing.T) { 947 tests := []struct { 948 name string 949 protoNum tcpip.NetworkProtocolNumber 950 multicastAddrs []tcpip.Address 951 finalMulticastAddr tcpip.Address 952 maxUnsolicitedResponseDelay time.Duration 953 sentReportStat func(*stack.Stack) *tcpip.StatCounter 954 sentLeaveStat func(*stack.Stack) *tcpip.StatCounter 955 validateReport func(*testing.T, channel.PacketInfo, tcpip.Address) 956 validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address) 957 getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address 958 checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) 959 }{ 960 { 961 name: "IGMP", 962 protoNum: ipv4.ProtocolNumber, 963 multicastAddrs: []tcpip.Address{ipv4MulticastAddr1, ipv4MulticastAddr2}, 964 finalMulticastAddr: ipv4MulticastAddr3, 965 maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, 966 sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { 967 return s.Stats().IGMP.PacketsSent.V2MembershipReport 968 }, 969 sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { 970 return s.Stats().IGMP.PacketsSent.LeaveGroup 971 }, 972 validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { 973 t.Helper() 974 975 validateIGMPPacket(t, p, addr, igmpv2MembershipReport, 0, addr) 976 }, 977 validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { 978 t.Helper() 979 980 validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, addr) 981 }, 982 getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { 983 t.Helper() 984 985 ipv4 := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) 986 if got := tcpip.TransportProtocolNumber(ipv4.Protocol()); got != header.IGMPProtocolNumber { 987 t.Fatalf("got ipv4.Protocol() = %d, want = %d", got, header.IGMPProtocolNumber) 988 } 989 addr := header.IGMP(ipv4.Payload()).GroupAddress() 990 s, ok := seen[addr] 991 if !ok { 992 t.Fatalf("unexpectedly got a packet for group %s", addr) 993 } 994 if s { 995 t.Fatalf("already saw packet for group %s", addr) 996 } 997 seen[addr] = true 998 return addr 999 }, 1000 }, 1001 { 1002 name: "MLD", 1003 protoNum: ipv6.ProtocolNumber, 1004 multicastAddrs: []tcpip.Address{ipv6MulticastAddr1, ipv6MulticastAddr2}, 1005 finalMulticastAddr: ipv6MulticastAddr3, 1006 maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, 1007 sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { 1008 return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport 1009 }, 1010 sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { 1011 return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone 1012 }, 1013 validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { 1014 t.Helper() 1015 1016 validateMLDPacket(t, p, addr, mldReport, 0, addr) 1017 }, 1018 validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { 1019 t.Helper() 1020 1021 validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, addr) 1022 }, 1023 getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { 1024 t.Helper() 1025 1026 ipv6 := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) 1027 1028 ipv6HeaderIter := header.MakeIPv6PayloadIterator( 1029 header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()), 1030 buffer.View(ipv6.Payload()).ToVectorisedView(), 1031 ) 1032 1033 var transport header.IPv6RawPayloadHeader 1034 for { 1035 h, done, err := ipv6HeaderIter.Next() 1036 if err != nil { 1037 t.Fatalf("ipv6HeaderIter.Next(): %s", err) 1038 } 1039 if done { 1040 t.Fatalf("ipv6HeaderIter.Next() = (%T, %t, _), want = (_, false, _)", h, done) 1041 } 1042 if t, ok := h.(header.IPv6RawPayloadHeader); ok { 1043 transport = t 1044 break 1045 } 1046 } 1047 1048 if got := tcpip.TransportProtocolNumber(transport.Identifier); got != header.ICMPv6ProtocolNumber { 1049 t.Fatalf("got ipv6.NextHeader() = %d, want = %d", got, header.ICMPv6ProtocolNumber) 1050 } 1051 icmpv6 := header.ICMPv6(transport.Buf.ToView()) 1052 if got := icmpv6.Type(); got != header.ICMPv6MulticastListenerReport && got != header.ICMPv6MulticastListenerDone { 1053 t.Fatalf("got icmpv6.Type() = %d, want = %d or %d", got, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone) 1054 } 1055 addr := header.MLD(icmpv6.MessageBody()).MulticastAddress() 1056 s, ok := seen[addr] 1057 if !ok { 1058 t.Fatalf("unexpectedly got a packet for group %s", addr) 1059 } 1060 if s { 1061 t.Fatalf("already saw packet for group %s", addr) 1062 } 1063 seen[addr] = true 1064 return addr 1065 }, 1066 checkInitialGroups: checkInitialIPv6Groups, 1067 }, 1068 } 1069 1070 for _, test := range tests { 1071 t.Run(test.name, func(t *testing.T) { 1072 e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) 1073 1074 var reportCounter uint64 1075 var leaveCounter uint64 1076 if test.checkInitialGroups != nil { 1077 reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) 1078 } 1079 1080 sentReportStat := test.sentReportStat(s) 1081 for _, a := range test.multicastAddrs { 1082 if err := s.JoinGroup(test.protoNum, nicID, a); err != nil { 1083 t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err) 1084 } 1085 reportCounter++ 1086 if got := sentReportStat.Value(); got != reportCounter { 1087 t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) 1088 } 1089 if p, ok := e.Read(); !ok { 1090 t.Fatalf("expected a report message to be sent for %s", a) 1091 } else { 1092 test.validateReport(t, p, a) 1093 } 1094 } 1095 if t.Failed() { 1096 t.FailNow() 1097 } 1098 1099 // Leave messages should be sent for the joined groups when the NIC is 1100 // disabled. 1101 if err := s.DisableNIC(nicID); err != nil { 1102 t.Fatalf("DisableNIC(%d): %s", nicID, err) 1103 } 1104 sentLeaveStat := test.sentLeaveStat(s) 1105 leaveCounter += uint64(len(test.multicastAddrs)) 1106 if got := sentLeaveStat.Value(); got != leaveCounter { 1107 t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) 1108 } 1109 { 1110 seen := make(map[tcpip.Address]bool) 1111 for _, a := range test.multicastAddrs { 1112 seen[a] = false 1113 } 1114 1115 for i := range test.multicastAddrs { 1116 p, ok := e.Read() 1117 if !ok { 1118 t.Fatalf("expected (%d-th) leave message to be sent", i) 1119 } 1120 1121 test.validateLeave(t, p, test.getAndCheckGroupAddress(t, seen, p)) 1122 } 1123 } 1124 if t.Failed() { 1125 t.FailNow() 1126 } 1127 1128 // Reports should be sent for the joined groups when the NIC is enabled. 1129 if err := s.EnableNIC(nicID); err != nil { 1130 t.Fatalf("EnableNIC(%d): %s", nicID, err) 1131 } 1132 reportCounter += uint64(len(test.multicastAddrs)) 1133 if got := sentReportStat.Value(); got != reportCounter { 1134 t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) 1135 } 1136 { 1137 seen := make(map[tcpip.Address]bool) 1138 for _, a := range test.multicastAddrs { 1139 seen[a] = false 1140 } 1141 1142 for i := range test.multicastAddrs { 1143 p, ok := e.Read() 1144 if !ok { 1145 t.Fatalf("expected (%d-th) report message to be sent", i) 1146 } 1147 1148 test.validateReport(t, p, test.getAndCheckGroupAddress(t, seen, p)) 1149 } 1150 } 1151 if t.Failed() { 1152 t.FailNow() 1153 } 1154 1155 // Joining/leaving a group while disabled should not send any messages. 1156 if err := s.DisableNIC(nicID); err != nil { 1157 t.Fatalf("DisableNIC(%d): %s", nicID, err) 1158 } 1159 leaveCounter += uint64(len(test.multicastAddrs)) 1160 if got := sentLeaveStat.Value(); got != leaveCounter { 1161 t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) 1162 } 1163 for i := range test.multicastAddrs { 1164 if _, ok := e.Read(); !ok { 1165 t.Fatalf("expected (%d-th) leave message to be sent", i) 1166 } 1167 } 1168 for _, a := range test.multicastAddrs { 1169 if err := s.LeaveGroup(test.protoNum, nicID, a); err != nil { 1170 t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, a, err) 1171 } 1172 if got := sentLeaveStat.Value(); got != leaveCounter { 1173 t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) 1174 } 1175 if p, ok := e.Read(); ok { 1176 t.Fatalf("leaving group %s on disabled NIC sent unexpected packet = %#v", a, p.Pkt) 1177 } 1178 } 1179 if err := s.JoinGroup(test.protoNum, nicID, test.finalMulticastAddr); err != nil { 1180 t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.finalMulticastAddr, err) 1181 } 1182 if got := sentReportStat.Value(); got != reportCounter { 1183 t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) 1184 } 1185 if p, ok := e.Read(); ok { 1186 t.Fatalf("joining group %s on disabled NIC sent unexpected packet = %#v", test.finalMulticastAddr, p.Pkt) 1187 } 1188 1189 // A report should only be sent for the group we last joined after 1190 // enabling the NIC since the original groups were all left. 1191 if err := s.EnableNIC(nicID); err != nil { 1192 t.Fatalf("EnableNIC(%d): %s", nicID, err) 1193 } 1194 reportCounter++ 1195 if got := sentReportStat.Value(); got != reportCounter { 1196 t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) 1197 } 1198 if p, ok := e.Read(); !ok { 1199 t.Fatal("expected a report message to be sent") 1200 } else { 1201 test.validateReport(t, p, test.finalMulticastAddr) 1202 } 1203 1204 clock.Advance(test.maxUnsolicitedResponseDelay) 1205 reportCounter++ 1206 if got := sentReportStat.Value(); got != reportCounter { 1207 t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) 1208 } 1209 if p, ok := e.Read(); !ok { 1210 t.Fatal("expected a report message to be sent") 1211 } else { 1212 test.validateReport(t, p, test.finalMulticastAddr) 1213 } 1214 1215 // Should not send any more packets. 1216 clock.Advance(time.Hour) 1217 if p, ok := e.Read(); ok { 1218 t.Fatalf("sent unexpected packet = %#v", p) 1219 } 1220 }) 1221 } 1222 } 1223 1224 // TestMGPDisabledOnLoopback tests that the multicast group protocol is not 1225 // performed on loopback interfaces since they have no neighbours. 1226 func TestMGPDisabledOnLoopback(t *testing.T) { 1227 tests := []struct { 1228 name string 1229 protoNum tcpip.NetworkProtocolNumber 1230 multicastAddr tcpip.Address 1231 sentReportStat func(*stack.Stack) *tcpip.StatCounter 1232 }{ 1233 { 1234 name: "IGMP", 1235 protoNum: ipv4.ProtocolNumber, 1236 multicastAddr: ipv4MulticastAddr1, 1237 sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { 1238 return s.Stats().IGMP.PacketsSent.V2MembershipReport 1239 }, 1240 }, 1241 { 1242 name: "MLD", 1243 protoNum: ipv6.ProtocolNumber, 1244 multicastAddr: ipv6MulticastAddr1, 1245 sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { 1246 return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport 1247 }, 1248 }, 1249 } 1250 1251 for _, test := range tests { 1252 t.Run(test.name, func(t *testing.T) { 1253 s, clock := createStackWithLinkEndpoint(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */, loopback.New()) 1254 1255 sentReportStat := test.sentReportStat(s) 1256 if got := sentReportStat.Value(); got != 0 { 1257 t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) 1258 } 1259 clock.Advance(time.Hour) 1260 if got := sentReportStat.Value(); got != 0 { 1261 t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) 1262 } 1263 1264 // Test joining a specific group explicitly and verify that no reports are 1265 // sent. 1266 if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { 1267 t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) 1268 } 1269 if got := sentReportStat.Value(); got != 0 { 1270 t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) 1271 } 1272 clock.Advance(time.Hour) 1273 if got := sentReportStat.Value(); got != 0 { 1274 t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) 1275 } 1276 }) 1277 } 1278 }