gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/checker/checker.go (about) 1 // Copyright 2021 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package checker provides helper functions to check networking packets for 16 // validity. 17 package checker 18 19 import ( 20 "encoding/binary" 21 "slices" 22 "testing" 23 "time" 24 25 "github.com/google/go-cmp/cmp" 26 "gvisor.dev/gvisor/pkg/buffer" 27 "gvisor.dev/gvisor/pkg/tcpip" 28 "gvisor.dev/gvisor/pkg/tcpip/checksum" 29 "gvisor.dev/gvisor/pkg/tcpip/header" 30 "gvisor.dev/gvisor/pkg/tcpip/seqnum" 31 ) 32 33 // NetworkChecker is a function to check a property of a network packet. 34 type NetworkChecker func(*testing.T, []header.Network) 35 36 // TransportChecker is a function to check a property of a transport packet. 37 type TransportChecker func(*testing.T, header.Transport) 38 39 // ControlMessagesChecker is a function to check a property of ancillary data. 40 type ControlMessagesChecker func(*testing.T, tcpip.ReceivableControlMessages) 41 42 // IPv4 checks the validity and properties of the given IPv4 packet. It is 43 // expected to be used in conjunction with other network checkers for specific 44 // properties. For example, to check the source and destination address, one 45 // would call: 46 // 47 // checker.IPv4(t, v, checker.SrcAddr(x), checker.DstAddr(y)) 48 func IPv4(t *testing.T, v *buffer.View, checkers ...NetworkChecker) { 49 t.Helper() 50 51 ipv4 := header.IPv4(v.AsSlice()) 52 53 if !ipv4.IsValid(len(v.AsSlice())) { 54 t.Fatalf("Not a valid IPv4 packet: %x", ipv4) 55 } 56 57 if !ipv4.IsChecksumValid() { 58 t.Errorf("Bad checksum, got = %d", ipv4.Checksum()) 59 } 60 61 for _, f := range checkers { 62 f(t, []header.Network{ipv4}) 63 } 64 if t.Failed() { 65 t.FailNow() 66 } 67 } 68 69 // IPv6 checks the validity and properties of the given IPv6 packet. The usage 70 // is similar to IPv4. 71 func IPv6(t *testing.T, v *buffer.View, checkers ...NetworkChecker) { 72 t.Helper() 73 74 ipv6 := header.IPv6(v.AsSlice()) 75 if !ipv6.IsValid(len(v.AsSlice())) { 76 t.Fatalf("Not a valid IPv6 packet: %x", ipv6) 77 } 78 79 for _, f := range checkers { 80 f(t, []header.Network{ipv6}) 81 } 82 if t.Failed() { 83 t.FailNow() 84 } 85 } 86 87 // SrcAddr creates a checker that checks the source address. 88 func SrcAddr(addr tcpip.Address) NetworkChecker { 89 return func(t *testing.T, h []header.Network) { 90 t.Helper() 91 92 if a := h[0].SourceAddress(); a != addr { 93 t.Errorf("Bad source address, got %v, want %v", a, addr) 94 } 95 } 96 } 97 98 // DstAddr creates a checker that checks the destination address. 99 func DstAddr(addr tcpip.Address) NetworkChecker { 100 return func(t *testing.T, h []header.Network) { 101 t.Helper() 102 103 if a := h[0].DestinationAddress(); a != addr { 104 t.Errorf("Bad destination address, got %v, want %v", a, addr) 105 } 106 } 107 } 108 109 // TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6). 110 func TTL(ttl uint8) NetworkChecker { 111 return func(t *testing.T, h []header.Network) { 112 t.Helper() 113 114 var v uint8 115 switch ip := h[0].(type) { 116 case header.IPv4: 117 v = ip.TTL() 118 case header.IPv6: 119 v = ip.HopLimit() 120 case *ipv6HeaderWithExtHdr: 121 v = ip.HopLimit() 122 default: 123 t.Fatalf("unrecognized header type %T for TTL evaluation", ip) 124 } 125 if v != ttl { 126 t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl) 127 } 128 } 129 } 130 131 // IPFullLength creates a checker for the full IP packet length. The 132 // expected size is checked against both the Total Length in the 133 // header and the number of bytes received. 134 func IPFullLength(packetLength uint16) NetworkChecker { 135 return func(t *testing.T, h []header.Network) { 136 t.Helper() 137 138 var v uint16 139 var l uint16 140 switch ip := h[0].(type) { 141 case header.IPv4: 142 v = ip.TotalLength() 143 l = uint16(len(ip)) 144 case header.IPv6: 145 v = ip.PayloadLength() + header.IPv6FixedHeaderSize 146 l = uint16(len(ip)) 147 default: 148 t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4 or header.IPv6", ip) 149 } 150 if l != packetLength { 151 t.Errorf("bad packet length, got = %d, want = %d", l, packetLength) 152 } 153 if v != packetLength { 154 t.Errorf("unexpected packet length in header, got = %d, want = %d", v, packetLength) 155 } 156 } 157 } 158 159 // IPv4HeaderLength creates a checker that checks the IPv4 Header length. 160 func IPv4HeaderLength(headerLength int) NetworkChecker { 161 return func(t *testing.T, h []header.Network) { 162 t.Helper() 163 164 switch ip := h[0].(type) { 165 case header.IPv4: 166 if hl := ip.HeaderLength(); hl != uint8(headerLength) { 167 t.Errorf("Bad header length, got = %d, want = %d", hl, headerLength) 168 } 169 default: 170 t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", ip) 171 } 172 } 173 } 174 175 // PayloadLen creates a checker that checks the payload length. 176 func PayloadLen(payloadLength int) NetworkChecker { 177 return func(t *testing.T, h []header.Network) { 178 t.Helper() 179 180 if l := len(h[0].Payload()); l != payloadLength { 181 t.Errorf("Bad payload length, got = %d, want = %d", l, payloadLength) 182 } 183 } 184 } 185 186 // IPPayload creates a checker that checks the payload. 187 func IPPayload(payload []byte) NetworkChecker { 188 return func(t *testing.T, h []header.Network) { 189 t.Helper() 190 191 got := h[0].Payload() 192 193 // cmp.Diff does not consider nil slices equal to empty slices, but we do. 194 if len(got) == 0 && len(payload) == 0 { 195 return 196 } 197 198 if diff := cmp.Diff(payload, got); diff != "" { 199 t.Errorf("payload mismatch (-want +got):\n%s", diff) 200 } 201 } 202 } 203 204 // IPv4Options returns a checker that checks the options in an IPv4 packet. 205 func IPv4Options(want header.IPv4Options) NetworkChecker { 206 return func(t *testing.T, h []header.Network) { 207 t.Helper() 208 209 ip, ok := h[0].(header.IPv4) 210 if !ok { 211 t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0]) 212 } 213 options := ip.Options() 214 // cmp.Diff does not consider nil slices equal to empty slices, but we do. 215 if len(want) == 0 && len(options) == 0 { 216 return 217 } 218 if diff := cmp.Diff(want, options); diff != "" { 219 t.Errorf("options mismatch (-want +got):\n%s", diff) 220 } 221 } 222 } 223 224 // IPv4RouterAlert returns a checker that checks that the RouterAlert option is 225 // set in an IPv4 packet. 226 func IPv4RouterAlert() NetworkChecker { 227 return func(t *testing.T, h []header.Network) { 228 t.Helper() 229 ip, ok := h[0].(header.IPv4) 230 if !ok { 231 t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0]) 232 } 233 iterator := ip.Options().MakeIterator() 234 for { 235 opt, done, err := iterator.Next() 236 if err != nil { 237 t.Fatalf("error acquiring next IPv4 option at offset %d", err.Pointer) 238 } 239 if done { 240 break 241 } 242 if opt.Type() != header.IPv4OptionRouterAlertType { 243 continue 244 } 245 want := [header.IPv4OptionRouterAlertLength]byte{ 246 byte(header.IPv4OptionRouterAlertType), 247 header.IPv4OptionRouterAlertLength, 248 header.IPv4OptionRouterAlertValue, 249 header.IPv4OptionRouterAlertValue, 250 } 251 if diff := cmp.Diff(want[:], opt.Contents()); diff != "" { 252 t.Errorf("router alert option mismatch (-want +got):\n%s", diff) 253 } 254 return 255 } 256 t.Errorf("failed to find router alert option in %v", ip.Options()) 257 } 258 } 259 260 // FragmentOffset creates a checker that checks the FragmentOffset field. 261 func FragmentOffset(offset uint16) NetworkChecker { 262 return func(t *testing.T, h []header.Network) { 263 t.Helper() 264 265 // We only do this for IPv4 for now. 266 switch ip := h[0].(type) { 267 case header.IPv4: 268 if v := ip.FragmentOffset(); v != offset { 269 t.Errorf("Bad fragment offset, got = %d, want = %d", v, offset) 270 } 271 } 272 } 273 } 274 275 // FragmentFlags creates a checker that checks the fragment flags field. 276 func FragmentFlags(flags uint8) NetworkChecker { 277 return func(t *testing.T, h []header.Network) { 278 t.Helper() 279 280 // We only do this for IPv4 for now. 281 switch ip := h[0].(type) { 282 case header.IPv4: 283 if v := ip.Flags(); v != flags { 284 t.Errorf("Bad fragment offset, got = %d, want = %d", v, flags) 285 } 286 } 287 } 288 } 289 290 // ReceiveTClass creates a checker that checks the TCLASS field in 291 // ControlMessages. 292 func ReceiveTClass(want uint32) ControlMessagesChecker { 293 return func(t *testing.T, cm tcpip.ReceivableControlMessages) { 294 t.Helper() 295 if !cm.HasTClass { 296 t.Error("got cm.HasTClass = false, want = true") 297 } else if got := cm.TClass; got != want { 298 t.Errorf("got cm.TClass = %d, want %d", got, want) 299 } 300 } 301 } 302 303 // NoTClassReceived creates a checker that checks the absence of the TCLASS 304 // field in ControlMessages. 305 func NoTClassReceived() ControlMessagesChecker { 306 return func(t *testing.T, cm tcpip.ReceivableControlMessages) { 307 t.Helper() 308 if cm.HasTClass { 309 t.Error("got cm.HasTClass = true, want = false") 310 } 311 } 312 } 313 314 // ReceiveTOS creates a checker that checks the TOS field in ControlMessages. 315 func ReceiveTOS(want uint8) ControlMessagesChecker { 316 return func(t *testing.T, cm tcpip.ReceivableControlMessages) { 317 t.Helper() 318 if !cm.HasTOS { 319 t.Error("got cm.HasTOS = false, want = true") 320 } else if got := cm.TOS; got != want { 321 t.Errorf("got cm.TOS = %d, want %d", got, want) 322 } 323 } 324 } 325 326 // NoTOSReceived creates a checker that checks the absence of the TOS field in 327 // ControlMessages. 328 func NoTOSReceived() ControlMessagesChecker { 329 return func(t *testing.T, cm tcpip.ReceivableControlMessages) { 330 t.Helper() 331 if cm.HasTOS { 332 t.Error("got cm.HasTOS = true, want = false") 333 } 334 } 335 } 336 337 // ReceiveTTL creates a checker that checks the TTL field in 338 // ControlMessages. 339 func ReceiveTTL(want uint8) ControlMessagesChecker { 340 return func(t *testing.T, cm tcpip.ReceivableControlMessages) { 341 t.Helper() 342 if !cm.HasTTL { 343 t.Errorf("got cm.HasTTL = %t, want = true", cm.HasTTL) 344 } else if got := cm.TTL; got != want { 345 t.Errorf("got cm.TTL = %d, want = %d", got, want) 346 } 347 } 348 } 349 350 // NoTTLReceived creates a checker that checks the absence of the TTL field in 351 // ControlMessages. 352 func NoTTLReceived() ControlMessagesChecker { 353 return func(t *testing.T, cm tcpip.ReceivableControlMessages) { 354 t.Helper() 355 if cm.HasTTL { 356 t.Error("got cm.HasTTL = true, want = false") 357 } 358 } 359 } 360 361 // ReceiveHopLimit creates a checker that checks the HopLimit field in 362 // ControlMessages. 363 func ReceiveHopLimit(want uint8) ControlMessagesChecker { 364 return func(t *testing.T, cm tcpip.ReceivableControlMessages) { 365 t.Helper() 366 if !cm.HasHopLimit { 367 t.Errorf("got cm.HasHopLimit = %t, want = true", cm.HasHopLimit) 368 } else if got := cm.HopLimit; got != want { 369 t.Errorf("got cm.HopLimit = %d, want = %d", got, want) 370 } 371 } 372 } 373 374 // NoHopLimitReceived creates a checker that checks the absence of the HopLimit 375 // field in ControlMessages. 376 func NoHopLimitReceived() ControlMessagesChecker { 377 return func(t *testing.T, cm tcpip.ReceivableControlMessages) { 378 t.Helper() 379 if cm.HasHopLimit { 380 t.Error("got cm.HasHopLimit = true, want = false") 381 } 382 } 383 } 384 385 // ReceiveIPPacketInfo creates a checker that checks the PacketInfo field in 386 // ControlMessages. 387 func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker { 388 return func(t *testing.T, cm tcpip.ReceivableControlMessages) { 389 t.Helper() 390 if !cm.HasIPPacketInfo { 391 t.Error("got cm.HasIPPacketInfo = false, want = true") 392 } else if diff := cmp.Diff(want, cm.PacketInfo); diff != "" { 393 t.Errorf("IPPacketInfo mismatch (-want +got):\n%s", diff) 394 } 395 } 396 } 397 398 // NoIPPacketInfoReceived creates a checker that checks the PacketInfo field in 399 // ControlMessages. 400 func NoIPPacketInfoReceived() ControlMessagesChecker { 401 return func(t *testing.T, cm tcpip.ReceivableControlMessages) { 402 t.Helper() 403 if cm.HasIPPacketInfo { 404 t.Error("got cm.HasIPPacketInfo = true, want = false") 405 } 406 } 407 } 408 409 // ReceiveIPv6PacketInfo creates a checker that checks the IPv6PacketInfo field 410 // in ControlMessages. 411 func ReceiveIPv6PacketInfo(want tcpip.IPv6PacketInfo) ControlMessagesChecker { 412 return func(t *testing.T, cm tcpip.ReceivableControlMessages) { 413 t.Helper() 414 if !cm.HasIPv6PacketInfo { 415 t.Error("got cm.HasIPv6PacketInfo = false, want = true") 416 } else if diff := cmp.Diff(want, cm.IPv6PacketInfo); diff != "" { 417 t.Errorf("IPv6PacketInfo mismatch (-want +got):\n%s", diff) 418 } 419 } 420 } 421 422 // NoIPv6PacketInfoReceived creates a checker that checks the PacketInfo field 423 // in ControlMessages. 424 func NoIPv6PacketInfoReceived() ControlMessagesChecker { 425 return func(t *testing.T, cm tcpip.ReceivableControlMessages) { 426 t.Helper() 427 if cm.HasIPv6PacketInfo { 428 t.Error("got cm.HasIPv6PacketInfo = true, want = false") 429 } 430 } 431 } 432 433 // ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress 434 // field in ControlMessages. 435 func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker { 436 return func(t *testing.T, cm tcpip.ReceivableControlMessages) { 437 t.Helper() 438 if !cm.HasOriginalDstAddress { 439 t.Error("got cm.HasOriginalDstAddress = false, want = true") 440 } else if diff := cmp.Diff(want, cm.OriginalDstAddress); diff != "" { 441 t.Errorf("OriginalDstAddress mismatch (-want +got):\n%s", diff) 442 } 443 } 444 } 445 446 // TOS creates a checker that checks the TOS field. 447 func TOS(tos uint8, label uint32) NetworkChecker { 448 return func(t *testing.T, h []header.Network) { 449 t.Helper() 450 451 if v, l := h[0].TOS(); v != tos || l != label { 452 t.Errorf("Bad TOS, got = (%d, %d), want = (%d,%d)", v, l, tos, label) 453 } 454 } 455 } 456 457 // Raw creates a checker that checks the bytes of payload. 458 // The checker always checks the payload of the last network header. 459 // For instance, in case of IPv6 fragments, the payload that will be checked 460 // is the one containing the actual data that the packet is carrying, without 461 // the bytes added by the IPv6 fragmentation. 462 func Raw(want []byte) NetworkChecker { 463 return func(t *testing.T, h []header.Network) { 464 t.Helper() 465 466 if got := h[len(h)-1].Payload(); !slices.Equal(got, want) { 467 t.Errorf("Wrong payload, got %v, want %v", got, want) 468 } 469 } 470 } 471 472 // IPv6Fragment creates a checker that validates an IPv6 fragment. 473 func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker { 474 return func(t *testing.T, h []header.Network) { 475 t.Helper() 476 477 if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader { 478 t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber) 479 } 480 481 ipv6Frag := header.IPv6Fragment(h[0].Payload()) 482 if !ipv6Frag.IsValid() { 483 t.Error("Not a valid IPv6 fragment") 484 } 485 486 for _, f := range checkers { 487 f(t, []header.Network{h[0], ipv6Frag}) 488 } 489 if t.Failed() { 490 t.FailNow() 491 } 492 } 493 } 494 495 // TCP creates a checker that checks that the transport protocol is TCP and 496 // potentially additional transport header fields. 497 func TCP(checkers ...TransportChecker) NetworkChecker { 498 return func(t *testing.T, h []header.Network) { 499 t.Helper() 500 501 first := h[0] 502 last := h[len(h)-1] 503 504 if p := last.TransportProtocol(); p != header.TCPProtocolNumber { 505 t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber) 506 } 507 508 tcp := header.TCP(last.Payload()) 509 payload := tcp.Payload() 510 payloadChecksum := checksum.Checksum(payload, 0) 511 if !tcp.IsChecksumValid(first.SourceAddress(), first.DestinationAddress(), payloadChecksum, uint16(len(payload))) { 512 t.Errorf("Bad checksum, got = %d", tcp.Checksum()) 513 } 514 515 // Run the transport checkers. 516 for _, f := range checkers { 517 f(t, tcp) 518 } 519 if t.Failed() { 520 t.FailNow() 521 } 522 } 523 } 524 525 // UDP creates a checker that checks that the transport protocol is UDP and 526 // potentially additional transport header fields. 527 func UDP(checkers ...TransportChecker) NetworkChecker { 528 return func(t *testing.T, h []header.Network) { 529 t.Helper() 530 531 last := h[len(h)-1] 532 533 if p := last.TransportProtocol(); p != header.UDPProtocolNumber { 534 t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber) 535 } 536 537 udp := header.UDP(last.Payload()) 538 for _, f := range checkers { 539 f(t, udp) 540 } 541 if t.Failed() { 542 t.FailNow() 543 } 544 } 545 } 546 547 // SrcPort creates a checker that checks the source port. 548 func SrcPort(port uint16) TransportChecker { 549 return func(t *testing.T, h header.Transport) { 550 t.Helper() 551 552 if p := h.SourcePort(); p != port { 553 t.Errorf("Bad source port, got = %d, want = %d", p, port) 554 } 555 } 556 } 557 558 // DstPort creates a checker that checks the destination port. 559 func DstPort(port uint16) TransportChecker { 560 return func(t *testing.T, h header.Transport) { 561 t.Helper() 562 563 if p := h.DestinationPort(); p != port { 564 t.Errorf("Bad destination port, got = %d, want = %d", p, port) 565 } 566 } 567 } 568 569 // TransportChecksum creates a checker that checks the checksum value. 570 func TransportChecksum(want uint16) TransportChecker { 571 return func(t *testing.T, transportHdr header.Transport) { 572 t.Helper() 573 574 if got := transportHdr.Checksum(); got != want { 575 t.Errorf("got transportHdr.Checksum() = %d, want = %d", got, want) 576 } 577 } 578 } 579 580 // NoChecksum creates a checker that checks if the checksum is zero. 581 func NoChecksum(noChecksum bool) TransportChecker { 582 return func(t *testing.T, h header.Transport) { 583 t.Helper() 584 585 udp, ok := h.(header.UDP) 586 if !ok { 587 t.Fatalf("UDP header not found in h: %T", h) 588 } 589 590 if b := udp.Checksum() == 0; b != noChecksum { 591 t.Errorf("bad checksum state, got %t, want %t", b, noChecksum) 592 } 593 } 594 } 595 596 // TCPSeqNum creates a checker that checks the sequence number. 597 func TCPSeqNum(seq uint32) TransportChecker { 598 return func(t *testing.T, h header.Transport) { 599 t.Helper() 600 601 tcp, ok := h.(header.TCP) 602 if !ok { 603 t.Fatalf("TCP header not found in h: %T", h) 604 } 605 606 if s := tcp.SequenceNumber(); s != seq { 607 t.Errorf("Bad sequence number, got = %d, want = %d", s, seq) 608 } 609 } 610 } 611 612 // TCPAckNum creates a checker that checks the ack number. 613 func TCPAckNum(seq uint32) TransportChecker { 614 return func(t *testing.T, h header.Transport) { 615 t.Helper() 616 617 tcp, ok := h.(header.TCP) 618 if !ok { 619 t.Fatalf("TCP header not found in h: %T", h) 620 } 621 622 if s := tcp.AckNumber(); s != seq { 623 t.Errorf("Bad ack number, got = %d, want = %d", s, seq) 624 } 625 } 626 } 627 628 // TCPWindow creates a checker that checks the tcp window. 629 func TCPWindow(window uint16) TransportChecker { 630 return func(t *testing.T, h header.Transport) { 631 t.Helper() 632 633 tcp, ok := h.(header.TCP) 634 if !ok { 635 t.Fatalf("TCP header not found in hdr : %T", h) 636 } 637 638 if w := tcp.WindowSize(); w != window { 639 t.Errorf("Bad window, got %d, want %d", w, window) 640 } 641 } 642 } 643 644 // TCPWindowGreaterThanEq creates a checker that checks that the TCP window 645 // is greater than or equal to the provided value. 646 func TCPWindowGreaterThanEq(window uint16) TransportChecker { 647 return func(t *testing.T, h header.Transport) { 648 t.Helper() 649 650 tcp, ok := h.(header.TCP) 651 if !ok { 652 t.Fatalf("TCP header not found in h: %T", h) 653 } 654 655 if w := tcp.WindowSize(); w < window { 656 t.Errorf("Bad window, got %d, want > %d", w, window) 657 } 658 } 659 } 660 661 // TCPWindowLessThanEq creates a checker that checks that the tcp window 662 // is less than or equal to the provided value. 663 func TCPWindowLessThanEq(window uint16) TransportChecker { 664 return func(t *testing.T, h header.Transport) { 665 t.Helper() 666 667 tcp, ok := h.(header.TCP) 668 if !ok { 669 t.Fatalf("TCP header not found in h: %T", h) 670 } 671 672 if w := tcp.WindowSize(); w > window { 673 t.Errorf("Bad window, got %d, want < %d", w, window) 674 } 675 } 676 } 677 678 // TCPFlags creates a checker that checks the tcp flags. 679 func TCPFlags(flags header.TCPFlags) TransportChecker { 680 return func(t *testing.T, h header.Transport) { 681 t.Helper() 682 683 tcp, ok := h.(header.TCP) 684 if !ok { 685 t.Fatalf("TCP header not found in h: %T", h) 686 } 687 688 if got := tcp.Flags(); got != flags { 689 t.Errorf("got tcp.Flags() = %s, want %s", got, flags) 690 } 691 } 692 } 693 694 // TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the 695 // given mask, match the supplied flags. 696 func TCPFlagsMatch(flags, mask header.TCPFlags) TransportChecker { 697 return func(t *testing.T, h header.Transport) { 698 t.Helper() 699 700 tcp, ok := h.(header.TCP) 701 if !ok { 702 t.Fatalf("TCP header not found in h: %T", h) 703 } 704 705 if got := tcp.Flags(); (got & mask) != (flags & mask) { 706 t.Errorf("got tcp.Flags() = %s, want %s, mask %s", got, flags, mask) 707 } 708 } 709 } 710 711 // TCPSynOptions creates a checker that checks the presence of TCP options in 712 // SYN segments. 713 // 714 // If wndscale is negative, the window scale option must not be present. 715 func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { 716 return func(t *testing.T, h header.Transport) { 717 t.Helper() 718 719 tcp, ok := h.(header.TCP) 720 if !ok { 721 return 722 } 723 opts := tcp.Options() 724 limit := len(opts) 725 foundMSS := false 726 foundWS := false 727 foundTS := false 728 foundSACKPermitted := false 729 tsVal := uint32(0) 730 tsEcr := uint32(0) 731 for i := 0; i < limit; { 732 switch opts[i] { 733 case header.TCPOptionEOL: 734 i = limit 735 case header.TCPOptionNOP: 736 i++ 737 case header.TCPOptionMSS: 738 v := uint16(opts[i+2])<<8 | uint16(opts[i+3]) 739 if wantOpts.MSS != v { 740 t.Errorf("Bad MSS, got = %d, want = %d", v, wantOpts.MSS) 741 } 742 foundMSS = true 743 i += 4 744 case header.TCPOptionWS: 745 if wantOpts.WS < 0 { 746 t.Error("WS present when it shouldn't be") 747 } 748 v := int(opts[i+2]) 749 if v != wantOpts.WS { 750 t.Errorf("Bad WS, got = %d, want = %d", v, wantOpts.WS) 751 } 752 foundWS = true 753 i += 3 754 case header.TCPOptionTS: 755 if i+9 >= limit { 756 t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i) 757 } 758 if opts[i+1] != 10 { 759 t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit) 760 } 761 tsVal = binary.BigEndian.Uint32(opts[i+2:]) 762 tsEcr = uint32(0) 763 if tcp.Flags()&header.TCPFlagAck != 0 { 764 // If the syn is an SYN-ACK then read 765 // the tsEcr value as well. 766 tsEcr = binary.BigEndian.Uint32(opts[i+6:]) 767 } 768 foundTS = true 769 i += 10 770 case header.TCPOptionSACKPermitted: 771 if i+1 >= limit { 772 t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i) 773 } 774 if opts[i+1] != 2 { 775 t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit) 776 } 777 foundSACKPermitted = true 778 i += 2 779 780 default: 781 i += int(opts[i+1]) 782 } 783 } 784 785 if !foundMSS { 786 t.Errorf("MSS option not found. Options: %x", opts) 787 } 788 789 if !foundWS && wantOpts.WS >= 0 { 790 t.Errorf("WS option not found. Options: %x", opts) 791 } 792 if wantOpts.TS && !foundTS { 793 t.Errorf("TS option not found. Options: %x", opts) 794 } 795 if foundTS && tsVal == 0 { 796 t.Error("TS option specified but the timestamp value is zero") 797 } 798 if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 { 799 t.Errorf("TS option specified but TSEcr is incorrect, got = %d, want = %d", tsEcr, wantOpts.TSEcr) 800 } 801 if wantOpts.SACKPermitted && !foundSACKPermitted { 802 t.Errorf("SACKPermitted option not found. Options: %x", opts) 803 } 804 } 805 } 806 807 // TCPTimestampChecker creates a checker that validates that a TCP segment has a 808 // TCP Timestamp option if wantTS is true, it also compares the wantTSVal and 809 // wantTSEcr values with those in the TCP segment (if present). 810 // 811 // If wantTSVal or wantTSEcr is zero then the corresponding comparison is 812 // skipped. 813 func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker { 814 return func(t *testing.T, h header.Transport) { 815 t.Helper() 816 817 tcp, ok := h.(header.TCP) 818 if !ok { 819 return 820 } 821 opts := tcp.Options() 822 limit := len(opts) 823 foundTS := false 824 tsVal := uint32(0) 825 tsEcr := uint32(0) 826 for i := 0; i < limit; { 827 switch opts[i] { 828 case header.TCPOptionEOL: 829 i = limit 830 case header.TCPOptionNOP: 831 i++ 832 case header.TCPOptionTS: 833 if i+9 >= limit { 834 t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i) 835 } 836 if opts[i+1] != 10 { 837 t.Errorf("TS option found, but bad length specified: got = %d, want = 10", opts[i+1]) 838 } 839 tsVal = binary.BigEndian.Uint32(opts[i+2:]) 840 tsEcr = binary.BigEndian.Uint32(opts[i+6:]) 841 foundTS = true 842 i += 10 843 default: 844 // We don't recognize this option, just skip over it. 845 if i+2 > limit { 846 return 847 } 848 l := int(opts[i+1]) 849 if l < 2 || i+l > limit { 850 return 851 } 852 i += l 853 } 854 } 855 856 if wantTS != foundTS { 857 t.Errorf("TS Option mismatch, got TS= %t, want TS= %t", foundTS, wantTS) 858 } 859 if wantTS && wantTSVal != 0 && wantTSVal != tsVal { 860 t.Errorf("Timestamp value is incorrect, got = %d, want = %d", tsVal, wantTSVal) 861 } 862 if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr { 863 t.Errorf("Timestamp Echo Reply is incorrect, got = %d, want = %d", tsEcr, wantTSEcr) 864 } 865 } 866 } 867 868 // TCPSACKBlockChecker creates a checker that verifies that the segment does 869 // contain the specified SACK blocks in the TCP options. 870 func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { 871 return func(t *testing.T, h header.Transport) { 872 t.Helper() 873 tcp, ok := h.(header.TCP) 874 if !ok { 875 return 876 } 877 var gotSACKBlocks []header.SACKBlock 878 879 opts := tcp.Options() 880 limit := len(opts) 881 for i := 0; i < limit; { 882 switch opts[i] { 883 case header.TCPOptionEOL: 884 i = limit 885 case header.TCPOptionNOP: 886 i++ 887 case header.TCPOptionSACK: 888 if i+2 > limit { 889 // Malformed SACK block. 890 t.Errorf("malformed SACK option in options: %v", opts) 891 } 892 sackOptionLen := int(opts[i+1]) 893 if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 { 894 // Malformed SACK block. 895 t.Errorf("malformed SACK option length in options: %v", opts) 896 } 897 numBlocks := sackOptionLen / 8 898 for j := 0; j < numBlocks; j++ { 899 start := binary.BigEndian.Uint32(opts[i+2+j*8:]) 900 end := binary.BigEndian.Uint32(opts[i+2+j*8+4:]) 901 gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{ 902 Start: seqnum.Value(start), 903 End: seqnum.Value(end), 904 }) 905 } 906 i += sackOptionLen 907 default: 908 // We don't recognize this option, just skip over it. 909 if i+2 > limit { 910 break 911 } 912 l := int(opts[i+1]) 913 if l < 2 || i+l > limit { 914 break 915 } 916 i += l 917 } 918 } 919 920 if !slices.Equal(gotSACKBlocks, sackBlocks) { 921 t.Errorf("SACKBlocks are not equal, got = %v, want = %v", gotSACKBlocks, sackBlocks) 922 } 923 } 924 } 925 926 // Payload creates a checker that checks the payload. 927 func Payload(want []byte) TransportChecker { 928 return func(t *testing.T, h header.Transport) { 929 t.Helper() 930 931 if got := h.Payload(); !slices.Equal(got, want) { 932 t.Errorf("Wrong payload, got %v, want %v", got, want) 933 } 934 } 935 } 936 937 // ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 938 // and potentially additional ICMPv4 header fields. 939 func ICMPv4(checkers ...TransportChecker) NetworkChecker { 940 return func(t *testing.T, h []header.Network) { 941 t.Helper() 942 943 last := h[len(h)-1] 944 945 if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber { 946 t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber) 947 } 948 949 icmp := header.ICMPv4(last.Payload()) 950 for _, f := range checkers { 951 f(t, icmp) 952 } 953 if t.Failed() { 954 t.FailNow() 955 } 956 } 957 } 958 959 // ICMPv4Type creates a checker that checks the ICMPv4 Type field. 960 func ICMPv4Type(want header.ICMPv4Type) TransportChecker { 961 return func(t *testing.T, h header.Transport) { 962 t.Helper() 963 964 icmpv4, ok := h.(header.ICMPv4) 965 if !ok { 966 t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) 967 } 968 if got := icmpv4.Type(); got != want { 969 t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want) 970 } 971 } 972 } 973 974 // ICMPv4Code creates a checker that checks the ICMPv4 Code field. 975 func ICMPv4Code(want header.ICMPv4Code) TransportChecker { 976 return func(t *testing.T, h header.Transport) { 977 t.Helper() 978 979 icmpv4, ok := h.(header.ICMPv4) 980 if !ok { 981 t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) 982 } 983 if got := icmpv4.Code(); got != want { 984 t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want) 985 } 986 } 987 } 988 989 // ICMPv4Ident creates a checker that checks the ICMPv4 echo Ident. 990 func ICMPv4Ident(want uint16) TransportChecker { 991 return func(t *testing.T, h header.Transport) { 992 t.Helper() 993 994 icmpv4, ok := h.(header.ICMPv4) 995 if !ok { 996 t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) 997 } 998 if got := icmpv4.Ident(); got != want { 999 t.Fatalf("unexpected ICMP ident, got = %d, want = %d", got, want) 1000 } 1001 } 1002 } 1003 1004 // ICMPv4Seq creates a checker that checks the ICMPv4 echo Sequence. 1005 func ICMPv4Seq(want uint16) TransportChecker { 1006 return func(t *testing.T, h header.Transport) { 1007 t.Helper() 1008 1009 icmpv4, ok := h.(header.ICMPv4) 1010 if !ok { 1011 t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) 1012 } 1013 if got := icmpv4.Sequence(); got != want { 1014 t.Fatalf("unexpected ICMP sequence, got = %d, want = %d", got, want) 1015 } 1016 } 1017 } 1018 1019 // ICMPv4Pointer creates a checker that checks the ICMPv4 Param Problem pointer. 1020 func ICMPv4Pointer(want uint8) TransportChecker { 1021 return func(t *testing.T, h header.Transport) { 1022 t.Helper() 1023 1024 icmpv4, ok := h.(header.ICMPv4) 1025 if !ok { 1026 t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) 1027 } 1028 if got := icmpv4.Pointer(); got != want { 1029 t.Fatalf("unexpected ICMP Param Problem pointer, got = %d, want = %d", got, want) 1030 } 1031 } 1032 } 1033 1034 // ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum. 1035 // This assumes that the payload exactly makes up the rest of the slice. 1036 func ICMPv4Checksum() TransportChecker { 1037 return func(t *testing.T, h header.Transport) { 1038 t.Helper() 1039 1040 icmpv4, ok := h.(header.ICMPv4) 1041 if !ok { 1042 t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) 1043 } 1044 heldChecksum := icmpv4.Checksum() 1045 icmpv4.SetChecksum(0) 1046 newChecksum := ^checksum.Checksum(icmpv4, 0) 1047 icmpv4.SetChecksum(heldChecksum) 1048 if heldChecksum != newChecksum { 1049 t.Errorf("unexpected ICMP checksum, got = %d, want = %d", heldChecksum, newChecksum) 1050 } 1051 } 1052 } 1053 1054 // ICMPv4Payload creates a checker that checks the payload in an ICMPv4 packet. 1055 func ICMPv4Payload(want []byte) TransportChecker { 1056 return func(t *testing.T, h header.Transport) { 1057 t.Helper() 1058 1059 icmpv4, ok := h.(header.ICMPv4) 1060 if !ok { 1061 t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) 1062 } 1063 payload := icmpv4.Payload() 1064 1065 // cmp.Diff does not consider nil slices equal to empty slices, but we do. 1066 if len(want) == 0 && len(payload) == 0 { 1067 return 1068 } 1069 1070 if diff := cmp.Diff(want, payload); diff != "" { 1071 t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) 1072 } 1073 } 1074 } 1075 1076 // ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and 1077 // potentially additional ICMPv6 header fields. 1078 // 1079 // ICMPv6 will validate the checksum field before calling checkers. 1080 func ICMPv6(checkers ...TransportChecker) NetworkChecker { 1081 return func(t *testing.T, h []header.Network) { 1082 t.Helper() 1083 1084 last := h[len(h)-1] 1085 1086 if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber { 1087 t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber) 1088 } 1089 1090 icmp := header.ICMPv6(last.Payload()) 1091 if got, want := icmp.Checksum(), header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ 1092 Header: icmp, 1093 Src: last.SourceAddress(), 1094 Dst: last.DestinationAddress(), 1095 }); got != want { 1096 t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want) 1097 } 1098 1099 for _, f := range checkers { 1100 f(t, icmp) 1101 } 1102 if t.Failed() { 1103 t.FailNow() 1104 } 1105 } 1106 } 1107 1108 // ICMPv6Type creates a checker that checks the ICMPv6 Type field. 1109 func ICMPv6Type(want header.ICMPv6Type) TransportChecker { 1110 return func(t *testing.T, h header.Transport) { 1111 t.Helper() 1112 1113 icmpv6, ok := h.(header.ICMPv6) 1114 if !ok { 1115 t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) 1116 } 1117 if got := icmpv6.Type(); got != want { 1118 t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want) 1119 } 1120 } 1121 } 1122 1123 // ICMPv6Code creates a checker that checks the ICMPv6 Code field. 1124 func ICMPv6Code(want header.ICMPv6Code) TransportChecker { 1125 return func(t *testing.T, h header.Transport) { 1126 t.Helper() 1127 1128 icmpv6, ok := h.(header.ICMPv6) 1129 if !ok { 1130 t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) 1131 } 1132 if got := icmpv6.Code(); got != want { 1133 t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want) 1134 } 1135 } 1136 } 1137 1138 // ICMPv6TypeSpecific creates a checker that checks the ICMPv6 TypeSpecific 1139 // field. 1140 func ICMPv6TypeSpecific(want uint32) TransportChecker { 1141 return func(t *testing.T, h header.Transport) { 1142 t.Helper() 1143 1144 icmpv6, ok := h.(header.ICMPv6) 1145 if !ok { 1146 t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) 1147 } 1148 if got := icmpv6.TypeSpecific(); got != want { 1149 t.Fatalf("unexpected ICMP TypeSpecific, got = %d, want = %d", got, want) 1150 } 1151 } 1152 } 1153 1154 // ICMPv6Payload creates a checker that checks the payload in an ICMPv6 packet. 1155 func ICMPv6Payload(want []byte) TransportChecker { 1156 return func(t *testing.T, h header.Transport) { 1157 t.Helper() 1158 1159 icmpv6, ok := h.(header.ICMPv6) 1160 if !ok { 1161 t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) 1162 } 1163 payload := icmpv6.Payload() 1164 1165 // cmp.Diff does not consider nil slices equal to empty slices, but we do. 1166 if len(want) == 0 && len(payload) == 0 { 1167 return 1168 } 1169 1170 if diff := cmp.Diff(want, payload); diff != "" { 1171 t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) 1172 } 1173 } 1174 } 1175 1176 // MLD creates a checker that checks that the packet contains a valid MLD 1177 // message for type of mldType, with potentially additional checks specified by 1178 // checkers. 1179 // 1180 // Checkers may assume that a valid ICMPv6 is passed to it containing a valid 1181 // MLD message as far as the size of the message (minSize) is concerned. The 1182 // values within the message are up to checkers to validate. 1183 func MLD(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker { 1184 return func(t *testing.T, h []header.Network) { 1185 t.Helper() 1186 1187 // Check normal ICMPv6 first. 1188 ICMPv6( 1189 ICMPv6Type(msgType), 1190 ICMPv6Code(0))(t, h) 1191 1192 last := h[len(h)-1] 1193 1194 icmp := header.ICMPv6(last.Payload()) 1195 if got := len(icmp.MessageBody()); got < minSize { 1196 t.Fatalf("ICMPv6 MLD (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) 1197 } 1198 1199 for _, f := range checkers { 1200 f(t, icmp) 1201 } 1202 if t.Failed() { 1203 t.FailNow() 1204 } 1205 } 1206 } 1207 1208 // MLDMaxRespDelay creates a checker that checks the Maximum Response Delay 1209 // field of a MLD message. 1210 // 1211 // The returned TransportChecker assumes that a valid ICMPv6 is passed to it 1212 // containing a valid MLD message as far as the size is concerned. 1213 func MLDMaxRespDelay(want time.Duration) TransportChecker { 1214 return func(t *testing.T, h header.Transport) { 1215 t.Helper() 1216 1217 icmp := h.(header.ICMPv6) 1218 ns := header.MLD(icmp.MessageBody()) 1219 1220 if got := ns.MaximumResponseDelay(); got != want { 1221 t.Errorf("got %T.MaximumResponseDelay() = %s, want = %s", ns, got, want) 1222 } 1223 } 1224 } 1225 1226 // MLDMulticastAddressUnordered creates a checker that checks that the multicast 1227 // address in the MLD message is expected to be seen. 1228 // 1229 // The seen address is removed from the expected groups map. 1230 // 1231 // The returned TransportChecker assumes that a valid ICMPv6 is passed to it 1232 // containing a valid MLD message as far as the size is concerned. 1233 func MLDMulticastAddressUnordered(expectedGroups map[tcpip.Address]struct{}) TransportChecker { 1234 return func(t *testing.T, h header.Transport) { 1235 t.Helper() 1236 1237 icmp := h.(header.ICMPv6) 1238 ns := header.MLD(icmp.MessageBody()) 1239 1240 addr := ns.MulticastAddress() 1241 1242 if _, ok := expectedGroups[addr]; !ok { 1243 t.Errorf("unexpected multicast group %s", addr) 1244 } else { 1245 delete(expectedGroups, addr) 1246 } 1247 } 1248 } 1249 1250 // MLDMulticastAddress creates a checker that checks the Multicast Address 1251 // field of a MLD message. 1252 // 1253 // The returned TransportChecker assumes that a valid ICMPv6 is passed to it 1254 // containing a valid MLD message as far as the size is concerned. 1255 func MLDMulticastAddress(want tcpip.Address) TransportChecker { 1256 return MLDMulticastAddressUnordered(map[tcpip.Address]struct{}{ 1257 want: struct{}{}, 1258 }) 1259 } 1260 1261 // MLDv2Report creates a checker that checks that the packet contains a valid 1262 // MLDv2 report with the specified records. 1263 // 1264 // Note that observed records are removed from expectedRecords. No error is 1265 // logged if the report does not have all the records expected. 1266 func MLDv2Report(expectedRecords map[tcpip.Address]header.MLDv2ReportRecordType) NetworkChecker { 1267 return func(t *testing.T, h []header.Network) { 1268 t.Helper() 1269 1270 // Check normal ICMPv6 first. 1271 ICMPv6( 1272 ICMPv6Type(header.ICMPv6MulticastListenerV2Report), 1273 ICMPv6Code(0))(t, h) 1274 1275 last := h[len(h)-1] 1276 icmp := header.ICMPv6(last.Payload()) 1277 report := header.MLDv2Report(icmp.MessageBody()) 1278 1279 records := report.MulticastAddressRecords() 1280 for len(expectedRecords) != 0 { 1281 record, res := records.Next() 1282 switch res { 1283 case header.MLDv2ReportMulticastAddressRecordIteratorNextOk: 1284 case header.MLDv2ReportMulticastAddressRecordIteratorNextDone: 1285 return 1286 default: 1287 t.Fatalf("unhandled res = %d", res) 1288 } 1289 1290 addr := record.MulticastAddress() 1291 expectedRecordType, ok := expectedRecords[addr] 1292 if !ok { 1293 t.Errorf("unexpected record for address %s", addr) 1294 continue 1295 } 1296 1297 if got, want := record.RecordType(), expectedRecordType; got != want { 1298 t.Errorf("got record.RecordType() = %d, want = %d", got, want) 1299 } 1300 1301 if got := record.AuxDataLen(); got != 0 { 1302 t.Errorf("got record.AuxDataLen() = %d, want = 0", got) 1303 } 1304 1305 sources, ok := record.Sources() 1306 if !ok { 1307 t.Error("got record.Sources() = (_, false), want = (_, true)") 1308 continue 1309 } 1310 1311 if source, ok := sources.Next(); ok { 1312 t.Fatalf("got sources.Next() = (%s, true), want = (_, false)", source) 1313 } 1314 1315 delete(expectedRecords, addr) 1316 } 1317 1318 if record, res := records.Next(); res != header.MLDv2ReportMulticastAddressRecordIteratorNextDone { 1319 t.Fatalf("got records.Next() = (%#v, %d), want = (_, %d)", record, res, header.MLDv2ReportMulticastAddressRecordIteratorNextDone) 1320 } 1321 } 1322 } 1323 1324 // NDP creates a checker that checks that the packet contains a valid NDP 1325 // message for type of ty, with potentially additional checks specified by 1326 // checkers. 1327 // 1328 // Checkers may assume that a valid ICMPv6 is passed to it containing a valid 1329 // NDP message as far as the size of the message (minSize) is concerned. The 1330 // values within the message are up to checkers to validate. 1331 func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker { 1332 return func(t *testing.T, h []header.Network) { 1333 t.Helper() 1334 1335 // Check normal ICMPv6 first. 1336 ICMPv6( 1337 ICMPv6Type(msgType), 1338 ICMPv6Code(0))(t, h) 1339 1340 last := h[len(h)-1] 1341 1342 icmp := header.ICMPv6(last.Payload()) 1343 if got := len(icmp.MessageBody()); got < minSize { 1344 t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) 1345 } 1346 1347 for _, f := range checkers { 1348 f(t, icmp) 1349 } 1350 if t.Failed() { 1351 t.FailNow() 1352 } 1353 } 1354 } 1355 1356 // NDPNS creates a checker that checks that the packet contains a valid NDP 1357 // Neighbor Solicitation message (as per the raw wire format), with potentially 1358 // additional checks specified by checkers. 1359 // 1360 // Checkers may assume that a valid ICMPv6 is passed to it containing a valid 1361 // NDPNS message as far as the size of the message is concerned. The values 1362 // within the message are up to checkers to validate. 1363 func NDPNS(checkers ...TransportChecker) NetworkChecker { 1364 return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...) 1365 } 1366 1367 // NDPNSTargetAddress creates a checker that checks the Target Address field of 1368 // a header.NDPNeighborSolicit. 1369 // 1370 // The returned TransportChecker assumes that a valid ICMPv6 is passed to it 1371 // containing a valid NDPNS message as far as the size is concerned. 1372 func NDPNSTargetAddress(want tcpip.Address) TransportChecker { 1373 return func(t *testing.T, h header.Transport) { 1374 t.Helper() 1375 1376 icmp := h.(header.ICMPv6) 1377 ns := header.NDPNeighborSolicit(icmp.MessageBody()) 1378 1379 if got := ns.TargetAddress(); got != want { 1380 t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want) 1381 } 1382 } 1383 } 1384 1385 // NDPNA creates a checker that checks that the packet contains a valid NDP 1386 // Neighbor Advertisement message (as per the raw wire format), with potentially 1387 // additional checks specified by checkers. 1388 // 1389 // Checkers may assume that a valid ICMPv6 is passed to it containing a valid 1390 // NDPNA message as far as the size of the message is concerned. The values 1391 // within the message are up to checkers to validate. 1392 func NDPNA(checkers ...TransportChecker) NetworkChecker { 1393 return NDP(header.ICMPv6NeighborAdvert, header.NDPNAMinimumSize, checkers...) 1394 } 1395 1396 // NDPNATargetAddress creates a checker that checks the Target Address field of 1397 // a header.NDPNeighborAdvert. 1398 // 1399 // The returned TransportChecker assumes that a valid ICMPv6 is passed to it 1400 // containing a valid NDPNA message as far as the size is concerned. 1401 func NDPNATargetAddress(want tcpip.Address) TransportChecker { 1402 return func(t *testing.T, h header.Transport) { 1403 t.Helper() 1404 1405 icmp := h.(header.ICMPv6) 1406 na := header.NDPNeighborAdvert(icmp.MessageBody()) 1407 1408 if got := na.TargetAddress(); got != want { 1409 t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want) 1410 } 1411 } 1412 } 1413 1414 // NDPNASolicitedFlag creates a checker that checks the Solicited field of 1415 // a header.NDPNeighborAdvert. 1416 // 1417 // The returned TransportChecker assumes that a valid ICMPv6 is passed to it 1418 // containing a valid NDPNA message as far as the size is concerned. 1419 func NDPNASolicitedFlag(want bool) TransportChecker { 1420 return func(t *testing.T, h header.Transport) { 1421 t.Helper() 1422 1423 icmp := h.(header.ICMPv6) 1424 na := header.NDPNeighborAdvert(icmp.MessageBody()) 1425 1426 if got := na.SolicitedFlag(); got != want { 1427 t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want) 1428 } 1429 } 1430 } 1431 1432 // ndpOptions checks that optsBuf only contains opts. 1433 func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption) { 1434 t.Helper() 1435 1436 it, err := optsBuf.Iter(true) 1437 if err != nil { 1438 t.Errorf("optsBuf.Iter(true): %s", err) 1439 return 1440 } 1441 1442 i := 0 1443 for { 1444 opt, done, err := it.Next() 1445 if err != nil { 1446 // This should never happen as Iter(true) above did not return an error. 1447 t.Fatalf("unexpected error when iterating over NDP options: %s", err) 1448 } 1449 if done { 1450 break 1451 } 1452 1453 if i >= len(opts) { 1454 t.Errorf("got unexpected option: %s", opt) 1455 continue 1456 } 1457 1458 switch wantOpt := opts[i].(type) { 1459 case header.NDPSourceLinkLayerAddressOption: 1460 gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption) 1461 if !ok { 1462 t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) 1463 } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { 1464 t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) 1465 } 1466 case header.NDPTargetLinkLayerAddressOption: 1467 gotOpt, ok := opt.(header.NDPTargetLinkLayerAddressOption) 1468 if !ok { 1469 t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) 1470 } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { 1471 t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) 1472 } 1473 case header.NDPNonceOption: 1474 gotOpt, ok := opt.(header.NDPNonceOption) 1475 if !ok { 1476 t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) 1477 } else if diff := cmp.Diff(wantOpt.Nonce(), gotOpt.Nonce()); diff != "" { 1478 t.Errorf("nonce mismatch (-want +got):\n%s", diff) 1479 } 1480 default: 1481 t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt) 1482 } 1483 1484 i++ 1485 } 1486 1487 if missing := opts[i:]; len(missing) > 0 { 1488 t.Errorf("missing options: %s", missing) 1489 } 1490 } 1491 1492 // NDPNAOptions creates a checker that checks that the packet contains the 1493 // provided NDP options within an NDP Neighbor Solicitation message. 1494 // 1495 // The returned TransportChecker assumes that a valid ICMPv6 is passed to it 1496 // containing a valid NDPNA message as far as the size is concerned. 1497 func NDPNAOptions(opts []header.NDPOption) TransportChecker { 1498 return func(t *testing.T, h header.Transport) { 1499 t.Helper() 1500 1501 icmp := h.(header.ICMPv6) 1502 na := header.NDPNeighborAdvert(icmp.MessageBody()) 1503 ndpOptions(t, na.Options(), opts) 1504 } 1505 } 1506 1507 // NDPNSOptions creates a checker that checks that the packet contains the 1508 // provided NDP options within an NDP Neighbor Solicitation message. 1509 // 1510 // The returned TransportChecker assumes that a valid ICMPv6 is passed to it 1511 // containing a valid NDPNS message as far as the size is concerned. 1512 func NDPNSOptions(opts []header.NDPOption) TransportChecker { 1513 return func(t *testing.T, h header.Transport) { 1514 t.Helper() 1515 1516 icmp := h.(header.ICMPv6) 1517 ns := header.NDPNeighborSolicit(icmp.MessageBody()) 1518 ndpOptions(t, ns.Options(), opts) 1519 } 1520 } 1521 1522 // NDPRS creates a checker that checks that the packet contains a valid NDP 1523 // Router Solicitation message (as per the raw wire format). 1524 // 1525 // Checkers may assume that a valid ICMPv6 is passed to it containing a valid 1526 // NDPRS as far as the size of the message is concerned. The values within the 1527 // message are up to checkers to validate. 1528 func NDPRS(checkers ...TransportChecker) NetworkChecker { 1529 return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize, checkers...) 1530 } 1531 1532 // NDPRSOptions creates a checker that checks that the packet contains the 1533 // provided NDP options within an NDP Router Solicitation message. 1534 // 1535 // The returned TransportChecker assumes that a valid ICMPv6 is passed to it 1536 // containing a valid NDPRS message as far as the size is concerned. 1537 func NDPRSOptions(opts []header.NDPOption) TransportChecker { 1538 return func(t *testing.T, h header.Transport) { 1539 t.Helper() 1540 1541 icmp := h.(header.ICMPv6) 1542 rs := header.NDPRouterSolicit(icmp.MessageBody()) 1543 ndpOptions(t, rs.Options(), opts) 1544 } 1545 } 1546 1547 // IGMP checks the validity and properties of the given IGMP packet. It is 1548 // expected to be used in conjunction with other IGMP transport checkers for 1549 // specific properties. 1550 func IGMP(checkers ...TransportChecker) NetworkChecker { 1551 return func(t *testing.T, h []header.Network) { 1552 t.Helper() 1553 1554 last := h[len(h)-1] 1555 1556 if p := last.TransportProtocol(); p != header.IGMPProtocolNumber { 1557 t.Fatalf("Bad protocol, got %d, want %d", p, header.IGMPProtocolNumber) 1558 } 1559 1560 igmp := header.IGMP(last.Payload()) 1561 for _, f := range checkers { 1562 f(t, igmp) 1563 } 1564 if t.Failed() { 1565 t.FailNow() 1566 } 1567 } 1568 } 1569 1570 // IGMPType creates a checker that checks the IGMP Type field. 1571 func IGMPType(want header.IGMPType) TransportChecker { 1572 return func(t *testing.T, h header.Transport) { 1573 t.Helper() 1574 1575 igmp, ok := h.(header.IGMP) 1576 if !ok { 1577 t.Fatalf("got transport header = %T, want = header.IGMP", h) 1578 } 1579 if got := igmp.Type(); got != want { 1580 t.Errorf("got igmp.Type() = %d, want = %d", got, want) 1581 } 1582 } 1583 } 1584 1585 // IGMPMaxRespTime creates a checker that checks the IGMP Max Resp Time field. 1586 func IGMPMaxRespTime(want time.Duration) TransportChecker { 1587 return func(t *testing.T, h header.Transport) { 1588 t.Helper() 1589 1590 igmp, ok := h.(header.IGMP) 1591 if !ok { 1592 t.Fatalf("got transport header = %T, want = header.IGMP", h) 1593 } 1594 if got := igmp.MaxRespTime(); got != want { 1595 t.Errorf("got igmp.MaxRespTime() = %s, want = %s", got, want) 1596 } 1597 } 1598 } 1599 1600 // IGMPGroupAddressUnordered creates a checker that checks that the group 1601 // address in the IGMP message is expected to be seen. 1602 // 1603 // The seen address is removed from the expected groups map. 1604 // 1605 // The returned TransportChecker assumes that a valid IGMP is passed to it 1606 // containing a valid IGMP message as far as the size is concerned. 1607 func IGMPGroupAddressUnordered(expectedGroups map[tcpip.Address]struct{}) TransportChecker { 1608 return func(t *testing.T, h header.Transport) { 1609 t.Helper() 1610 1611 igmp, ok := h.(header.IGMP) 1612 if !ok { 1613 t.Fatalf("got transport header = %T, want = header.IGMP", h) 1614 } 1615 1616 addr := igmp.GroupAddress() 1617 1618 if _, ok := expectedGroups[addr]; !ok { 1619 t.Errorf("unexpected multicast group %s", addr) 1620 } else { 1621 delete(expectedGroups, addr) 1622 } 1623 } 1624 } 1625 1626 // IGMPGroupAddress creates a checker that checks the IGMP Group Address field. 1627 func IGMPGroupAddress(want tcpip.Address) TransportChecker { 1628 return IGMPGroupAddressUnordered(map[tcpip.Address]struct{}{ 1629 want: struct{}{}, 1630 }) 1631 } 1632 1633 // IGMPv3Report creates a checker that checks that the packet contains a valid 1634 // IGMPv3 report with the specified records. 1635 // 1636 // Note that observed records are removed from expectedRecords. No error is 1637 // logged if the report does not have all the records expected. 1638 func IGMPv3Report(expectedRecords map[tcpip.Address]header.IGMPv3ReportRecordType) NetworkChecker { 1639 return func(t *testing.T, h []header.Network) { 1640 t.Helper() 1641 1642 last := h[len(h)-1] 1643 if p := last.TransportProtocol(); p != header.IGMPProtocolNumber { 1644 t.Fatalf("Bad protocol, got %d, want %d", p, header.IGMPProtocolNumber) 1645 } 1646 1647 igmp := header.IGMP(last.Payload()) 1648 if got := igmp.Type(); got != header.IGMPv3MembershipReport { 1649 t.Errorf("got igmp.Type() = %d, want = %d", got, header.IGMPv3MembershipReport) 1650 } 1651 1652 report := header.IGMPv3Report(igmp) 1653 if got, want := report.Checksum(), header.IGMPCalculateChecksum(igmp); got != want { 1654 t.Errorf("got report.Checksum() = %d, want = %d", got, want) 1655 } 1656 1657 records := report.GroupAddressRecords() 1658 for len(expectedRecords) != 0 { 1659 record, res := records.Next() 1660 switch res { 1661 case header.IGMPv3ReportGroupAddressRecordIteratorNextOk: 1662 case header.IGMPv3ReportGroupAddressRecordIteratorNextDone: 1663 return 1664 default: 1665 t.Fatalf("unhandled res = %d", res) 1666 } 1667 1668 addr := record.GroupAddress() 1669 expectedRecordType, ok := expectedRecords[addr] 1670 if !ok { 1671 t.Errorf("unexpected record for address %s", addr) 1672 continue 1673 } 1674 1675 if got, want := record.RecordType(), expectedRecordType; got != want { 1676 t.Errorf("got record.RecordType() = %d, want = %d", got, want) 1677 } 1678 1679 if got := record.AuxDataLen(); got != 0 { 1680 t.Errorf("got record.AuxDataLen() = %d, want = 0", got) 1681 } 1682 1683 sources, ok := record.Sources() 1684 if !ok { 1685 t.Error("got record.Sources() = (_, false), want = (_, true)") 1686 continue 1687 } 1688 1689 if source, ok := sources.Next(); ok { 1690 t.Fatalf("got sources.Next() = (%s, true), want = (_, false)", source) 1691 } 1692 1693 delete(expectedRecords, addr) 1694 } 1695 1696 if record, res := records.Next(); res != header.IGMPv3ReportGroupAddressRecordIteratorNextDone { 1697 t.Fatalf("got records.Next() = (%#v, %d), want = (_, %d)", record, res, header.IGMPv3ReportGroupAddressRecordIteratorNextDone) 1698 } 1699 } 1700 } 1701 1702 // IPv6ExtHdrChecker is a function to check an extension header. 1703 type IPv6ExtHdrChecker func(*testing.T, header.IPv6PayloadHeader) 1704 1705 // IPv6WithExtHdr is like IPv6 but allows IPv6 packets with extension headers. 1706 func IPv6WithExtHdr(t *testing.T, v *buffer.View, checkers ...NetworkChecker) { 1707 t.Helper() 1708 1709 ipv6 := header.IPv6(v.AsSlice()) 1710 if !ipv6.IsValid(len(v.AsSlice())) { 1711 t.Error("not a valid IPv6 packet") 1712 return 1713 } 1714 1715 payloadIterator := header.MakeIPv6PayloadIterator( 1716 header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()), 1717 buffer.MakeWithData(ipv6.Payload()), 1718 ) 1719 defer payloadIterator.Release() 1720 1721 var rawPayloadHeader header.IPv6RawPayloadHeader 1722 for { 1723 h, done, err := payloadIterator.Next() 1724 if err != nil { 1725 t.Errorf("payloadIterator.Next(): %s", err) 1726 return 1727 } 1728 if done { 1729 t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, true, _)", h, done) 1730 return 1731 } 1732 defer h.Release() 1733 r, ok := h.(header.IPv6RawPayloadHeader) 1734 if ok { 1735 rawPayloadHeader = r 1736 break 1737 } 1738 } 1739 1740 networkHeader := ipv6HeaderWithExtHdr{ 1741 IPv6: ipv6, 1742 transport: tcpip.TransportProtocolNumber(rawPayloadHeader.Identifier), 1743 payload: rawPayloadHeader.Buf.Flatten(), 1744 } 1745 1746 for _, checker := range checkers { 1747 checker(t, []header.Network{&networkHeader}) 1748 } 1749 } 1750 1751 // IPv6ExtHdr checks for the presence of extension headers. 1752 // 1753 // All the extension headers in headers will be checked exhaustively in the 1754 // order provided. 1755 func IPv6ExtHdr(headers ...IPv6ExtHdrChecker) NetworkChecker { 1756 return func(t *testing.T, h []header.Network) { 1757 t.Helper() 1758 1759 extHdrs, ok := h[0].(*ipv6HeaderWithExtHdr) 1760 if !ok { 1761 t.Errorf("got network header = %T, want = *ipv6HeaderWithExtHdr", h[0]) 1762 return 1763 } 1764 1765 payloadIterator := header.MakeIPv6PayloadIterator( 1766 header.IPv6ExtensionHeaderIdentifier(extHdrs.IPv6.NextHeader()), 1767 buffer.MakeWithData(extHdrs.IPv6.Payload()), 1768 ) 1769 defer payloadIterator.Release() 1770 1771 for _, check := range headers { 1772 h, done, err := payloadIterator.Next() 1773 if err != nil { 1774 t.Errorf("payloadIterator.Next(): %s", err) 1775 return 1776 } 1777 if done { 1778 t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, false, _)", h, done) 1779 return 1780 } 1781 check(t, h) 1782 h.Release() 1783 } 1784 // Validate we consumed all headers. 1785 // 1786 // The next one over should be a raw payload and then iterator should 1787 // terminate. 1788 wantDone := false 1789 for { 1790 h, done, err := payloadIterator.Next() 1791 if err != nil { 1792 t.Errorf("payloadIterator.Next(): %s", err) 1793 return 1794 } 1795 if done != wantDone { 1796 t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, %t, _)", h, done, wantDone) 1797 return 1798 } 1799 if done { 1800 break 1801 } 1802 if _, ok := h.(header.IPv6RawPayloadHeader); !ok { 1803 t.Errorf("got payloadIterator.Next() = (%T, _, _), want = (header.IPv6RawPayloadHeader, _, _)", h) 1804 continue 1805 } else { 1806 h.Release() 1807 } 1808 wantDone = true 1809 } 1810 } 1811 } 1812 1813 var _ header.Network = (*ipv6HeaderWithExtHdr)(nil) 1814 1815 // ipv6HeaderWithExtHdr provides a header.Network implementation that takes 1816 // extension headers into consideration, which is not the case with vanilla 1817 // header.IPv6. 1818 type ipv6HeaderWithExtHdr struct { 1819 header.IPv6 1820 transport tcpip.TransportProtocolNumber 1821 payload []byte 1822 } 1823 1824 // TransportProtocol implements header.Network. 1825 func (h *ipv6HeaderWithExtHdr) TransportProtocol() tcpip.TransportProtocolNumber { 1826 return h.transport 1827 } 1828 1829 // Payload implements header.Network. 1830 func (h *ipv6HeaderWithExtHdr) Payload() []byte { 1831 return h.payload 1832 } 1833 1834 // IPv6ExtHdrOptionChecker is a function to check an extension header option. 1835 type IPv6ExtHdrOptionChecker func(*testing.T, header.IPv6ExtHdrOption) 1836 1837 // IPv6HopByHopExtensionHeader checks the extension header is a Hop by Hop 1838 // extension header and validates the containing options with checkers. 1839 // 1840 // checkers must exhaustively contain all the expected options. 1841 func IPv6HopByHopExtensionHeader(checkers ...IPv6ExtHdrOptionChecker) IPv6ExtHdrChecker { 1842 return func(t *testing.T, payloadHeader header.IPv6PayloadHeader) { 1843 t.Helper() 1844 1845 hbh, ok := payloadHeader.(header.IPv6HopByHopOptionsExtHdr) 1846 if !ok { 1847 t.Errorf("unexpected IPv6 payload header, got = %T, want = header.IPv6HopByHopOptionsExtHdr", payloadHeader) 1848 return 1849 } 1850 optionsIterator := hbh.Iter() 1851 for _, f := range checkers { 1852 opt, done, err := optionsIterator.Next() 1853 if err != nil { 1854 t.Errorf("optionsIterator.Next(): %s", err) 1855 return 1856 } 1857 if done { 1858 t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, false, _)", opt, done) 1859 } 1860 f(t, opt) 1861 if uo, ok := opt.(*header.IPv6UnknownExtHdrOption); ok { 1862 uo.Data.Release() 1863 } 1864 } 1865 // Validate all options were consumed. 1866 for { 1867 opt, done, err := optionsIterator.Next() 1868 if err != nil { 1869 t.Errorf("optionsIterator.Next(): %s", err) 1870 return 1871 } 1872 if !done { 1873 t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, true, _)", opt, done) 1874 } 1875 if done { 1876 break 1877 } 1878 if uo, ok := opt.(*header.IPv6UnknownExtHdrOption); ok { 1879 uo.Data.Release() 1880 } 1881 } 1882 } 1883 } 1884 1885 // IPv6RouterAlert validates that an extension header option is the RouterAlert 1886 // option and matches on its value. 1887 func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker { 1888 return func(t *testing.T, opt header.IPv6ExtHdrOption) { 1889 routerAlert, ok := opt.(*header.IPv6RouterAlertOption) 1890 if !ok { 1891 t.Errorf("unexpected extension header option, got = %T, want = header.IPv6RouterAlertOption", opt) 1892 return 1893 } 1894 if routerAlert.Value != want { 1895 t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, want) 1896 } 1897 } 1898 } 1899 1900 // IPv6UnknownOption validates that an extension header option is the 1901 // unknown header option. 1902 func IPv6UnknownOption() IPv6ExtHdrOptionChecker { 1903 return func(t *testing.T, opt header.IPv6ExtHdrOption) { 1904 _, ok := opt.(*header.IPv6UnknownExtHdrOption) 1905 if !ok { 1906 t.Errorf("got = %T, want = header.IPv6UnknownExtHdrOption", opt) 1907 } 1908 } 1909 } 1910 1911 // IgnoreCmpPath returns a cmp.Option that ignores listed field paths. 1912 func IgnoreCmpPath(paths ...string) cmp.Option { 1913 ignores := map[string]struct{}{} 1914 for _, path := range paths { 1915 ignores[path] = struct{}{} 1916 } 1917 return cmp.FilterPath(func(path cmp.Path) bool { 1918 _, ok := ignores[path.String()] 1919 return ok 1920 }, cmp.Ignore()) 1921 }