gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go (about) 1 // Copyright 2020 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package ip_test 16 17 import ( 18 "bytes" 19 "fmt" 20 "math/rand" 21 "testing" 22 "time" 23 24 "github.com/google/go-cmp/cmp" 25 "gvisor.dev/gvisor/pkg/sync" 26 "gvisor.dev/gvisor/pkg/tcpip" 27 "gvisor.dev/gvisor/pkg/tcpip/faketime" 28 "gvisor.dev/gvisor/pkg/tcpip/header" 29 "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" 30 ) 31 32 const maxUnsolicitedReportDelay = time.Second 33 34 var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil) 35 36 type mockMulticastGroupProtocolProtectedFields struct { 37 sync.RWMutex 38 39 genericMulticastGroup ip.GenericMulticastProtocolState 40 sendReportGroupAddrCount map[tcpip.Address]int 41 sendLeaveGroupAddrCount map[tcpip.Address]int 42 makeQueuePackets bool 43 disabled bool 44 sentV2Reports map[tcpip.Address][]ip.MulticastGroupProtocolV2ReportRecordType 45 } 46 47 type mockMulticastGroupProtocol struct { 48 t *testing.T 49 50 skipProtocolAddress tcpip.Address 51 52 mu mockMulticastGroupProtocolProtectedFields 53 } 54 55 func (m *mockMulticastGroupProtocol) init(opts ip.GenericMulticastProtocolOptions, v1Compatibility bool) { 56 m.mu.Lock() 57 defer m.mu.Unlock() 58 m.initLocked() 59 opts.Protocol = m 60 m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts) 61 62 if v1Compatibility { 63 m.mu.genericMulticastGroup.SetV1ModeLocked(true) 64 } 65 } 66 67 func (m *mockMulticastGroupProtocol) initLocked() { 68 m.mu.sendReportGroupAddrCount = make(map[tcpip.Address]int) 69 m.mu.sendLeaveGroupAddrCount = make(map[tcpip.Address]int) 70 m.mu.sentV2Reports = make(map[tcpip.Address][]ip.MulticastGroupProtocolV2ReportRecordType) 71 } 72 73 func (m *mockMulticastGroupProtocol) setEnabled(v bool) { 74 m.mu.Lock() 75 defer m.mu.Unlock() 76 m.mu.disabled = !v 77 } 78 79 func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) { 80 m.mu.Lock() 81 defer m.mu.Unlock() 82 m.mu.makeQueuePackets = v 83 } 84 85 func (m *mockMulticastGroupProtocol) setV1Mode(v bool) bool { 86 m.mu.Lock() 87 defer m.mu.Unlock() 88 return m.mu.genericMulticastGroup.SetV1ModeLocked(v) 89 } 90 91 func (m *mockMulticastGroupProtocol) getV1Mode() bool { 92 m.mu.RLock() 93 defer m.mu.RUnlock() 94 return m.mu.genericMulticastGroup.GetV1ModeLocked() 95 } 96 97 func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) { 98 m.mu.Lock() 99 defer m.mu.Unlock() 100 m.mu.genericMulticastGroup.JoinGroupLocked(addr) 101 } 102 103 func (m *mockMulticastGroupProtocol) leaveGroup(addr tcpip.Address) bool { 104 m.mu.Lock() 105 defer m.mu.Unlock() 106 return m.mu.genericMulticastGroup.LeaveGroupLocked(addr) 107 } 108 109 func (m *mockMulticastGroupProtocol) handleReport(addr tcpip.Address) { 110 m.mu.Lock() 111 defer m.mu.Unlock() 112 m.mu.genericMulticastGroup.HandleReportLocked(addr) 113 } 114 115 func (m *mockMulticastGroupProtocol) handleQuery(addr tcpip.Address, maxRespTime time.Duration) { 116 m.mu.Lock() 117 defer m.mu.Unlock() 118 m.mu.genericMulticastGroup.HandleQueryLocked(addr, maxRespTime) 119 } 120 121 func (m *mockMulticastGroupProtocol) handleQueryV2(addr tcpip.Address, maxResponseCode uint16, sources header.AddressIterator, robustnessVariable uint8, queryInterval time.Duration) { 122 m.mu.Lock() 123 defer m.mu.Unlock() 124 m.mu.genericMulticastGroup.HandleQueryV2Locked(addr, maxResponseCode, sources, robustnessVariable, queryInterval) 125 } 126 127 func (m *mockMulticastGroupProtocol) isLocallyJoined(addr tcpip.Address) bool { 128 m.mu.RLock() 129 defer m.mu.RUnlock() 130 return m.mu.genericMulticastGroup.IsLocallyJoinedRLocked(addr) 131 } 132 133 func (m *mockMulticastGroupProtocol) makeAllNonMember() { 134 m.mu.Lock() 135 defer m.mu.Unlock() 136 m.mu.genericMulticastGroup.MakeAllNonMemberLocked() 137 } 138 139 func (m *mockMulticastGroupProtocol) initializeGroups() { 140 m.mu.Lock() 141 defer m.mu.Unlock() 142 m.mu.genericMulticastGroup.InitializeGroupsLocked() 143 } 144 145 func (m *mockMulticastGroupProtocol) sendQueuedReports() { 146 m.mu.Lock() 147 defer m.mu.Unlock() 148 m.mu.genericMulticastGroup.SendQueuedReportsLocked() 149 } 150 151 // Enabled implements ip.MulticastGroupProtocol. 152 // 153 // Precondition: m.mu must be read locked. 154 func (m *mockMulticastGroupProtocol) Enabled() bool { 155 if m.mu.TryLock() { 156 m.mu.Unlock() // +checklocksforce: TryLock. 157 m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled") 158 } 159 160 return !m.mu.disabled 161 } 162 163 // SendReport implements ip.MulticastGroupProtocol. 164 // 165 // Precondition: m.mu must be locked. 166 func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) { 167 if m.mu.TryLock() { 168 m.mu.Unlock() // +checklocksforce: TryLock. 169 m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) 170 } 171 if m.mu.TryRLock() { 172 m.mu.RUnlock() // +checklocksforce: TryLock. 173 m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) 174 } 175 176 m.mu.sendReportGroupAddrCount[groupAddress]++ 177 return !m.mu.makeQueuePackets, nil 178 } 179 180 // SendLeave implements ip.MulticastGroupProtocol. 181 // 182 // Precondition: m.mu must be locked. 183 func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error { 184 if m.mu.TryLock() { 185 m.mu.Unlock() // +checklocksforce: TryLock. 186 m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) 187 } 188 if m.mu.TryRLock() { 189 m.mu.RUnlock() // +checklocksforce: TryLock. 190 m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) 191 } 192 193 m.mu.sendLeaveGroupAddrCount[groupAddress]++ 194 return nil 195 } 196 197 // ShouldPerformProtocol implements ip.MulticastGroupProtocol. 198 func (m *mockMulticastGroupProtocol) ShouldPerformProtocol(groupAddress tcpip.Address) bool { 199 return groupAddress != m.skipProtocolAddress 200 } 201 202 type mockReportV2Record struct { 203 recordType ip.MulticastGroupProtocolV2ReportRecordType 204 groupAddress tcpip.Address 205 } 206 207 type mockReportV2 struct { 208 records []mockReportV2Record 209 } 210 211 type mockReportV2Builder struct { 212 m *mockMulticastGroupProtocol 213 report mockReportV2 214 } 215 216 // AddRecord implements ip.MulticastGroupProtocolV2ReportBuilder. 217 func (b *mockReportV2Builder) AddRecord(recordType ip.MulticastGroupProtocolV2ReportRecordType, groupAddress tcpip.Address) { 218 b.report.records = append(b.report.records, mockReportV2Record{recordType: recordType, groupAddress: groupAddress}) 219 } 220 221 func recordsToMap(m map[tcpip.Address][]ip.MulticastGroupProtocolV2ReportRecordType, records []mockReportV2Record) { 222 for _, record := range records { 223 m[record.groupAddress] = append(m[record.groupAddress], record.recordType) 224 } 225 } 226 227 // Send implements ip.MulticastGroupProtocolV2ReportBuilder. 228 func (b *mockReportV2Builder) Send() (sent bool, err tcpip.Error) { 229 if b.m.mu.TryLock() { 230 b.m.mu.Unlock() // +checklocksforce: TryLock. 231 b.m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending v2 report") 232 } 233 if b.m.mu.TryRLock() { 234 b.m.mu.RUnlock() // +checklocksforce: TryLock. 235 b.m.t.Fatal("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending v2 report") 236 } 237 238 recordsToMap(b.m.mu.sentV2Reports, b.report.records) 239 return !b.m.mu.makeQueuePackets, nil 240 } 241 242 // NewReportV2Builder implements ip.MulticastGroupProtocol. 243 func (m *mockMulticastGroupProtocol) NewReportV2Builder() ip.MulticastGroupProtocolV2ReportBuilder { 244 return &mockReportV2Builder{m: m} 245 } 246 247 // V2QueryMaxRespCodeToV2Delay implements ip.MulticastGroupProtocol. 248 func (*mockMulticastGroupProtocol) V2QueryMaxRespCodeToV2Delay(code uint16) time.Duration { 249 return time.Duration(code) * time.Millisecond 250 } 251 252 // V2QueryMaxRespCodeToV1Delay implements ip.MulticastGroupProtocol. 253 func (*mockMulticastGroupProtocol) V2QueryMaxRespCodeToV1Delay(code uint16) time.Duration { 254 return time.Duration(code) * time.Millisecond 255 } 256 257 type checkFields struct { 258 sendReportGroupAddresses []tcpip.Address 259 sendLeaveGroupAddresses []tcpip.Address 260 sentV2Reports []mockReportV2 261 } 262 263 func (m *mockMulticastGroupProtocol) check(fields checkFields) string { 264 m.mu.Lock() 265 defer m.mu.Unlock() 266 267 sendReportGroupAddrCount := make(map[tcpip.Address]int) 268 for _, a := range fields.sendReportGroupAddresses { 269 sendReportGroupAddrCount[a] = 1 270 } 271 272 sendLeaveGroupAddrCount := make(map[tcpip.Address]int) 273 for _, a := range fields.sendLeaveGroupAddresses { 274 sendLeaveGroupAddrCount[a] = 1 275 } 276 277 sentV2Reports := make(map[tcpip.Address][]ip.MulticastGroupProtocolV2ReportRecordType) 278 for _, report := range fields.sentV2Reports { 279 recordsToMap(sentV2Reports, report.records) 280 } 281 282 diff := cmp.Diff( 283 &mockMulticastGroupProtocol{ 284 mu: mockMulticastGroupProtocolProtectedFields{ 285 sendReportGroupAddrCount: sendReportGroupAddrCount, 286 sendLeaveGroupAddrCount: sendLeaveGroupAddrCount, 287 sentV2Reports: sentV2Reports, 288 }, 289 }, 290 m, 291 cmp.AllowUnexported(mockMulticastGroupProtocol{}), 292 cmp.AllowUnexported(mockMulticastGroupProtocolProtectedFields{}), 293 cmp.AllowUnexported(mockReportV2{}), 294 cmp.AllowUnexported(mockReportV2Record{}), 295 // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t 296 cmp.FilterPath( 297 func(p cmp.Path) bool { 298 switch p.Last().String() { 299 case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup", ".skipProtocolAddress": 300 return true 301 default: 302 return false 303 } 304 }, 305 cmp.Ignore(), 306 ), 307 ) 308 m.initLocked() 309 return diff 310 } 311 312 func TestJoinGroup(t *testing.T) { 313 tests := []struct { 314 name string 315 addr tcpip.Address 316 shouldSendReports bool 317 }{ 318 { 319 name: "Normal group", 320 addr: addr1, 321 shouldSendReports: true, 322 }, 323 { 324 name: "All-nodes group", 325 addr: addr2, 326 shouldSendReports: false, 327 }, 328 } 329 330 subTests := []struct { 331 name string 332 v1Compatibility bool 333 checkFields func(tcpip.Address) checkFields 334 }{ 335 { 336 name: "V1 Compatibility", 337 v1Compatibility: true, 338 checkFields: func(addr tcpip.Address) checkFields { 339 return checkFields{sendReportGroupAddresses: []tcpip.Address{addr}} 340 }, 341 }, 342 { 343 name: "V2", 344 v1Compatibility: false, 345 checkFields: func(addr tcpip.Address) checkFields { 346 return checkFields{sentV2Reports: []mockReportV2{{records: []mockReportV2Record{ 347 { 348 recordType: ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode, 349 groupAddress: addr, 350 }, 351 }}}} 352 }, 353 }, 354 } 355 356 for _, test := range tests { 357 t.Run(test.name, func(t *testing.T) { 358 for _, subTest := range subTests { 359 t.Run(subTest.name, func(t *testing.T) { 360 mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2} 361 clock := faketime.NewManualClock() 362 363 mgp.init(ip.GenericMulticastProtocolOptions{ 364 Rand: rand.New(rand.NewSource(0)), 365 Clock: clock, 366 MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, 367 }, subTest.v1Compatibility) 368 369 // Joining a group should send a report immediately and another after 370 // a random interval between 0 and the maximum unsolicited report delay. 371 mgp.joinGroup(test.addr) 372 if test.shouldSendReports { 373 expected := subTest.checkFields(test.addr) 374 if diff := mgp.check(expected); diff != "" { 375 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 376 } 377 378 // Generic multicast protocol timers are expected to take the job mutex. 379 clock.Advance(maxUnsolicitedReportDelay) 380 if diff := mgp.check(expected); diff != "" { 381 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 382 } 383 } 384 385 // Should have no more messages to send. 386 clock.Advance(time.Hour) 387 if diff := mgp.check(checkFields{}); diff != "" { 388 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 389 } 390 }) 391 } 392 }) 393 } 394 } 395 396 func TestLeaveGroup(t *testing.T) { 397 const maxRespCode = 1 398 399 tests := []struct { 400 name string 401 addr tcpip.Address 402 shouldSendMessages bool 403 }{ 404 { 405 name: "Normal group", 406 addr: addr1, 407 shouldSendMessages: true, 408 }, 409 { 410 name: "All-nodes group", 411 addr: addr2, 412 shouldSendMessages: false, 413 }, 414 } 415 416 subTests := []struct { 417 name string 418 v1Compatibility bool 419 checkFields func(tcpip.Address, bool) checkFields 420 handleQuery func(*mockMulticastGroupProtocol, tcpip.Address) 421 }{ 422 { 423 name: "V1 Compatibility", 424 v1Compatibility: true, 425 checkFields: func(addr tcpip.Address, leave bool) checkFields { 426 if leave { 427 return checkFields{sendLeaveGroupAddresses: []tcpip.Address{addr}} 428 } 429 return checkFields{sendReportGroupAddresses: []tcpip.Address{addr}} 430 }, 431 handleQuery: func(mgp *mockMulticastGroupProtocol, groupAddress tcpip.Address) { 432 mgp.handleQuery(groupAddress, maxRespCode) 433 }, 434 }, 435 { 436 name: "V2", 437 v1Compatibility: false, 438 checkFields: func(addr tcpip.Address, leave bool) checkFields { 439 recordType := ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode 440 if leave { 441 recordType = ip.MulticastGroupProtocolV2ReportRecordChangeToIncludeMode 442 } 443 444 return checkFields{sentV2Reports: []mockReportV2{{records: []mockReportV2Record{ 445 { 446 recordType: recordType, 447 groupAddress: addr, 448 }, 449 }}}} 450 }, 451 handleQuery: func(mgp *mockMulticastGroupProtocol, groupAddress tcpip.Address) { 452 mgp.handleQueryV2(groupAddress, maxRespCode, header.MakeAddressIterator(addr1.Len(), bytes.NewBuffer(nil)), 0, 0) 453 }, 454 }, 455 } 456 457 for _, test := range tests { 458 t.Run(test.name, func(t *testing.T) { 459 for _, subTest := range subTests { 460 t.Run(subTest.name, func(t *testing.T) { 461 for _, queryAddr := range []tcpip.Address{test.addr, tcpip.Address{}} { 462 t.Run(fmt.Sprintf("QueryAddr=%s", queryAddr), func(t *testing.T) { 463 mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2} 464 clock := faketime.NewManualClock() 465 466 mgp.init(ip.GenericMulticastProtocolOptions{ 467 Rand: rand.New(rand.NewSource(1)), 468 Clock: clock, 469 MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, 470 }, subTest.v1Compatibility) 471 472 mgp.joinGroup(test.addr) 473 if test.shouldSendMessages { 474 if diff := mgp.check(subTest.checkFields(test.addr, false /* leave */)); diff != "" { 475 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 476 } 477 } 478 479 // The timer scheduled to send the query response should do 480 // nothing since we will leave the group before the response is 481 // sent. 482 subTest.handleQuery(&mgp, queryAddr) 483 484 // Leaving a group should send a leave report immediately and 485 // cancel any delayed reports. 486 if !mgp.leaveGroup(test.addr) { 487 t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", test.addr) 488 } 489 490 // A query should not do anything since we left the group. 491 subTest.handleQuery(&mgp, queryAddr) 492 493 if test.shouldSendMessages { 494 if diff := mgp.check(subTest.checkFields(test.addr, true /* leave */)); diff != "" { 495 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 496 } 497 498 if !subTest.v1Compatibility { 499 clock.Advance(maxUnsolicitedReportDelay) 500 501 if diff := mgp.check(subTest.checkFields(test.addr, true /* leave */)); diff != "" { 502 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 503 } 504 } 505 } 506 507 // Should have no more messages to send. 508 clock.Advance(time.Hour) 509 if diff := mgp.check(checkFields{}); diff != "" { 510 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 511 } 512 }) 513 } 514 }) 515 } 516 }) 517 } 518 } 519 520 func TestHandleReport(t *testing.T) { 521 tests := []struct { 522 name string 523 reportAddr tcpip.Address 524 expectReportsFor []tcpip.Address 525 }{ 526 { 527 name: "Unpecified empty", 528 reportAddr: tcpip.Address{}, 529 expectReportsFor: []tcpip.Address{addr1, addr2}, 530 }, 531 { 532 name: "Unpecified any", 533 reportAddr: tcpip.AddrFromSlice([]byte("\x00\x00\x00\x00")), 534 expectReportsFor: []tcpip.Address{addr1, addr2}, 535 }, 536 { 537 name: "Specified", 538 reportAddr: addr1, 539 expectReportsFor: []tcpip.Address{addr2}, 540 }, 541 { 542 name: "Specified all-nodes", 543 reportAddr: addr3, 544 expectReportsFor: []tcpip.Address{addr1, addr2}, 545 }, 546 { 547 name: "Specified other", 548 reportAddr: addr4, 549 expectReportsFor: []tcpip.Address{addr1, addr2}, 550 }, 551 } 552 553 subTests := []struct { 554 name string 555 v1Compatibility bool 556 checkFields func([]tcpip.Address) checkFields 557 }{ 558 { 559 name: "V1 Compatibility", 560 v1Compatibility: true, 561 checkFields: func(addrs []tcpip.Address) checkFields { 562 return checkFields{sendReportGroupAddresses: addrs} 563 }, 564 }, 565 { 566 name: "V2", 567 v1Compatibility: false, 568 checkFields: func(addrs []tcpip.Address) checkFields { 569 var records []mockReportV2Record 570 for _, addr := range addrs { 571 records = append(records, mockReportV2Record{ 572 recordType: ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode, 573 groupAddress: addr, 574 }) 575 } 576 577 return checkFields{sentV2Reports: []mockReportV2{{records: records}}} 578 }, 579 }, 580 } 581 582 for _, test := range tests { 583 t.Run(test.name, func(t *testing.T) { 584 for _, subTest := range subTests { 585 t.Run(subTest.name, func(t *testing.T) { 586 mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} 587 clock := faketime.NewManualClock() 588 589 mgp.init(ip.GenericMulticastProtocolOptions{ 590 Rand: rand.New(rand.NewSource(2)), 591 Clock: clock, 592 MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, 593 }, subTest.v1Compatibility) 594 595 mgp.joinGroup(addr1) 596 if diff := mgp.check(subTest.checkFields([]tcpip.Address{addr1})); diff != "" { 597 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 598 } 599 mgp.joinGroup(addr2) 600 if diff := mgp.check(subTest.checkFields([]tcpip.Address{addr2})); diff != "" { 601 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 602 } 603 mgp.joinGroup(addr3) 604 if diff := mgp.check(checkFields{}); diff != "" { 605 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 606 } 607 608 // Receiving a report for a group we have a timer scheduled for should 609 // cancel our delayed report timer for the group. 610 mgp.handleReport(test.reportAddr) 611 if len(test.expectReportsFor) != 0 { 612 // Generic multicast protocol timers are expected to take the job mutex. 613 clock.Advance(maxUnsolicitedReportDelay) 614 if diff := mgp.check(subTest.checkFields(test.expectReportsFor)); diff != "" { 615 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 616 } 617 } 618 619 // Should have no more messages to send. 620 clock.Advance(time.Hour) 621 if diff := mgp.check(checkFields{}); diff != "" { 622 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 623 } 624 }) 625 } 626 }) 627 } 628 } 629 630 func TestHandleQuery(t *testing.T) { 631 tests := []struct { 632 name string 633 queryAddr tcpip.Address 634 maxDelay time.Duration 635 expectQueriedReportsFor []tcpip.Address 636 expectDelayedReportsFor []tcpip.Address 637 }{ 638 { 639 name: "Unpecified empty", 640 queryAddr: tcpip.Address{}, 641 maxDelay: 0, 642 expectQueriedReportsFor: []tcpip.Address{addr1, addr2}, 643 expectDelayedReportsFor: nil, 644 }, 645 { 646 name: "Unpecified any", 647 queryAddr: tcpip.AddrFromSlice([]byte("\x00\x00\x00\x00")), 648 maxDelay: 1, 649 expectQueriedReportsFor: []tcpip.Address{addr1, addr2}, 650 expectDelayedReportsFor: nil, 651 }, 652 { 653 name: "Specified", 654 queryAddr: addr1, 655 maxDelay: 2, 656 expectQueriedReportsFor: []tcpip.Address{addr1}, 657 expectDelayedReportsFor: []tcpip.Address{addr2}, 658 }, 659 { 660 name: "Specified all-nodes", 661 queryAddr: addr3, 662 maxDelay: 3, 663 expectQueriedReportsFor: nil, 664 expectDelayedReportsFor: []tcpip.Address{addr1, addr2}, 665 }, 666 { 667 name: "Specified other", 668 queryAddr: addr4, 669 maxDelay: 4, 670 expectQueriedReportsFor: nil, 671 expectDelayedReportsFor: []tcpip.Address{addr1, addr2}, 672 }, 673 } 674 675 subTests := []struct { 676 name string 677 v1Compatibility bool 678 checkFields func([]tcpip.Address) checkFields 679 }{ 680 { 681 name: "V1 Compatibility", 682 v1Compatibility: true, 683 checkFields: func(addrs []tcpip.Address) checkFields { 684 return checkFields{sendReportGroupAddresses: addrs} 685 }, 686 }, 687 { 688 name: "V2", 689 v1Compatibility: false, 690 checkFields: func(addrs []tcpip.Address) checkFields { 691 var records []mockReportV2Record 692 for _, addr := range addrs { 693 records = append(records, mockReportV2Record{ 694 recordType: ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode, 695 groupAddress: addr, 696 }) 697 } 698 699 return checkFields{sentV2Reports: []mockReportV2{{records: records}}} 700 }, 701 }, 702 } 703 704 for _, test := range tests { 705 t.Run(test.name, func(t *testing.T) { 706 for _, subTest := range subTests { 707 t.Run(subTest.name, func(t *testing.T) { 708 mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} 709 clock := faketime.NewManualClock() 710 711 mgp.init(ip.GenericMulticastProtocolOptions{ 712 Rand: rand.New(rand.NewSource(3)), 713 Clock: clock, 714 MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, 715 }, subTest.v1Compatibility) 716 717 mgp.joinGroup(addr1) 718 if diff := mgp.check(subTest.checkFields([]tcpip.Address{addr1})); diff != "" { 719 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 720 } 721 mgp.joinGroup(addr2) 722 if diff := mgp.check(subTest.checkFields([]tcpip.Address{addr2})); diff != "" { 723 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 724 } 725 mgp.joinGroup(addr3) 726 if diff := mgp.check(checkFields{}); diff != "" { 727 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 728 } 729 730 // Receiving a query should make us reschedule our delayed report timer 731 // to some time within the new max response delay. 732 mgp.handleQuery(test.queryAddr, test.maxDelay) 733 clock.Advance(test.maxDelay) 734 if diff := mgp.check(checkFields{sendReportGroupAddresses: test.expectQueriedReportsFor}); diff != "" { 735 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 736 } 737 738 // The groups that were not affected by the query should still send a 739 // report after the max unsolicited report delay. 740 // 741 // If we were in V2 mode, then we would have cancelled the interface's 742 // state changed timer so we won't see any further reports after 743 // receiving a V1 query. 744 if subTest.v1Compatibility { 745 clock.Advance(maxUnsolicitedReportDelay) 746 if diff := mgp.check(subTest.checkFields(test.expectDelayedReportsFor)); diff != "" { 747 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 748 } 749 } 750 751 // Should have no more messages to send. 752 clock.Advance(time.Hour) 753 if diff := mgp.check(checkFields{}); diff != "" { 754 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 755 } 756 }) 757 } 758 }) 759 } 760 } 761 762 func TestHandleQueryV2Response(t *testing.T) { 763 tests := []struct { 764 name string 765 queryAddr tcpip.Address 766 maxDelay uint16 767 expectQueriedReportsFor []tcpip.Address 768 expectDelayedReportsFor []tcpip.Address 769 }{ 770 { 771 name: "Unpecified empty", 772 queryAddr: tcpip.Address{}, 773 maxDelay: 0, 774 expectQueriedReportsFor: []tcpip.Address{addr1, addr2}, 775 expectDelayedReportsFor: nil, 776 }, 777 { 778 name: "Unpecified any", 779 queryAddr: tcpip.AddrFromSlice([]byte("\x00\x00\x00\x00")), 780 maxDelay: 1, 781 expectQueriedReportsFor: []tcpip.Address{addr1, addr2}, 782 expectDelayedReportsFor: nil, 783 }, 784 { 785 name: "Specified", 786 queryAddr: addr1, 787 maxDelay: 2, 788 expectQueriedReportsFor: []tcpip.Address{addr1}, 789 expectDelayedReportsFor: []tcpip.Address{addr2}, 790 }, 791 { 792 name: "Specified all-nodes", 793 queryAddr: addr3, 794 maxDelay: 3, 795 expectQueriedReportsFor: nil, 796 expectDelayedReportsFor: []tcpip.Address{addr1, addr2}, 797 }, 798 { 799 name: "Specified other", 800 queryAddr: addr4, 801 maxDelay: 4, 802 expectQueriedReportsFor: nil, 803 expectDelayedReportsFor: []tcpip.Address{addr1, addr2}, 804 }, 805 } 806 807 subTests := []struct { 808 name string 809 v1Compatibility bool 810 checkFields func([]tcpip.Address, bool) checkFields 811 }{ 812 { 813 name: "V1 Compatibility", 814 v1Compatibility: true, 815 checkFields: func(addrs []tcpip.Address, _ bool) checkFields { 816 return checkFields{sendReportGroupAddresses: addrs} 817 }, 818 }, 819 { 820 name: "V2", 821 v1Compatibility: false, 822 checkFields: func(addrs []tcpip.Address, queryResponse bool) checkFields { 823 var records []mockReportV2Record 824 recordType := ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode 825 if queryResponse { 826 recordType = ip.MulticastGroupProtocolV2ReportRecordModeIsExclude 827 } 828 829 for _, addr := range addrs { 830 records = append(records, mockReportV2Record{ 831 recordType: recordType, 832 groupAddress: addr, 833 }) 834 } 835 836 return checkFields{sentV2Reports: []mockReportV2{{records: records}}} 837 }, 838 }, 839 } 840 841 for _, test := range tests { 842 t.Run(test.name, func(t *testing.T) { 843 for _, subTest := range subTests { 844 t.Run(subTest.name, func(t *testing.T) { 845 mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} 846 clock := faketime.NewManualClock() 847 848 mgp.init(ip.GenericMulticastProtocolOptions{ 849 Rand: rand.New(rand.NewSource(3)), 850 Clock: clock, 851 MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, 852 }, subTest.v1Compatibility) 853 854 mgp.joinGroup(addr1) 855 if diff := mgp.check(subTest.checkFields([]tcpip.Address{addr1}, false /* queryResponse */)); diff != "" { 856 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 857 } 858 mgp.joinGroup(addr2) 859 if diff := mgp.check(subTest.checkFields([]tcpip.Address{addr2}, false /* queryResponse */)); diff != "" { 860 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 861 } 862 mgp.joinGroup(addr3) 863 if diff := mgp.check(checkFields{}); diff != "" { 864 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 865 } 866 clock.Advance(maxUnsolicitedReportDelay) 867 if diff := mgp.check(subTest.checkFields([]tcpip.Address{addr1, addr2}, false /* queryResponse */)); diff != "" { 868 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 869 } 870 clock.Advance(maxUnsolicitedReportDelay) 871 if diff := mgp.check(checkFields{}); diff != "" { 872 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 873 } 874 875 // Receiving a query should make us reschedule our delayed report 876 // timer to some time within the new max response delay. 877 // 878 // Note that if we are in V1 compatibility mode, the V2 query will be 879 // handled as a V1 query. 880 mgp.handleQueryV2(test.queryAddr, test.maxDelay, header.MakeAddressIterator(addr1.Len(), bytes.NewBuffer(nil)), 0, 0) 881 if subTest.v1Compatibility { 882 clock.Advance(mgp.V2QueryMaxRespCodeToV1Delay(test.maxDelay)) 883 } else { 884 clock.Advance(mgp.V2QueryMaxRespCodeToV2Delay(test.maxDelay)) 885 } 886 if diff := mgp.check(subTest.checkFields(test.expectQueriedReportsFor, true /* queryResponse */)); diff != "" { 887 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 888 } 889 890 // Should have no more messages to send. 891 clock.Advance(time.Hour) 892 if diff := mgp.check(checkFields{}); diff != "" { 893 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 894 } 895 }) 896 } 897 }) 898 } 899 } 900 901 func TestV1CompatbilityModeTimer(t *testing.T) { 902 tests := []struct { 903 name string 904 robustnessVariable uint8 905 queryInterval time.Duration 906 }{ 907 { 908 name: "Unspecified Robustness variable and Query interval", 909 robustnessVariable: 0, 910 queryInterval: 0, 911 }, 912 { 913 name: "Unspecified Robustness variable", 914 robustnessVariable: 0, 915 queryInterval: ip.DefaultQueryInterval + time.Second, 916 }, 917 { 918 name: "Unspecified Query interval", 919 robustnessVariable: ip.DefaultRobustnessVariable + 1, 920 queryInterval: 0, 921 }, 922 { 923 name: "Default Robustness variable and Query interval", 924 robustnessVariable: ip.DefaultRobustnessVariable, 925 queryInterval: ip.DefaultQueryInterval, 926 }, 927 { 928 name: "Specified Robustness variable and Query interval", 929 robustnessVariable: ip.DefaultRobustnessVariable + 1, 930 queryInterval: ip.DefaultQueryInterval + time.Second, 931 }, 932 } 933 934 for _, test := range tests { 935 t.Run(test.name, func(t *testing.T) { 936 mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} 937 clock := faketime.NewManualClock() 938 939 mgp.init(ip.GenericMulticastProtocolOptions{ 940 Rand: rand.New(rand.NewSource(3)), 941 Clock: clock, 942 MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, 943 }, false /* v1Compatibiltiy */) 944 945 v2Check := func(t *testing.T) { 946 t.Helper() 947 948 mgp.joinGroup(addr1) 949 if diff := mgp.check(checkFields{sentV2Reports: []mockReportV2{{records: []mockReportV2Record{ 950 { 951 recordType: ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode, 952 groupAddress: addr1, 953 }, 954 }}}}); diff != "" { 955 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 956 } 957 if !mgp.leaveGroup(addr1) { 958 t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", addr1) 959 } 960 if diff := mgp.check(checkFields{sentV2Reports: []mockReportV2{{records: []mockReportV2Record{ 961 { 962 recordType: ip.MulticastGroupProtocolV2ReportRecordChangeToIncludeMode, 963 groupAddress: addr1, 964 }, 965 }}}}); diff != "" { 966 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 967 } 968 } 969 v2Check(t) 970 971 subTests := []struct { 972 name string 973 advanceTime time.Duration 974 }{ 975 { 976 name: "Default", 977 advanceTime: ip.DefaultRobustnessVariable * ip.DefaultQueryInterval, 978 }, 979 { 980 name: "After V2 Query", 981 advanceTime: func() time.Duration { 982 robustnessVariable := test.robustnessVariable 983 if robustnessVariable == 0 { 984 robustnessVariable = ip.DefaultRobustnessVariable 985 } 986 987 queryInterval := test.queryInterval 988 if queryInterval == 0 { 989 queryInterval = ip.DefaultQueryInterval 990 } 991 992 return time.Duration(robustnessVariable) * queryInterval 993 }(), 994 }, 995 } 996 997 for _, subTest := range subTests { 998 t.Run(subTest.name, func(t *testing.T) { 999 mgp.handleQuery(addr3, time.Nanosecond) 1000 v1Check := func() { 1001 t.Helper() 1002 mgp.joinGroup(addr1) 1003 if diff := mgp.check(checkFields{sendReportGroupAddresses: []tcpip.Address{addr1}}); diff != "" { 1004 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1005 } 1006 if !mgp.leaveGroup(addr1) { 1007 t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", addr1) 1008 } 1009 if diff := mgp.check(checkFields{sendLeaveGroupAddresses: []tcpip.Address{addr1}}); diff != "" { 1010 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1011 } 1012 } 1013 v1Check() 1014 const minDuration = time.Duration(1) 1015 clock.Advance(subTest.advanceTime - minDuration) 1016 v1Check() 1017 1018 clock.Advance(minDuration) 1019 v2Check(t) 1020 // Should update the Robustness variable and Querier's Query interval. 1021 mgp.handleQueryV2(addr3, 0, header.MakeAddressIterator(addr1.Len(), bytes.NewBuffer(nil)), test.robustnessVariable, test.queryInterval) 1022 }) 1023 } 1024 }) 1025 } 1026 } 1027 1028 func TestJoinCount(t *testing.T) { 1029 const maxUnsolicitedReportDelay = time.Second 1030 1031 tests := []struct { 1032 name string 1033 v1Compatibility bool 1034 checkFields func(tcpip.Address, bool) checkFields 1035 }{ 1036 { 1037 name: "V1 Compatibility", 1038 v1Compatibility: true, 1039 checkFields: func(addr tcpip.Address, leave bool) checkFields { 1040 if leave { 1041 return checkFields{sendLeaveGroupAddresses: []tcpip.Address{addr}} 1042 } 1043 return checkFields{sendReportGroupAddresses: []tcpip.Address{addr}} 1044 }, 1045 }, 1046 { 1047 name: "V2", 1048 v1Compatibility: false, 1049 checkFields: func(addr tcpip.Address, leave bool) checkFields { 1050 recordType := ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode 1051 if leave { 1052 recordType = ip.MulticastGroupProtocolV2ReportRecordChangeToIncludeMode 1053 } 1054 1055 return checkFields{sentV2Reports: []mockReportV2{{records: []mockReportV2Record{ 1056 { 1057 recordType: recordType, 1058 groupAddress: addr, 1059 }, 1060 }}}} 1061 }, 1062 }, 1063 } 1064 1065 for _, test := range tests { 1066 t.Run(test.name, func(t *testing.T) { 1067 mgp := mockMulticastGroupProtocol{t: t} 1068 clock := faketime.NewManualClock() 1069 1070 mgp.init(ip.GenericMulticastProtocolOptions{ 1071 Rand: rand.New(rand.NewSource(4)), 1072 Clock: clock, 1073 MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, 1074 }, test.v1Compatibility) 1075 1076 // Set the join count to 2 for a group. 1077 mgp.joinGroup(addr1) 1078 if !mgp.isLocallyJoined(addr1) { 1079 t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) 1080 } 1081 // Only the first join should trigger a report to be sent. 1082 if diff := mgp.check(test.checkFields(addr1, false /* leave */)); diff != "" { 1083 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1084 } 1085 mgp.joinGroup(addr1) 1086 if !mgp.isLocallyJoined(addr1) { 1087 t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) 1088 } 1089 if diff := mgp.check(checkFields{}); diff != "" { 1090 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1091 } 1092 if t.Failed() { 1093 t.FailNow() 1094 } 1095 1096 // Group should still be considered joined after leaving once. 1097 if !mgp.leaveGroup(addr1) { 1098 t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) 1099 } 1100 if !mgp.isLocallyJoined(addr1) { 1101 t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) 1102 } 1103 // A leave report should only be sent once the join count reaches 0. 1104 if diff := mgp.check(checkFields{}); diff != "" { 1105 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1106 } 1107 if t.Failed() { 1108 t.FailNow() 1109 } 1110 1111 // Leaving once more should actually remove us from the group. 1112 if !mgp.leaveGroup(addr1) { 1113 t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) 1114 } 1115 if mgp.isLocallyJoined(addr1) { 1116 t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) 1117 } 1118 if diff := mgp.check(test.checkFields(addr1, true /* leave */)); diff != "" { 1119 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1120 } 1121 if !test.v1Compatibility { 1122 // V2 should still have a queued state-changed report. 1123 clock.Advance(maxUnsolicitedReportDelay) 1124 if diff := mgp.check(test.checkFields(addr1, true /* leave */)); diff != "" { 1125 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1126 } 1127 } 1128 if t.Failed() { 1129 t.FailNow() 1130 } 1131 1132 // Group should no longer be joined so we should not have anything to 1133 // leave. 1134 if mgp.leaveGroup(addr1) { 1135 t.Errorf("got mgp.leaveGroup(%s) = true, want = false", addr1) 1136 } 1137 if mgp.isLocallyJoined(addr1) { 1138 t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) 1139 } 1140 if diff := mgp.check(checkFields{}); diff != "" { 1141 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1142 } 1143 1144 // Should have no more messages to send. 1145 // 1146 // Generic multicast protocol timers are expected to take the job mutex. 1147 clock.Advance(time.Hour) 1148 if diff := mgp.check(checkFields{}); diff != "" { 1149 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1150 } 1151 }) 1152 } 1153 } 1154 1155 func TestMakeAllNonMemberAndInitialize(t *testing.T) { 1156 const unsolicitedTransmissionCount = 2 1157 1158 tests := []struct { 1159 name string 1160 v1 bool 1161 v1Compatibility bool 1162 checkFields func([]tcpip.Address, bool) checkFields 1163 }{ 1164 { 1165 name: "V1", 1166 v1: true, 1167 v1Compatibility: false, 1168 checkFields: func(addrs []tcpip.Address, leave bool) checkFields { 1169 if leave { 1170 return checkFields{sendLeaveGroupAddresses: addrs} 1171 } 1172 return checkFields{sendReportGroupAddresses: addrs} 1173 }, 1174 }, 1175 { 1176 name: "V1 Compatibility", 1177 v1: false, 1178 v1Compatibility: true, 1179 checkFields: func(addrs []tcpip.Address, leave bool) checkFields { 1180 if leave { 1181 return checkFields{sendLeaveGroupAddresses: addrs} 1182 } 1183 return checkFields{sendReportGroupAddresses: addrs} 1184 }, 1185 }, 1186 { 1187 name: "V2", 1188 v1: false, 1189 v1Compatibility: false, 1190 checkFields: func(addrs []tcpip.Address, leave bool) checkFields { 1191 recordType := ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode 1192 if leave { 1193 recordType = ip.MulticastGroupProtocolV2ReportRecordChangeToIncludeMode 1194 } 1195 var records []mockReportV2Record 1196 for _, addr := range addrs { 1197 records = append(records, mockReportV2Record{ 1198 recordType: recordType, 1199 groupAddress: addr, 1200 }) 1201 } 1202 1203 return checkFields{sentV2Reports: []mockReportV2{{records: records}}} 1204 }, 1205 }, 1206 } 1207 1208 for _, test := range tests { 1209 t.Run(test.name, func(t *testing.T) { 1210 mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} 1211 clock := faketime.NewManualClock() 1212 1213 mgp.init(ip.GenericMulticastProtocolOptions{ 1214 Rand: rand.New(rand.NewSource(3)), 1215 Clock: clock, 1216 MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, 1217 }, test.v1) 1218 1219 if test.v1Compatibility { 1220 // V1 query targeting an unjoined group should drop us into V1 1221 // compatibility mode without sending any packets, affecting tests. 1222 mgp.handleQuery(addr3, 0) 1223 } 1224 1225 mgp.joinGroup(addr1) 1226 if diff := mgp.check(test.checkFields([]tcpip.Address{addr1}, false /* leave */)); diff != "" { 1227 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1228 } 1229 mgp.joinGroup(addr2) 1230 if diff := mgp.check(test.checkFields([]tcpip.Address{addr2}, false /* leave */)); diff != "" { 1231 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1232 } 1233 mgp.joinGroup(addr3) 1234 if diff := mgp.check(checkFields{}); diff != "" { 1235 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1236 } 1237 1238 // Should send the leave reports for each but still consider them locally 1239 // joined. 1240 mgp.makeAllNonMember() 1241 if diff := mgp.check(test.checkFields([]tcpip.Address{addr1, addr2}, true /* leave */)); diff != "" { 1242 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1243 } 1244 1245 // Generic multicast protocol timers are expected to take the job mutex. 1246 clock.Advance(time.Hour) 1247 if diff := mgp.check(checkFields{}); diff != "" { 1248 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1249 } 1250 for _, group := range []tcpip.Address{addr1, addr2, addr3} { 1251 if !mgp.isLocallyJoined(group) { 1252 t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", group) 1253 } 1254 } 1255 1256 // Should send the initial set of unsolcited V2 reports. 1257 mgp.initializeGroups() 1258 for i := 0; i < unsolicitedTransmissionCount; i++ { 1259 if test.v1 { 1260 if diff := mgp.check(test.checkFields([]tcpip.Address{addr1, addr2}, false /* leave */)); diff != "" { 1261 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1262 } 1263 } else { 1264 if diff := mgp.check(checkFields{sentV2Reports: []mockReportV2{ 1265 { 1266 records: []mockReportV2Record{ 1267 { 1268 recordType: ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode, 1269 groupAddress: addr1, 1270 }, 1271 }, 1272 }, 1273 { 1274 records: []mockReportV2Record{ 1275 { 1276 recordType: ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode, 1277 groupAddress: addr2, 1278 }, 1279 }, 1280 }, 1281 }}); diff != "" { 1282 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1283 } 1284 } 1285 clock.Advance(maxUnsolicitedReportDelay) 1286 } 1287 1288 // Should have no more messages to send. 1289 clock.Advance(time.Hour) 1290 if diff := mgp.check(checkFields{}); diff != "" { 1291 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1292 } 1293 1294 if got := mgp.getV1Mode(); got != test.v1 { 1295 t.Errorf("got mgp.getV1Mode() = %t, want = %t", got, test.v1) 1296 } 1297 }) 1298 } 1299 } 1300 1301 // TestGroupStateNonMember tests that groups do not send packets when in the 1302 // non-member state, but are still considered locally joined. 1303 func TestGroupStateNonMember(t *testing.T) { 1304 tests := []struct { 1305 name string 1306 v1Compatibility bool 1307 checkFields func([]tcpip.Address, bool) checkFields 1308 }{ 1309 { 1310 name: "V1 Compatibility", 1311 v1Compatibility: true, 1312 checkFields: func(addrs []tcpip.Address, leave bool) checkFields { 1313 if leave { 1314 return checkFields{sendLeaveGroupAddresses: addrs} 1315 } 1316 return checkFields{sendReportGroupAddresses: addrs} 1317 }, 1318 }, 1319 { 1320 name: "V2", 1321 v1Compatibility: false, 1322 checkFields: func(addrs []tcpip.Address, leave bool) checkFields { 1323 recordType := ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode 1324 if leave { 1325 recordType = ip.MulticastGroupProtocolV2ReportRecordChangeToIncludeMode 1326 } 1327 var records []mockReportV2Record 1328 for _, addr := range addrs { 1329 records = append(records, mockReportV2Record{ 1330 recordType: recordType, 1331 groupAddress: addr, 1332 }) 1333 } 1334 1335 return checkFields{sentV2Reports: []mockReportV2{{records: records}}} 1336 }, 1337 }, 1338 } 1339 1340 for _, test := range tests { 1341 t.Run(test.name, func(t *testing.T) { 1342 mgp := mockMulticastGroupProtocol{t: t} 1343 clock := faketime.NewManualClock() 1344 1345 mgp.init(ip.GenericMulticastProtocolOptions{ 1346 Rand: rand.New(rand.NewSource(3)), 1347 Clock: clock, 1348 MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, 1349 }, test.v1Compatibility) 1350 mgp.setEnabled(false) 1351 1352 // Joining groups should not send any reports. 1353 mgp.joinGroup(addr1) 1354 if !mgp.isLocallyJoined(addr1) { 1355 t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) 1356 } 1357 if diff := mgp.check(checkFields{}); diff != "" { 1358 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1359 } 1360 mgp.joinGroup(addr2) 1361 if !mgp.isLocallyJoined(addr1) { 1362 t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr2) 1363 } 1364 if diff := mgp.check(checkFields{}); diff != "" { 1365 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1366 } 1367 1368 // Receiving a query should not send any reports. 1369 mgp.handleQuery(addr1, time.Nanosecond) 1370 // Generic multicast protocol timers are expected to take the job mutex. 1371 clock.Advance(time.Nanosecond) 1372 if diff := mgp.check(checkFields{}); diff != "" { 1373 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1374 } 1375 1376 // Leaving groups should not send any leave messages. 1377 if !mgp.leaveGroup(addr1) { 1378 t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr2) 1379 } 1380 if mgp.isLocallyJoined(addr1) { 1381 t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr2) 1382 } 1383 if diff := mgp.check(checkFields{}); diff != "" { 1384 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1385 } 1386 1387 clock.Advance(time.Hour) 1388 if diff := mgp.check(checkFields{}); diff != "" { 1389 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1390 } 1391 }) 1392 } 1393 } 1394 1395 // TestMakeAllNonMemberCancelsDelayedReportJob tests that the delayed report job 1396 // is cancelled on MakeAllNonMember, otherwise the job will panic if the endpoint 1397 // is disabled. 1398 func TestMakeAllNonMemberCancelsDelayedReportJob(t *testing.T) { 1399 const maxRespCode = 1 1400 1401 tests := []struct { 1402 name string 1403 v1 bool 1404 v1Compatibility bool 1405 checkFields func(tcpip.Address, bool) checkFields 1406 }{ 1407 { 1408 name: "V1", 1409 v1: true, 1410 v1Compatibility: false, 1411 checkFields: func(addr tcpip.Address, leave bool) checkFields { 1412 if leave { 1413 return checkFields{sendLeaveGroupAddresses: []tcpip.Address{addr}} 1414 } 1415 return checkFields{sendReportGroupAddresses: []tcpip.Address{addr}} 1416 }, 1417 }, 1418 { 1419 name: "V1 Compatibility", 1420 v1: false, 1421 v1Compatibility: true, 1422 checkFields: func(addr tcpip.Address, leave bool) checkFields { 1423 if leave { 1424 return checkFields{sendLeaveGroupAddresses: []tcpip.Address{addr}} 1425 } 1426 return checkFields{sendReportGroupAddresses: []tcpip.Address{addr}} 1427 }, 1428 }, 1429 { 1430 name: "V2", 1431 v1: false, 1432 v1Compatibility: false, 1433 checkFields: func(addr tcpip.Address, leave bool) checkFields { 1434 recordType := ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode 1435 if leave { 1436 recordType = ip.MulticastGroupProtocolV2ReportRecordChangeToIncludeMode 1437 } 1438 return checkFields{sentV2Reports: []mockReportV2{{records: []mockReportV2Record{mockReportV2Record{ 1439 recordType: recordType, 1440 groupAddress: addr, 1441 }}}}} 1442 }, 1443 }, 1444 } 1445 1446 for _, test := range tests { 1447 t.Run(test.name, func(t *testing.T) { 1448 mgp := mockMulticastGroupProtocol{t: t} 1449 clock := faketime.NewManualClock() 1450 1451 mgp.init(ip.GenericMulticastProtocolOptions{ 1452 Rand: rand.New(rand.NewSource(3)), 1453 Clock: clock, 1454 MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, 1455 }, test.v1) 1456 1457 if test.v1Compatibility { 1458 // V1 query targeting an unjoined group should drop us into V1 1459 // compatibility mode without sending any packets, affecting tests. 1460 mgp.handleQuery(addr3, 0) 1461 } 1462 1463 mgp.joinGroup(addr1) 1464 if diff := mgp.check(test.checkFields(addr1, false /* leave */)); diff != "" { 1465 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1466 } 1467 1468 // Handle a query so that the delayed report job is scheduled when operating 1469 // in V2 mode. 1470 mgp.handleQueryV2(addr1, maxRespCode, header.MakeAddressIterator(addr1.Len(), bytes.NewBuffer(nil)), 0, 0) 1471 1472 mgp.makeAllNonMember() 1473 if diff := mgp.check(test.checkFields(addr1, true /* leave */)); diff != "" { 1474 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1475 } 1476 1477 mgp.setEnabled(false) 1478 1479 // Generic multicast protocol timers are expected to take the job mutex. 1480 // 1481 // Advance the clock to after the delayed report job is supposed to fire. 1482 // If the delayed report job isn't cancelled by the MakeAllNonMember call, 1483 // it will panic due to the expectation that the protocol is enabled. 1484 if test.v1 || test.v1Compatibility { 1485 clock.Advance(mgp.V2QueryMaxRespCodeToV1Delay(maxRespCode)) 1486 } else { 1487 clock.Advance(mgp.V2QueryMaxRespCodeToV2Delay(maxRespCode)) 1488 } 1489 if diff := mgp.check(checkFields{}); diff != "" { 1490 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1491 } 1492 }) 1493 } 1494 } 1495 1496 func TestQueuedPackets(t *testing.T) { 1497 tests := []struct { 1498 name string 1499 v1Compatibility bool 1500 checkFields func(tcpip.Address) checkFields 1501 }{ 1502 { 1503 name: "V1 Compatibility", 1504 v1Compatibility: true, 1505 checkFields: func(addr tcpip.Address) checkFields { 1506 return checkFields{sendReportGroupAddresses: []tcpip.Address{addr}} 1507 }, 1508 }, 1509 { 1510 name: "V2", 1511 v1Compatibility: false, 1512 checkFields: func(addr tcpip.Address) checkFields { 1513 return checkFields{sentV2Reports: []mockReportV2{{records: []mockReportV2Record{ 1514 { 1515 recordType: ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode, 1516 groupAddress: addr, 1517 }, 1518 }}}} 1519 }, 1520 }, 1521 } 1522 1523 for _, test := range tests { 1524 t.Run(test.name, func(t *testing.T) { 1525 clock := faketime.NewManualClock() 1526 mgp := mockMulticastGroupProtocol{t: t} 1527 mgp.init(ip.GenericMulticastProtocolOptions{ 1528 Rand: rand.New(rand.NewSource(4)), 1529 Clock: clock, 1530 MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, 1531 }, test.v1Compatibility) 1532 1533 // Joining should trigger a SendReport, but mgp should report that we did not 1534 // send the packet. 1535 mgp.setQueuePackets(true) 1536 mgp.joinGroup(addr1) 1537 if diff := mgp.check(test.checkFields(addr1)); diff != "" { 1538 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1539 } 1540 1541 // The delayed report timer should have been cancelled since we did not send 1542 // the initial report earlier. 1543 clock.Advance(time.Hour) 1544 if diff := mgp.check(checkFields{}); diff != "" { 1545 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1546 } 1547 1548 // Mock being able to successfully send the report. 1549 mgp.setQueuePackets(false) 1550 mgp.sendQueuedReports() 1551 if diff := mgp.check(test.checkFields(addr1)); diff != "" { 1552 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1553 } 1554 1555 // The delayed report (sent after the initial report) should now be sent. 1556 clock.Advance(maxUnsolicitedReportDelay) 1557 if diff := mgp.check(test.checkFields(addr1)); diff != "" { 1558 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1559 } 1560 1561 // Should not have anything else to send (we should be idle). 1562 mgp.sendQueuedReports() 1563 clock.Advance(time.Hour) 1564 if diff := mgp.check(checkFields{}); diff != "" { 1565 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1566 } 1567 1568 // Receive a query but mock being unable to send reports again. 1569 mgp.setQueuePackets(true) 1570 mgp.handleQuery(addr1, time.Nanosecond) 1571 clock.Advance(time.Nanosecond) 1572 if diff := mgp.check(checkFields{sendReportGroupAddresses: []tcpip.Address{addr1}}); diff != "" { 1573 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1574 } 1575 1576 // Mock being able to send reports again - we should have a packet queued to 1577 // send. 1578 mgp.setQueuePackets(false) 1579 mgp.sendQueuedReports() 1580 if diff := mgp.check(checkFields{sendReportGroupAddresses: []tcpip.Address{addr1}}); diff != "" { 1581 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1582 } 1583 1584 // Should not have anything else to send. 1585 mgp.sendQueuedReports() 1586 clock.Advance(time.Hour) 1587 if diff := mgp.check(checkFields{}); diff != "" { 1588 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1589 } 1590 1591 // Receive a query again, but mock being unable to send reports. 1592 mgp.setQueuePackets(true) 1593 mgp.handleQuery(addr1, time.Nanosecond) 1594 clock.Advance(time.Nanosecond) 1595 if diff := mgp.check(checkFields{sendReportGroupAddresses: []tcpip.Address{addr1}}); diff != "" { 1596 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1597 } 1598 1599 // Receiving a report should transition us into the idle member state, 1600 // even if we had a packet queued. We should no longer have any packets to 1601 // send. 1602 mgp.handleReport(addr1) 1603 mgp.sendQueuedReports() 1604 if diff := mgp.check(checkFields{}); diff != "" { 1605 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1606 } 1607 1608 // When we fail to send the initial set of reports, incoming reports should 1609 // prevent a newly joined group's reports from being sent. 1610 mgp.setQueuePackets(true) 1611 mgp.joinGroup(addr2) 1612 if diff := mgp.check(checkFields{sendReportGroupAddresses: []tcpip.Address{addr2}}); diff != "" { 1613 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1614 } 1615 mgp.handleReport(addr2) 1616 // Attempting to send queued reports while still unable to send reports should 1617 // not change the host state. 1618 mgp.sendQueuedReports() 1619 if diff := mgp.check(checkFields{}); diff != "" { 1620 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1621 } 1622 // Should not have any packets queued. 1623 mgp.setQueuePackets(false) 1624 mgp.sendQueuedReports() 1625 clock.Advance(time.Hour) 1626 if diff := mgp.check(checkFields{}); diff != "" { 1627 t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1628 } 1629 }) 1630 } 1631 } 1632 1633 func TestGetSetV1Mode(t *testing.T) { 1634 clock := faketime.NewManualClock() 1635 mgp := mockMulticastGroupProtocol{t: t} 1636 mgp.init(ip.GenericMulticastProtocolOptions{ 1637 Rand: rand.New(rand.NewSource(4)), 1638 Clock: clock, 1639 MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, 1640 }, false /* v1Compatibility */) 1641 1642 if mgp.getV1Mode() { 1643 t.Error("got mgp.getV1Mode() = true, want = false") 1644 } 1645 1646 mgp.joinGroup(addr1) 1647 if diff := mgp.check(checkFields{sentV2Reports: []mockReportV2{{records: []mockReportV2Record{ 1648 { 1649 recordType: ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode, 1650 groupAddress: addr1, 1651 }, 1652 }}}}); diff != "" { 1653 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1654 } 1655 1656 if mgp.setV1Mode(true) { 1657 t.Error("got mgp.setV1Mode(true) = true, want = false") 1658 } 1659 if !mgp.getV1Mode() { 1660 t.Error("got mgp.getV1Mode() = false, want = true") 1661 } 1662 mgp.joinGroup(addr2) 1663 if diff := mgp.check(checkFields{sendReportGroupAddresses: []tcpip.Address{addr2}}); diff != "" { 1664 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1665 } 1666 1667 if !mgp.setV1Mode(false) { 1668 t.Error("got mgp.setV1Mode(false) = false, want = true") 1669 } 1670 if mgp.getV1Mode() { 1671 t.Error("got mgp.getV1Mode() = true, want = false") 1672 } 1673 mgp.joinGroup(addr3) 1674 if diff := mgp.check(checkFields{sentV2Reports: []mockReportV2{{records: []mockReportV2Record{ 1675 { 1676 recordType: ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode, 1677 groupAddress: addr3, 1678 }, 1679 }}}}); diff != "" { 1680 t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) 1681 } 1682 }