gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/header/mld_test.go (about) 1 // Copyright 2020 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package header 16 17 import ( 18 "encoding/binary" 19 "fmt" 20 "testing" 21 "time" 22 23 "gvisor.dev/gvisor/pkg/tcpip" 24 "gvisor.dev/gvisor/pkg/tcpip/testutil" 25 ) 26 27 func TestMLD(t *testing.T) { 28 b := []byte{ 29 // Maximum Response Delay 30 0, 0, 31 32 // Reserved 33 0, 0, 34 35 // MulticastAddress 36 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 37 } 38 39 const maxRespDelay = 513 40 binary.BigEndian.PutUint16(b, maxRespDelay) 41 42 mld := MLD(b) 43 44 if got, want := mld.MaximumResponseDelay(), maxRespDelay*time.Millisecond; got != want { 45 t.Errorf("got mld.MaximumResponseDelay() = %s, want = %s", got, want) 46 } 47 48 const newMaxRespDelay = 1234 49 mld.SetMaximumResponseDelay(newMaxRespDelay) 50 if got, want := mld.MaximumResponseDelay(), newMaxRespDelay*time.Millisecond; got != want { 51 t.Errorf("got mld.MaximumResponseDelay() = %s, want = %s", got, want) 52 } 53 54 if got, want := mld.MulticastAddress(), tcpip.AddrFrom16([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}); got != want { 55 t.Errorf("got mld.MulticastAddress() = %s, want = %s", got, want) 56 } 57 58 multicastAddress := tcpip.AddrFrom16([16]byte{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}) 59 mld.SetMulticastAddress(multicastAddress) 60 if got := mld.MulticastAddress(); got != multicastAddress { 61 t.Errorf("got mld.MulticastAddress() = %s, want = %s", got, multicastAddress) 62 } 63 } 64 65 func TestMLDv2MaximumResponseDelay(t *testing.T) { 66 const ( 67 exponentialResponseDelayStartCode = 32768 68 mantMaxRespBits = 12 69 ) 70 71 type respCodeTest struct { 72 maxResponseCode uint16 73 expectedMaxResponseDelay time.Duration 74 } 75 76 exponentialRespDelay := func(mant, exp uint16) respCodeTest { 77 return respCodeTest{ 78 maxResponseCode: exponentialResponseDelayStartCode | mant | exp<<mantMaxRespBits, 79 expectedMaxResponseDelay: ((time.Duration(mant) | 0x1000) << (time.Duration(exp) + 3)) * time.Millisecond, 80 } 81 } 82 83 tests := []respCodeTest{ 84 { 85 maxResponseCode: 0, 86 expectedMaxResponseDelay: 0, 87 }, 88 { 89 maxResponseCode: 1, 90 expectedMaxResponseDelay: time.Millisecond, 91 }, 92 { 93 maxResponseCode: exponentialResponseDelayStartCode - 1, 94 expectedMaxResponseDelay: (exponentialResponseDelayStartCode - 1) * time.Millisecond, 95 }, 96 exponentialRespDelay(0, 0), 97 exponentialRespDelay(1, 0), 98 exponentialRespDelay(0, 1), 99 exponentialRespDelay(1, 1), 100 } 101 102 for _, test := range tests { 103 t.Run(fmt.Sprintf("Code=%d", test.maxResponseCode), func(t *testing.T) { 104 if got := MLDv2MaximumResponseDelay(test.maxResponseCode); got != test.expectedMaxResponseDelay { 105 t.Errorf("got MLDv2MaximumResponseDelay(%d) = %s, want = %s", test.maxResponseCode, got, test.expectedMaxResponseDelay) 106 } 107 }) 108 } 109 } 110 111 func TestMLDv2Query(t *testing.T) { 112 const ( 113 exponentialQueryIntervalStartCode = 128 114 mantQQICBits = 4 115 ) 116 117 qrvs := []uint8{0, 1, 2, 3, 4, 5, 6, 7} 118 119 type qqicTest struct { 120 val uint8 121 expectedInterval time.Duration 122 } 123 124 exponentialQQIC := func(mant, exp uint8) qqicTest { 125 return qqicTest{ 126 val: exponentialQueryIntervalStartCode | mant | exp<<mantQQICBits, 127 expectedInterval: ((time.Duration(mant) | 0x10) << (time.Duration(exp) + 3)) * time.Second, 128 } 129 } 130 131 queryIntervalCodes := []qqicTest{ 132 { 133 val: 0, 134 expectedInterval: 0, 135 }, 136 { 137 val: 1, 138 expectedInterval: time.Second, 139 }, 140 { 141 val: exponentialQueryIntervalStartCode - 1, 142 expectedInterval: (exponentialQueryIntervalStartCode - 1) * time.Second, 143 }, 144 { 145 val: exponentialQueryIntervalStartCode, 146 expectedInterval: exponentialQueryIntervalStartCode * time.Second, 147 }, 148 exponentialQQIC(0, 0), 149 exponentialQQIC(1, 0), 150 exponentialQQIC(0, 1), 151 exponentialQQIC(1, 1), 152 } 153 154 sourceAddrs := []tcpip.Address{ 155 testutil.MustParse6("a00::a"), 156 testutil.MustParse6("b00::b"), 157 testutil.MustParse6("c00::c"), 158 } 159 160 sources := []struct { 161 count uint16 162 expectedOK bool 163 }{ 164 { 165 count: 0, 166 expectedOK: true, 167 }, 168 { 169 count: 0, 170 expectedOK: true, 171 }, 172 { 173 count: 1, 174 expectedOK: true, 175 }, 176 { 177 count: uint16(len(sourceAddrs)), 178 expectedOK: true, 179 }, 180 { 181 count: uint16(len(sourceAddrs) + 1), 182 expectedOK: false, 183 }, 184 } 185 186 for _, respCode := range []uint16{0x0001, 0x0100} { 187 for _, qrv := range qrvs { 188 for _, qqic := range queryIntervalCodes { 189 for _, source := range sources { 190 t.Run(fmt.Sprintf("MaxRespCode=%d QRV=%d QQIC=%d Sources=%d", respCode, qrv, qqic.val, source.count), func(t *testing.T) { 191 b := []byte{ 192 // Maximum Response Code 193 0, 0, 194 195 // Reserved 196 0, 0, 197 198 // MulticastAddress 199 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 200 201 // Resv, S, QRV 202 qrv, 203 204 // QQIC 205 qqic.val, 206 207 // Number of Sources 208 0, 0, 209 210 // Sources 211 0xA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xA, 212 0xB, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xB, 213 0xC, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xC, 214 } 215 216 binary.BigEndian.PutUint16(b[mldMaximumResponseDelayOffset:], respCode) 217 binary.BigEndian.PutUint16(b[mldv2QueryNumberOfSourcesOffset:], source.count) 218 219 query := MLDv2Query(b) 220 if got := query.MaximumResponseCode(); got != respCode { 221 t.Errorf("got query.MaximumResponseCode() = %d, want = %d", got, respCode) 222 } 223 if got := query.QuerierRobustnessVariable(); got != qrv { 224 t.Errorf("got query.QuerierRobustnessVariable() = %d, want = %d", got, qrv) 225 } 226 if got := query.QuerierQueryInterval(); got != qqic.expectedInterval { 227 t.Errorf("got query.QuerierQueryInterval() = %s, want = %s", got, qqic.expectedInterval) 228 } 229 if got, want := query.MulticastAddress(), tcpip.AddrFrom16([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}); got != want { 230 t.Errorf("got query.MulticastAddress() = %s, want = %s", got, want) 231 } 232 233 iterator, ok := query.Sources() 234 if ok != source.expectedOK { 235 t.Errorf("got query.Sources() = (_, %t), want = (_, %t)", ok, source.expectedOK) 236 } 237 if !source.expectedOK { 238 return 239 } 240 241 sourceAddrs := sourceAddrs[:source.count] 242 for i := uint16(0); ; i++ { 243 if len(sourceAddrs) == 0 { 244 break 245 } 246 247 source, ok := iterator.Next() 248 if !ok { 249 t.Fatalf("expected %d-th source", i) 250 } 251 if source != sourceAddrs[0] { 252 t.Errorf("got %d-th source = %s, want = %s", i, source, sourceAddrs[0]) 253 } 254 255 sourceAddrs = sourceAddrs[1:] 256 } 257 if len(sourceAddrs) != 0 { 258 t.Errorf("missing sources = %#v", sourceAddrs) 259 } 260 if source, ok := iterator.Next(); ok { 261 t.Errorf("unexpected source = %s", source) 262 } 263 }) 264 } 265 } 266 } 267 } 268 } 269 270 func TestMLDv2Report(t *testing.T) { 271 var ( 272 mcastAddr1 = testutil.MustParse6("ff02::a") 273 mcastAddr2 = testutil.MustParse6("ff02::b") 274 mcastAddr3 = testutil.MustParse6("ff02::c") 275 276 srcAddr1 = testutil.MustParse6("a::a") 277 srcAddr2 = testutil.MustParse6("b::b") 278 srcAddr3 = testutil.MustParse6("c::c") 279 ) 280 281 tests := []struct { 282 name string 283 serializer MLDv2ReportSerializer 284 }{ 285 { 286 name: "zero reports", 287 serializer: MLDv2ReportSerializer{}, 288 }, 289 { 290 name: "one record with one source", 291 serializer: MLDv2ReportSerializer{ 292 Records: []MLDv2ReportMulticastAddressRecordSerializer{ 293 { 294 RecordType: MLDv2ReportRecordModeIsInclude, 295 MulticastAddress: mcastAddr1, 296 Sources: []tcpip.Address{srcAddr1}, 297 }, 298 }, 299 }, 300 }, 301 { 302 name: "multiple records with multiple sources", 303 serializer: MLDv2ReportSerializer{ 304 Records: []MLDv2ReportMulticastAddressRecordSerializer{ 305 { 306 RecordType: MLDv2ReportRecordModeIsInclude, 307 MulticastAddress: mcastAddr1, 308 Sources: nil, 309 }, 310 { 311 RecordType: MLDv2ReportRecordModeIsExclude, 312 MulticastAddress: mcastAddr2, 313 Sources: []tcpip.Address{srcAddr1, srcAddr2, srcAddr3}, 314 }, 315 { 316 RecordType: MLDv2ReportRecordChangeToIncludeMode, 317 MulticastAddress: mcastAddr3, 318 Sources: []tcpip.Address{srcAddr1, srcAddr2}, 319 }, 320 }, 321 }, 322 }, 323 } 324 325 for _, test := range tests { 326 t.Run(test.name, func(t *testing.T) { 327 b := make([]byte, test.serializer.Length()) 328 test.serializer.SerializeInto(b) 329 330 report := MLDv2Report(b) 331 expectedRecords := test.serializer.Records 332 333 records := report.MulticastAddressRecords() 334 for { 335 if len(expectedRecords) == 0 { 336 break 337 } 338 339 record, res := records.Next() 340 if res != MLDv2ReportMulticastAddressRecordIteratorNextOk { 341 t.Fatalf("got records.Next() = (%#v, %d), want = (_, %d)", record, res, MLDv2ReportMulticastAddressRecordIteratorNextOk) 342 } 343 344 if got, want := record.RecordType(), expectedRecords[0].RecordType; got != want { 345 t.Errorf("got record.RecordType() = %d, want = %d", got, want) 346 } 347 348 if got := record.AuxDataLen(); got != 0 { 349 t.Errorf("got record.AuxDataLen() = %d, want = 0", got) 350 } 351 352 if got, want := record.MulticastAddress(), expectedRecords[0].MulticastAddress; got != want { 353 t.Errorf("got record.MulticastAddress() = %s, want = %s", got, want) 354 } 355 356 sources, ok := record.Sources() 357 if !ok { 358 t.Error("got record.Sources() = (_, false), want = (_, true)") 359 continue 360 } 361 362 expectedSources := expectedRecords[0].Sources 363 for { 364 if len(expectedSources) == 0 { 365 break 366 } 367 368 source, ok := sources.Next() 369 if !ok { 370 t.Fatal("got sources.Next() = (_, false), want = (_, true)") 371 } 372 if source != expectedSources[0] { 373 t.Errorf("got sources.Next() = %s, want = %s", source, expectedSources[0]) 374 } 375 376 expectedSources = expectedSources[1:] 377 } 378 379 expectedRecords = expectedRecords[1:] 380 } 381 382 if record, res := records.Next(); res != MLDv2ReportMulticastAddressRecordIteratorNextDone { 383 t.Fatalf("got records.Next() = (%#v, %d), want = (_, %d)", record, res, MLDv2ReportMulticastAddressRecordIteratorNextDone) 384 } 385 }) 386 } 387 }