github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package ip_test
    16  
    17  import (
    18  	"math/rand"
    19  	"testing"
    20  	"time"
    21  
    22  	"github.com/google/go-cmp/cmp"
    23  	"github.com/SagerNet/gvisor/pkg/sync"
    24  	"github.com/SagerNet/gvisor/pkg/tcpip"
    25  	"github.com/SagerNet/gvisor/pkg/tcpip/faketime"
    26  	"github.com/SagerNet/gvisor/pkg/tcpip/network/internal/ip"
    27  )
    28  
    29  const maxUnsolicitedReportDelay = time.Second
    30  
    31  var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil)
    32  
    33  type mockMulticastGroupProtocolProtectedFields struct {
    34  	sync.RWMutex
    35  
    36  	genericMulticastGroup    ip.GenericMulticastProtocolState
    37  	sendReportGroupAddrCount map[tcpip.Address]int
    38  	sendLeaveGroupAddrCount  map[tcpip.Address]int
    39  	makeQueuePackets         bool
    40  	disabled                 bool
    41  }
    42  
    43  type mockMulticastGroupProtocol struct {
    44  	t *testing.T
    45  
    46  	skipProtocolAddress tcpip.Address
    47  
    48  	mu mockMulticastGroupProtocolProtectedFields
    49  }
    50  
    51  func (m *mockMulticastGroupProtocol) init(opts ip.GenericMulticastProtocolOptions) {
    52  	m.mu.Lock()
    53  	defer m.mu.Unlock()
    54  	m.initLocked()
    55  	opts.Protocol = m
    56  	m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts)
    57  }
    58  
    59  func (m *mockMulticastGroupProtocol) initLocked() {
    60  	m.mu.sendReportGroupAddrCount = make(map[tcpip.Address]int)
    61  	m.mu.sendLeaveGroupAddrCount = make(map[tcpip.Address]int)
    62  }
    63  
    64  func (m *mockMulticastGroupProtocol) setEnabled(v bool) {
    65  	m.mu.Lock()
    66  	defer m.mu.Unlock()
    67  	m.mu.disabled = !v
    68  }
    69  
    70  func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) {
    71  	m.mu.Lock()
    72  	defer m.mu.Unlock()
    73  	m.mu.makeQueuePackets = v
    74  }
    75  
    76  func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) {
    77  	m.mu.Lock()
    78  	defer m.mu.Unlock()
    79  	m.mu.genericMulticastGroup.JoinGroupLocked(addr)
    80  }
    81  
    82  func (m *mockMulticastGroupProtocol) leaveGroup(addr tcpip.Address) bool {
    83  	m.mu.Lock()
    84  	defer m.mu.Unlock()
    85  	return m.mu.genericMulticastGroup.LeaveGroupLocked(addr)
    86  }
    87  
    88  func (m *mockMulticastGroupProtocol) handleReport(addr tcpip.Address) {
    89  	m.mu.Lock()
    90  	defer m.mu.Unlock()
    91  	m.mu.genericMulticastGroup.HandleReportLocked(addr)
    92  }
    93  
    94  func (m *mockMulticastGroupProtocol) handleQuery(addr tcpip.Address, maxRespTime time.Duration) {
    95  	m.mu.Lock()
    96  	defer m.mu.Unlock()
    97  	m.mu.genericMulticastGroup.HandleQueryLocked(addr, maxRespTime)
    98  }
    99  
   100  func (m *mockMulticastGroupProtocol) isLocallyJoined(addr tcpip.Address) bool {
   101  	m.mu.RLock()
   102  	defer m.mu.RUnlock()
   103  	return m.mu.genericMulticastGroup.IsLocallyJoinedRLocked(addr)
   104  }
   105  
   106  func (m *mockMulticastGroupProtocol) makeAllNonMember() {
   107  	m.mu.Lock()
   108  	defer m.mu.Unlock()
   109  	m.mu.genericMulticastGroup.MakeAllNonMemberLocked()
   110  }
   111  
   112  func (m *mockMulticastGroupProtocol) initializeGroups() {
   113  	m.mu.Lock()
   114  	defer m.mu.Unlock()
   115  	m.mu.genericMulticastGroup.InitializeGroupsLocked()
   116  }
   117  
   118  func (m *mockMulticastGroupProtocol) sendQueuedReports() {
   119  	m.mu.Lock()
   120  	defer m.mu.Unlock()
   121  	m.mu.genericMulticastGroup.SendQueuedReportsLocked()
   122  }
   123  
   124  // Enabled implements ip.MulticastGroupProtocol.
   125  //
   126  // Precondition: m.mu must be read locked.
   127  func (m *mockMulticastGroupProtocol) Enabled() bool {
   128  	if m.mu.TryLock() {
   129  		m.mu.Unlock() // +checklocksforce: TryLock.
   130  		m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled")
   131  	}
   132  
   133  	return !m.mu.disabled
   134  }
   135  
   136  // SendReport implements ip.MulticastGroupProtocol.
   137  //
   138  // Precondition: m.mu must be locked.
   139  func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) {
   140  	if m.mu.TryLock() {
   141  		m.mu.Unlock() // +checklocksforce: TryLock.
   142  		m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
   143  	}
   144  	if m.mu.TryRLock() {
   145  		m.mu.RUnlock() // +checklocksforce: TryLock.
   146  		m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
   147  	}
   148  
   149  	m.mu.sendReportGroupAddrCount[groupAddress]++
   150  	return !m.mu.makeQueuePackets, nil
   151  }
   152  
   153  // SendLeave implements ip.MulticastGroupProtocol.
   154  //
   155  // Precondition: m.mu must be locked.
   156  func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error {
   157  	if m.mu.TryLock() {
   158  		m.mu.Unlock() // +checklocksforce: TryLock.
   159  		m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
   160  	}
   161  	if m.mu.TryRLock() {
   162  		m.mu.RUnlock() // +checklocksforce: TryLock.
   163  		m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
   164  	}
   165  
   166  	m.mu.sendLeaveGroupAddrCount[groupAddress]++
   167  	return nil
   168  }
   169  
   170  // ShouldPerformProtocol implements ip.MulticastGroupProtocol.
   171  func (m *mockMulticastGroupProtocol) ShouldPerformProtocol(groupAddress tcpip.Address) bool {
   172  	return groupAddress != m.skipProtocolAddress
   173  }
   174  
   175  func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string {
   176  	m.mu.Lock()
   177  	defer m.mu.Unlock()
   178  
   179  	sendReportGroupAddrCount := make(map[tcpip.Address]int)
   180  	for _, a := range sendReportGroupAddresses {
   181  		sendReportGroupAddrCount[a] = 1
   182  	}
   183  
   184  	sendLeaveGroupAddrCount := make(map[tcpip.Address]int)
   185  	for _, a := range sendLeaveGroupAddresses {
   186  		sendLeaveGroupAddrCount[a] = 1
   187  	}
   188  
   189  	diff := cmp.Diff(
   190  		&mockMulticastGroupProtocol{
   191  			mu: mockMulticastGroupProtocolProtectedFields{
   192  				sendReportGroupAddrCount: sendReportGroupAddrCount,
   193  				sendLeaveGroupAddrCount:  sendLeaveGroupAddrCount,
   194  			},
   195  		},
   196  		m,
   197  		cmp.AllowUnexported(mockMulticastGroupProtocol{}),
   198  		cmp.AllowUnexported(mockMulticastGroupProtocolProtectedFields{}),
   199  		// ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t
   200  		cmp.FilterPath(
   201  			func(p cmp.Path) bool {
   202  				switch p.Last().String() {
   203  				case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup", ".skipProtocolAddress":
   204  					return true
   205  				default:
   206  					return false
   207  				}
   208  			},
   209  			cmp.Ignore(),
   210  		),
   211  	)
   212  	m.initLocked()
   213  	return diff
   214  }
   215  
   216  func TestJoinGroup(t *testing.T) {
   217  	tests := []struct {
   218  		name              string
   219  		addr              tcpip.Address
   220  		shouldSendReports bool
   221  	}{
   222  		{
   223  			name:              "Normal group",
   224  			addr:              addr1,
   225  			shouldSendReports: true,
   226  		},
   227  		{
   228  			name:              "All-nodes group",
   229  			addr:              addr2,
   230  			shouldSendReports: false,
   231  		},
   232  	}
   233  
   234  	for _, test := range tests {
   235  		t.Run(test.name, func(t *testing.T) {
   236  			mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2}
   237  			clock := faketime.NewManualClock()
   238  
   239  			mgp.init(ip.GenericMulticastProtocolOptions{
   240  				Rand:                      rand.New(rand.NewSource(0)),
   241  				Clock:                     clock,
   242  				MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
   243  			})
   244  
   245  			// Joining a group should send a report immediately and another after
   246  			// a random interval between 0 and the maximum unsolicited report delay.
   247  			mgp.joinGroup(test.addr)
   248  			if test.shouldSendReports {
   249  				if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   250  					t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   251  				}
   252  
   253  				// Generic multicast protocol timers are expected to take the job mutex.
   254  				clock.Advance(maxUnsolicitedReportDelay)
   255  				if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   256  					t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   257  				}
   258  			}
   259  
   260  			// Should have no more messages to send.
   261  			clock.Advance(time.Hour)
   262  			if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   263  				t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   264  			}
   265  		})
   266  	}
   267  }
   268  
   269  func TestLeaveGroup(t *testing.T) {
   270  	tests := []struct {
   271  		name               string
   272  		addr               tcpip.Address
   273  		shouldSendMessages bool
   274  	}{
   275  		{
   276  			name:               "Normal group",
   277  			addr:               addr1,
   278  			shouldSendMessages: true,
   279  		},
   280  		{
   281  			name:               "All-nodes group",
   282  			addr:               addr2,
   283  			shouldSendMessages: false,
   284  		},
   285  	}
   286  
   287  	for _, test := range tests {
   288  		t.Run(test.name, func(t *testing.T) {
   289  			mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2}
   290  			clock := faketime.NewManualClock()
   291  
   292  			mgp.init(ip.GenericMulticastProtocolOptions{
   293  				Rand:                      rand.New(rand.NewSource(1)),
   294  				Clock:                     clock,
   295  				MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
   296  			})
   297  
   298  			mgp.joinGroup(test.addr)
   299  			if test.shouldSendMessages {
   300  				if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   301  					t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   302  				}
   303  			}
   304  
   305  			// Leaving a group should send a leave report immediately and cancel any
   306  			// delayed reports.
   307  			{
   308  
   309  				if !mgp.leaveGroup(test.addr) {
   310  					t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", test.addr)
   311  				}
   312  			}
   313  			if test.shouldSendMessages {
   314  				if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" {
   315  					t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   316  				}
   317  			}
   318  
   319  			// Should have no more messages to send.
   320  			//
   321  			// Generic multicast protocol timers are expected to take the job mutex.
   322  			clock.Advance(time.Hour)
   323  			if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   324  				t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   325  			}
   326  		})
   327  	}
   328  }
   329  
   330  func TestHandleReport(t *testing.T) {
   331  	tests := []struct {
   332  		name             string
   333  		reportAddr       tcpip.Address
   334  		expectReportsFor []tcpip.Address
   335  	}{
   336  		{
   337  			name:             "Unpecified empty",
   338  			reportAddr:       "",
   339  			expectReportsFor: []tcpip.Address{addr1, addr2},
   340  		},
   341  		{
   342  			name:             "Unpecified any",
   343  			reportAddr:       "\x00",
   344  			expectReportsFor: []tcpip.Address{addr1, addr2},
   345  		},
   346  		{
   347  			name:             "Specified",
   348  			reportAddr:       addr1,
   349  			expectReportsFor: []tcpip.Address{addr2},
   350  		},
   351  		{
   352  			name:             "Specified all-nodes",
   353  			reportAddr:       addr3,
   354  			expectReportsFor: []tcpip.Address{addr1, addr2},
   355  		},
   356  		{
   357  			name:             "Specified other",
   358  			reportAddr:       addr4,
   359  			expectReportsFor: []tcpip.Address{addr1, addr2},
   360  		},
   361  	}
   362  
   363  	for _, test := range tests {
   364  		t.Run(test.name, func(t *testing.T) {
   365  			mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
   366  			clock := faketime.NewManualClock()
   367  
   368  			mgp.init(ip.GenericMulticastProtocolOptions{
   369  				Rand:                      rand.New(rand.NewSource(2)),
   370  				Clock:                     clock,
   371  				MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
   372  			})
   373  
   374  			mgp.joinGroup(addr1)
   375  			if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   376  				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   377  			}
   378  			mgp.joinGroup(addr2)
   379  			if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   380  				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   381  			}
   382  			mgp.joinGroup(addr3)
   383  			if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   384  				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   385  			}
   386  
   387  			// Receiving a report for a group we have a timer scheduled for should
   388  			// cancel our delayed report timer for the group.
   389  			mgp.handleReport(test.reportAddr)
   390  			if len(test.expectReportsFor) != 0 {
   391  				// Generic multicast protocol timers are expected to take the job mutex.
   392  				clock.Advance(maxUnsolicitedReportDelay)
   393  				if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   394  					t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   395  				}
   396  			}
   397  
   398  			// Should have no more messages to send.
   399  			clock.Advance(time.Hour)
   400  			if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   401  				t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   402  			}
   403  		})
   404  	}
   405  }
   406  
   407  func TestHandleQuery(t *testing.T) {
   408  	tests := []struct {
   409  		name                    string
   410  		queryAddr               tcpip.Address
   411  		maxDelay                time.Duration
   412  		expectQueriedReportsFor []tcpip.Address
   413  		expectDelayedReportsFor []tcpip.Address
   414  	}{
   415  		{
   416  			name:                    "Unpecified empty",
   417  			queryAddr:               "",
   418  			maxDelay:                0,
   419  			expectQueriedReportsFor: []tcpip.Address{addr1, addr2},
   420  			expectDelayedReportsFor: nil,
   421  		},
   422  		{
   423  			name:                    "Unpecified any",
   424  			queryAddr:               "\x00",
   425  			maxDelay:                1,
   426  			expectQueriedReportsFor: []tcpip.Address{addr1, addr2},
   427  			expectDelayedReportsFor: nil,
   428  		},
   429  		{
   430  			name:                    "Specified",
   431  			queryAddr:               addr1,
   432  			maxDelay:                2,
   433  			expectQueriedReportsFor: []tcpip.Address{addr1},
   434  			expectDelayedReportsFor: []tcpip.Address{addr2},
   435  		},
   436  		{
   437  			name:                    "Specified all-nodes",
   438  			queryAddr:               addr3,
   439  			maxDelay:                3,
   440  			expectQueriedReportsFor: nil,
   441  			expectDelayedReportsFor: []tcpip.Address{addr1, addr2},
   442  		},
   443  		{
   444  			name:                    "Specified other",
   445  			queryAddr:               addr4,
   446  			maxDelay:                4,
   447  			expectQueriedReportsFor: nil,
   448  			expectDelayedReportsFor: []tcpip.Address{addr1, addr2},
   449  		},
   450  	}
   451  
   452  	for _, test := range tests {
   453  		t.Run(test.name, func(t *testing.T) {
   454  			mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
   455  			clock := faketime.NewManualClock()
   456  
   457  			mgp.init(ip.GenericMulticastProtocolOptions{
   458  				Rand:                      rand.New(rand.NewSource(3)),
   459  				Clock:                     clock,
   460  				MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
   461  			})
   462  
   463  			mgp.joinGroup(addr1)
   464  			if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   465  				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   466  			}
   467  			mgp.joinGroup(addr2)
   468  			if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   469  				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   470  			}
   471  			mgp.joinGroup(addr3)
   472  			if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   473  				t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   474  			}
   475  
   476  			// Receiving a query should make us reschedule our delayed report timer
   477  			// to some time within the new max response delay.
   478  			mgp.handleQuery(test.queryAddr, test.maxDelay)
   479  			clock.Advance(test.maxDelay)
   480  			if diff := mgp.check(test.expectQueriedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   481  				t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   482  			}
   483  
   484  			// The groups that were not affected by the query should still send a
   485  			// report after the max unsolicited report delay.
   486  			clock.Advance(maxUnsolicitedReportDelay)
   487  			if diff := mgp.check(test.expectDelayedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   488  				t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   489  			}
   490  
   491  			// Should have no more messages to send.
   492  			clock.Advance(time.Hour)
   493  			if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   494  				t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   495  			}
   496  		})
   497  	}
   498  }
   499  
   500  func TestJoinCount(t *testing.T) {
   501  	mgp := mockMulticastGroupProtocol{t: t}
   502  	clock := faketime.NewManualClock()
   503  
   504  	mgp.init(ip.GenericMulticastProtocolOptions{
   505  		Rand:                      rand.New(rand.NewSource(4)),
   506  		Clock:                     clock,
   507  		MaxUnsolicitedReportDelay: time.Second,
   508  	})
   509  
   510  	// Set the join count to 2 for a group.
   511  	mgp.joinGroup(addr1)
   512  	if !mgp.isLocallyJoined(addr1) {
   513  		t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
   514  	}
   515  	// Only the first join should trigger a report to be sent.
   516  	if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   517  		t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   518  	}
   519  	mgp.joinGroup(addr1)
   520  	if !mgp.isLocallyJoined(addr1) {
   521  		t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
   522  	}
   523  	if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   524  		t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   525  	}
   526  	if t.Failed() {
   527  		t.FailNow()
   528  	}
   529  
   530  	// Group should still be considered joined after leaving once.
   531  	if !mgp.leaveGroup(addr1) {
   532  		t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1)
   533  	}
   534  	if !mgp.isLocallyJoined(addr1) {
   535  		t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
   536  	}
   537  	// A leave report should only be sent once the join count reaches 0.
   538  	if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   539  		t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   540  	}
   541  	if t.Failed() {
   542  		t.FailNow()
   543  	}
   544  
   545  	// Leaving once more should actually remove us from the group.
   546  	if !mgp.leaveGroup(addr1) {
   547  		t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1)
   548  	}
   549  	if mgp.isLocallyJoined(addr1) {
   550  		t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1)
   551  	}
   552  	if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" {
   553  		t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   554  	}
   555  	if t.Failed() {
   556  		t.FailNow()
   557  	}
   558  
   559  	// Group should no longer be joined so we should not have anything to
   560  	// leave.
   561  	if mgp.leaveGroup(addr1) {
   562  		t.Errorf("got mgp.leaveGroup(%s) = true, want = false", addr1)
   563  	}
   564  	if mgp.isLocallyJoined(addr1) {
   565  		t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1)
   566  	}
   567  	if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   568  		t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   569  	}
   570  
   571  	// Should have no more messages to send.
   572  	//
   573  	// Generic multicast protocol timers are expected to take the job mutex.
   574  	clock.Advance(time.Hour)
   575  	if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   576  		t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   577  	}
   578  }
   579  
   580  func TestMakeAllNonMemberAndInitialize(t *testing.T) {
   581  	mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
   582  	clock := faketime.NewManualClock()
   583  
   584  	mgp.init(ip.GenericMulticastProtocolOptions{
   585  		Rand:                      rand.New(rand.NewSource(3)),
   586  		Clock:                     clock,
   587  		MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
   588  	})
   589  
   590  	mgp.joinGroup(addr1)
   591  	if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   592  		t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   593  	}
   594  	mgp.joinGroup(addr2)
   595  	if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   596  		t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   597  	}
   598  	mgp.joinGroup(addr3)
   599  	if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   600  		t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   601  	}
   602  
   603  	// Should send the leave reports for each but still consider them locally
   604  	// joined.
   605  	mgp.makeAllNonMember()
   606  	if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" {
   607  		t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   608  	}
   609  	// Generic multicast protocol timers are expected to take the job mutex.
   610  	clock.Advance(time.Hour)
   611  	if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   612  		t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   613  	}
   614  	for _, group := range []tcpip.Address{addr1, addr2, addr3} {
   615  		if !mgp.isLocallyJoined(group) {
   616  			t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", group)
   617  		}
   618  	}
   619  
   620  	// Should send the initial set of unsolcited reports.
   621  	mgp.initializeGroups()
   622  	if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   623  		t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   624  	}
   625  	clock.Advance(maxUnsolicitedReportDelay)
   626  	if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   627  		t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   628  	}
   629  
   630  	// Should have no more messages to send.
   631  	clock.Advance(time.Hour)
   632  	if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   633  		t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   634  	}
   635  }
   636  
   637  // TestGroupStateNonMember tests that groups do not send packets when in the
   638  // non-member state, but are still considered locally joined.
   639  func TestGroupStateNonMember(t *testing.T) {
   640  	mgp := mockMulticastGroupProtocol{t: t}
   641  	clock := faketime.NewManualClock()
   642  
   643  	mgp.init(ip.GenericMulticastProtocolOptions{
   644  		Rand:                      rand.New(rand.NewSource(3)),
   645  		Clock:                     clock,
   646  		MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
   647  	})
   648  	mgp.setEnabled(false)
   649  
   650  	// Joining groups should not send any reports.
   651  	mgp.joinGroup(addr1)
   652  	if !mgp.isLocallyJoined(addr1) {
   653  		t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
   654  	}
   655  	if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   656  		t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   657  	}
   658  	mgp.joinGroup(addr2)
   659  	if !mgp.isLocallyJoined(addr1) {
   660  		t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr2)
   661  	}
   662  	if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   663  		t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   664  	}
   665  
   666  	// Receiving a query should not send any reports.
   667  	mgp.handleQuery(addr1, time.Nanosecond)
   668  	// Generic multicast protocol timers are expected to take the job mutex.
   669  	clock.Advance(time.Nanosecond)
   670  	if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   671  		t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   672  	}
   673  
   674  	// Leaving groups should not send any leave messages.
   675  	if !mgp.leaveGroup(addr1) {
   676  		t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr2)
   677  	}
   678  	if mgp.isLocallyJoined(addr1) {
   679  		t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr2)
   680  	}
   681  	if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   682  		t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   683  	}
   684  
   685  	clock.Advance(time.Hour)
   686  	if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   687  		t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   688  	}
   689  }
   690  
   691  func TestQueuedPackets(t *testing.T) {
   692  	clock := faketime.NewManualClock()
   693  	mgp := mockMulticastGroupProtocol{t: t}
   694  	mgp.init(ip.GenericMulticastProtocolOptions{
   695  		Rand:                      rand.New(rand.NewSource(4)),
   696  		Clock:                     clock,
   697  		MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
   698  	})
   699  
   700  	// Joining should trigger a SendReport, but mgp should report that we did not
   701  	// send the packet.
   702  	mgp.setQueuePackets(true)
   703  	mgp.joinGroup(addr1)
   704  	if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   705  		t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   706  	}
   707  
   708  	// The delayed report timer should have been cancelled since we did not send
   709  	// the initial report earlier.
   710  	clock.Advance(time.Hour)
   711  	if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   712  		t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   713  	}
   714  
   715  	// Mock being able to successfully send the report.
   716  	mgp.setQueuePackets(false)
   717  	mgp.sendQueuedReports()
   718  	if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   719  		t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   720  	}
   721  
   722  	// The delayed report (sent after the initial report) should now be sent.
   723  	clock.Advance(maxUnsolicitedReportDelay)
   724  	if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   725  		t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   726  	}
   727  
   728  	// Should not have anything else to send (we should be idle).
   729  	mgp.sendQueuedReports()
   730  	clock.Advance(time.Hour)
   731  	if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   732  		t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   733  	}
   734  
   735  	// Receive a query but mock being unable to send reports again.
   736  	mgp.setQueuePackets(true)
   737  	mgp.handleQuery(addr1, time.Nanosecond)
   738  	clock.Advance(time.Nanosecond)
   739  	if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   740  		t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   741  	}
   742  
   743  	// Mock being able to send reports again - we should have a packet queued to
   744  	// send.
   745  	mgp.setQueuePackets(false)
   746  	mgp.sendQueuedReports()
   747  	if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   748  		t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   749  	}
   750  
   751  	// Should not have anything else to send.
   752  	mgp.sendQueuedReports()
   753  	clock.Advance(time.Hour)
   754  	if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   755  		t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   756  	}
   757  
   758  	// Receive a query again, but mock being unable to send reports.
   759  	mgp.setQueuePackets(true)
   760  	mgp.handleQuery(addr1, time.Nanosecond)
   761  	clock.Advance(time.Nanosecond)
   762  	if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   763  		t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   764  	}
   765  
   766  	// Receiving a report should transition us into the idle member state,
   767  	// even if we had a packet queued. We should no longer have any packets to
   768  	// send.
   769  	mgp.handleReport(addr1)
   770  	mgp.sendQueuedReports()
   771  	clock.Advance(time.Hour)
   772  	if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   773  		t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   774  	}
   775  
   776  	// When we fail to send the initial set of reports, incoming reports should
   777  	// not prevent a newly joined group's reports from being sent.
   778  	mgp.setQueuePackets(true)
   779  	mgp.joinGroup(addr2)
   780  	if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   781  		t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   782  	}
   783  	mgp.handleReport(addr2)
   784  	// Attempting to send queued reports while still unable to send reports should
   785  	// not change the host state.
   786  	mgp.sendQueuedReports()
   787  	if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   788  		t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   789  	}
   790  	// Mock being able to successfully send the report.
   791  	mgp.setQueuePackets(false)
   792  	mgp.sendQueuedReports()
   793  	if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   794  		t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   795  	}
   796  	// The delayed report (sent after the initial report) should now be sent.
   797  	clock.Advance(maxUnsolicitedReportDelay)
   798  	if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   799  		t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   800  	}
   801  
   802  	// Should not have anything else to send.
   803  	mgp.sendQueuedReports()
   804  	clock.Advance(time.Hour)
   805  	if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
   806  		t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
   807  	}
   808  }