gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package ip_test
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"math/rand"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/google/go-cmp/cmp"
    25  	"gvisor.dev/gvisor/pkg/sync"
    26  	"gvisor.dev/gvisor/pkg/tcpip"
    27  	"gvisor.dev/gvisor/pkg/tcpip/faketime"
    28  	"gvisor.dev/gvisor/pkg/tcpip/header"
    29  	"gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
    30  )
    31  
    32  const maxUnsolicitedReportDelay = time.Second
    33  
    34  var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil)
    35  
// mockMulticastGroupProtocolProtectedFields holds the mock's state that is
// guarded by the embedded RWMutex. init hands this same mutex to
// genericMulticastGroup, so the protocol state and the mock share one lock.
type mockMulticastGroupProtocolProtectedFields struct {
	sync.RWMutex

	// genericMulticastGroup is the protocol state under test.
	genericMulticastGroup ip.GenericMulticastProtocolState
	// sendReportGroupAddrCount counts SendReport calls per group address.
	sendReportGroupAddrCount map[tcpip.Address]int
	// sendLeaveGroupAddrCount counts SendLeave calls per group address.
	sendLeaveGroupAddrCount map[tcpip.Address]int
	// makeQueuePackets makes SendReport and V2 report sends return
	// sent=false.
	makeQueuePackets bool
	// disabled is the inverse of what Enabled reports.
	disabled bool
	// sentV2Reports accumulates, per group address, the record types of every
	// V2 report record that was sent.
	sentV2Reports map[tcpip.Address][]ip.MulticastGroupProtocolV2ReportRecordType
}
    46  
// mockMulticastGroupProtocol is a test implementation of
// ip.MulticastGroupProtocol. It records the messages the generic multicast
// protocol asks it to send and fails the test when its callbacks are invoked
// without the expected lock held.
type mockMulticastGroupProtocol struct {
	// t is used to fail the test when a locking precondition is violated.
	t *testing.T

	// skipProtocolAddress is the group address for which
	// ShouldPerformProtocol returns false.
	skipProtocolAddress tcpip.Address

	mu mockMulticastGroupProtocolProtectedFields
}
    54  
    55  func (m *mockMulticastGroupProtocol) init(opts ip.GenericMulticastProtocolOptions, v1Compatibility bool) {
    56  	m.mu.Lock()
    57  	defer m.mu.Unlock()
    58  	m.initLocked()
    59  	opts.Protocol = m
    60  	m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts)
    61  
    62  	if v1Compatibility {
    63  		m.mu.genericMulticastGroup.SetV1ModeLocked(true)
    64  	}
    65  }
    66  
    67  func (m *mockMulticastGroupProtocol) initLocked() {
    68  	m.mu.sendReportGroupAddrCount = make(map[tcpip.Address]int)
    69  	m.mu.sendLeaveGroupAddrCount = make(map[tcpip.Address]int)
    70  	m.mu.sentV2Reports = make(map[tcpip.Address][]ip.MulticastGroupProtocolV2ReportRecordType)
    71  }
    72  
    73  func (m *mockMulticastGroupProtocol) setEnabled(v bool) {
    74  	m.mu.Lock()
    75  	defer m.mu.Unlock()
    76  	m.mu.disabled = !v
    77  }
    78  
    79  func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) {
    80  	m.mu.Lock()
    81  	defer m.mu.Unlock()
    82  	m.mu.makeQueuePackets = v
    83  }
    84  
    85  func (m *mockMulticastGroupProtocol) setV1Mode(v bool) bool {
    86  	m.mu.Lock()
    87  	defer m.mu.Unlock()
    88  	return m.mu.genericMulticastGroup.SetV1ModeLocked(v)
    89  }
    90  
    91  func (m *mockMulticastGroupProtocol) getV1Mode() bool {
    92  	m.mu.RLock()
    93  	defer m.mu.RUnlock()
    94  	return m.mu.genericMulticastGroup.GetV1ModeLocked()
    95  }
    96  
    97  func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) {
    98  	m.mu.Lock()
    99  	defer m.mu.Unlock()
   100  	m.mu.genericMulticastGroup.JoinGroupLocked(addr)
   101  }
   102  
   103  func (m *mockMulticastGroupProtocol) leaveGroup(addr tcpip.Address) bool {
   104  	m.mu.Lock()
   105  	defer m.mu.Unlock()
   106  	return m.mu.genericMulticastGroup.LeaveGroupLocked(addr)
   107  }
   108  
   109  func (m *mockMulticastGroupProtocol) handleReport(addr tcpip.Address) {
   110  	m.mu.Lock()
   111  	defer m.mu.Unlock()
   112  	m.mu.genericMulticastGroup.HandleReportLocked(addr)
   113  }
   114  
   115  func (m *mockMulticastGroupProtocol) handleQuery(addr tcpip.Address, maxRespTime time.Duration) {
   116  	m.mu.Lock()
   117  	defer m.mu.Unlock()
   118  	m.mu.genericMulticastGroup.HandleQueryLocked(addr, maxRespTime)
   119  }
   120  
   121  func (m *mockMulticastGroupProtocol) handleQueryV2(addr tcpip.Address, maxResponseCode uint16, sources header.AddressIterator, robustnessVariable uint8, queryInterval time.Duration) {
   122  	m.mu.Lock()
   123  	defer m.mu.Unlock()
   124  	m.mu.genericMulticastGroup.HandleQueryV2Locked(addr, maxResponseCode, sources, robustnessVariable, queryInterval)
   125  }
   126  
   127  func (m *mockMulticastGroupProtocol) isLocallyJoined(addr tcpip.Address) bool {
   128  	m.mu.RLock()
   129  	defer m.mu.RUnlock()
   130  	return m.mu.genericMulticastGroup.IsLocallyJoinedRLocked(addr)
   131  }
   132  
   133  func (m *mockMulticastGroupProtocol) makeAllNonMember() {
   134  	m.mu.Lock()
   135  	defer m.mu.Unlock()
   136  	m.mu.genericMulticastGroup.MakeAllNonMemberLocked()
   137  }
   138  
   139  func (m *mockMulticastGroupProtocol) initializeGroups() {
   140  	m.mu.Lock()
   141  	defer m.mu.Unlock()
   142  	m.mu.genericMulticastGroup.InitializeGroupsLocked()
   143  }
   144  
   145  func (m *mockMulticastGroupProtocol) sendQueuedReports() {
   146  	m.mu.Lock()
   147  	defer m.mu.Unlock()
   148  	m.mu.genericMulticastGroup.SendQueuedReportsLocked()
   149  }
   150  
   151  // Enabled implements ip.MulticastGroupProtocol.
   152  //
   153  // Precondition: m.mu must be read locked.
   154  func (m *mockMulticastGroupProtocol) Enabled() bool {
   155  	if m.mu.TryLock() {
   156  		m.mu.Unlock() // +checklocksforce: TryLock.
   157  		m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled")
   158  	}
   159  
   160  	return !m.mu.disabled
   161  }
   162  
   163  // SendReport implements ip.MulticastGroupProtocol.
   164  //
   165  // Precondition: m.mu must be locked.
   166  func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) {
   167  	if m.mu.TryLock() {
   168  		m.mu.Unlock() // +checklocksforce: TryLock.
   169  		m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
   170  	}
   171  	if m.mu.TryRLock() {
   172  		m.mu.RUnlock() // +checklocksforce: TryLock.
   173  		m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
   174  	}
   175  
   176  	m.mu.sendReportGroupAddrCount[groupAddress]++
   177  	return !m.mu.makeQueuePackets, nil
   178  }
   179  
   180  // SendLeave implements ip.MulticastGroupProtocol.
   181  //
   182  // Precondition: m.mu must be locked.
   183  func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error {
   184  	if m.mu.TryLock() {
   185  		m.mu.Unlock() // +checklocksforce: TryLock.
   186  		m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
   187  	}
   188  	if m.mu.TryRLock() {
   189  		m.mu.RUnlock() // +checklocksforce: TryLock.
   190  		m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
   191  	}
   192  
   193  	m.mu.sendLeaveGroupAddrCount[groupAddress]++
   194  	return nil
   195  }
   196  
   197  // ShouldPerformProtocol implements ip.MulticastGroupProtocol.
   198  func (m *mockMulticastGroupProtocol) ShouldPerformProtocol(groupAddress tcpip.Address) bool {
   199  	return groupAddress != m.skipProtocolAddress
   200  }
   201  
