github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/checker/checker.go (about) 1 // Copyright 2021 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package checker provides helper functions to check networking packets for 16 // validity. 17 package checker 18 19 import ( 20 "encoding/binary" 21 "reflect" 22 "testing" 23 "time" 24 25 "github.com/google/go-cmp/cmp" 26 "github.com/SagerNet/gvisor/pkg/tcpip" 27 "github.com/SagerNet/gvisor/pkg/tcpip/buffer" 28 "github.com/SagerNet/gvisor/pkg/tcpip/header" 29 "github.com/SagerNet/gvisor/pkg/tcpip/seqnum" 30 ) 31 32 // NetworkChecker is a function to check a property of a network packet. 33 type NetworkChecker func(*testing.T, []header.Network) 34 35 // TransportChecker is a function to check a property of a transport packet. 36 type TransportChecker func(*testing.T, header.Transport) 37 38 // ControlMessagesChecker is a function to check a property of ancillary data. 39 type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages) 40 41 // IPv4 checks the validity and properties of the given IPv4 packet. It is 42 // expected to be used in conjunction with other network checkers for specific 43 // properties. For example, to check the source and destination address, one 44 // would call: 45 // 46 // checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y)) 47 func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) { 48 t.Helper() 49 50 ipv4 := header.IPv4(b) 51 52 if !ipv4.IsValid(len(b)) { 53 t.Fatalf("Not a valid IPv4 packet: %x", ipv4) 54 } 55 56 if !ipv4.IsChecksumValid() { 57 t.Errorf("Bad checksum, got = %d", ipv4.Checksum()) 58 } 59 60 for _, f := range checkers { 61 f(t, []header.Network{ipv4}) 62 } 63 if t.Failed() { 64 t.FailNow() 65 } 66 } 67 68 // IPv6 checks the validity and properties of the given IPv6 packet. The usage 69 // is similar to IPv4. 70 func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) { 71 t.Helper() 72 73 ipv6 := header.IPv6(b) 74 if !ipv6.IsValid(len(b)) { 75 t.Fatalf("Not a valid IPv6 packet: %x", ipv6) 76 } 77 78 for _, f := range checkers { 79 f(t, []header.Network{ipv6}) 80 } 81 if t.Failed() { 82 t.FailNow() 83 } 84 } 85 86 // SrcAddr creates a checker that checks the source address. 87 func SrcAddr(addr tcpip.Address) NetworkChecker { 88 return func(t *testing.T, h []header.Network) { 89 t.Helper() 90 91 if a := h[0].SourceAddress(); a != addr { 92 t.Errorf("Bad source address, got %v, want %v", a, addr) 93 } 94 } 95 } 96 97 // DstAddr creates a checker that checks the destination address. 98 func DstAddr(addr tcpip.Address) NetworkChecker { 99 return func(t *testing.T, h []header.Network) { 100 t.Helper() 101 102 if a := h[0].DestinationAddress(); a != addr { 103 t.Errorf("Bad destination address, got %v, want %v", a, addr) 104 } 105 } 106 } 107 108 // TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6). 109 func TTL(ttl uint8) NetworkChecker { 110 return func(t *testing.T, h []header.Network) { 111 t.Helper() 112 113 var v uint8 114 switch ip := h[0].(type) { 115 case header.IPv4: 116 v = ip.TTL() 117 case header.IPv6: 118 v = ip.HopLimit() 119 case *ipv6HeaderWithExtHdr: 120 v = ip.HopLimit() 121 default: 122 t.Fatalf("unrecognized header type %T for TTL evaluation", ip) 123 } 124 if v != ttl { 125 t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl) 126 } 127 } 128 } 129 130 // IPFullLength creates a checker for the full IP packet length. The 131 // expected size is checked against both the Total Length in the 132 // header and the number of bytes received. 133 func IPFullLength(packetLength uint16) NetworkChecker { 134 return func(t *testing.T, h []header.Network) { 135 t.Helper() 136 137 var v uint16 138 var l uint16 139 switch ip := h[0].(type) { 140 case header.IPv4: 141 v = ip.TotalLength() 142 l = uint16(len(ip)) 143 case header.IPv6: 144 v = ip.PayloadLength() + header.IPv6FixedHeaderSize 145 l = uint16(len(ip)) 146 default: 147 t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4 or header.IPv6", ip) 148 } 149 if l != packetLength { 150 t.Errorf("bad packet length, got = %d, want = %d", l, packetLength) 151 } 152 if v != packetLength { 153 t.Errorf("unexpected packet length in header, got = %d, want = %d", v, packetLength) 154 } 155 } 156 } 157 158 // IPv4HeaderLength creates a checker that checks the IPv4 Header length. 159 func IPv4HeaderLength(headerLength int) NetworkChecker { 160 return func(t *testing.T, h []header.Network) { 161 t.Helper() 162 163 switch ip := h[0].(type) { 164 case header.IPv4: 165 if hl := ip.HeaderLength(); hl != uint8(headerLength) { 166 t.Errorf("Bad header length, got = %d, want = %d", hl, headerLength) 167 } 168 default: 169 t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", ip) 170 } 171 } 172 } 173 174 // PayloadLen creates a checker that checks the payload length. 175 func PayloadLen(payloadLength int) NetworkChecker { 176 return func(t *testing.T, h []header.Network) { 177 t.Helper() 178 179 if l := len(h[0].Payload()); l != payloadLength { 180 t.Errorf("Bad payload length, got = %d, want = %d", l, payloadLength) 181 } 182 } 183 } 184 185 // IPPayload creates a checker that checks the payload. 186 func IPPayload(payload []byte) NetworkChecker { 187 return func(t *testing.T, h []header.Network) { 188 t.Helper() 189 190 got := h[0].Payload() 191 192 // cmp.Diff does not consider nil slices equal to empty slices, but we do. 193 if len(got) == 0 && len(payload) == 0 { 194 return 195 } 196 197 if diff := cmp.Diff(payload, got); diff != "" { 198 t.Errorf("payload mismatch (-want +got):\n%s", diff) 199 } 200 } 201 } 202 203 // IPv4Options returns a checker that checks the options in an IPv4 packet. 204 func IPv4Options(want header.IPv4Options) NetworkChecker { 205 return func(t *testing.T, h []header.Network) { 206 t.Helper() 207 208 ip, ok := h[0].(header.IPv4) 209 if !ok { 210 t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0]) 211 } 212 options := ip.Options() 213 // cmp.Diff does not consider nil slices equal to empty slices, but we do. 214 if len(want) == 0 && len(options) == 0 { 215 return 216 } 217 if diff := cmp.Diff(want, options); diff != "" { 218 t.Errorf("options mismatch (-want +got):\n%s", diff) 219 } 220 } 221 } 222 223 // IPv4RouterAlert returns a checker that checks that the RouterAlert option is 224 // set in an IPv4 packet. 225 func IPv4RouterAlert() NetworkChecker { 226 return func(t *testing.T, h []header.Network) { 227 t.Helper() 228 ip, ok := h[0].(header.IPv4) 229 if !ok { 230 t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0]) 231 } 232 iterator := ip.Options().MakeIterator() 233 for { 234 opt, done, err := iterator.Next() 235 if err != nil { 236 t.Fatalf("error acquiring next IPv4 option at offset %d", err.Pointer) 237 } 238 if done { 239 break 240 } 241 if opt.Type() != header.IPv4OptionRouterAlertType { 242 continue 243 } 244 want := [header.IPv4OptionRouterAlertLength]byte{ 245 byte(header.IPv4OptionRouterAlertType), 246 header.IPv4OptionRouterAlertLength, 247 header.IPv4OptionRouterAlertValue, 248 header.IPv4OptionRouterAlertValue, 249 } 250 if diff := cmp.Diff(want[:], opt.Contents()); diff != "" { 251 t.Errorf("router alert option mismatch (-want +got):\n%s", diff) 252 } 253 return 254 } 255 t.Errorf("failed to find router alert option in %v", ip.Options()) 256 } 257 } 258 259 // FragmentOffset creates a checker that checks the FragmentOffset field. 260 func FragmentOffset(offset uint16) NetworkChecker { 261 return func(t *testing.T, h []header.Network) { 262 t.Helper() 263 264 // We only do this for IPv4 for now. 265 switch ip := h[0].(type) { 266 case header.IPv4: 267 if v := ip.FragmentOffset(); v != offset { 268 t.Errorf("Bad fragment offset, got = %d, want = %d", v, offset) 269 } 270 } 271 } 272 } 273 274 // FragmentFlags creates a checker that checks the fragment flags field. 275 func FragmentFlags(flags uint8) NetworkChecker { 276 return func(t *testing.T, h []header.Network) { 277 t.Helper() 278 279 // We only do this for IPv4 for now. 280 switch ip := h[0].(type) { 281 case header.IPv4: 282 if v := ip.Flags(); v != flags { 283 t.Errorf("Bad fragment offset, got = %d, want = %d", v, flags) 284 } 285 } 286 } 287 } 288 289 // ReceiveTClass creates a checker that checks the TCLASS field in 290 // ControlMessages. 291 func ReceiveTClass(want uint32) ControlMessagesChecker { 292 return func(t *testing.T, cm tcpip.ControlMessages) { 293 t.Helper() 294 if !cm.HasTClass { 295 t.Errorf("got cm.HasTClass = %t, want = true", cm.HasTClass) 296 } else if got := cm.TClass; got != want { 297 t.Errorf("got cm.TClass = %d, want %d", got, want) 298 } 299 } 300 } 301 302 // ReceiveTOS creates a checker that checks the TOS field in ControlMessages. 303 func ReceiveTOS(want uint8) ControlMessagesChecker { 304 return func(t *testing.T, cm tcpip.ControlMessages) { 305 t.Helper() 306 if !cm.HasTOS { 307 t.Errorf("got cm.HasTOS = %t, want = true", cm.HasTOS) 308 } else if got := cm.TOS; got != want { 309 t.Errorf("got cm.TOS = %d, want %d", got, want) 310 } 311 } 312 } 313 314 // ReceiveIPPacketInfo creates a checker that checks the PacketInfo field in 315 // ControlMessages. 316 func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker { 317 return func(t *testing.T, cm tcpip.ControlMessages) { 318 t.Helper() 319 if !cm.HasIPPacketInfo { 320 t.Errorf("got cm.HasIPPacketInfo = %t, want = true", cm.HasIPPacketInfo) 321 } else if diff := cmp.Diff(want, cm.PacketInfo); diff != "" { 322 t.Errorf("IPPacketInfo mismatch (-want +got):\n%s", diff) 323 } 324 } 325 } 326 327 // ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress 328 // field in ControlMessages. 329 func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker { 330 return func(t *testing.T, cm tcpip.ControlMessages) { 331 t.Helper() 332 if !cm.HasOriginalDstAddress { 333 t.Errorf("got cm.HasOriginalDstAddress = %t, want = true", cm.HasOriginalDstAddress) 334 } else if diff := cmp.Diff(want, cm.OriginalDstAddress); diff != "" { 335 t.Errorf("OriginalDstAddress mismatch (-want +got):\n%s", diff) 336 } 337 } 338 } 339 340 // TOS creates a checker that checks the TOS field. 341 func TOS(tos uint8, label uint32) NetworkChecker { 342 return func(t *testing.T, h []header.Network) { 343 t.Helper() 344 345 if v, l := h[0].TOS(); v != tos || l != label { 346 t.Errorf("Bad TOS, got = (%d, %d), want = (%d,%d)", v, l, tos, label) 347 } 348 } 349 } 350 351 // Raw creates a checker that checks the bytes of payload. 352 // The checker always checks the payload of the last network header. 353 // For instance, in case of IPv6 fragments, the payload that will be checked 354 // is the one containing the actual data that the packet is carrying, without 355 // the bytes added by the IPv6 fragmentation. 356 func Raw(want []byte) NetworkChecker { 357 return func(t *testing.T, h []header.Network) { 358 t.Helper() 359 360 if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) { 361 t.Errorf("Wrong payload, got %v, want %v", got, want) 362 } 363 } 364 } 365 366 // IPv6Fragment creates a checker that validates an IPv6 fragment. 367 func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker { 368 return func(t *testing.T, h []header.Network) { 369 t.Helper() 370 371 if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader { 372 t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber) 373 } 374 375 ipv6Frag := header.IPv6Fragment(h[0].Payload()) 376 if !ipv6Frag.IsValid() { 377 t.Error("Not a valid IPv6 fragment") 378 } 379 380 for _, f := range checkers { 381 f(t, []header.Network{h[0], ipv6Frag}) 382 } 383 if t.Failed() { 384 t.FailNow() 385 } 386 } 387 } 388 389 // TCP creates a checker that checks that the transport protocol is TCP and 390 // potentially additional transport header fields. 391 func TCP(checkers ...TransportChecker) NetworkChecker { 392 return func(t *testing.T, h []header.Network) { 393 t.Helper() 394 395 first := h[0] 396 last := h[len(h)-1] 397 398 if p := last.TransportProtocol(); p != header.TCPProtocolNumber { 399 t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber) 400 } 401 402 tcp := header.TCP(last.Payload()) 403 payload := tcp.Payload() 404 payloadChecksum := header.Checksum(payload, 0) 405 if !tcp.IsChecksumValid(first.SourceAddress(), first.DestinationAddress(), payloadChecksum, uint16(len(payload))) { 406 t.Errorf("Bad checksum, got = %d", tcp.Checksum()) 407 } 408 409 // Run the transport checkers. 410 for _, f := range checkers { 411 f(t, tcp) 412 } 413 if t.Failed() { 414 t.FailNow() 415 } 416 } 417 } 418 419 // UDP creates a checker that checks that the transport protocol is UDP and 420 // potentially additional transport header fields. 421 func UDP(checkers ...TransportChecker) NetworkChecker { 422 return func(t *testing.T, h []header.Network) { 423 t.Helper() 424 425 last := h[len(h)-1] 426 427 if p := last.TransportProtocol(); p != header.UDPProtocolNumber { 428 t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber) 429 } 430 431 udp := header.UDP(last.Payload()) 432 for _, f := range checkers { 433 f(t, udp) 434 } 435 if t.Failed() { 436 t.FailNow() 437 } 438 } 439 } 440 441 // SrcPort creates a checker that checks the source port. 442 func SrcPort(port uint16) TransportChecker { 443 return func(t *testing.T, h header.Transport) { 444 t.Helper() 445 446 if p := h.SourcePort(); p != port { 447 t.Errorf("Bad source port, got = %d, want = %d", p, port) 448 } 449 } 450 } 451 452 // DstPort creates a checker that checks the destination port. 453 func DstPort(port uint16) TransportChecker { 454 return func(t *testing.T, h header.Transport) { 455 t.Helper() 456 457 if p := h.DestinationPort(); p != port { 458 t.Errorf("Bad destination port, got = %d, want = %d", p, port) 459 } 460 } 461 } 462 463 // NoChecksum creates a checker that checks if the checksum is zero. 464 func NoChecksum(noChecksum bool) TransportChecker { 465 return func(t *testing.T, h header.Transport) { 466 t.Helper() 467 468 udp, ok := h.(header.UDP) 469 if !ok { 470 t.Fatalf("UDP header not found in h: %T", h) 471 } 472 473 if b := udp.Checksum() == 0; b != noChecksum { 474 t.Errorf("bad checksum state, got %t, want %t", b, noChecksum) 475 } 476 } 477 } 478 479 // TCPSeqNum creates a checker that checks the sequence number. 480 func TCPSeqNum(seq uint32) TransportChecker { 481 return func(t *testing.T, h header.Transport) { 482 t.Helper() 483 484 tcp, ok := h.(header.TCP) 485 if !ok { 486 t.Fatalf("TCP header not found in h: %T", h) 487 } 488 489 if s := tcp.SequenceNumber(); s != seq { 490 t.Errorf("Bad sequence number, got = %d, want = %d", s, seq) 491 } 492 } 493 } 494 495 // TCPAckNum creates a checker that checks the ack number. 496 func TCPAckNum(seq uint32) TransportChecker { 497 return func(t *testing.T, h header.Transport) { 498 t.Helper() 499 500 tcp, ok := h.(header.TCP) 501 if !ok { 502 t.Fatalf("TCP header not found in h: %T", h) 503 } 504 505 if s := tcp.AckNumber(); s != seq { 506 t.Errorf("Bad ack number, got = %d, want = %d", s, seq) 507 } 508 } 509 } 510 511 // TCPWindow creates a checker that checks the tcp window. 512 func TCPWindow(window uint16) TransportChecker { 513 return func(t *testing.T, h header.Transport) { 514 t.Helper() 515 516 tcp, ok := h.(header.TCP) 517 if !ok { 518 t.Fatalf("TCP header not found in hdr : %T", h) 519 } 520 521 if w := tcp.WindowSize(); w != window { 522 t.Errorf("Bad window, got %d, want %d", w, window) 523 } 524 } 525 } 526 527 // TCPWindowGreaterThanEq creates a checker that checks that the TCP window 528 // is greater than or equal to the provided value. 529 func TCPWindowGreaterThanEq(window uint16) TransportChecker { 530 return func(t *testing.T, h header.Transport) { 531 t.Helper() 532 533 tcp, ok := h.(header.TCP) 534 if !ok { 535 t.Fatalf("TCP header not found in h: %T", h) 536 } 537 538 if w := tcp.WindowSize(); w < window { 539 t.Errorf("Bad window, got %d, want > %d", w, window) 540 } 541 } 542 } 543 544 // TCPWindowLessThanEq creates a checker that checks that the tcp window 545 // is less than or equal to the provided value. 546 func TCPWindowLessThanEq(window uint16) TransportChecker { 547 return func(t *testing.T, h header.Transport) { 548 t.Helper() 549 550 tcp, ok := h.(header.TCP) 551 if !ok { 552 t.Fatalf("TCP header not found in h: %T", h) 553 } 554 555 if w := tcp.WindowSize(); w > window { 556 t.Errorf("Bad window, got %d, want < %d", w, window) 557 } 558 } 559 } 560 561 // TCPFlags creates a checker that checks the tcp flags. 562 func TCPFlags(flags header.TCPFlags) TransportChecker { 563 return func(t *testing.T, h header.Transport) { 564 t.Helper() 565 566 tcp, ok := h.(header.TCP) 567 if !ok { 568 t.Fatalf("TCP header not found in h: %T", h) 569 } 570 571 if got := tcp.Flags(); got != flags { 572 t.Errorf("got tcp.Flags() = %s, want %s", got, flags) 573 } 574 } 575 } 576 577 // TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the 578 // given mask, match the supplied flags. 579 func TCPFlagsMatch(flags, mask header.TCPFlags) TransportChecker { 580 return func(t *testing.T, h header.Transport) { 581 t.Helper() 582 583 tcp, ok := h.(header.TCP) 584 if !ok { 585 t.Fatalf("TCP header not found in h: %T", h) 586 } 587 588 if got := tcp.Flags(); (got & mask) != (flags & mask) { 589 t.Errorf("got tcp.Flags() = %s, want %s, mask %s", got, flags, mask) 590 } 591 } 592 } 593 594 // TCPSynOptions creates a checker that checks the presence of TCP options in 595 // SYN segments. 596 // 597 // If wndscale is negative, the window scale option must not be present. 598 func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { 599 return func(t *testing.T, h header.Transport) { 600 t.Helper() 601 602 tcp, ok := h.(header.TCP) 603 if !ok { 604 return 605 } 606 opts := tcp.Options() 607 limit := len(opts) 608 foundMSS := false 609 foundWS := false 610 foundTS := false 611 foundSACKPermitted := false 612 tsVal := uint32(0) 613 tsEcr := uint32(0) 614 for i := 0; i < limit; { 615 switch opts[i] { 616 case header.TCPOptionEOL: 617 i = limit 618 case header.TCPOptionNOP: 619 i++ 620 case header.TCPOptionMSS: 621 v := uint16(opts[i+2])<<8 | uint16(opts[i+3]) 622 if wantOpts.MSS != v { 623 t.Errorf("Bad MSS, got = %d, want = %d", v, wantOpts.MSS) 624 } 625 foundMSS = true 626 i += 4 627 case header.TCPOptionWS: 628 if wantOpts.WS < 0 { 629 t.Error("WS present when it shouldn't be") 630 } 631 v := int(opts[i+2]) 632 if v != wantOpts.WS { 633 t.Errorf("Bad WS, got = %d, want = %d", v, wantOpts.WS) 634 } 635 foundWS = true 636 i += 3 637 case header.TCPOptionTS: 638 if i+9 >= limit { 639 t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i) 640 } 641 if opts[i+1] != 10 { 642 t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit) 643 } 644 tsVal = binary.BigEndian.Uint32(opts[i+2:]) 645 tsEcr = uint32(0) 646 if tcp.Flags()&header.TCPFlagAck != 0 { 647 // If the syn is an SYN-ACK then read 648 // the tsEcr value as well. 649 tsEcr = binary.BigEndian.Uint32(opts[i+6:]) 650 } 651 foundTS = true 652 i += 10 653 case header.TCPOptionSACKPermitted: 654 if i+1 >= limit { 655 t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i) 656 } 657 if opts[i+1] != 2 { 658 t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit) 659 } 660 foundSACKPermitted = true 661 i += 2 662 663 default: 664 i += int(opts[i+1]) 665 } 666 } 667 668 if !foundMSS { 669 t.Errorf("MSS option not found. Options: %x", opts) 670 } 671 672 if !foundWS && wantOpts.WS >= 0 { 673 t.Errorf("WS option not found. Options: %x", opts) 674 } 675 if wantOpts.TS && !foundTS { 676 t.Errorf("TS option not found. Options: %x", opts) 677 } 678 if foundTS && tsVal == 0 { 679 t.Error("TS option specified but the timestamp value is zero") 680 } 681 if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 { 682 t.Errorf("TS option specified but TSEcr is incorrect, got = %d, want = %d", tsEcr, wantOpts.TSEcr) 683 } 684 if wantOpts.SACKPermitted && !foundSACKPermitted { 685 t.Errorf("SACKPermitted option not found. Options: %x", opts) 686 } 687 } 688 } 689 690 // TCPTimestampChecker creates a checker that validates that a TCP segment has a 691 // TCP Timestamp option if wantTS is true, it also compares the wantTSVal and 692 // wantTSEcr values with those in the TCP segment (if present). 693 // 694 // If wantTSVal or wantTSEcr is zero then the corresponding comparison is 695 // skipped. 696 func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker { 697 return func(t *testing.T, h header.Transport) { 698 t.Helper() 699 700 tcp, ok := h.(header.TCP) 701 if !ok { 702 return 703 } 704 opts := tcp.Options() 705 limit := len(opts) 706 foundTS := false 707 tsVal := uint32(0) 708 tsEcr := uint32(0) 709 for i := 0; i < limit; { 710 switch opts[i] { 711 case header.TCPOptionEOL: 712 i = limit 713 case header.TCPOptionNOP: 714 i++ 715 case header.TCPOptionTS: 716 if i+9 >= limit { 717 t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i) 718 } 719 if opts[i+1] != 10 { 720 t.Errorf("TS option found, but bad length specified: got = %d, want = 10", opts[i+1]) 721 } 722 tsVal = binary.BigEndian.Uint32(opts[i+2:]) 723 tsEcr = binary.BigEndian.Uint32(opts[i+6:]) 724 foundTS = true 725 i += 10 726 default: 727 // We don't recognize this option, just skip over it. 728 if i+2 > limit { 729 return 730 } 731 l := int(opts[i+1]) 732 if i < 2 || i+l > limit { 733 return 734 } 735 i += l 736 } 737 } 738 739 if wantTS != foundTS { 740 t.Errorf("TS Option mismatch, got TS= %t, want TS= %t", foundTS, wantTS) 741 } 742 if wantTS && wantTSVal != 0 && wantTSVal != tsVal { 743 t.Errorf("Timestamp value is incorrect, got = %d, want = %d", tsVal, wantTSVal) 744 } 745 if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr { 746 t.Errorf("Timestamp Echo Reply is incorrect, got = %d, want = %d", tsEcr, wantTSEcr) 747 } 748 } 749 } 750 751 // TCPSACKBlockChecker creates a checker that verifies that the segment does 752 // contain the specified SACK blocks in the TCP options. 753 func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { 754 return func(t *testing.T, h header.Transport) { 755 t.Helper() 756 tcp, ok := h.(header.TCP) 757 if !ok { 758 return 759 } 760 var gotSACKBlocks []header.SACKBlock 761 762 opts := tcp.Options() 763 limit := len(opts) 764 for i := 0; i < limit; { 765 switch opts[i] { 766 case header.TCPOptionEOL: 767 i = limit 768 case header.TCPOptionNOP: 769 i++ 770 case header.TCPOptionSACK: 771 if i+2 > limit { 772 // Malformed SACK block. 773 t.Errorf("malformed SACK option in options: %v", opts) 774 } 775 sackOptionLen := int(opts[i+1]) 776 if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 { 777 // Malformed SACK block. 778 t.Errorf("malformed SACK option length in options: %v", opts) 779 } 780 numBlocks := sackOptionLen / 8 781 for j := 0; j < numBlocks; j++ { 782 start := binary.BigEndian.Uint32(opts[i+2+j*8:]) 783 end := binary.BigEndian.Uint32(opts[i+2+j*8+4:]) 784 gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{ 785 Start: seqnum.Value(start), 786 End: seqnum.Value(end), 787 }) 788 } 789 i += sackOptionLen 790 default: 791 // We don't recognize this option, just skip over it. 792 if i+2 > limit { 793 break 794 } 795 l := int(opts[i+1]) 796 if l < 2 || i+l > limit { 797 break 798 } 799 i += l 800 } 801 } 802 803 if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) { 804 t.Errorf("SACKBlocks are not equal, got = %v, want = %v", gotSACKBlocks, sackBlocks) 805 } 806 } 807 } 808 809 // Payload creates a checker that checks the payload. 810 func Payload(want []byte) TransportChecker { 811 return func(t *testing.T, h header.Transport) { 812 t.Helper() 813 814 if got := h.Payload(); !reflect.DeepEqual(got, want) { 815 t.Errorf("Wrong payload, got %v, want %v", got, want) 816 } 817 } 818 } 819 820 // ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 821 // and potentially additional ICMPv4 header fields. 822 func ICMPv4(checkers ...TransportChecker) NetworkChecker { 823 return func(t *testing.T, h []header.Network) { 824 t.Helper() 825 826 last := h[len(h)-1] 827 828 if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber { 829 t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber) 830 } 831 832 icmp := header.ICMPv4(last.Payload()) 833 for _, f := range checkers { 834 f(t, icmp) 835 } 836 if t.Failed() { 837 t.FailNow() 838 } 839 } 840 } 841 842 // ICMPv4Type creates a checker that checks the ICMPv4 Type field. 843 func ICMPv4Type(want header.ICMPv4Type) TransportChecker { 844 return func(t *testing.T, h header.Transport) { 845 t.Helper() 846 847 icmpv4, ok := h.(header.ICMPv4) 848 if !ok { 849 t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) 850 } 851 if got := icmpv4.Type(); got != want { 852 t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want) 853 } 854 } 855 } 856 857 // ICMPv4Code creates a checker that checks the ICMPv4 Code field. 858 func ICMPv4Code(want header.ICMPv4Code) TransportChecker { 859 return func(t *testing.T, h header.Transport) { 860 t.Helper() 861 862 icmpv4, ok := h.(header.ICMPv4) 863 if !ok { 864 t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) 865 } 866 if got := icmpv4.Code(); got != want { 867 t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want) 868 } 869 } 870 } 871 872 // ICMPv4Ident creates a checker that checks the ICMPv4 echo Ident. 873 func ICMPv4Ident(want uint16) TransportChecker { 874 return func(t *testing.T, h header.Transport) { 875 t.Helper() 876 877 icmpv4, ok := h.(header.ICMPv4) 878 if !ok { 879 t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) 880 } 881 if got := icmpv4.Ident(); got != want { 882 t.Fatalf("unexpected ICMP ident, got = %d, want = %d", got, want) 883 } 884 } 885 } 886 887 // ICMPv4Seq creates a checker that checks the ICMPv4 echo Sequence. 888 func ICMPv4Seq(want uint16) TransportChecker { 889 return func(t *testing.T, h header.Transport) { 890 t.Helper() 891 892 icmpv4, ok := h.(header.ICMPv4) 893 if !ok { 894 t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) 895 } 896 if got := icmpv4.Sequence(); got != want { 897 t.Fatalf("unexpected ICMP sequence, got = %d, want = %d", got, want) 898 } 899 } 900 } 901 902 // ICMPv4Pointer creates a checker that checks the ICMPv4 Param Problem pointer. 903 func ICMPv4Pointer(want uint8) TransportChecker { 904 return func(t *testing.T, h header.Transport) { 905 t.Helper() 906 907 icmpv4, ok := h.(header.ICMPv4) 908 if !ok { 909 t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) 910 } 911 if got := icmpv4.Pointer(); got != want { 912 t.Fatalf("unexpected ICMP Param Problem pointer, got = %d, want = %d", got, want) 913 } 914 } 915 } 916 917 // ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum. 918 // This assumes that the payload exactly makes up the rest of the slice. 919 func ICMPv4Checksum() TransportChecker { 920 return func(t *testing.T, h header.Transport) { 921 t.Helper() 922 923 icmpv4, ok := h.(header.ICMPv4) 924 if !ok { 925 t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) 926 } 927 heldChecksum := icmpv4.Checksum() 928 icmpv4.SetChecksum(0) 929 newChecksum := ^header.Checksum(icmpv4, 0) 930 icmpv4.SetChecksum(heldChecksum) 931 if heldChecksum != newChecksum { 932 t.Errorf("unexpected ICMP checksum, got = %d, want = %d", heldChecksum, newChecksum) 933 } 934 } 935 } 936 937 // ICMPv4Payload creates a checker that checks the payload in an ICMPv4 packet. 938 func ICMPv4Payload(want []byte) TransportChecker { 939 return func(t *testing.T, h header.Transport) { 940 t.Helper() 941 942 icmpv4, ok := h.(header.ICMPv4) 943 if !ok { 944 t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) 945 } 946 payload := icmpv4.Payload() 947 948 // cmp.Diff does not consider nil slices equal to empty slices, but we do. 949 if len(want) == 0 && len(payload) == 0 { 950 return 951 } 952 953 if diff := cmp.Diff(want, payload); diff != "" { 954 t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) 955 } 956 } 957 } 958 959 // ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and 960 // potentially additional ICMPv6 header fields. 961 // 962 // ICMPv6 will validate the checksum field before calling checkers. 963 func ICMPv6(checkers ...TransportChecker) NetworkChecker { 964 return func(t *testing.T, h []header.Network) { 965 t.Helper() 966 967 last := h[len(h)-1] 968 969 if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber { 970 t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber) 971 } 972 973 icmp := header.ICMPv6(last.Payload()) 974 if got, want := icmp.Checksum(), header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ 975 Header: icmp, 976 Src: last.SourceAddress(), 977 Dst: last.DestinationAddress(), 978 }); got != want { 979 t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want) 980 } 981 982 for _, f := range checkers { 983 f(t, icmp) 984 } 985 if t.Failed() { 986 t.FailNow() 987 } 988 } 989 } 990 991 // ICMPv6Type creates a checker that checks the ICMPv6 Type field. 992 func ICMPv6Type(want header.ICMPv6Type) TransportChecker { 993 return func(t *testing.T, h header.Transport) { 994 t.Helper() 995 996 icmpv6, ok := h.(header.ICMPv6) 997 if !ok { 998 t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) 999 } 1000 if got := icmpv6.Type(); got != want { 1001 t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want) 1002 } 1003 } 1004 } 1005 1006 // ICMPv6Code creates a checker that checks the ICMPv6 Code field. 1007 func ICMPv6Code(want header.ICMPv6Code) TransportChecker { 1008 return func(t *testing.T, h header.Transport) { 1009 t.Helper() 1010 1011 icmpv6, ok := h.(header.ICMPv6) 1012 if !ok { 1013 t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) 1014 } 1015 if got := icmpv6.Code(); got != want { 1016 t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want) 1017 } 1018 } 1019 } 1020 1021 // ICMPv6TypeSpecific creates a checker that checks the ICMPv6 TypeSpecific 1022 // field. 1023 func ICMPv6TypeSpecific(want uint32) TransportChecker { 1024 return func(t *testing.T, h header.Transport) { 1025 t.Helper() 1026 1027 icmpv6, ok := h.(header.ICMPv6) 1028 if !ok { 1029 t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) 1030 } 1031 if got := icmpv6.TypeSpecific(); got != want { 1032 t.Fatalf("unexpected ICMP TypeSpecific, got = %d, want = %d", got, want) 1033 } 1034 } 1035 } 1036 1037 // ICMPv6Payload creates a checker that checks the payload in an ICMPv6 packet. 1038 func ICMPv6Payload(want []byte) TransportChecker { 1039 return func(t *testing.T, h header.Transport) { 1040 t.Helper() 1041 1042 icmpv6, ok := h.(header.ICMPv6) 1043 if !ok { 1044 t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) 1045 } 1046 payload := icmpv6.Payload() 1047 1048 // cmp.Diff does not consider nil slices equal to empty slices, but we do. 1049 if len(want) == 0 && len(payload) == 0 { 1050 return 1051 } 1052 1053 if diff := cmp.Diff(want, payload); diff != "" { 1054 t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) 1055 } 1056 } 1057 } 1058 1059 // MLD creates a checker that checks that the packet contains a valid MLD 1060 // message for type of mldType, with potentially additional checks specified by 1061 // checkers. 1062 // 1063 // Checkers may assume that a valid ICMPv6 is passed to it containing a valid 1064 // MLD message as far as the size of the message (minSize) is concerned. The 1065 // values within the message are up to checkers to validate. 1066 func MLD(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker { 1067 return func(t *testing.T, h []header.Network) { 1068 t.Helper() 1069 1070 // Check normal ICMPv6 first. 1071 ICMPv6( 1072 ICMPv6Type(msgType), 1073 ICMPv6Code(0))(t, h) 1074 1075 last := h[len(h)-1] 1076 1077 icmp := header.ICMPv6(last.Payload()) 1078 if got := len(icmp.MessageBody()); got < minSize { 1079 t.Fatalf("ICMPv6 MLD (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) 1080 } 1081 1082 for _, f := range checkers { 1083 f(t, icmp) 1084 } 1085 if t.Failed() { 1086 t.FailNow() 1087 } 1088 } 1089 } 1090 1091 // MLDMaxRespDelay creates a checker that checks the Maximum Response Delay 1092 // field of a MLD message. 1093 // 1094 // The returned TransportChecker assumes that a valid ICMPv6 is passed to it 1095 // containing a valid MLD message as far as the size is concerned. 1096 func MLDMaxRespDelay(want time.Duration) TransportChecker { 1097 return func(t *testing.T, h header.Transport) { 1098 t.Helper() 1099 1100 icmp := h.(header.ICMPv6) 1101 ns := header.MLD(icmp.MessageBody()) 1102 1103 if got := ns.MaximumResponseDelay(); got != want { 1104 t.Errorf("got %T.MaximumResponseDelay() = %s, want = %s", ns, got, want) 1105 } 1106 } 1107 } 1108 1109 // MLDMulticastAddress creates a checker that checks the Multicast Address 1110 // field of a MLD message. 1111 // 1112 // The returned TransportChecker assumes that a valid ICMPv6 is passed to it 1113 // containing a valid MLD message as far as the size is concerned. 1114 func MLDMulticastAddress(want tcpip.Address) TransportChecker { 1115 return func(t *testing.T, h header.Transport) { 1116 t.Helper() 1117 1118 icmp := h.(header.ICMPv6) 1119 ns := header.MLD(icmp.MessageBody()) 1120 1121 if got := ns.MulticastAddress(); got != want { 1122 t.Errorf("got %T.MulticastAddress() = %s, want = %s", ns, got, want) 1123 } 1124 } 1125 } 1126 1127 // NDP creates a checker that checks that the packet contains a valid NDP 1128 // message for type of ty, with potentially additional checks specified by 1129 // checkers. 1130 // 1131 // Checkers may assume that a valid ICMPv6 is passed to it containing a valid 1132 // NDP message as far as the size of the message (minSize) is concerned. The 1133 // values within the message are up to checkers to validate. 1134 func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker { 1135 return func(t *testing.T, h []header.Network) { 1136 t.Helper() 1137 1138 // Check normal ICMPv6 first. 1139 ICMPv6( 1140 ICMPv6Type(msgType), 1141 ICMPv6Code(0))(t, h) 1142 1143 last := h[len(h)-1] 1144 1145 icmp := header.ICMPv6(last.Payload()) 1146 if got := len(icmp.MessageBody()); got < minSize { 1147 t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) 1148 } 1149 1150 for _, f := range checkers { 1151 f(t, icmp) 1152 } 1153 if t.Failed() { 1154 t.FailNow() 1155 } 1156 } 1157 } 1158 1159 // NDPNS creates a checker that checks that the packet contains a valid NDP 1160 // Neighbor Solicitation message (as per the raw wire format), with potentially 1161 // additional checks specified by checkers. 1162 // 1163 // Checkers may assume that a valid ICMPv6 is passed to it containing a valid 1164 // NDPNS message as far as the size of the message is concerned. The values 1165 // within the message are up to checkers to validate. 1166 func NDPNS(checkers ...TransportChecker) NetworkChecker { 1167 return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...) 1168 } 1169 1170 // NDPNSTargetAddress creates a checker that checks the Target Address field of 1171 // a header.NDPNeighborSolicit. 1172 // 1173 // The returned TransportChecker assumes that a valid ICMPv6 is passed to it 1174 // containing a valid NDPNS message as far as the size is concerned. 1175 func NDPNSTargetAddress(want tcpip.Address) TransportChecker { 1176 return func(t *testing.T, h header.Transport) { 1177 t.Helper() 1178 1179 icmp := h.(header.ICMPv6) 1180 ns := header.NDPNeighborSolicit(icmp.MessageBody()) 1181 1182 if got := ns.TargetAddress(); got != want { 1183 t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want) 1184 } 1185 } 1186 } 1187 1188 // NDPNA creates a checker that checks that the packet contains a valid NDP 1189 // Neighbor Advertisement message (as per the raw wire format), with potentially 1190 // additional checks specified by checkers. 1191 // 1192 // Checkers may assume that a valid ICMPv6 is passed to it containing a valid 1193 // NDPNA message as far as the size of the message is concerned. The values 1194 // within the message are up to checkers to validate. 1195 func NDPNA(checkers ...TransportChecker) NetworkChecker { 1196 return NDP(header.ICMPv6NeighborAdvert, header.NDPNAMinimumSize, checkers...) 1197 } 1198 1199 // NDPNATargetAddress creates a checker that checks the Target Address field of 1200 // a header.NDPNeighborAdvert. 1201 // 1202 // The returned TransportChecker assumes that a valid ICMPv6 is passed to it 1203 // containing a valid NDPNA message as far as the size is concerned. 1204 func NDPNATargetAddress(want tcpip.Address) TransportChecker { 1205 return func(t *testing.T, h header.Transport) { 1206 t.Helper() 1207 1208 icmp := h.(header.ICMPv6) 1209 na := header.NDPNeighborAdvert(icmp.MessageBody()) 1210 1211 if got := na.TargetAddress(); got != want { 1212 t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want) 1213 } 1214 } 1215 } 1216 1217 // NDPNASolicitedFlag creates a checker that checks the Solicited field of 1218 // a header.NDPNeighborAdvert. 1219 // 1220 // The returned TransportChecker assumes that a valid ICMPv6 is passed to it 1221 // containing a valid NDPNA message as far as the size is concerned. 1222 func NDPNASolicitedFlag(want bool) TransportChecker { 1223 return func(t *testing.T, h header.Transport) { 1224 t.Helper() 1225 1226 icmp := h.(header.ICMPv6) 1227 na := header.NDPNeighborAdvert(icmp.MessageBody()) 1228 1229 if got := na.SolicitedFlag(); got != want { 1230 t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want) 1231 } 1232 } 1233 } 1234 1235 // ndpOptions checks that optsBuf only contains opts. 1236 func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption) { 1237 t.Helper() 1238 1239 it, err := optsBuf.Iter(true) 1240 if err != nil { 1241 t.Errorf("optsBuf.Iter(true): %s", err) 1242 return 1243 } 1244 1245 i := 0 1246 for { 1247 opt, done, err := it.Next() 1248 if err != nil { 1249 // This should never happen as Iter(true) above did not return an error. 1250 t.Fatalf("unexpected error when iterating over NDP options: %s", err) 1251 } 1252 if done { 1253 break 1254 } 1255 1256 if i >= len(opts) { 1257 t.Errorf("got unexpected option: %s", opt) 1258 continue 1259 } 1260 1261 switch wantOpt := opts[i].(type) { 1262 case header.NDPSourceLinkLayerAddressOption: 1263 gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption) 1264 if !ok { 1265 t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) 1266 } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { 1267 t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) 1268 } 1269 case header.NDPTargetLinkLayerAddressOption: 1270 gotOpt, ok := opt.(header.NDPTargetLinkLayerAddressOption) 1271 if !ok { 1272 t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) 1273 } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { 1274 t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) 1275 } 1276 case header.NDPNonceOption: 1277 gotOpt, ok := opt.(header.NDPNonceOption) 1278 if !ok { 1279 t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) 1280 } else if diff := cmp.Diff(wantOpt.Nonce(), gotOpt.Nonce()); diff != "" { 1281 t.Errorf("nonce mismatch (-want +got):\n%s", diff) 1282 } 1283 default: 1284 t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt) 1285 } 1286 1287 i++ 1288 } 1289 1290 if missing := opts[i:]; len(missing) > 0 { 1291 t.Errorf("missing options: %s", missing) 1292 } 1293 } 1294 1295 // NDPNAOptions creates a checker that checks that the packet contains the 1296 // provided NDP options within an NDP Neighbor Solicitation message. 1297 // 1298 // The returned TransportChecker assumes that a valid ICMPv6 is passed to it 1299 // containing a valid NDPNA message as far as the size is concerned. 1300 func NDPNAOptions(opts []header.NDPOption) TransportChecker { 1301 return func(t *testing.T, h header.Transport) { 1302 t.Helper() 1303 1304 icmp := h.(header.ICMPv6) 1305 na := header.NDPNeighborAdvert(icmp.MessageBody()) 1306 ndpOptions(t, na.Options(), opts) 1307 } 1308 } 1309 1310 // NDPNSOptions creates a checker that checks that the packet contains the 1311 // provided NDP options within an NDP Neighbor Solicitation message. 1312 // 1313 // The returned TransportChecker assumes that a valid ICMPv6 is passed to it 1314 // containing a valid NDPNS message as far as the size is concerned. 1315 func NDPNSOptions(opts []header.NDPOption) TransportChecker { 1316 return func(t *testing.T, h header.Transport) { 1317 t.Helper() 1318 1319 icmp := h.(header.ICMPv6) 1320 ns := header.NDPNeighborSolicit(icmp.MessageBody()) 1321 ndpOptions(t, ns.Options(), opts) 1322 } 1323 } 1324 1325 // NDPRS creates a checker that checks that the packet contains a valid NDP 1326 // Router Solicitation message (as per the raw wire format). 1327 // 1328 // Checkers may assume that a valid ICMPv6 is passed to it containing a valid 1329 // NDPRS as far as the size of the message is concerned. The values within the 1330 // message are up to checkers to validate. 1331 func NDPRS(checkers ...TransportChecker) NetworkChecker { 1332 return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize, checkers...) 1333 } 1334 1335 // NDPRSOptions creates a checker that checks that the packet contains the 1336 // provided NDP options within an NDP Router Solicitation message. 1337 // 1338 // The returned TransportChecker assumes that a valid ICMPv6 is passed to it 1339 // containing a valid NDPRS message as far as the size is concerned. 1340 func NDPRSOptions(opts []header.NDPOption) TransportChecker { 1341 return func(t *testing.T, h header.Transport) { 1342 t.Helper() 1343 1344 icmp := h.(header.ICMPv6) 1345 rs := header.NDPRouterSolicit(icmp.MessageBody()) 1346 ndpOptions(t, rs.Options(), opts) 1347 } 1348 } 1349 1350 // IGMP checks the validity and properties of the given IGMP packet. It is 1351 // expected to be used in conjunction with other IGMP transport checkers for 1352 // specific properties. 1353 func IGMP(checkers ...TransportChecker) NetworkChecker { 1354 return func(t *testing.T, h []header.Network) { 1355 t.Helper() 1356 1357 last := h[len(h)-1] 1358 1359 if p := last.TransportProtocol(); p != header.IGMPProtocolNumber { 1360 t.Fatalf("Bad protocol, got %d, want %d", p, header.IGMPProtocolNumber) 1361 } 1362 1363 igmp := header.IGMP(last.Payload()) 1364 for _, f := range checkers { 1365 f(t, igmp) 1366 } 1367 if t.Failed() { 1368 t.FailNow() 1369 } 1370 } 1371 } 1372 1373 // IGMPType creates a checker that checks the IGMP Type field. 1374 func IGMPType(want header.IGMPType) TransportChecker { 1375 return func(t *testing.T, h header.Transport) { 1376 t.Helper() 1377 1378 igmp, ok := h.(header.IGMP) 1379 if !ok { 1380 t.Fatalf("got transport header = %T, want = header.IGMP", h) 1381 } 1382 if got := igmp.Type(); got != want { 1383 t.Errorf("got igmp.Type() = %d, want = %d", got, want) 1384 } 1385 } 1386 } 1387 1388 // IGMPMaxRespTime creates a checker that checks the IGMP Max Resp Time field. 1389 func IGMPMaxRespTime(want time.Duration) TransportChecker { 1390 return func(t *testing.T, h header.Transport) { 1391 t.Helper() 1392 1393 igmp, ok := h.(header.IGMP) 1394 if !ok { 1395 t.Fatalf("got transport header = %T, want = header.IGMP", h) 1396 } 1397 if got := igmp.MaxRespTime(); got != want { 1398 t.Errorf("got igmp.MaxRespTime() = %s, want = %s", got, want) 1399 } 1400 } 1401 } 1402 1403 // IGMPGroupAddress creates a checker that checks the IGMP Group Address field. 1404 func IGMPGroupAddress(want tcpip.Address) TransportChecker { 1405 return func(t *testing.T, h header.Transport) { 1406 t.Helper() 1407 1408 igmp, ok := h.(header.IGMP) 1409 if !ok { 1410 t.Fatalf("got transport header = %T, want = header.IGMP", h) 1411 } 1412 if got := igmp.GroupAddress(); got != want { 1413 t.Errorf("got igmp.GroupAddress() = %s, want = %s", got, want) 1414 } 1415 } 1416 } 1417 1418 // IPv6ExtHdrChecker is a function to check an extension header. 1419 type IPv6ExtHdrChecker func(*testing.T, header.IPv6PayloadHeader) 1420 1421 // IPv6WithExtHdr is like IPv6 but allows IPv6 packets with extension headers. 1422 func IPv6WithExtHdr(t *testing.T, b []byte, checkers ...NetworkChecker) { 1423 t.Helper() 1424 1425 ipv6 := header.IPv6(b) 1426 if !ipv6.IsValid(len(b)) { 1427 t.Error("not a valid IPv6 packet") 1428 return 1429 } 1430 1431 payloadIterator := header.MakeIPv6PayloadIterator( 1432 header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()), 1433 buffer.View(ipv6.Payload()).ToVectorisedView(), 1434 ) 1435 1436 var rawPayloadHeader header.IPv6RawPayloadHeader 1437 for { 1438 h, done, err := payloadIterator.Next() 1439 if err != nil { 1440 t.Errorf("payloadIterator.Next(): %s", err) 1441 return 1442 } 1443 if done { 1444 t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, true, _)", h, done) 1445 return 1446 } 1447 r, ok := h.(header.IPv6RawPayloadHeader) 1448 if ok { 1449 rawPayloadHeader = r 1450 break 1451 } 1452 } 1453 1454 networkHeader := ipv6HeaderWithExtHdr{ 1455 IPv6: ipv6, 1456 transport: tcpip.TransportProtocolNumber(rawPayloadHeader.Identifier), 1457 payload: rawPayloadHeader.Buf.ToView(), 1458 } 1459 1460 for _, checker := range checkers { 1461 checker(t, []header.Network{&networkHeader}) 1462 } 1463 } 1464 1465 // IPv6ExtHdr checks for the presence of extension headers. 1466 // 1467 // All the extension headers in headers will be checked exhaustively in the 1468 // order provided. 1469 func IPv6ExtHdr(headers ...IPv6ExtHdrChecker) NetworkChecker { 1470 return func(t *testing.T, h []header.Network) { 1471 t.Helper() 1472 1473 extHdrs, ok := h[0].(*ipv6HeaderWithExtHdr) 1474 if !ok { 1475 t.Errorf("got network header = %T, want = *ipv6HeaderWithExtHdr", h[0]) 1476 return 1477 } 1478 1479 payloadIterator := header.MakeIPv6PayloadIterator( 1480 header.IPv6ExtensionHeaderIdentifier(extHdrs.IPv6.NextHeader()), 1481 buffer.View(extHdrs.IPv6.Payload()).ToVectorisedView(), 1482 ) 1483 1484 for _, check := range headers { 1485 h, done, err := payloadIterator.Next() 1486 if err != nil { 1487 t.Errorf("payloadIterator.Next(): %s", err) 1488 return 1489 } 1490 if done { 1491 t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, false, _)", h, done) 1492 return 1493 } 1494 check(t, h) 1495 } 1496 // Validate we consumed all headers. 1497 // 1498 // The next one over should be a raw payload and then iterator should 1499 // terminate. 1500 wantDone := false 1501 for { 1502 h, done, err := payloadIterator.Next() 1503 if err != nil { 1504 t.Errorf("payloadIterator.Next(): %s", err) 1505 return 1506 } 1507 if done != wantDone { 1508 t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, %t, _)", h, done, wantDone) 1509 return 1510 } 1511 if done { 1512 break 1513 } 1514 if _, ok := h.(header.IPv6RawPayloadHeader); !ok { 1515 t.Errorf("got payloadIterator.Next() = (%T, _, _), want = (header.IPv6RawPayloadHeader, _, _)", h) 1516 continue 1517 } 1518 wantDone = true 1519 } 1520 } 1521 } 1522 1523 var _ header.Network = (*ipv6HeaderWithExtHdr)(nil) 1524 1525 // ipv6HeaderWithExtHdr provides a header.Network implementation that takes 1526 // extension headers into consideration, which is not the case with vanilla 1527 // header.IPv6. 1528 type ipv6HeaderWithExtHdr struct { 1529 header.IPv6 1530 transport tcpip.TransportProtocolNumber 1531 payload []byte 1532 } 1533 1534 // TransportProtocol implements header.Network. 1535 func (h *ipv6HeaderWithExtHdr) TransportProtocol() tcpip.TransportProtocolNumber { 1536 return h.transport 1537 } 1538 1539 // Payload implements header.Network. 1540 func (h *ipv6HeaderWithExtHdr) Payload() []byte { 1541 return h.payload 1542 } 1543 1544 // IPv6ExtHdrOptionChecker is a function to check an extension header option. 1545 type IPv6ExtHdrOptionChecker func(*testing.T, header.IPv6ExtHdrOption) 1546 1547 // IPv6HopByHopExtensionHeader checks the extension header is a Hop by Hop 1548 // extension header and validates the containing options with checkers. 1549 // 1550 // checkers must exhaustively contain all the expected options. 1551 func IPv6HopByHopExtensionHeader(checkers ...IPv6ExtHdrOptionChecker) IPv6ExtHdrChecker { 1552 return func(t *testing.T, payloadHeader header.IPv6PayloadHeader) { 1553 t.Helper() 1554 1555 hbh, ok := payloadHeader.(header.IPv6HopByHopOptionsExtHdr) 1556 if !ok { 1557 t.Errorf("unexpected IPv6 payload header, got = %T, want = header.IPv6HopByHopOptionsExtHdr", payloadHeader) 1558 return 1559 } 1560 optionsIterator := hbh.Iter() 1561 for _, f := range checkers { 1562 opt, done, err := optionsIterator.Next() 1563 if err != nil { 1564 t.Errorf("optionsIterator.Next(): %s", err) 1565 return 1566 } 1567 if done { 1568 t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, false, _)", opt, done) 1569 } 1570 f(t, opt) 1571 } 1572 // Validate all options were consumed. 1573 for { 1574 opt, done, err := optionsIterator.Next() 1575 if err != nil { 1576 t.Errorf("optionsIterator.Next(): %s", err) 1577 return 1578 } 1579 if !done { 1580 t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, true, _)", opt, done) 1581 } 1582 if done { 1583 break 1584 } 1585 } 1586 } 1587 } 1588 1589 // IPv6RouterAlert validates that an extension header option is the RouterAlert 1590 // option and matches on its value. 1591 func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker { 1592 return func(t *testing.T, opt header.IPv6ExtHdrOption) { 1593 routerAlert, ok := opt.(*header.IPv6RouterAlertOption) 1594 if !ok { 1595 t.Errorf("unexpected extension header option, got = %T, want = header.IPv6RouterAlertOption", opt) 1596 return 1597 } 1598 if routerAlert.Value != want { 1599 t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, want) 1600 } 1601 } 1602 } 1603 1604 // IPv6UnknownOption validates that an extension header option is the 1605 // unknown header option. 1606 func IPv6UnknownOption() IPv6ExtHdrOptionChecker { 1607 return func(t *testing.T, opt header.IPv6ExtHdrOption) { 1608 _, ok := opt.(*header.IPv6UnknownExtHdrOption) 1609 if !ok { 1610 t.Errorf("got = %T, want = header.IPv6UnknownExtHdrOption", opt) 1611 } 1612 } 1613 } 1614 1615 // IgnoreCmpPath returns a cmp.Option that ignores listed field paths. 1616 func IgnoreCmpPath(paths ...string) cmp.Option { 1617 ignores := map[string]struct{}{} 1618 for _, path := range paths { 1619 ignores[path] = struct{}{} 1620 } 1621 return cmp.FilterPath(func(path cmp.Path) bool { 1622 _, ok := ignores[path.String()] 1623 return ok 1624 }, cmp.Ignore()) 1625 }