gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/header/mld_test.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package header
    16  
    17  import (
    18  	"encoding/binary"
    19  	"fmt"
    20  	"testing"
    21  	"time"
    22  
    23  	"gvisor.dev/gvisor/pkg/tcpip"
    24  	"gvisor.dev/gvisor/pkg/tcpip/testutil"
    25  )
    26  
    27  func TestMLD(t *testing.T) {
    28  	b := []byte{
    29  		// Maximum Response Delay
    30  		0, 0,
    31  
    32  		// Reserved
    33  		0, 0,
    34  
    35  		// MulticastAddress
    36  		1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6,
    37  	}
    38  
    39  	const maxRespDelay = 513
    40  	binary.BigEndian.PutUint16(b, maxRespDelay)
    41  
    42  	mld := MLD(b)
    43  
    44  	if got, want := mld.MaximumResponseDelay(), maxRespDelay*time.Millisecond; got != want {
    45  		t.Errorf("got mld.MaximumResponseDelay() = %s, want = %s", got, want)
    46  	}
    47  
    48  	const newMaxRespDelay = 1234
    49  	mld.SetMaximumResponseDelay(newMaxRespDelay)
    50  	if got, want := mld.MaximumResponseDelay(), newMaxRespDelay*time.Millisecond; got != want {
    51  		t.Errorf("got mld.MaximumResponseDelay() = %s, want = %s", got, want)
    52  	}
    53  
    54  	if got, want := mld.MulticastAddress(), tcpip.AddrFrom16([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}); got != want {
    55  		t.Errorf("got mld.MulticastAddress() = %s, want = %s", got, want)
    56  	}
    57  
    58  	multicastAddress := tcpip.AddrFrom16([16]byte{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0})
    59  	mld.SetMulticastAddress(multicastAddress)
    60  	if got := mld.MulticastAddress(); got != multicastAddress {
    61  		t.Errorf("got mld.MulticastAddress() = %s, want = %s", got, multicastAddress)
    62  	}
    63  }
    64  
    65  func TestMLDv2MaximumResponseDelay(t *testing.T) {
    66  	const (
    67  		exponentialResponseDelayStartCode = 32768
    68  		mantMaxRespBits                   = 12
    69  	)
    70  
    71  	type respCodeTest struct {
    72  		maxResponseCode          uint16
    73  		expectedMaxResponseDelay time.Duration
    74  	}
    75  
    76  	exponentialRespDelay := func(mant, exp uint16) respCodeTest {
    77  		return respCodeTest{
    78  			maxResponseCode:          exponentialResponseDelayStartCode | mant | exp<<mantMaxRespBits,
    79  			expectedMaxResponseDelay: ((time.Duration(mant) | 0x1000) << (time.Duration(exp) + 3)) * time.Millisecond,
    80  		}
    81  	}
    82  
    83  	tests := []respCodeTest{
    84  		{
    85  			maxResponseCode:          0,
    86  			expectedMaxResponseDelay: 0,
    87  		},
    88  		{
    89  			maxResponseCode:          1,
    90  			expectedMaxResponseDelay: time.Millisecond,
    91  		},
    92  		{
    93  			maxResponseCode:          exponentialResponseDelayStartCode - 1,
    94  			expectedMaxResponseDelay: (exponentialResponseDelayStartCode - 1) * time.Millisecond,
    95  		},
    96  		exponentialRespDelay(0, 0),
    97  		exponentialRespDelay(1, 0),
    98  		exponentialRespDelay(0, 1),
    99  		exponentialRespDelay(1, 1),
   100  	}
   101  
   102  	for _, test := range tests {
   103  		t.Run(fmt.Sprintf("Code=%d", test.maxResponseCode), func(t *testing.T) {
   104  			if got := MLDv2MaximumResponseDelay(test.maxResponseCode); got != test.expectedMaxResponseDelay {
   105  				t.Errorf("got MLDv2MaximumResponseDelay(%d) = %s, want = %s", test.maxResponseCode, got, test.expectedMaxResponseDelay)
   106  			}
   107  		})
   108  	}
   109  }
   110  
   111  func TestMLDv2Query(t *testing.T) {
   112  	const (
   113  		exponentialQueryIntervalStartCode = 128
   114  		mantQQICBits                      = 4
   115  	)
   116  
   117  	qrvs := []uint8{0, 1, 2, 3, 4, 5, 6, 7}
   118  
   119  	type qqicTest struct {
   120  		val              uint8
   121  		expectedInterval time.Duration
   122  	}
   123  
   124  	exponentialQQIC := func(mant, exp uint8) qqicTest {
   125  		return qqicTest{
   126  			val:              exponentialQueryIntervalStartCode | mant | exp<<mantQQICBits,
   127  			expectedInterval: ((time.Duration(mant) | 0x10) << (time.Duration(exp) + 3)) * time.Second,
   128  		}
   129  	}
   130  
   131  	queryIntervalCodes := []qqicTest{
   132  		{
   133  			val:              0,
   134  			expectedInterval: 0,
   135  		},
   136  		{
   137  			val:              1,
   138  			expectedInterval: time.Second,
   139  		},
   140  		{
   141  			val:              exponentialQueryIntervalStartCode - 1,
   142  			expectedInterval: (exponentialQueryIntervalStartCode - 1) * time.Second,
   143  		},
   144  		{
   145  			val:              exponentialQueryIntervalStartCode,
   146  			expectedInterval: exponentialQueryIntervalStartCode * time.Second,
   147  		},
   148  		exponentialQQIC(0, 0),
   149  		exponentialQQIC(1, 0),
   150  		exponentialQQIC(0, 1),
   151  		exponentialQQIC(1, 1),
   152  	}
   153  
   154  	sourceAddrs := []tcpip.Address{
   155  		testutil.MustParse6("a00::a"),
   156  		testutil.MustParse6("b00::b"),
   157  		testutil.MustParse6("c00::c"),
   158  	}
   159  
   160  	sources := []struct {
   161  		count      uint16
   162  		expectedOK bool
   163  	}{
   164  		{
   165  			count:      0,
   166  			expectedOK: true,
   167  		},
   168  		{
   169  			count:      0,
   170  			expectedOK: true,
   171  		},
   172  		{
   173  			count:      1,
   174  			expectedOK: true,
   175  		},
   176  		{
   177  			count:      uint16(len(sourceAddrs)),
   178  			expectedOK: true,
   179  		},
   180  		{
   181  			count:      uint16(len(sourceAddrs) + 1),
   182  			expectedOK: false,
   183  		},
   184  	}
   185  
   186  	for _, respCode := range []uint16{0x0001, 0x0100} {
   187  		for _, qrv := range qrvs {
   188  			for _, qqic := range queryIntervalCodes {
   189  				for _, source := range sources {
   190  					t.Run(fmt.Sprintf("MaxRespCode=%d QRV=%d QQIC=%d Sources=%d", respCode, qrv, qqic.val, source.count), func(t *testing.T) {
   191  						b := []byte{
   192  							// Maximum Response Code
   193  							0, 0,
   194  
   195  							// Reserved
   196  							0, 0,
   197  
   198  							// MulticastAddress
   199  							1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6,
   200  
   201  							// Resv, S, QRV
   202  							qrv,
   203  
   204  							// QQIC
   205  							qqic.val,
   206  
   207  							// Number of Sources
   208  							0, 0,
   209  
   210  							// Sources
   211  							0xA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xA,
   212  							0xB, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xB,
   213  							0xC, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xC,
   214  						}
   215  
   216  						binary.BigEndian.PutUint16(b[mldMaximumResponseDelayOffset:], respCode)
   217  						binary.BigEndian.PutUint16(b[mldv2QueryNumberOfSourcesOffset:], source.count)
   218  
   219  						query := MLDv2Query(b)
   220  						if got := query.MaximumResponseCode(); got != respCode {
   221  							t.Errorf("got query.MaximumResponseCode() = %d, want = %d", got, respCode)
   222  						}
   223  						if got := query.QuerierRobustnessVariable(); got != qrv {
   224  							t.Errorf("got query.QuerierRobustnessVariable() = %d, want = %d", got, qrv)
   225  						}
   226  						if got := query.QuerierQueryInterval(); got != qqic.expectedInterval {
   227  							t.Errorf("got query.QuerierQueryInterval() = %s, want = %s", got, qqic.expectedInterval)
   228  						}
   229  						if got, want := query.MulticastAddress(), tcpip.AddrFrom16([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}); got != want {
   230  							t.Errorf("got query.MulticastAddress() = %s, want = %s", got, want)
   231  						}
   232  
   233  						iterator, ok := query.Sources()
   234  						if ok != source.expectedOK {
   235  							t.Errorf("got query.Sources() = (_, %t), want = (_, %t)", ok, source.expectedOK)
   236  						}
   237  						if !source.expectedOK {
   238  							return
   239  						}
   240  
   241  						sourceAddrs := sourceAddrs[:source.count]
   242  						for i := uint16(0); ; i++ {
   243  							if len(sourceAddrs) == 0 {
   244  								break
   245  							}
   246  
   247  							source, ok := iterator.Next()
   248  							if !ok {
   249  								t.Fatalf("expected %d-th source", i)
   250  							}
   251  							if source != sourceAddrs[0] {
   252  								t.Errorf("got %d-th source = %s, want = %s", i, source, sourceAddrs[0])
   253  							}
   254  
   255  							sourceAddrs = sourceAddrs[1:]
   256  						}
   257  						if len(sourceAddrs) != 0 {
   258  							t.Errorf("missing sources = %#v", sourceAddrs)
   259  						}
   260  						if source, ok := iterator.Next(); ok {
   261  							t.Errorf("unexpected source = %s", source)
   262  						}
   263  					})
   264  				}
   265  			}
   266  		}
   267  	}
   268  }
   269  
   270  func TestMLDv2Report(t *testing.T) {
   271  	var (
   272  		mcastAddr1 = testutil.MustParse6("ff02::a")
   273  		mcastAddr2 = testutil.MustParse6("ff02::b")
   274  		mcastAddr3 = testutil.MustParse6("ff02::c")
   275  
   276  		srcAddr1 = testutil.MustParse6("a::a")
   277  		srcAddr2 = testutil.MustParse6("b::b")
   278  		srcAddr3 = testutil.MustParse6("c::c")
   279  	)
   280  
   281  	tests := []struct {
   282  		name       string
   283  		serializer MLDv2ReportSerializer
   284  	}{
   285  		{
   286  			name:       "zero reports",
   287  			serializer: MLDv2ReportSerializer{},
   288  		},
   289  		{
   290  			name: "one record with one source",
   291  			serializer: MLDv2ReportSerializer{
   292  				Records: []MLDv2ReportMulticastAddressRecordSerializer{
   293  					{
   294  						RecordType:       MLDv2ReportRecordModeIsInclude,
   295  						MulticastAddress: mcastAddr1,
   296  						Sources:          []tcpip.Address{srcAddr1},
   297  					},
   298  				},
   299  			},
   300  		},
   301  		{
   302  			name: "multiple records with multiple sources",
   303  			serializer: MLDv2ReportSerializer{
   304  				Records: []MLDv2ReportMulticastAddressRecordSerializer{
   305  					{
   306  						RecordType:       MLDv2ReportRecordModeIsInclude,
   307  						MulticastAddress: mcastAddr1,
   308  						Sources:          nil,
   309  					},
   310  					{
   311  						RecordType:       MLDv2ReportRecordModeIsExclude,
   312  						MulticastAddress: mcastAddr2,
   313  						Sources:          []tcpip.Address{srcAddr1, srcAddr2, srcAddr3},
   314  					},
   315  					{
   316  						RecordType:       MLDv2ReportRecordChangeToIncludeMode,
   317  						MulticastAddress: mcastAddr3,
   318  						Sources:          []tcpip.Address{srcAddr1, srcAddr2},
   319  					},
   320  				},
   321  			},
   322  		},
   323  	}
   324  
   325  	for _, test := range tests {
   326  		t.Run(test.name, func(t *testing.T) {
   327  			b := make([]byte, test.serializer.Length())
   328  			test.serializer.SerializeInto(b)
   329  
   330  			report := MLDv2Report(b)
   331  			expectedRecords := test.serializer.Records
   332  
   333  			records := report.MulticastAddressRecords()
   334  			for {
   335  				if len(expectedRecords) == 0 {
   336  					break
   337  				}
   338  
   339  				record, res := records.Next()
   340  				if res != MLDv2ReportMulticastAddressRecordIteratorNextOk {
   341  					t.Fatalf("got records.Next() = (%#v, %d), want = (_, %d)", record, res, MLDv2ReportMulticastAddressRecordIteratorNextOk)
   342  				}
   343  
   344  				if got, want := record.RecordType(), expectedRecords[0].RecordType; got != want {
   345  					t.Errorf("got record.RecordType() = %d, want = %d", got, want)
   346  				}
   347  
   348  				if got := record.AuxDataLen(); got != 0 {
   349  					t.Errorf("got record.AuxDataLen() = %d, want = 0", got)
   350  				}
   351  
   352  				if got, want := record.MulticastAddress(), expectedRecords[0].MulticastAddress; got != want {
   353  					t.Errorf("got record.MulticastAddress() = %s, want = %s", got, want)
   354  				}
   355  
   356  				sources, ok := record.Sources()
   357  				if !ok {
   358  					t.Error("got record.Sources() = (_, false), want = (_, true)")
   359  					continue
   360  				}
   361  
   362  				expectedSources := expectedRecords[0].Sources
   363  				for {
   364  					if len(expectedSources) == 0 {
   365  						break
   366  					}
   367  
   368  					source, ok := sources.Next()
   369  					if !ok {
   370  						t.Fatal("got sources.Next() = (_, false), want = (_, true)")
   371  					}
   372  					if source != expectedSources[0] {
   373  						t.Errorf("got sources.Next() = %s, want = %s", source, expectedSources[0])
   374  					}
   375  
   376  					expectedSources = expectedSources[1:]
   377  				}
   378  
   379  				expectedRecords = expectedRecords[1:]
   380  			}
   381  
   382  			if record, res := records.Next(); res != MLDv2ReportMulticastAddressRecordIteratorNextDone {
   383  				t.Fatalf("got records.Next() = (%#v, %d), want = (_, %d)", record, res, MLDv2ReportMulticastAddressRecordIteratorNextDone)
   384  			}
   385  		})
   386  	}
   387  }