// mockReportV2Record is one record added to a mock V2 report through
// AddRecord: the record type and the group it refers to.
type mockReportV2Record struct {
	recordType   ip.MulticastGroupProtocolV2ReportRecordType
	groupAddress tcpip.Address
}
   206  
// mockReportV2 is a V2 report under construction. records holds the records
// in the order they were added.
type mockReportV2 struct {
	records []mockReportV2Record
}
   210  
// mockReportV2Builder is the ip.MulticastGroupProtocolV2ReportBuilder handed
// out by mockMulticastGroupProtocol.NewReportV2Builder. It collects records in
// report and, on Send, records them in m.
type mockReportV2Builder struct {
	m      *mockMulticastGroupProtocol
	report mockReportV2
}
   215  
   216  // AddRecord implements ip.MulticastGroupProtocolV2ReportBuilder.
   217  func (b *mockReportV2Builder) AddRecord(recordType ip.MulticastGroupProtocolV2ReportRecordType, groupAddress tcpip.Address) {
   218  	b.report.records = append(b.report.records, mockReportV2Record{recordType: recordType, groupAddress: groupAddress})
   219  }
   220  
   221  func recordsToMap(m map[tcpip.Address][]ip.MulticastGroupProtocolV2ReportRecordType, records []mockReportV2Record) {
   222  	for _, record := range records {
   223  		m[record.groupAddress] = append(m[record.groupAddress], record.recordType)
   224  	}
   225  }
   226  
   227  // Send implements ip.MulticastGroupProtocolV2ReportBuilder.
   228  func (b *mockReportV2Builder) Send() (sent bool, err tcpip.Error) {
   229  	if b.m.mu.TryLock() {
   230  		b.m.mu.Unlock() // +checklocksforce: TryLock.
   231  		b.m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending v2 report")
   232  	}
   233  	if b.m.mu.TryRLock() {
   234  		b.m.mu.RUnlock() // +checklocksforce: TryLock.
   235  		b.m.t.Fatal("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending v2 report")
   236  	}
   237  
   238  	recordsToMap(b.m.mu.sentV2Reports, b.report.records)
   239  	return !b.m.mu.makeQueuePackets, nil
   240  }
   241  
   242  // NewReportV2Builder implements ip.MulticastGroupProtocol.
   243  func (m *mockMulticastGroupProtocol) NewReportV2Builder() ip.MulticastGroupProtocolV2ReportBuilder {
   244  	return &mockReportV2Builder{m: m}
   245  }
   246  
   247  // V2QueryMaxRespCodeToV2Delay implements ip.MulticastGroupProtocol.
   248  func (*mockMulticastGroupProtocol) V2QueryMaxRespCodeToV2Delay(code uint16) time.Duration {
   249  	return time.Duration(code) * time.Millisecond
   250  }
   251  
   252  // V2QueryMaxRespCodeToV1Delay implements ip.MulticastGroupProtocol.
   253  func (*mockMulticastGroupProtocol) V2QueryMaxRespCodeToV1Delay(code uint16) time.Duration {
   254  	return time.Duration(code) * time.Millisecond
   255  }
   256  
