gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/network/internal/testutil/testutil.go (about) 1 // Copyright 2020 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package testutil defines types and functions used to test Network Layer 16 // functionality such as IP fragmentation. 17 package testutil 18 19 import ( 20 "fmt" 21 "math/rand" 22 "testing" 23 24 "github.com/google/go-cmp/cmp" 25 "gvisor.dev/gvisor/pkg/buffer" 26 "gvisor.dev/gvisor/pkg/tcpip" 27 "gvisor.dev/gvisor/pkg/tcpip/checker" 28 "gvisor.dev/gvisor/pkg/tcpip/header" 29 "gvisor.dev/gvisor/pkg/tcpip/link/channel" 30 "gvisor.dev/gvisor/pkg/tcpip/stack" 31 ) 32 33 // MockLinkEndpoint is an endpoint used for testing, it stores packets written 34 // to it and can mock errors. 35 type MockLinkEndpoint struct { 36 // WrittenPackets is where packets written to the endpoint are stored. 37 WrittenPackets []*stack.PacketBuffer 38 39 mtu uint32 40 err tcpip.Error 41 allowPackets int 42 } 43 44 // NewMockLinkEndpoint creates a new MockLinkEndpoint. 45 // 46 // err is the error that will be returned once allowPackets packets are written 47 // to the endpoint. 48 func NewMockLinkEndpoint(mtu uint32, err tcpip.Error, allowPackets int) *MockLinkEndpoint { 49 return &MockLinkEndpoint{ 50 mtu: mtu, 51 err: err, 52 allowPackets: allowPackets, 53 } 54 } 55 56 // MTU implements LinkEndpoint.MTU. 57 func (ep *MockLinkEndpoint) MTU() uint32 { return ep.mtu } 58 59 // Capabilities implements LinkEndpoint.Capabilities. 60 func (*MockLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return 0 } 61 62 // MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. 63 func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 } 64 65 // LinkAddress implements LinkEndpoint.LinkAddress. 66 func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } 67 68 // WritePackets implements LinkEndpoint.WritePackets. 69 func (ep *MockLinkEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { 70 var n int 71 for _, pkt := range pkts.AsSlice() { 72 if ep.allowPackets == 0 { 73 return n, ep.err 74 } 75 ep.allowPackets-- 76 ep.WrittenPackets = append(ep.WrittenPackets, pkt.IncRef()) 77 n++ 78 } 79 return n, nil 80 } 81 82 // Attach implements LinkEndpoint.Attach. 83 func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {} 84 85 // IsAttached implements LinkEndpoint.IsAttached. 86 func (*MockLinkEndpoint) IsAttached() bool { return false } 87 88 // Wait implements LinkEndpoint.Wait. 89 func (*MockLinkEndpoint) Wait() {} 90 91 // ARPHardwareType implements LinkEndpoint.ARPHardwareType. 92 func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone } 93 94 // AddHeader implements LinkEndpoint.AddHeader. 95 func (*MockLinkEndpoint) AddHeader(*stack.PacketBuffer) {} 96 97 // ParseHeader implements LinkEndpoint.ParseHeader. 98 func (*MockLinkEndpoint) ParseHeader(*stack.PacketBuffer) bool { return true } 99 100 // Close releases all resources. 101 func (ep *MockLinkEndpoint) Close() { 102 for _, pkt := range ep.WrittenPackets { 103 pkt.DecRef() 104 } 105 ep.WrittenPackets = nil 106 } 107 108 // MakeRandPkt generates a randomized packet. transportHeaderLength indicates 109 // how many random bytes will be copied in the Transport Header. 110 // extraHeaderReserveLength indicates how much extra space will be reserved for 111 // the other headers. The payload is made from Views of the sizes listed in 112 // viewSizes. 113 func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer { 114 var buf buffer.Buffer 115 116 for _, s := range viewSizes { 117 newView := buffer.NewViewSize(s) 118 if _, err := rand.Read(newView.AsSlice()); err != nil { 119 panic(fmt.Sprintf("rand.Read: %s", err)) 120 } 121 buf.Append(newView) 122 } 123 124 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 125 ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength, 126 Payload: buf, 127 }) 128 pkt.NetworkProtocolNumber = proto 129 if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil { 130 panic(fmt.Sprintf("rand.Read: %s", err)) 131 } 132 return pkt 133 } 134 135 func checkIGMPStats(t *testing.T, s *stack.Stack, reports, leaves, reportsV2 uint64) { 136 t.Helper() 137 138 if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != reports { 139 t.Errorf("got s.Stats().IGMP.PacketsSent.V2MembershipReport.Value() = %d, want = %d", got, reports) 140 } 141 if got := s.Stats().IGMP.PacketsSent.V3MembershipReport.Value(); got != reportsV2 { 142 t.Errorf("got s.Stats().IGMP.PacketsSent.V3MembershipReport.Value() = %d, want = %d", got, reportsV2) 143 } 144 if got := s.Stats().IGMP.PacketsSent.LeaveGroup.Value(); got != leaves { 145 t.Errorf("got s.Stats().IGMP.PacketsSent.LeaveGroup.Value() = %d, want = %d", got, leaves) 146 } 147 } 148 149 // CheckIGMPv2Stats checks IGMPv2 stats. 150 func CheckIGMPv2Stats(t *testing.T, s *stack.Stack, reports, leaves, reportsV2 uint64) { 151 t.Helper() 152 // We still check V3 stats in V2 compatibility tests because the test may send 153 // V3 reports before we drop into compatibility mode. 154 checkIGMPStats(t, s, reports, leaves, reportsV2) 155 } 156 157 // CheckIGMPv3Stats checks IGMPv3 stats. 158 func CheckIGMPv3Stats(t *testing.T, s *stack.Stack, reports, leaves, reportsV2 uint64) { 159 t.Helper() 160 // In IGMPv3 tests, reports/leaves are just IGMPv3 reports. 161 checkIGMPStats(t, s, 0 /* reports */, 0 /* leaves */, reports+leaves+reportsV2) 162 } 163 164 func checkMLDStats(t *testing.T, s *stack.Stack, reports, leaves, reportsV2 uint64) { 165 t.Helper() 166 167 if got := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport.Value(); got != reports { 168 t.Errorf("got s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport.Value() = %d, want = %d", got, reports) 169 } 170 if got := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReportV2.Value(); got != reportsV2 { 171 t.Errorf("got s.Stats().ICMP.V6.PacketsSent.MulticastListenerReportV2.Value() = %d, want = %d", got, reportsV2) 172 } 173 if got := s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone.Value(); got != leaves { 174 t.Errorf("got s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone.Value() = %d, want = %d", got, leaves) 175 } 176 } 177 178 // CheckMLDv1Stats checks MLDv1 stats. 179 func CheckMLDv1Stats(t *testing.T, s *stack.Stack, reports, leaves, reportsV2 uint64) { 180 t.Helper() 181 // We still check V2 stats in V1 compatibility tests because the test may send 182 // V2 reports before we drop into compatibility mode. 183 checkMLDStats(t, s, reports, leaves, reportsV2) 184 } 185 186 // CheckMLDv2Stats checks MLDv2 stats. 187 func CheckMLDv2Stats(t *testing.T, s *stack.Stack, reports, leaves, reportsV2 uint64) { 188 t.Helper() 189 // In MLDv2 tests, reports/leaves are just MLDv2 reports. 190 checkMLDStats(t, s, 0 /* reports */, 0 /* leaves */, reports+leaves+reportsV2) 191 } 192 193 // ValidateIGMPv3ReportWithRecords validates an IGMPv3 report. 194 // 195 // Note that observed records are removed from expectedRecords. No error is 196 // logged if the report does not have all the records expected. 197 func ValidateIGMPv3ReportWithRecords(t *testing.T, v *buffer.View, srcAddr tcpip.Address, expectedRecords map[tcpip.Address]header.IGMPv3ReportRecordType) { 198 t.Helper() 199 200 checker.IPv4(t, v, 201 checker.SrcAddr(srcAddr), 202 checker.DstAddr(header.IGMPv3RoutersAddress), 203 checker.TTL(header.IGMPTTL), 204 checker.IPv4RouterAlert(), 205 checker.IGMPv3Report(expectedRecords), 206 ) 207 } 208 209 // ValidateIGMPv3Report validates an IGMPv3 report. 210 func ValidateIGMPv3Report(t *testing.T, v *buffer.View, srcAddr tcpip.Address, addrs []tcpip.Address, recordType header.IGMPv3ReportRecordType) { 211 t.Helper() 212 213 records := make(map[tcpip.Address]header.IGMPv3ReportRecordType) 214 for _, addr := range addrs { 215 records[addr] = recordType 216 } 217 218 ValidateIGMPv3ReportWithRecords(t, v, srcAddr, records) 219 220 if diff := cmp.Diff(map[tcpip.Address]header.IGMPv3ReportRecordType{}, records); diff != "" { 221 t.Errorf("post-validation records map mismatch (-want +got):\n%s", diff) 222 } 223 } 224 225 // ValidateIGMPv3RecordsAcrossReports validates IGMPv3 records across one or 226 // more reports. 227 func ValidateIGMPv3RecordsAcrossReports(t *testing.T, e *channel.Endpoint, srcAddr tcpip.Address, addrs []tcpip.Address, recordType header.IGMPv3ReportRecordType) { 228 t.Helper() 229 230 expectedRecords := make(map[tcpip.Address]header.IGMPv3ReportRecordType) 231 for _, addr := range addrs { 232 expectedRecords[addr] = recordType 233 } 234 235 for len(expectedRecords) != 0 { 236 p := e.Read() 237 if p == nil { 238 t.Fatalf("expected IGMP message with expectedRecords = %#v", expectedRecords) 239 } 240 v := stack.PayloadSince(p.NetworkHeader()) 241 ValidateIGMPv3ReportWithRecords(t, v, srcAddr, expectedRecords) 242 v.Release() 243 p.DecRef() 244 } 245 246 if diff := cmp.Diff(map[tcpip.Address]header.IGMPv3ReportRecordType{}, expectedRecords); diff != "" { 247 t.Errorf("post-validation records map mismatch (-want +got):\n%s", diff) 248 } 249 } 250 251 // ValidMultipleIGMPv2ReportLeaves validates the reception of multiple IGMPv2 252 // report/leave messages. 253 func ValidMultipleIGMPv2ReportLeaves(t *testing.T, e *channel.Endpoint, srcAddr tcpip.Address, addrs []tcpip.Address, leave bool) { 254 t.Helper() 255 256 expectedGroups := make(map[tcpip.Address]struct{}) 257 for _, addr := range addrs { 258 expectedGroups[addr] = struct{}{} 259 } 260 261 igmpType := header.IGMPv2MembershipReport 262 if leave { 263 igmpType = header.IGMPLeaveGroup 264 } 265 266 for len(expectedGroups) != 0 { 267 p := e.Read() 268 if p == nil { 269 t.Fatalf("expected IGMP message with expectedGroups = %#v", expectedGroups) 270 } 271 v := stack.PayloadSince(p.NetworkHeader()) 272 checker.IPv4(t, v, 273 checker.SrcAddr(srcAddr), 274 checker.TTL(header.IGMPTTL), 275 checker.IPv4RouterAlert(), 276 checker.IGMP( 277 checker.IGMPType(igmpType), 278 checker.IGMPMaxRespTime(0), 279 checker.IGMPGroupAddressUnordered(expectedGroups), 280 ), 281 ) 282 v.Release() 283 p.DecRef() 284 } 285 286 if diff := cmp.Diff(map[tcpip.Address]struct{}{}, expectedGroups); diff != "" { 287 t.Errorf("post-validation groups map mismatch (-want +got):\n%s", diff) 288 } 289 } 290 291 // ValidateMLDv2ReportWithRecords validates an MLDv2 report. 292 // 293 // Note that observed records are removed from expectedRecords. No error is 294 // logged if the report does not have all the records expected. 295 func ValidateMLDv2ReportWithRecords(t *testing.T, v *buffer.View, srcAddr tcpip.Address, expectedRecords map[tcpip.Address]header.MLDv2ReportRecordType) { 296 t.Helper() 297 298 checker.IPv6WithExtHdr(t, v, 299 checker.IPv6ExtHdr( 300 checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)), 301 ), 302 checker.SrcAddr(srcAddr), 303 checker.DstAddr(header.MLDv2RoutersAddress), 304 checker.TTL(header.MLDHopLimit), 305 checker.MLDv2Report(expectedRecords), 306 ) 307 } 308 309 // ValidateMLDv2Report validates an MLDv2 report. 310 func ValidateMLDv2Report(t *testing.T, v *buffer.View, srcAddr tcpip.Address, addrs []tcpip.Address, recordType header.MLDv2ReportRecordType) { 311 t.Helper() 312 313 records := make(map[tcpip.Address]header.MLDv2ReportRecordType) 314 for _, addr := range addrs { 315 records[addr] = recordType 316 } 317 318 ValidateMLDv2ReportWithRecords(t, v, srcAddr, records) 319 320 if diff := cmp.Diff(map[tcpip.Address]header.MLDv2ReportRecordType{}, records); diff != "" { 321 t.Errorf("post-validation records map mismatch (-want +got):\n%s", diff) 322 } 323 } 324 325 // ValidateMLDv2RecordsAcrossReports validates MLDv2 records across one or more 326 // reports. 327 func ValidateMLDv2RecordsAcrossReports(t *testing.T, e *channel.Endpoint, srcAddr tcpip.Address, addrs []tcpip.Address, recordType header.MLDv2ReportRecordType) { 328 t.Helper() 329 330 expectedRecords := make(map[tcpip.Address]header.MLDv2ReportRecordType) 331 for _, addr := range addrs { 332 expectedRecords[addr] = recordType 333 } 334 335 for len(expectedRecords) != 0 { 336 p := e.Read() 337 if p == nil { 338 t.Fatalf("expected MLD Message with expectedRecords = %#v", expectedRecords) 339 } 340 v := stack.PayloadSince(p.NetworkHeader()) 341 ValidateMLDv2ReportWithRecords(t, v, srcAddr, expectedRecords) 342 v.Release() 343 p.DecRef() 344 } 345 346 if diff := cmp.Diff(map[tcpip.Address]header.MLDv2ReportRecordType{}, expectedRecords); diff != "" { 347 t.Errorf("post-validation records map mismatch (-want +got):\n%s", diff) 348 } 349 } 350 351 // ValidMultipleMLDv1ReportLeaves validates the reception of multiple MLDv1 352 // report/leave messages. 353 func ValidMultipleMLDv1ReportLeaves(t *testing.T, e *channel.Endpoint, srcAddr tcpip.Address, addrs []tcpip.Address, leave bool) { 354 t.Helper() 355 356 expectedGroups := make(map[tcpip.Address]struct{}) 357 for _, addr := range addrs { 358 expectedGroups[addr] = struct{}{} 359 } 360 361 mldType := header.ICMPv6MulticastListenerReport 362 if leave { 363 mldType = header.ICMPv6MulticastListenerDone 364 } 365 366 for len(expectedGroups) != 0 { 367 p := e.Read() 368 if p == nil { 369 t.Fatalf("expected MLD Message with expectedGroups = %#v", expectedGroups) 370 } 371 v := stack.PayloadSince(p.NetworkHeader()) 372 checker.IPv6WithExtHdr(t, v, 373 checker.IPv6ExtHdr( 374 checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)), 375 ), 376 checker.SrcAddr(srcAddr), 377 checker.TTL(header.MLDHopLimit), 378 checker.MLD(mldType, header.MLDMinimumSize, 379 checker.MLDMaxRespDelay(0), 380 checker.MLDMulticastAddressUnordered(expectedGroups), 381 ), 382 ) 383 v.Release() 384 p.DecRef() 385 } 386 387 if diff := cmp.Diff(map[tcpip.Address]struct{}{}, expectedGroups); diff != "" { 388 t.Errorf("post-validation groups map mismatch (-want +got):\n%s", diff) 389 } 390 }