github.com/google/netstack@v0.0.0-20191123085552-55fcc16cd0eb/tcpip/checker/checker.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package checker provides helper functions to check networking packets for 16 // validity. 17 package checker 18 19 import ( 20 "encoding/binary" 21 "reflect" 22 "testing" 23 24 "github.com/google/netstack/tcpip" 25 "github.com/google/netstack/tcpip/buffer" 26 "github.com/google/netstack/tcpip/header" 27 "github.com/google/netstack/tcpip/seqnum" 28 ) 29 30 // NetworkChecker is a function to check a property of a network packet. 31 type NetworkChecker func(*testing.T, []header.Network) 32 33 // TransportChecker is a function to check a property of a transport packet. 34 type TransportChecker func(*testing.T, header.Transport) 35 36 // IPv4 checks the validity and properties of the given IPv4 packet. It is 37 // expected to be used in conjunction with other network checkers for specific 38 // properties. For example, to check the source and destination address, one 39 // would call: 40 // 41 // checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y)) 42 func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) { 43 t.Helper() 44 45 ipv4 := header.IPv4(b) 46 47 if !ipv4.IsValid(len(b)) { 48 t.Error("Not a valid IPv4 packet") 49 } 50 51 xsum := ipv4.CalculateChecksum() 52 if xsum != 0 && xsum != 0xffff { 53 t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum()) 54 } 55 56 for _, f := range checkers { 57 f(t, []header.Network{ipv4}) 58 } 59 if t.Failed() { 60 t.FailNow() 61 } 62 } 63 64 // IPv6 checks the validity and properties of the given IPv6 packet. The usage 65 // is similar to IPv4. 66 func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) { 67 t.Helper() 68 69 ipv6 := header.IPv6(b) 70 if !ipv6.IsValid(len(b)) { 71 t.Error("Not a valid IPv6 packet") 72 } 73 74 for _, f := range checkers { 75 f(t, []header.Network{ipv6}) 76 } 77 if t.Failed() { 78 t.FailNow() 79 } 80 } 81 82 // SrcAddr creates a checker that checks the source address. 83 func SrcAddr(addr tcpip.Address) NetworkChecker { 84 return func(t *testing.T, h []header.Network) { 85 t.Helper() 86 87 if a := h[0].SourceAddress(); a != addr { 88 t.Errorf("Bad source address, got %v, want %v", a, addr) 89 } 90 } 91 } 92 93 // DstAddr creates a checker that checks the destination address. 94 func DstAddr(addr tcpip.Address) NetworkChecker { 95 return func(t *testing.T, h []header.Network) { 96 t.Helper() 97 98 if a := h[0].DestinationAddress(); a != addr { 99 t.Errorf("Bad destination address, got %v, want %v", a, addr) 100 } 101 } 102 } 103 104 // TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6). 105 func TTL(ttl uint8) NetworkChecker { 106 return func(t *testing.T, h []header.Network) { 107 var v uint8 108 switch ip := h[0].(type) { 109 case header.IPv4: 110 v = ip.TTL() 111 case header.IPv6: 112 v = ip.HopLimit() 113 } 114 if v != ttl { 115 t.Fatalf("Bad TTL, got %v, want %v", v, ttl) 116 } 117 } 118 } 119 120 // PayloadLen creates a checker that checks the payload length. 121 func PayloadLen(plen int) NetworkChecker { 122 return func(t *testing.T, h []header.Network) { 123 t.Helper() 124 125 if l := len(h[0].Payload()); l != plen { 126 t.Errorf("Bad payload length, got %v, want %v", l, plen) 127 } 128 } 129 } 130 131 // FragmentOffset creates a checker that checks the FragmentOffset field. 132 func FragmentOffset(offset uint16) NetworkChecker { 133 return func(t *testing.T, h []header.Network) { 134 t.Helper() 135 136 // We only do this of IPv4 for now. 137 switch ip := h[0].(type) { 138 case header.IPv4: 139 if v := ip.FragmentOffset(); v != offset { 140 t.Errorf("Bad fragment offset, got %v, want %v", v, offset) 141 } 142 } 143 } 144 } 145 146 // FragmentFlags creates a checker that checks the fragment flags field. 147 func FragmentFlags(flags uint8) NetworkChecker { 148 return func(t *testing.T, h []header.Network) { 149 t.Helper() 150 151 // We only do this of IPv4 for now. 152 switch ip := h[0].(type) { 153 case header.IPv4: 154 if v := ip.Flags(); v != flags { 155 t.Errorf("Bad fragment offset, got %v, want %v", v, flags) 156 } 157 } 158 } 159 } 160 161 // TOS creates a checker that checks the TOS field. 162 func TOS(tos uint8, label uint32) NetworkChecker { 163 return func(t *testing.T, h []header.Network) { 164 t.Helper() 165 166 if v, l := h[0].TOS(); v != tos || l != label { 167 t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label) 168 } 169 } 170 } 171 172 // Raw creates a checker that checks the bytes of payload. 173 // The checker always checks the payload of the last network header. 174 // For instance, in case of IPv6 fragments, the payload that will be checked 175 // is the one containing the actual data that the packet is carrying, without 176 // the bytes added by the IPv6 fragmentation. 177 func Raw(want []byte) NetworkChecker { 178 return func(t *testing.T, h []header.Network) { 179 t.Helper() 180 181 if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) { 182 t.Errorf("Wrong payload, got %v, want %v", got, want) 183 } 184 } 185 } 186 187 // IPv6Fragment creates a checker that validates an IPv6 fragment. 188 func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker { 189 return func(t *testing.T, h []header.Network) { 190 t.Helper() 191 192 if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader { 193 t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) 194 } 195 196 ipv6Frag := header.IPv6Fragment(h[0].Payload()) 197 if !ipv6Frag.IsValid() { 198 t.Error("Not a valid IPv6 fragment") 199 } 200 201 for _, f := range checkers { 202 f(t, []header.Network{h[0], ipv6Frag}) 203 } 204 if t.Failed() { 205 t.FailNow() 206 } 207 } 208 } 209 210 // TCP creates a checker that checks that the transport protocol is TCP and 211 // potentially additional transport header fields. 212 func TCP(checkers ...TransportChecker) NetworkChecker { 213 return func(t *testing.T, h []header.Network) { 214 t.Helper() 215 216 first := h[0] 217 last := h[len(h)-1] 218 219 if p := last.TransportProtocol(); p != header.TCPProtocolNumber { 220 t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber) 221 } 222 223 // Verify the checksum. 224 tcp := header.TCP(last.Payload()) 225 l := uint16(len(tcp)) 226 227 xsum := header.Checksum([]byte(first.SourceAddress()), 0) 228 xsum = header.Checksum([]byte(first.DestinationAddress()), xsum) 229 xsum = header.Checksum([]byte{0, byte(last.TransportProtocol())}, xsum) 230 xsum = header.Checksum([]byte{byte(l >> 8), byte(l)}, xsum) 231 xsum = header.Checksum(tcp, xsum) 232 233 if xsum != 0 && xsum != 0xffff { 234 t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum()) 235 } 236 237 // Run the transport checkers. 238 for _, f := range checkers { 239 f(t, tcp) 240 } 241 if t.Failed() { 242 t.FailNow() 243 } 244 } 245 } 246 247 // UDP creates a checker that checks that the transport protocol is UDP and 248 // potentially additional transport header fields. 249 func UDP(checkers ...TransportChecker) NetworkChecker { 250 return func(t *testing.T, h []header.Network) { 251 t.Helper() 252 253 last := h[len(h)-1] 254 255 if p := last.TransportProtocol(); p != header.UDPProtocolNumber { 256 t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) 257 } 258 259 udp := header.UDP(last.Payload()) 260 for _, f := range checkers { 261 f(t, udp) 262 } 263 if t.Failed() { 264 t.FailNow() 265 } 266 } 267 } 268 269 // SrcPort creates a checker that checks the source port. 270 func SrcPort(port uint16) TransportChecker { 271 return func(t *testing.T, h header.Transport) { 272 t.Helper() 273 274 if p := h.SourcePort(); p != port { 275 t.Errorf("Bad source port, got %v, want %v", p, port) 276 } 277 } 278 } 279 280 // DstPort creates a checker that checks the destination port. 281 func DstPort(port uint16) TransportChecker { 282 return func(t *testing.T, h header.Transport) { 283 if p := h.DestinationPort(); p != port { 284 t.Errorf("Bad destination port, got %v, want %v", p, port) 285 } 286 } 287 } 288 289 // SeqNum creates a checker that checks the sequence number. 290 func SeqNum(seq uint32) TransportChecker { 291 return func(t *testing.T, h header.Transport) { 292 t.Helper() 293 294 tcp, ok := h.(header.TCP) 295 if !ok { 296 return 297 } 298 299 if s := tcp.SequenceNumber(); s != seq { 300 t.Errorf("Bad sequence number, got %v, want %v", s, seq) 301 } 302 } 303 } 304 305 // AckNum creates a checker that checks the ack number. 306 func AckNum(seq uint32) TransportChecker { 307 return func(t *testing.T, h header.Transport) { 308 t.Helper() 309 tcp, ok := h.(header.TCP) 310 if !ok { 311 return 312 } 313 314 if s := tcp.AckNumber(); s != seq { 315 t.Errorf("Bad ack number, got %v, want %v", s, seq) 316 } 317 } 318 } 319 320 // Window creates a checker that checks the tcp window. 321 func Window(window uint16) TransportChecker { 322 return func(t *testing.T, h header.Transport) { 323 tcp, ok := h.(header.TCP) 324 if !ok { 325 return 326 } 327 328 if w := tcp.WindowSize(); w != window { 329 t.Errorf("Bad window, got 0x%x, want 0x%x", w, window) 330 } 331 } 332 } 333 334 // TCPFlags creates a checker that checks the tcp flags. 335 func TCPFlags(flags uint8) TransportChecker { 336 return func(t *testing.T, h header.Transport) { 337 t.Helper() 338 339 tcp, ok := h.(header.TCP) 340 if !ok { 341 return 342 } 343 344 if f := tcp.Flags(); f != flags { 345 t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags) 346 } 347 } 348 } 349 350 // TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the 351 // given mask, match the supplied flags. 352 func TCPFlagsMatch(flags, mask uint8) TransportChecker { 353 return func(t *testing.T, h header.Transport) { 354 tcp, ok := h.(header.TCP) 355 if !ok { 356 return 357 } 358 359 if f := tcp.Flags(); (f & mask) != (flags & mask) { 360 t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask) 361 } 362 } 363 } 364 365 // TCPSynOptions creates a checker that checks the presence of TCP options in 366 // SYN segments. 367 // 368 // If wndscale is negative, the window scale option must not be present. 369 func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { 370 return func(t *testing.T, h header.Transport) { 371 tcp, ok := h.(header.TCP) 372 if !ok { 373 return 374 } 375 opts := tcp.Options() 376 limit := len(opts) 377 foundMSS := false 378 foundWS := false 379 foundTS := false 380 foundSACKPermitted := false 381 tsVal := uint32(0) 382 tsEcr := uint32(0) 383 for i := 0; i < limit; { 384 switch opts[i] { 385 case header.TCPOptionEOL: 386 i = limit 387 case header.TCPOptionNOP: 388 i++ 389 case header.TCPOptionMSS: 390 v := uint16(opts[i+2])<<8 | uint16(opts[i+3]) 391 if wantOpts.MSS != v { 392 t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS) 393 } 394 foundMSS = true 395 i += 4 396 case header.TCPOptionWS: 397 if wantOpts.WS < 0 { 398 t.Error("WS present when it shouldn't be") 399 } 400 v := int(opts[i+2]) 401 if v != wantOpts.WS { 402 t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS) 403 } 404 foundWS = true 405 i += 3 406 case header.TCPOptionTS: 407 if i+9 >= limit { 408 t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i) 409 } 410 if opts[i+1] != 10 { 411 t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit) 412 } 413 tsVal = binary.BigEndian.Uint32(opts[i+2:]) 414 tsEcr = uint32(0) 415 if tcp.Flags()&header.TCPFlagAck != 0 { 416 // If the syn is an SYN-ACK then read 417 // the tsEcr value as well. 418 tsEcr = binary.BigEndian.Uint32(opts[i+6:]) 419 } 420 foundTS = true 421 i += 10 422 case header.TCPOptionSACKPermitted: 423 if i+1 >= limit { 424 t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i) 425 } 426 if opts[i+1] != 2 { 427 t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit) 428 } 429 foundSACKPermitted = true 430 i += 2 431 432 default: 433 i += int(opts[i+1]) 434 } 435 } 436 437 if !foundMSS { 438 t.Errorf("MSS option not found. Options: %x", opts) 439 } 440 441 if !foundWS && wantOpts.WS >= 0 { 442 t.Errorf("WS option not found. Options: %x", opts) 443 } 444 if wantOpts.TS && !foundTS { 445 t.Errorf("TS option not found. Options: %x", opts) 446 } 447 if foundTS && tsVal == 0 { 448 t.Error("TS option specified but the timestamp value is zero") 449 } 450 if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 { 451 t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr) 452 } 453 if wantOpts.SACKPermitted && !foundSACKPermitted { 454 t.Errorf("SACKPermitted option not found. Options: %x", opts) 455 } 456 } 457 } 458 459 // TCPTimestampChecker creates a checker that validates that a TCP segment has a 460 // TCP Timestamp option if wantTS is true, it also compares the wantTSVal and 461 // wantTSEcr values with those in the TCP segment (if present). 462 // 463 // If wantTSVal or wantTSEcr is zero then the corresponding comparison is 464 // skipped. 465 func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker { 466 return func(t *testing.T, h header.Transport) { 467 tcp, ok := h.(header.TCP) 468 if !ok { 469 return 470 } 471 opts := []byte(tcp.Options()) 472 limit := len(opts) 473 foundTS := false 474 tsVal := uint32(0) 475 tsEcr := uint32(0) 476 for i := 0; i < limit; { 477 switch opts[i] { 478 case header.TCPOptionEOL: 479 i = limit 480 case header.TCPOptionNOP: 481 i++ 482 case header.TCPOptionTS: 483 if i+9 >= limit { 484 t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i) 485 } 486 if opts[i+1] != 10 { 487 t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1]) 488 } 489 tsVal = binary.BigEndian.Uint32(opts[i+2:]) 490 tsEcr = binary.BigEndian.Uint32(opts[i+6:]) 491 foundTS = true 492 i += 10 493 default: 494 // We don't recognize this option, just skip over it. 495 if i+2 > limit { 496 return 497 } 498 l := int(opts[i+1]) 499 if i < 2 || i+l > limit { 500 return 501 } 502 i += l 503 } 504 } 505 506 if wantTS != foundTS { 507 t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS) 508 } 509 if wantTS && wantTSVal != 0 && wantTSVal != tsVal { 510 t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal) 511 } 512 if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr { 513 t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr) 514 } 515 } 516 } 517 518 // TCPNoSACKBlockChecker creates a checker that verifies that the segment does not 519 // contain any SACK blocks in the TCP options. 520 func TCPNoSACKBlockChecker() TransportChecker { 521 return TCPSACKBlockChecker(nil) 522 } 523 524 // TCPSACKBlockChecker creates a checker that verifies that the segment does 525 // contain the specified SACK blocks in the TCP options. 526 func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { 527 return func(t *testing.T, h header.Transport) { 528 t.Helper() 529 tcp, ok := h.(header.TCP) 530 if !ok { 531 return 532 } 533 var gotSACKBlocks []header.SACKBlock 534 535 opts := []byte(tcp.Options()) 536 limit := len(opts) 537 for i := 0; i < limit; { 538 switch opts[i] { 539 case header.TCPOptionEOL: 540 i = limit 541 case header.TCPOptionNOP: 542 i++ 543 case header.TCPOptionSACK: 544 if i+2 > limit { 545 // Malformed SACK block. 546 t.Errorf("malformed SACK option in options: %v", opts) 547 } 548 sackOptionLen := int(opts[i+1]) 549 if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 { 550 // Malformed SACK block. 551 t.Errorf("malformed SACK option length in options: %v", opts) 552 } 553 numBlocks := sackOptionLen / 8 554 for j := 0; j < numBlocks; j++ { 555 start := binary.BigEndian.Uint32(opts[i+2+j*8:]) 556 end := binary.BigEndian.Uint32(opts[i+2+j*8+4:]) 557 gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{ 558 Start: seqnum.Value(start), 559 End: seqnum.Value(end), 560 }) 561 } 562 i += sackOptionLen 563 default: 564 // We don't recognize this option, just skip over it. 565 if i+2 > limit { 566 break 567 } 568 l := int(opts[i+1]) 569 if l < 2 || i+l > limit { 570 break 571 } 572 i += l 573 } 574 } 575 576 if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) { 577 t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks) 578 } 579 } 580 } 581 582 // Payload creates a checker that checks the payload. 583 func Payload(want []byte) TransportChecker { 584 return func(t *testing.T, h header.Transport) { 585 if got := h.Payload(); !reflect.DeepEqual(got, want) { 586 t.Errorf("Wrong payload, got %v, want %v", got, want) 587 } 588 } 589 } 590 591 // ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and 592 // potentially additional ICMPv4 header fields. 593 func ICMPv4(checkers ...TransportChecker) NetworkChecker { 594 return func(t *testing.T, h []header.Network) { 595 t.Helper() 596 597 last := h[len(h)-1] 598 599 if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber { 600 t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber) 601 } 602 603 icmp := header.ICMPv4(last.Payload()) 604 for _, f := range checkers { 605 f(t, icmp) 606 } 607 if t.Failed() { 608 t.FailNow() 609 } 610 } 611 } 612 613 // ICMPv4Type creates a checker that checks the ICMPv4 Type field. 614 func ICMPv4Type(want header.ICMPv4Type) TransportChecker { 615 return func(t *testing.T, h header.Transport) { 616 t.Helper() 617 icmpv4, ok := h.(header.ICMPv4) 618 if !ok { 619 t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) 620 } 621 if got := icmpv4.Type(); got != want { 622 t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) 623 } 624 } 625 } 626 627 // ICMPv4Code creates a checker that checks the ICMPv4 Code field. 628 func ICMPv4Code(want byte) TransportChecker { 629 return func(t *testing.T, h header.Transport) { 630 t.Helper() 631 icmpv4, ok := h.(header.ICMPv4) 632 if !ok { 633 t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) 634 } 635 if got := icmpv4.Code(); got != want { 636 t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) 637 } 638 } 639 } 640 641 // ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and 642 // potentially additional ICMPv6 header fields. 643 // 644 // ICMPv6 will validate the checksum field before calling checkers. 645 func ICMPv6(checkers ...TransportChecker) NetworkChecker { 646 return func(t *testing.T, h []header.Network) { 647 t.Helper() 648 649 last := h[len(h)-1] 650 651 if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber { 652 t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber) 653 } 654 655 icmp := header.ICMPv6(last.Payload()) 656 if got, want := icmp.Checksum(), header.ICMPv6Checksum(icmp, last.SourceAddress(), last.DestinationAddress(), buffer.VectorisedView{}); got != want { 657 t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want) 658 } 659 660 for _, f := range checkers { 661 f(t, icmp) 662 } 663 if t.Failed() { 664 t.FailNow() 665 } 666 } 667 } 668 669 // ICMPv6Type creates a checker that checks the ICMPv6 Type field. 670 func ICMPv6Type(want header.ICMPv6Type) TransportChecker { 671 return func(t *testing.T, h header.Transport) { 672 t.Helper() 673 icmpv6, ok := h.(header.ICMPv6) 674 if !ok { 675 t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) 676 } 677 if got := icmpv6.Type(); got != want { 678 t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) 679 } 680 } 681 } 682 683 // ICMPv6Code creates a checker that checks the ICMPv6 Code field. 684 func ICMPv6Code(want byte) TransportChecker { 685 return func(t *testing.T, h header.Transport) { 686 t.Helper() 687 icmpv6, ok := h.(header.ICMPv6) 688 if !ok { 689 t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) 690 } 691 if got := icmpv6.Code(); got != want { 692 t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) 693 } 694 } 695 } 696 697 // NDP creates a checker that checks that the packet contains a valid NDP 698 // message for type of ty, with potentially additional checks specified by 699 // checkers. 700 // 701 // checkers may assume that a valid ICMPv6 is passed to it containing a valid 702 // NDP message as far as the size of the message (minSize) is concerned. The 703 // values within the message are up to checkers to validate. 704 func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker { 705 return func(t *testing.T, h []header.Network) { 706 t.Helper() 707 708 // Check normal ICMPv6 first. 709 ICMPv6( 710 ICMPv6Type(msgType), 711 ICMPv6Code(0))(t, h) 712 713 last := h[len(h)-1] 714 715 icmp := header.ICMPv6(last.Payload()) 716 if got := len(icmp.NDPPayload()); got < minSize { 717 t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) 718 } 719 720 for _, f := range checkers { 721 f(t, icmp) 722 } 723 if t.Failed() { 724 t.FailNow() 725 } 726 } 727 } 728 729 // NDPNS creates a checker that checks that the packet contains a valid NDP 730 // Neighbor Solicitation message (as per the raw wire format), with potentially 731 // additional checks specified by checkers. 732 // 733 // checkers may assume that a valid ICMPv6 is passed to it containing a valid 734 // NDPNS message as far as the size of the messages concerned. The values within 735 // the message are up to checkers to validate. 736 func NDPNS(checkers ...TransportChecker) NetworkChecker { 737 return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...) 738 } 739 740 // NDPNSTargetAddress creates a checker that checks the Target Address field of 741 // a header.NDPNeighborSolicit. 742 // 743 // The returned TransportChecker assumes that a valid ICMPv6 is passed to it 744 // containing a valid NDPNS message as far as the size is concerned. 745 func NDPNSTargetAddress(want tcpip.Address) TransportChecker { 746 return func(t *testing.T, h header.Transport) { 747 t.Helper() 748 749 icmp := h.(header.ICMPv6) 750 ns := header.NDPNeighborSolicit(icmp.NDPPayload()) 751 752 if got := ns.TargetAddress(); got != want { 753 t.Fatalf("got %T.TargetAddress = %s, want = %s", ns, got, want) 754 } 755 } 756 }