// checkFields describes the messages a test expects the mock to have recorded
// since the previous call to check.
type checkFields struct {
	// sendReportGroupAddresses lists groups expected to have had exactly one
	// SendReport call each.
	sendReportGroupAddresses []tcpip.Address
	// sendLeaveGroupAddresses lists groups expected to have had exactly one
	// SendLeave call each.
	sendLeaveGroupAddresses []tcpip.Address
	// sentV2Reports lists the V2 reports expected to have been sent. Only the
	// per-group sequence of record types is compared.
	sentV2Reports []mockReportV2
}
   262  
   263  func (m *mockMulticastGroupProtocol) check(fields checkFields) string {
   264  	m.mu.Lock()
   265  	defer m.mu.Unlock()
   266  
   267  	sendReportGroupAddrCount := make(map[tcpip.Address]int)
   268  	for _, a := range fields.sendReportGroupAddresses {
   269  		sendReportGroupAddrCount[a] = 1
   270  	}
   271  
   272  	sendLeaveGroupAddrCount := make(map[tcpip.Address]int)
   273  	for _, a := range fields.sendLeaveGroupAddresses {
   274  		sendLeaveGroupAddrCount[a] = 1
   275  	}
   276  
   277  	sentV2Reports := make(map[tcpip.Address][]ip.MulticastGroupProtocolV2ReportRecordType)
   278  	for _, report := range fields.sentV2Reports {
   279  		recordsToMap(sentV2Reports, report.records)
   280  	}
   281  
   282  	diff := cmp.Diff(
   283  		&mockMulticastGroupProtocol{
   284  			mu: mockMulticastGroupProtocolProtectedFields{
   285  				sendReportGroupAddrCount: sendReportGroupAddrCount,
   286  				sendLeaveGroupAddrCount:  sendLeaveGroupAddrCount,
   287  				sentV2Reports:            sentV2Reports,
   288  			},
   289  		},
   290  		m,
   291  		cmp.AllowUnexported(mockMulticastGroupProtocol{}),
   292  		cmp.AllowUnexported(mockMulticastGroupProtocolProtectedFields{}),
   293  		cmp.AllowUnexported(mockReportV2{}),
   294  		cmp.AllowUnexported(mockReportV2Record{}),
   295  		// ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t
   296  		cmp.FilterPath(
   297  			func(p cmp.Path) bool {
   298  				switch p.Last().String() {
   299  				case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup", ".skipProtocolAddress":
   300  					return true
   301  				default:
   302  					return false
   303  				}
   304  			},
   305  			cmp.Ignore(),
   306  		),
   307  	)
   308  	m.initLocked()
   309  	return diff
   310  }
   311  
   312  func TestJoinGroup(t *testing.T) {
   313  	tests := []struct {
   314  		name              string
   315  		addr              tcpip.Address
   316  		shouldSendReports bool
   317  	}{
   318  		{
   319  			name:              "Normal group",
   320  			addr:              addr1,
   321  			shouldSendReports: true,
   322  		},
   323  		{
   324  			name:              "All-nodes group",
   325  			addr:              addr2,
   326  			shouldSendReports: false,
   327  		},
   328  	}
   329  
   330  	subTests := []struct {
   331  		name            string
   332  		v1Compatibility bool
   333  		checkFields     func(tcpip.Address) checkFields
   334  	}{
   335  		{
   336  			name:            "V1 Compatibility",
   337  			v1Compatibility: true,
   338  			checkFields: func(addr tcpip.Address) checkFields {
   339  				return checkFields{sendReportGroupAddresses: []tcpip.Address{addr}}
   340  			},
   341  		},
   342  		{
   343  			name:            "V2",
   344  			v1Compatibility: false,
   345  			checkFields: func(addr tcpip.Address) checkFields {
   346  				return checkFields{sentV2Reports: []mockReportV2{{records: []mockReportV2Record{
   347  					{
   348  						recordType:   ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode,
   349  						groupAddress: addr,
   350  					},
   351  				}}}}
   352  			},
   353  		},
   354  	}
   355  
   356  	for _, test := range tests {
   357  		t.Run(test.name, func(t *testing.T) {
   358  			for _, subTest := range subTests {
   359  				t.Run(subTest.name, func(t *testing.T) {
   360  					mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2}
   361  					clock := faketime.NewManualClock()
   362  
   363  					mgp.init(ip.GenericMulticastProtocolOptions{
   364  						Rand:                      rand.New(rand.NewSource(0)),
   365  						Clock:                     clock,
   366  						MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
   367  					}, subTest.v1Compatibility)
   368  
   369  					// Joining a group should send a report immediately and another after
   370  					// a random interval between 0 and the maximum unsolicited report delay.
   371  					mgp.joinGroup(test.addr)
   372  					if test.shouldSendReports {
   373  						expected := subTest.checkFields(test.addr)
   374  						if diff := mgp.check(expected); diff != "" {
   375  							t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   376  						}
   377  
   378  						// Generic multicast protocol timers are expected to take the job mutex.
   379  						clock.Advance(maxUnsolicitedReportDelay)
   380  						if diff := mgp.check(expected); diff != "" {
   381  							t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   382  						}
   383  					}
   384  
   385  					// Should have no more messages to send.
   386  					clock.Advance(time.Hour)
   387  					if diff := mgp.check(checkFields{}); diff != "" {
   388  						t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   389  					}
   390  				})
   391  			}
   392  		})
   393  	}
   394  }
   395  
   396  func TestLeaveGroup(t *testing.T) {
   397  	const maxRespCode = 1
   398  
   399  	tests := []struct {
   400  		name               string
   401  		addr               tcpip.Address
   402  		shouldSendMessages bool
   403  	}{
   404  		{
   405  			name:               "Normal group",
   406  			addr:               addr1,
   407  			shouldSendMessages: true,
   408  		},
   409  		{
   410  			name:               "All-nodes group",
   411  			addr:               addr2,
   412  			shouldSendMessages: false,
   413  		},
   414  	}
   415  
   416  	subTests := []struct {
   417  		name            string
   418  		v1Compatibility bool
   419  		checkFields     func(tcpip.Address, bool) checkFields
   420  		handleQuery     func(*mockMulticastGroupProtocol, tcpip.Address)
   421  	}{
   422  		{
   423  			name:            "V1 Compatibility",
   424  			v1Compatibility: true,
   425  			checkFields: func(addr tcpip.Address, leave bool) checkFields {
   426  				if leave {
   427  					return checkFields{sendLeaveGroupAddresses: []tcpip.Address{addr}}
   428  				}
   429  				return checkFields{sendReportGroupAddresses: []tcpip.Address{addr}}
   430  			},
   431  			handleQuery: func(mgp *mockMulticastGroupProtocol, groupAddress tcpip.Address) {
   432  				mgp.handleQuery(groupAddress, maxRespCode)
   433  			},
   434  		},
   435  		{
   436  			name:            "V2",
   437  			v1Compatibility: false,
   438  			checkFields: func(addr tcpip.Address, leave bool) checkFields {
   439  				recordType := ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode
   440  				if leave {
   441  					recordType = ip.MulticastGroupProtocolV2ReportRecordChangeToIncludeMode
   442  				}
   443  
   444  				return checkFields{sentV2Reports: []mockReportV2{{records: []mockReportV2Record{
   445  					{
   446  						recordType:   recordType,
   447  						groupAddress: addr,
   448  					},
   449  				}}}}
   450  			},
   451  			handleQuery: func(mgp *mockMulticastGroupProtocol, groupAddress tcpip.Address) {
   452  				mgp.handleQueryV2(groupAddress, maxRespCode, header.MakeAddressIterator(addr1.Len(), bytes.NewBuffer(nil)), 0, 0)
   453  			},
   454  		},
   455  	}
   456  
   457  	for _, test := range tests {
   458  		t.Run(test.name, func(t *testing.T) {
   459  			for _, subTest := range subTests {
   460  				t.Run(subTest.name, func(t *testing.T) {
   461  					for _, queryAddr := range []tcpip.Address{test.addr, tcpip.Address{}} {
   462  						t.Run(fmt.Sprintf("QueryAddr=%s", queryAddr), func(t *testing.T) {
   463  							mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2}
   464  							clock := faketime.NewManualClock()
   465  
   466  							mgp.init(ip.GenericMulticastProtocolOptions{
   467  								Rand:                      rand.New(rand.NewSource(1)),
   468  								Clock:                     clock,
   469  								MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
   470  							}, subTest.v1Compatibility)
   471  
   472  							mgp.joinGroup(test.addr)
   473  							if test.shouldSendMessages {
   474  								if diff := mgp.check(subTest.checkFields(test.addr, false /* leave */)); diff != "" {
   475  									t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   476  								}
   477  							}
   478  
   479  							// The timer scheduled to send the query response should do
   480  							// nothing since we will leave the group before the response is
   481  							// sent.
   482  							subTest.handleQuery(&mgp, queryAddr)
   483  
   484  							// Leaving a group should send a leave report immediately and
   485  							// cancel any delayed reports.
   486  							if !mgp.leaveGroup(test.addr) {
   487  								t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", test.addr)
   488  							}
   489  
   490  							// A query should not do anything since we left the group.
   491  							subTest.handleQuery(&mgp, queryAddr)
   492  
   493  							if test.shouldSendMessages {
   494  								if diff := mgp.check(subTest.checkFields(test.addr, true /* leave */)); diff != "" {
   495  									t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   496  								}
   497  
   498  								if !subTest.v1Compatibility {
   499  									clock.Advance(maxUnsolicitedReportDelay)
   500  
   501  									if diff := mgp.check(subTest.checkFields(test.addr, true /* leave */)); diff != "" {
   502  										t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   503  									}
   504  								}
   505  							}
   506  
   507  							// Should have no more messages to send.
   508  							clock.Advance(time.Hour)
   509  							if diff := mgp.check(checkFields{}); diff != "" {
   510  								t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   511  							}
   512  						})
   513  					}
   514  				})
   515  			}
   516  		})
   517  	}
   518  }
   519  
   520  func TestHandleReport(t *testing.T) {
   521  	tests := []struct {
   522  		name             string
   523  		reportAddr       tcpip.Address
   524  		expectReportsFor []tcpip.Address
   525  	}{
   526  		{
   527  			name:             "Unpecified empty",
   528  			reportAddr:       tcpip.Address{},
   529  			expectReportsFor: []tcpip.Address{addr1, addr2},
   530  		},
   531  		{
   532  			name:             "Unpecified any",
   533  			reportAddr:       tcpip.AddrFromSlice([]byte("\x00\x00\x00\x00")),
   534  			expectReportsFor: []tcpip.Address{addr1, addr2},
   535  		},
   536  		{
   537  			name:             "Specified",
   538  			reportAddr:       addr1,
   539  			expectReportsFor: []tcpip.Address{addr2},
   540  		},
   541  		{
   542  			name:             "Specified all-nodes",
   543  			reportAddr:       addr3,
   544  			expectReportsFor: []tcpip.Address{addr1, addr2},
   545  		},
   546  		{
   547  			name:             "Specified other",
   548  			reportAddr:       addr4,
   549  			expectReportsFor: []tcpip.Address{addr1, addr2},
   550  		},
   551  	}
   552  
   553  	subTests := []struct {
   554  		name            string
   555  		v1Compatibility bool
   556  		checkFields     func([]tcpip.Address) checkFields
   557  	}{
   558  		{
   559  			name:            "V1 Compatibility",
   560  			v1Compatibility: true,
   561  			checkFields: func(addrs []tcpip.Address) checkFields {
   562  				return checkFields{sendReportGroupAddresses: addrs}
   563  			},
   564  		},
   565  		{
   566  			name:            "V2",
   567  			v1Compatibility: false,
   568  			checkFields: func(addrs []tcpip.Address) checkFields {
   569  				var records []mockReportV2Record
   570  				for _, addr := range addrs {
   571  					records = append(records, mockReportV2Record{
   572  						recordType:   ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode,
   573  						groupAddress: addr,
   574  					})
   575  				}
   576  
   577  				return checkFields{sentV2Reports: []mockReportV2{{records: records}}}
   578  			},
   579  		},
   580  	}
   581  
   582  	for _, test := range tests {
   583  		t.Run(test.name, func(t *testing.T) {
   584  			for _, subTest := range subTests {
   585  				t.Run(subTest.name, func(t *testing.T) {
   586  					mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
   587  					clock := faketime.NewManualClock()
   588  
   589  					mgp.init(ip.GenericMulticastProtocolOptions{
   590  						Rand:                      rand.New(rand.NewSource(2)),
   591  						Clock:                     clock,
   592  						MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
   593  					}, subTest.v1Compatibility)
   594  
   595  					mgp.joinGroup(addr1)
   596  					if diff := mgp.check(subTest.checkFields([]tcpip.Address{addr1})); diff != "" {
   597  						t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   598  					}
   599  					mgp.joinGroup(addr2)
   600  					if diff := mgp.check(subTest.checkFields([]tcpip.Address{addr2})); diff != "" {
   601  						t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   602  					}
   603  					mgp.joinGroup(addr3)
   604  					if diff := mgp.check(checkFields{}); diff != "" {
   605  						t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   606  					}
   607  
   608  					// Receiving a report for a group we have a timer scheduled for should
   609  					// cancel our delayed report timer for the group.
   610  					mgp.handleReport(test.reportAddr)
   611  					if len(test.expectReportsFor) != 0 {
   612  						// Generic multicast protocol timers are expected to take the job mutex.
   613  						clock.Advance(maxUnsolicitedReportDelay)
   614  						if diff := mgp.check(subTest.checkFields(test.expectReportsFor)); diff != "" {
   615  							t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   616  						}
   617  					}
   618  
   619  					// Should have no more messages to send.
   620  					clock.Advance(time.Hour)
   621  					if diff := mgp.check(checkFields{}); diff != "" {
   622  						t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   623  					}
   624  				})
   625  			}
   626  		})
   627  	}
   628  }
   629  
   630  func TestHandleQuery(t *testing.T) {
   631  	tests := []struct {
   632  		name                    string
   633  		queryAddr               tcpip.Address
   634  		maxDelay                time.Duration
   635  		expectQueriedReportsFor []tcpip.Address
   636  		expectDelayedReportsFor []tcpip.Address
   637  	}{
   638  		{
   639  			name:                    "Unpecified empty",
   640  			queryAddr:               tcpip.Address{},
   641  			maxDelay:                0,
   642  			expectQueriedReportsFor: []tcpip.Address{addr1, addr2},
   643  			expectDelayedReportsFor: nil,
   644  		},
   645  		{
   646  			name:                    "Unpecified any",
   647  			queryAddr:               tcpip.AddrFromSlice([]byte("\x00\x00\x00\x00")),
   648  			maxDelay:                1,
   649  			expectQueriedReportsFor: []tcpip.Address{addr1, addr2},
   650  			expectDelayedReportsFor: nil,
   651  		},
   652  		{
   653  			name:                    "Specified",
   654  			queryAddr:               addr1,
   655  			maxDelay:                2,
   656  			expectQueriedReportsFor: []tcpip.Address{addr1},
   657  			expectDelayedReportsFor: []tcpip.Address{addr2},
   658  		},
   659  		{
   660  			name:                    "Specified all-nodes",
   661  			queryAddr:               addr3,
   662  			maxDelay:                3,
   663  			expectQueriedReportsFor: nil,
   664  			expectDelayedReportsFor: []tcpip.Address{addr1, addr2},
   665  		},
   666  		{
   667  			name:                    "Specified other",
   668  			queryAddr:               addr4,
   669  			maxDelay:                4,
   670  			expectQueriedReportsFor: nil,
   671  			expectDelayedReportsFor: []tcpip.Address{addr1, addr2},
   672  		},
   673  	}
   674  
   675  	subTests := []struct {
   676  		name            string
   677  		v1Compatibility bool
   678  		checkFields     func([]tcpip.Address) checkFields
   679  	}{
   680  		{
   681  			name:            "V1 Compatibility",
   682  			v1Compatibility: true,
   683  			checkFields: func(addrs []tcpip.Address) checkFields {
   684  				return checkFields{sendReportGroupAddresses: addrs}
   685  			},
   686  		},
   687  		{
   688  			name:            "V2",
   689  			v1Compatibility: false,
   690  			checkFields: func(addrs []tcpip.Address) checkFields {
   691  				var records []mockReportV2Record
   692  				for _, addr := range addrs {
   693  					records = append(records, mockReportV2Record{
   694  						recordType:   ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode,
   695  						groupAddress: addr,
   696  					})
   697  				}
   698  
   699  				return checkFields{sentV2Reports: []mockReportV2{{records: records}}}
   700  			},
   701  		},
   702  	}
   703  
   704  	for _, test := range tests {
   705  		t.Run(test.name, func(t *testing.T) {
   706  			for _, subTest := range subTests {
   707  				t.Run(subTest.name, func(t *testing.T) {
   708  					mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
   709  					clock := faketime.NewManualClock()
   710  
   711  					mgp.init(ip.GenericMulticastProtocolOptions{
   712  						Rand:                      rand.New(rand.NewSource(3)),
   713  						Clock:                     clock,
   714  						MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
   715  					}, subTest.v1Compatibility)
   716  
   717  					mgp.joinGroup(addr1)
   718  					if diff := mgp.check(subTest.checkFields([]tcpip.Address{addr1})); diff != "" {
   719  						t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   720  					}
   721  					mgp.joinGroup(addr2)
   722  					if diff := mgp.check(subTest.checkFields([]tcpip.Address{addr2})); diff != "" {
   723  						t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   724  					}
   725  					mgp.joinGroup(addr3)
   726  					if diff := mgp.check(checkFields{}); diff != "" {
   727  						t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   728  					}
   729  
   730  					// Receiving a query should make us reschedule our delayed report timer
   731  					// to some time within the new max response delay.
   732  					mgp.handleQuery(test.queryAddr, test.maxDelay)
   733  					clock.Advance(test.maxDelay)
   734  					if diff := mgp.check(checkFields{sendReportGroupAddresses: test.expectQueriedReportsFor}); diff != "" {
   735  						t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   736  					}
   737  
   738  					// The groups that were not affected by the query should still send a
   739  					// report after the max unsolicited report delay.
   740  					//
   741  					// If we were in V2 mode, then we would have cancelled the interface's
   742  					// state changed timer so we won't see any further reports after
   743  					// receiving a V1 query.
   744  					if subTest.v1Compatibility {
   745  						clock.Advance(maxUnsolicitedReportDelay)
   746  						if diff := mgp.check(subTest.checkFields(test.expectDelayedReportsFor)); diff != "" {
   747  							t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   748  						}
   749  					}
   750  
   751  					// Should have no more messages to send.
   752  					clock.Advance(time.Hour)
   753  					if diff := mgp.check(checkFields{}); diff != "" {
   754  						t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   755  					}
   756  				})
   757  			}
   758  		})
   759  	}
   760  }
   761  
   762  func TestHandleQueryV2Response(t *testing.T) {
   763  	tests := []struct {
   764  		name                    string
   765  		queryAddr               tcpip.Address
   766  		maxDelay                uint16
   767  		expectQueriedReportsFor []tcpip.Address
   768  		expectDelayedReportsFor []tcpip.Address
   769  	}{
   770  		{
   771  			name:                    "Unpecified empty",
   772  			queryAddr:               tcpip.Address{},
   773  			maxDelay:                0,
   774  			expectQueriedReportsFor: []tcpip.Address{addr1, addr2},
   775  			expectDelayedReportsFor: nil,
   776  		},
   777  		{
   778  			name:                    "Unpecified any",
   779  			queryAddr:               tcpip.AddrFromSlice([]byte("\x00\x00\x00\x00")),
   780  			maxDelay:                1,
   781  			expectQueriedReportsFor: []tcpip.Address{addr1, addr2},
   782  			expectDelayedReportsFor: nil,
   783  		},
   784  		{
   785  			name:                    "Specified",
   786  			queryAddr:               addr1,
   787  			maxDelay:                2,
   788  			expectQueriedReportsFor: []tcpip.Address{addr1},
   789  			expectDelayedReportsFor: []tcpip.Address{addr2},
   790  		},
   791  		{
   792  			name:                    "Specified all-nodes",
   793  			queryAddr:               addr3,
   794  			maxDelay:                3,
   795  			expectQueriedReportsFor: nil,
   796  			expectDelayedReportsFor: []tcpip.Address{addr1, addr2},
   797  		},
   798  		{
   799  			name:                    "Specified other",
   800  			queryAddr:               addr4,
   801  			maxDelay:                4,
   802  			expectQueriedReportsFor: nil,
   803  			expectDelayedReportsFor: []tcpip.Address{addr1, addr2},
   804  		},
   805  	}
   806  
   807  	subTests := []struct {
   808  		name            string
   809  		v1Compatibility bool
   810  		checkFields     func([]tcpip.Address, bool) checkFields
   811  	}{
   812  		{
   813  			name:            "V1 Compatibility",
   814  			v1Compatibility: true,
   815  			checkFields: func(addrs []tcpip.Address, _ bool) checkFields {
   816  				return checkFields{sendReportGroupAddresses: addrs}
   817  			},
   818  		},
   819  		{
   820  			name:            "V2",
   821  			v1Compatibility: false,
   822  			checkFields: func(addrs []tcpip.Address, queryResponse bool) checkFields {
   823  				var records []mockReportV2Record
   824  				recordType := ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode
   825  				if queryResponse {
   826  					recordType = ip.MulticastGroupProtocolV2ReportRecordModeIsExclude
   827  				}
   828  
   829  				for _, addr := range addrs {
   830  					records = append(records, mockReportV2Record{
   831  						recordType:   recordType,
   832  						groupAddress: addr,
   833  					})
   834  				}
   835  
   836  				return checkFields{sentV2Reports: []mockReportV2{{records: records}}}
   837  			},
   838  		},
   839  	}
   840  
   841  	for _, test := range tests {
   842  		t.Run(test.name, func(t *testing.T) {
   843  			for _, subTest := range subTests {
   844  				t.Run(subTest.name, func(t *testing.T) {
   845  					mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
   846  					clock := faketime.NewManualClock()
   847  
   848  					mgp.init(ip.GenericMulticastProtocolOptions{
   849  						Rand:                      rand.New(rand.NewSource(3)),
   850  						Clock:                     clock,
   851  						MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
   852  					}, subTest.v1Compatibility)
   853  
   854  					mgp.joinGroup(addr1)
   855  					if diff := mgp.check(subTest.checkFields([]tcpip.Address{addr1}, false /* queryResponse */)); diff != "" {
   856  						t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   857  					}
   858  					mgp.joinGroup(addr2)
   859  					if diff := mgp.check(subTest.checkFields([]tcpip.Address{addr2}, false /* queryResponse */)); diff != "" {
   860  						t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   861  					}
   862  					mgp.joinGroup(addr3)
   863  					if diff := mgp.check(checkFields{}); diff != "" {
   864  						t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   865  					}
   866  					clock.Advance(maxUnsolicitedReportDelay)
   867  					if diff := mgp.check(subTest.checkFields([]tcpip.Address{addr1, addr2}, false /* queryResponse */)); diff != "" {
   868  						t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   869  					}
   870  					clock.Advance(maxUnsolicitedReportDelay)
   871  					if diff := mgp.check(checkFields{}); diff != "" {
   872  						t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   873  					}
   874  
   875  					// Receiving a query should make us reschedule our delayed report
   876  					// timer to some time within the new max response delay.
   877  					//
   878  					// Note that if we are in V1 compatibility mode, the V2 query will be
   879  					// handled as a V1 query.
   880  					mgp.handleQueryV2(test.queryAddr, test.maxDelay, header.MakeAddressIterator(addr1.Len(), bytes.NewBuffer(nil)), 0, 0)
   881  					if subTest.v1Compatibility {
   882  						clock.Advance(mgp.V2QueryMaxRespCodeToV1Delay(test.maxDelay))
   883  					} else {
   884  						clock.Advance(mgp.V2QueryMaxRespCodeToV2Delay(test.maxDelay))
   885  					}
   886  					if diff := mgp.check(subTest.checkFields(test.expectQueriedReportsFor, true /* queryResponse */)); diff != "" {
   887  						t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   888  					}
   889  
   890  					// Should have no more messages to send.
   891  					clock.Advance(time.Hour)
   892  					if diff := mgp.check(checkFields{}); diff != "" {
   893  						t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   894  					}
   895  				})
   896  			}
   897  		})
   898  	}
   899  }
   900  
// TestV1CompatbilityModeTimer tests that a V1 query switches a V2 protocol into
// V1 compatibility mode, and that the mode lasts for Robustness Variable *
// Query Interval. The expected duration uses the defaults until a V2 query
// supplies values; a zero value in that query is expected to keep the default.
func TestV1CompatbilityModeTimer(t *testing.T) {
	tests := []struct {
		name               string
		robustnessVariable uint8
		queryInterval      time.Duration
	}{
		{
			name:               "Unspecified Robustness variable and Query interval",
			robustnessVariable: 0,
			queryInterval:      0,
		},
		{
			name:               "Unspecified Robustness variable",
			robustnessVariable: 0,
			queryInterval:      ip.DefaultQueryInterval + time.Second,
		},
		{
			name:               "Unspecified Query interval",
			robustnessVariable: ip.DefaultRobustnessVariable + 1,
			queryInterval:      0,
		},
		{
			name:               "Default Robustness variable and Query interval",
			robustnessVariable: ip.DefaultRobustnessVariable,
			queryInterval:      ip.DefaultQueryInterval,
		},
		{
			name:               "Specified Robustness variable and Query interval",
			robustnessVariable: ip.DefaultRobustnessVariable + 1,
			queryInterval:      ip.DefaultQueryInterval + time.Second,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
			clock := faketime.NewManualClock()

			mgp.init(ip.GenericMulticastProtocolOptions{
				Rand:                      rand.New(rand.NewSource(3)),
				Clock:                     clock,
				MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
			}, false /* v1Compatibility */)

			// v2Check joins and then leaves addr1, expecting the V2 state-change
			// reports (change to exclude mode, then change to include mode), i.e.
			// that the protocol is operating in V2 mode.
			v2Check := func(t *testing.T) {
				t.Helper()

				mgp.joinGroup(addr1)
				if diff := mgp.check(checkFields{sentV2Reports: []mockReportV2{{records: []mockReportV2Record{
					{
						recordType:   ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode,
						groupAddress: addr1,
					},
				}}}}); diff != "" {
					t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
				}
				if !mgp.leaveGroup(addr1) {
					t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", addr1)
				}
				if diff := mgp.check(checkFields{sentV2Reports: []mockReportV2{{records: []mockReportV2Record{
					{
						recordType:   ip.MulticastGroupProtocolV2ReportRecordChangeToIncludeMode,
						groupAddress: addr1,
					},
				}}}}); diff != "" {
					t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
				}
			}
			v2Check(t)

			// NOTE: the subtests share mgp and clock and must run in order. The
			// "Default" subtest expects the default duration because no V2 query
			// has been handled yet; it ends by handling a V2 query carrying the
			// test's values, which the "After V2 Query" subtest then expects.
			subTests := []struct {
				name        string
				advanceTime time.Duration
			}{
				{
					name:        "Default",
					advanceTime: ip.DefaultRobustnessVariable * ip.DefaultQueryInterval,
				},
				{
					name: "After V2 Query",
					advanceTime: func() time.Duration {
						robustnessVariable := test.robustnessVariable
						if robustnessVariable == 0 {
							robustnessVariable = ip.DefaultRobustnessVariable
						}

						queryInterval := test.queryInterval
						if queryInterval == 0 {
							queryInterval = ip.DefaultQueryInterval
						}

						return time.Duration(robustnessVariable) * queryInterval
					}(),
				},
			}

			for _, subTest := range subTests {
				t.Run(subTest.name, func(t *testing.T) {
					// A V1 query should switch the protocol into V1 compatibility mode.
					mgp.handleQuery(addr3, time.Nanosecond)
					// v1Check joins and then leaves addr1, expecting a V1 report and a
					// V1 leave message, i.e. that V1 compatibility mode is active.
					v1Check := func() {
						t.Helper()
						mgp.joinGroup(addr1)
						if diff := mgp.check(checkFields{sendReportGroupAddresses: []tcpip.Address{addr1}}); diff != "" {
							t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
						}
						if !mgp.leaveGroup(addr1) {
							t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", addr1)
						}
						if diff := mgp.check(checkFields{sendLeaveGroupAddresses: []tcpip.Address{addr1}}); diff != "" {
							t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
						}
					}
					v1Check()
					// One nanosecond before the timer expires we should still be in V1
					// compatibility mode...
					const minDuration = time.Duration(1)
					clock.Advance(subTest.advanceTime - minDuration)
					v1Check()

					// ...and once it expires we should be back in V2 mode.
					clock.Advance(minDuration)
					v2Check(t)
					// Should update the Robustness variable and Querier's Query interval.
					mgp.handleQueryV2(addr3, 0, header.MakeAddressIterator(addr1.Len(), bytes.NewBuffer(nil)), test.robustnessVariable, test.queryInterval)
				})
			}
		})
	}
}
  1027  
  1028  func TestJoinCount(t *testing.T) {
  1029  	const maxUnsolicitedReportDelay = time.Second
  1030  
  1031  	tests := []struct {
  1032  		name            string
  1033  		v1Compatibility bool
  1034  		checkFields     func(tcpip.Address, bool) checkFields
  1035  	}{
  1036  		{
  1037  			name:            "V1 Compatibility",
  1038  			v1Compatibility: true,
  1039  			checkFields: func(addr tcpip.Address, leave bool) checkFields {
  1040  				if leave {
  1041  					return checkFields{sendLeaveGroupAddresses: []tcpip.Address{addr}}
  1042  				}
  1043  				return checkFields{sendReportGroupAddresses: []tcpip.Address{addr}}
  1044  			},
  1045  		},
  1046  		{
  1047  			name:            "V2",
  1048  			v1Compatibility: false,
  1049  			checkFields: func(addr tcpip.Address, leave bool) checkFields {
  1050  				recordType := ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode
  1051  				if leave {
  1052  					recordType = ip.MulticastGroupProtocolV2ReportRecordChangeToIncludeMode
  1053  				}
  1054  
  1055  				return checkFields{sentV2Reports: []mockReportV2{{records: []mockReportV2Record{
  1056  					{
  1057  						recordType:   recordType,
  1058  						groupAddress: addr,
  1059  					},
  1060  				}}}}
  1061  			},
  1062  		},
  1063  	}
  1064  
  1065  	for _, test := range tests {
  1066  		t.Run(test.name, func(t *testing.T) {
  1067  			mgp := mockMulticastGroupProtocol{t: t}
  1068  			clock := faketime.NewManualClock()
  1069  
  1070  			mgp.init(ip.GenericMulticastProtocolOptions{
  1071  				Rand:                      rand.New(rand.NewSource(4)),
  1072  				Clock:                     clock,
  1073  				MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
  1074  			}, test.v1Compatibility)
  1075  
  1076  			// Set the join count to 2 for a group.
  1077  			mgp.joinGroup(addr1)
  1078  			if !mgp.isLocallyJoined(addr1) {
  1079  				t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
  1080  			}
  1081  			// Only the first join should trigger a report to be sent.
  1082  			if diff := mgp.check(test.checkFields(addr1, false /* leave */)); diff != "" {
  1083  				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1084  			}
  1085  			mgp.joinGroup(addr1)
  1086  			if !mgp.isLocallyJoined(addr1) {
  1087  				t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
  1088  			}
  1089  			if diff := mgp.check(checkFields{}); diff != "" {
  1090  				t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1091  			}
  1092  			if t.Failed() {
  1093  				t.FailNow()
  1094  			}
  1095  
  1096  			// Group should still be considered joined after leaving once.
  1097  			if !mgp.leaveGroup(addr1) {
  1098  				t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1)
  1099  			}
  1100  			if !mgp.isLocallyJoined(addr1) {
  1101  				t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
  1102  			}
  1103  			// A leave report should only be sent once the join count reaches 0.
  1104  			if diff := mgp.check(checkFields{}); diff != "" {
  1105  				t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1106  			}
  1107  			if t.Failed() {
  1108  				t.FailNow()
  1109  			}
  1110  
  1111  			// Leaving once more should actually remove us from the group.
  1112  			if !mgp.leaveGroup(addr1) {
  1113  				t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1)
  1114  			}
  1115  			if mgp.isLocallyJoined(addr1) {
  1116  				t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1)
  1117  			}
  1118  			if diff := mgp.check(test.checkFields(addr1, true /* leave */)); diff != "" {
  1119  				t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1120  			}
  1121  			if !test.v1Compatibility {
  1122  				// V2 should still have a queued state-changed report.
  1123  				clock.Advance(maxUnsolicitedReportDelay)
  1124  				if diff := mgp.check(test.checkFields(addr1, true /* leave */)); diff != "" {
  1125  					t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1126  				}
  1127  			}
  1128  			if t.Failed() {
  1129  				t.FailNow()
  1130  			}
  1131  
  1132  			// Group should no longer be joined so we should not have anything to
  1133  			// leave.
  1134  			if mgp.leaveGroup(addr1) {
  1135  				t.Errorf("got mgp.leaveGroup(%s) = true, want = false", addr1)
  1136  			}
  1137  			if mgp.isLocallyJoined(addr1) {
  1138  				t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1)
  1139  			}
  1140  			if diff := mgp.check(checkFields{}); diff != "" {
  1141  				t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1142  			}
  1143  
  1144  			// Should have no more messages to send.
  1145  			//
  1146  			// Generic multicast protocol timers are expected to take the job mutex.
  1147  			clock.Advance(time.Hour)
  1148  			if diff := mgp.check(checkFields{}); diff != "" {
  1149  				t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1150  			}
  1151  		})
  1152  	}
  1153  }
  1154  
// TestMakeAllNonMemberAndInitialize tests two things.
//
// First, makeAllNonMember should send leave reports for every joined group
// (except the skipped addr3) while keeping all groups locally joined.
//
// Second, initializeGroups should then send unsolicitedTransmissionCount
// rounds of unsolicited reports for those groups. Afterwards the protocol's V1
// mode is expected to match the mode it was initialized with.
func TestMakeAllNonMemberAndInitialize(t *testing.T) {
	// Number of unsolicited report rounds expected after initializeGroups.
	const unsolicitedTransmissionCount = 2

	tests := []struct {
		name            string
		v1              bool
		v1Compatibility bool
		checkFields     func([]tcpip.Address, bool) checkFields
	}{
		{
			name:            "V1",
			v1:              true,
			v1Compatibility: false,
			checkFields: func(addrs []tcpip.Address, leave bool) checkFields {
				if leave {
					return checkFields{sendLeaveGroupAddresses: addrs}
				}
				return checkFields{sendReportGroupAddresses: addrs}
			},
		},
		{
			name:            "V1 Compatibility",
			v1:              false,
			v1Compatibility: true,
			checkFields: func(addrs []tcpip.Address, leave bool) checkFields {
				if leave {
					return checkFields{sendLeaveGroupAddresses: addrs}
				}
				return checkFields{sendReportGroupAddresses: addrs}
			},
		},
		{
			name:            "V2",
			v1:              false,
			v1Compatibility: false,
			checkFields: func(addrs []tcpip.Address, leave bool) checkFields {
				recordType := ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode
				if leave {
					recordType = ip.MulticastGroupProtocolV2ReportRecordChangeToIncludeMode
				}
				var records []mockReportV2Record
				for _, addr := range addrs {
					records = append(records, mockReportV2Record{
						recordType:   recordType,
						groupAddress: addr,
					})
				}

				return checkFields{sentV2Reports: []mockReportV2{{records: records}}}
			},
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
			clock := faketime.NewManualClock()

			mgp.init(ip.GenericMulticastProtocolOptions{
				Rand:                      rand.New(rand.NewSource(3)),
				Clock:                     clock,
				MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
			}, test.v1)

			if test.v1Compatibility {
				// V1 query targeting an unjoined group should drop us into V1
				// compatibility mode without sending any packets, affecting tests.
				mgp.handleQuery(addr3, 0)
			}

			mgp.joinGroup(addr1)
			if diff := mgp.check(test.checkFields([]tcpip.Address{addr1}, false /* leave */)); diff != "" {
				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
			}
			mgp.joinGroup(addr2)
			if diff := mgp.check(test.checkFields([]tcpip.Address{addr2}, false /* leave */)); diff != "" {
				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
			}
			// addr3 is the skipped protocol address, so joining it sends nothing.
			mgp.joinGroup(addr3)
			if diff := mgp.check(checkFields{}); diff != "" {
				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
			}

			// Should send the leave reports for each but still consider them locally
			// joined.
			mgp.makeAllNonMember()
			if diff := mgp.check(test.checkFields([]tcpip.Address{addr1, addr2}, true /* leave */)); diff != "" {
				t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
			}

			// Generic multicast protocol timers are expected to take the job mutex.
			clock.Advance(time.Hour)
			if diff := mgp.check(checkFields{}); diff != "" {
				t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
			}
			for _, group := range []tcpip.Address{addr1, addr2, addr3} {
				if !mgp.isLocallyJoined(group) {
					t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", group)
				}
			}

			// Should send the initial set of unsolicited reports.
			//
			// In V1 mode these are V1 reports. Otherwise they are V2 reports, one
			// per group. The V1 Compatibility case also expects V2 reports here,
			// i.e. V1 compatibility mode is not retained across
			// makeAllNonMember/initializeGroups.
			mgp.initializeGroups()
			for i := 0; i < unsolicitedTransmissionCount; i++ {
				if test.v1 {
					if diff := mgp.check(test.checkFields([]tcpip.Address{addr1, addr2}, false /* leave */)); diff != "" {
						t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
					}
				} else {
					if diff := mgp.check(checkFields{sentV2Reports: []mockReportV2{
						{
							records: []mockReportV2Record{
								{
									recordType:   ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode,
									groupAddress: addr1,
								},
							},
						},
						{
							records: []mockReportV2Record{
								{
									recordType:   ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode,
									groupAddress: addr2,
								},
							},
						},
					}}); diff != "" {
						t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
					}
				}
				// Fire the next round of unsolicited reports.
				clock.Advance(maxUnsolicitedReportDelay)
			}

			// Should have no more messages to send.
			clock.Advance(time.Hour)
			if diff := mgp.check(checkFields{}); diff != "" {
				t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
			}

			if got := mgp.getV1Mode(); got != test.v1 {
				t.Errorf("got mgp.getV1Mode() = %t, want = %t", got, test.v1)
			}
		})
	}
}
  1300  
  1301  // TestGroupStateNonMember tests that groups do not send packets when in the
  1302  // non-member state, but are still considered locally joined.
  1303  func TestGroupStateNonMember(t *testing.T) {
  1304  	tests := []struct {
  1305  		name            string
  1306  		v1Compatibility bool
  1307  		checkFields     func([]tcpip.Address, bool) checkFields
  1308  	}{
  1309  		{
  1310  			name:            "V1 Compatibility",
  1311  			v1Compatibility: true,
  1312  			checkFields: func(addrs []tcpip.Address, leave bool) checkFields {
  1313  				if leave {
  1314  					return checkFields{sendLeaveGroupAddresses: addrs}
  1315  				}
  1316  				return checkFields{sendReportGroupAddresses: addrs}
  1317  			},
  1318  		},
  1319  		{
  1320  			name:            "V2",
  1321  			v1Compatibility: false,
  1322  			checkFields: func(addrs []tcpip.Address, leave bool) checkFields {
  1323  				recordType := ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode
  1324  				if leave {
  1325  					recordType = ip.MulticastGroupProtocolV2ReportRecordChangeToIncludeMode
  1326  				}
  1327  				var records []mockReportV2Record
  1328  				for _, addr := range addrs {
  1329  					records = append(records, mockReportV2Record{
  1330  						recordType:   recordType,
  1331  						groupAddress: addr,
  1332  					})
  1333  				}
  1334  
  1335  				return checkFields{sentV2Reports: []mockReportV2{{records: records}}}
  1336  			},
  1337  		},
  1338  	}
  1339  
  1340  	for _, test := range tests {
  1341  		t.Run(test.name, func(t *testing.T) {
  1342  			mgp := mockMulticastGroupProtocol{t: t}
  1343  			clock := faketime.NewManualClock()
  1344  
  1345  			mgp.init(ip.GenericMulticastProtocolOptions{
  1346  				Rand:                      rand.New(rand.NewSource(3)),
  1347  				Clock:                     clock,
  1348  				MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
  1349  			}, test.v1Compatibility)
  1350  			mgp.setEnabled(false)
  1351  
  1352  			// Joining groups should not send any reports.
  1353  			mgp.joinGroup(addr1)
  1354  			if !mgp.isLocallyJoined(addr1) {
  1355  				t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
  1356  			}
  1357  			if diff := mgp.check(checkFields{}); diff != "" {
  1358  				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1359  			}
  1360  			mgp.joinGroup(addr2)
  1361  			if !mgp.isLocallyJoined(addr1) {
  1362  				t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr2)
  1363  			}
  1364  			if diff := mgp.check(checkFields{}); diff != "" {
  1365  				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1366  			}
  1367  
  1368  			// Receiving a query should not send any reports.
  1369  			mgp.handleQuery(addr1, time.Nanosecond)
  1370  			// Generic multicast protocol timers are expected to take the job mutex.
  1371  			clock.Advance(time.Nanosecond)
  1372  			if diff := mgp.check(checkFields{}); diff != "" {
  1373  				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1374  			}
  1375  
  1376  			// Leaving groups should not send any leave messages.
  1377  			if !mgp.leaveGroup(addr1) {
  1378  				t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr2)
  1379  			}
  1380  			if mgp.isLocallyJoined(addr1) {
  1381  				t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr2)
  1382  			}
  1383  			if diff := mgp.check(checkFields{}); diff != "" {
  1384  				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1385  			}
  1386  
  1387  			clock.Advance(time.Hour)
  1388  			if diff := mgp.check(checkFields{}); diff != "" {
  1389  				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1390  			}
  1391  		})
  1392  	}
  1393  }
  1394  
  1395  // TestMakeAllNonMemberCancelsDelayedReportJob tests that the delayed report job
  1396  // is cancelled on MakeAllNonMember, otherwise the job will panic if the endpoint
  1397  // is disabled.
  1398  func TestMakeAllNonMemberCancelsDelayedReportJob(t *testing.T) {
  1399  	const maxRespCode = 1
  1400  
  1401  	tests := []struct {
  1402  		name            string
  1403  		v1              bool
  1404  		v1Compatibility bool
  1405  		checkFields     func(tcpip.Address, bool) checkFields
  1406  	}{
  1407  		{
  1408  			name:            "V1",
  1409  			v1:              true,
  1410  			v1Compatibility: false,
  1411  			checkFields: func(addr tcpip.Address, leave bool) checkFields {
  1412  				if leave {
  1413  					return checkFields{sendLeaveGroupAddresses: []tcpip.Address{addr}}
  1414  				}
  1415  				return checkFields{sendReportGroupAddresses: []tcpip.Address{addr}}
  1416  			},
  1417  		},
  1418  		{
  1419  			name:            "V1 Compatibility",
  1420  			v1:              false,
  1421  			v1Compatibility: true,
  1422  			checkFields: func(addr tcpip.Address, leave bool) checkFields {
  1423  				if leave {
  1424  					return checkFields{sendLeaveGroupAddresses: []tcpip.Address{addr}}
  1425  				}
  1426  				return checkFields{sendReportGroupAddresses: []tcpip.Address{addr}}
  1427  			},
  1428  		},
  1429  		{
  1430  			name:            "V2",
  1431  			v1:              false,
  1432  			v1Compatibility: false,
  1433  			checkFields: func(addr tcpip.Address, leave bool) checkFields {
  1434  				recordType := ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode
  1435  				if leave {
  1436  					recordType = ip.MulticastGroupProtocolV2ReportRecordChangeToIncludeMode
  1437  				}
  1438  				return checkFields{sentV2Reports: []mockReportV2{{records: []mockReportV2Record{mockReportV2Record{
  1439  					recordType:   recordType,
  1440  					groupAddress: addr,
  1441  				}}}}}
  1442  			},
  1443  		},
  1444  	}
  1445  
  1446  	for _, test := range tests {
  1447  		t.Run(test.name, func(t *testing.T) {
  1448  			mgp := mockMulticastGroupProtocol{t: t}
  1449  			clock := faketime.NewManualClock()
  1450  
  1451  			mgp.init(ip.GenericMulticastProtocolOptions{
  1452  				Rand:                      rand.New(rand.NewSource(3)),
  1453  				Clock:                     clock,
  1454  				MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
  1455  			}, test.v1)
  1456  
  1457  			if test.v1Compatibility {
  1458  				// V1 query targeting an unjoined group should drop us into V1
  1459  				// compatibility mode without sending any packets, affecting tests.
  1460  				mgp.handleQuery(addr3, 0)
  1461  			}
  1462  
  1463  			mgp.joinGroup(addr1)
  1464  			if diff := mgp.check(test.checkFields(addr1, false /* leave */)); diff != "" {
  1465  				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1466  			}
  1467  
  1468  			// Handle a query so that the delayed report job is scheduled when operating
  1469  			// in V2 mode.
  1470  			mgp.handleQueryV2(addr1, maxRespCode, header.MakeAddressIterator(addr1.Len(), bytes.NewBuffer(nil)), 0, 0)
  1471  
  1472  			mgp.makeAllNonMember()
  1473  			if diff := mgp.check(test.checkFields(addr1, true /* leave */)); diff != "" {
  1474  				t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1475  			}
  1476  
  1477  			mgp.setEnabled(false)
  1478  
  1479  			// Generic multicast protocol timers are expected to take the job mutex.
  1480  			//
  1481  			// Advance the clock to after the delayed report job is supposed to fire.
  1482  			// If the delayed report job isn't cancelled by the MakeAllNonMember call,
  1483  			// it will panic due to the expectation that the protocol is enabled.
  1484  			if test.v1 || test.v1Compatibility {
  1485  				clock.Advance(mgp.V2QueryMaxRespCodeToV1Delay(maxRespCode))
  1486  			} else {
  1487  				clock.Advance(mgp.V2QueryMaxRespCodeToV2Delay(maxRespCode))
  1488  			}
  1489  			if diff := mgp.check(checkFields{}); diff != "" {
  1490  				t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1491  			}
  1492  		})
  1493  	}
  1494  }
  1495  
  1496  func TestQueuedPackets(t *testing.T) {
  1497  	tests := []struct {
  1498  		name            string
  1499  		v1Compatibility bool
  1500  		checkFields     func(tcpip.Address) checkFields
  1501  	}{
  1502  		{
  1503  			name:            "V1 Compatibility",
  1504  			v1Compatibility: true,
  1505  			checkFields: func(addr tcpip.Address) checkFields {
  1506  				return checkFields{sendReportGroupAddresses: []tcpip.Address{addr}}
  1507  			},
  1508  		},
  1509  		{
  1510  			name:            "V2",
  1511  			v1Compatibility: false,
  1512  			checkFields: func(addr tcpip.Address) checkFields {
  1513  				return checkFields{sentV2Reports: []mockReportV2{{records: []mockReportV2Record{
  1514  					{
  1515  						recordType:   ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode,
  1516  						groupAddress: addr,
  1517  					},
  1518  				}}}}
  1519  			},
  1520  		},
  1521  	}
  1522  
  1523  	for _, test := range tests {
  1524  		t.Run(test.name, func(t *testing.T) {
  1525  			clock := faketime.NewManualClock()
  1526  			mgp := mockMulticastGroupProtocol{t: t}
  1527  			mgp.init(ip.GenericMulticastProtocolOptions{
  1528  				Rand:                      rand.New(rand.NewSource(4)),
  1529  				Clock:                     clock,
  1530  				MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
  1531  			}, test.v1Compatibility)
  1532  
  1533  			// Joining should trigger a SendReport, but mgp should report that we did not
  1534  			// send the packet.
  1535  			mgp.setQueuePackets(true)
  1536  			mgp.joinGroup(addr1)
  1537  			if diff := mgp.check(test.checkFields(addr1)); diff != "" {
  1538  				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1539  			}
  1540  
  1541  			// The delayed report timer should have been cancelled since we did not send
  1542  			// the initial report earlier.
  1543  			clock.Advance(time.Hour)
  1544  			if diff := mgp.check(checkFields{}); diff != "" {
  1545  				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1546  			}
  1547  
  1548  			// Mock being able to successfully send the report.
  1549  			mgp.setQueuePackets(false)
  1550  			mgp.sendQueuedReports()
  1551  			if diff := mgp.check(test.checkFields(addr1)); diff != "" {
  1552  				t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1553  			}
  1554  
  1555  			// The delayed report (sent after the initial report) should now be sent.
  1556  			clock.Advance(maxUnsolicitedReportDelay)
  1557  			if diff := mgp.check(test.checkFields(addr1)); diff != "" {
  1558  				t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1559  			}
  1560  
  1561  			// Should not have anything else to send (we should be idle).
  1562  			mgp.sendQueuedReports()
  1563  			clock.Advance(time.Hour)
  1564  			if diff := mgp.check(checkFields{}); diff != "" {
  1565  				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1566  			}
  1567  
  1568  			// Receive a query but mock being unable to send reports again.
  1569  			mgp.setQueuePackets(true)
  1570  			mgp.handleQuery(addr1, time.Nanosecond)
  1571  			clock.Advance(time.Nanosecond)
  1572  			if diff := mgp.check(checkFields{sendReportGroupAddresses: []tcpip.Address{addr1}}); diff != "" {
  1573  				t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1574  			}
  1575  
  1576  			// Mock being able to send reports again - we should have a packet queued to
  1577  			// send.
  1578  			mgp.setQueuePackets(false)
  1579  			mgp.sendQueuedReports()
  1580  			if diff := mgp.check(checkFields{sendReportGroupAddresses: []tcpip.Address{addr1}}); diff != "" {
  1581  				t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1582  			}
  1583  
  1584  			// Should not have anything else to send.
  1585  			mgp.sendQueuedReports()
  1586  			clock.Advance(time.Hour)
  1587  			if diff := mgp.check(checkFields{}); diff != "" {
  1588  				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1589  			}
  1590  
  1591  			// Receive a query again, but mock being unable to send reports.
  1592  			mgp.setQueuePackets(true)
  1593  			mgp.handleQuery(addr1, time.Nanosecond)
  1594  			clock.Advance(time.Nanosecond)
  1595  			if diff := mgp.check(checkFields{sendReportGroupAddresses: []tcpip.Address{addr1}}); diff != "" {
  1596  				t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1597  			}
  1598  
  1599  			// Receiving a report should transition us into the idle member state,
  1600  			// even if we had a packet queued. We should no longer have any packets to
  1601  			// send.
  1602  			mgp.handleReport(addr1)
  1603  			mgp.sendQueuedReports()
  1604  			if diff := mgp.check(checkFields{}); diff != "" {
  1605  				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1606  			}
  1607  
  1608  			// When we fail to send the initial set of reports, incoming reports should
  1609  			// prevent a newly joined group's reports from being sent.
  1610  			mgp.setQueuePackets(true)
  1611  			mgp.joinGroup(addr2)
  1612  			if diff := mgp.check(checkFields{sendReportGroupAddresses: []tcpip.Address{addr2}}); diff != "" {
  1613  				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1614  			}
  1615  			mgp.handleReport(addr2)
  1616  			// Attempting to send queued reports while still unable to send reports should
  1617  			// not change the host state.
  1618  			mgp.sendQueuedReports()
  1619  			if diff := mgp.check(checkFields{}); diff != "" {
  1620  				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1621  			}
  1622  			// Should not have any packets queued.
  1623  			mgp.setQueuePackets(false)
  1624  			mgp.sendQueuedReports()
  1625  			clock.Advance(time.Hour)
  1626  			if diff := mgp.check(checkFields{}); diff != "" {
  1627  				t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1628  			}
  1629  		})
  1630  	}
  1631  }
  1632  
  1633  func TestGetSetV1Mode(t *testing.T) {
  1634  	clock := faketime.NewManualClock()
  1635  	mgp := mockMulticastGroupProtocol{t: t}
  1636  	mgp.init(ip.GenericMulticastProtocolOptions{
  1637  		Rand:                      rand.New(rand.NewSource(4)),
  1638  		Clock:                     clock,
  1639  		MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
  1640  	}, false /* v1Compatibility */)
  1641  
  1642  	if mgp.getV1Mode() {
  1643  		t.Error("got mgp.getV1Mode() = true, want = false")
  1644  	}
  1645  
  1646  	mgp.joinGroup(addr1)
  1647  	if diff := mgp.check(checkFields{sentV2Reports: []mockReportV2{{records: []mockReportV2Record{
  1648  		{
  1649  			recordType:   ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode,
  1650  			groupAddress: addr1,
  1651  		},
  1652  	}}}}); diff != "" {
  1653  		t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1654  	}
  1655  
  1656  	if mgp.setV1Mode(true) {
  1657  		t.Error("got mgp.setV1Mode(true) = true, want = false")
  1658  	}
  1659  	if !mgp.getV1Mode() {
  1660  		t.Error("got mgp.getV1Mode() = false, want = true")
  1661  	}
  1662  	mgp.joinGroup(addr2)
  1663  	if diff := mgp.check(checkFields{sendReportGroupAddresses: []tcpip.Address{addr2}}); diff != "" {
  1664  		t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1665  	}
  1666  
  1667  	if !mgp.setV1Mode(false) {
  1668  		t.Error("got mgp.setV1Mode(false) = false, want = true")
  1669  	}
  1670  	if mgp.getV1Mode() {
  1671  		t.Error("got mgp.getV1Mode() = true, want = false")
  1672  	}
  1673  	mgp.joinGroup(addr3)
  1674  	if diff := mgp.check(checkFields{sentV2Reports: []mockReportV2{{records: []mockReportV2Record{
  1675  		{
  1676  			recordType:   ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode,
  1677  			groupAddress: addr3,
  1678  		},
  1679  	}}}}); diff != "" {
  1680  		t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
  1681  	}
  1682  }