github.com/flowerwrong/netstack@v0.0.0-20191009141956-e5848263af28/tcpip/checker/checker.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package checker provides helper functions to check networking packets for 16 // validity. 17 package checker 18 19 import ( 20 "encoding/binary" 21 "reflect" 22 "testing" 23 24 "github.com/FlowerWrong/netstack/tcpip" 25 "github.com/FlowerWrong/netstack/tcpip/header" 26 "github.com/FlowerWrong/netstack/tcpip/seqnum" 27 ) 28 29 // NetworkChecker is a function to check a property of a network packet. 30 type NetworkChecker func(*testing.T, []header.Network) 31 32 // TransportChecker is a function to check a property of a transport packet. 33 type TransportChecker func(*testing.T, header.Transport) 34 35 // IPv4 checks the validity and properties of the given IPv4 packet. It is 36 // expected to be used in conjunction with other network checkers for specific 37 // properties. For example, to check the source and destination address, one 38 // would call: 39 // 40 // checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y)) 41 func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) { 42 t.Helper() 43 44 ipv4 := header.IPv4(b) 45 46 if !ipv4.IsValid(len(b)) { 47 t.Error("Not a valid IPv4 packet") 48 } 49 50 xsum := ipv4.CalculateChecksum() 51 if xsum != 0 && xsum != 0xffff { 52 t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum()) 53 } 54 55 for _, f := range checkers { 56 f(t, []header.Network{ipv4}) 57 } 58 if t.Failed() { 59 t.FailNow() 60 } 61 } 62 63 // IPv6 checks the validity and properties of the given IPv6 packet. The usage 64 // is similar to IPv4. 65 func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) { 66 t.Helper() 67 68 ipv6 := header.IPv6(b) 69 if !ipv6.IsValid(len(b)) { 70 t.Error("Not a valid IPv6 packet") 71 } 72 73 for _, f := range checkers { 74 f(t, []header.Network{ipv6}) 75 } 76 if t.Failed() { 77 t.FailNow() 78 } 79 } 80 81 // SrcAddr creates a checker that checks the source address. 82 func SrcAddr(addr tcpip.Address) NetworkChecker { 83 return func(t *testing.T, h []header.Network) { 84 t.Helper() 85 86 if a := h[0].SourceAddress(); a != addr { 87 t.Errorf("Bad source address, got %v, want %v", a, addr) 88 } 89 } 90 } 91 92 // DstAddr creates a checker that checks the destination address. 93 func DstAddr(addr tcpip.Address) NetworkChecker { 94 return func(t *testing.T, h []header.Network) { 95 t.Helper() 96 97 if a := h[0].DestinationAddress(); a != addr { 98 t.Errorf("Bad destination address, got %v, want %v", a, addr) 99 } 100 } 101 } 102 103 // TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6). 104 func TTL(ttl uint8) NetworkChecker { 105 return func(t *testing.T, h []header.Network) { 106 var v uint8 107 switch ip := h[0].(type) { 108 case header.IPv4: 109 v = ip.TTL() 110 case header.IPv6: 111 v = ip.HopLimit() 112 } 113 if v != ttl { 114 t.Fatalf("Bad TTL, got %v, want %v", v, ttl) 115 } 116 } 117 } 118 119 // PayloadLen creates a checker that checks the payload length. 120 func PayloadLen(plen int) NetworkChecker { 121 return func(t *testing.T, h []header.Network) { 122 t.Helper() 123 124 if l := len(h[0].Payload()); l != plen { 125 t.Errorf("Bad payload length, got %v, want %v", l, plen) 126 } 127 } 128 } 129 130 // FragmentOffset creates a checker that checks the FragmentOffset field. 131 func FragmentOffset(offset uint16) NetworkChecker { 132 return func(t *testing.T, h []header.Network) { 133 t.Helper() 134 135 // We only do this of IPv4 for now. 136 switch ip := h[0].(type) { 137 case header.IPv4: 138 if v := ip.FragmentOffset(); v != offset { 139 t.Errorf("Bad fragment offset, got %v, want %v", v, offset) 140 } 141 } 142 } 143 } 144 145 // FragmentFlags creates a checker that checks the fragment flags field. 146 func FragmentFlags(flags uint8) NetworkChecker { 147 return func(t *testing.T, h []header.Network) { 148 t.Helper() 149 150 // We only do this of IPv4 for now. 151 switch ip := h[0].(type) { 152 case header.IPv4: 153 if v := ip.Flags(); v != flags { 154 t.Errorf("Bad fragment offset, got %v, want %v", v, flags) 155 } 156 } 157 } 158 } 159 160 // TOS creates a checker that checks the TOS field. 161 func TOS(tos uint8, label uint32) NetworkChecker { 162 return func(t *testing.T, h []header.Network) { 163 t.Helper() 164 165 if v, l := h[0].TOS(); v != tos || l != label { 166 t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label) 167 } 168 } 169 } 170 171 // Raw creates a checker that checks the bytes of payload. 172 // The checker always checks the payload of the last network header. 173 // For instance, in case of IPv6 fragments, the payload that will be checked 174 // is the one containing the actual data that the packet is carrying, without 175 // the bytes added by the IPv6 fragmentation. 176 func Raw(want []byte) NetworkChecker { 177 return func(t *testing.T, h []header.Network) { 178 t.Helper() 179 180 if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) { 181 t.Errorf("Wrong payload, got %v, want %v", got, want) 182 } 183 } 184 } 185 186 // IPv6Fragment creates a checker that validates an IPv6 fragment. 187 func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker { 188 return func(t *testing.T, h []header.Network) { 189 t.Helper() 190 191 if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader { 192 t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) 193 } 194 195 ipv6Frag := header.IPv6Fragment(h[0].Payload()) 196 if !ipv6Frag.IsValid() { 197 t.Error("Not a valid IPv6 fragment") 198 } 199 200 for _, f := range checkers { 201 f(t, []header.Network{h[0], ipv6Frag}) 202 } 203 if t.Failed() { 204 t.FailNow() 205 } 206 } 207 } 208 209 // TCP creates a checker that checks that the transport protocol is TCP and 210 // potentially additional transport header fields. 211 func TCP(checkers ...TransportChecker) NetworkChecker { 212 return func(t *testing.T, h []header.Network) { 213 t.Helper() 214 215 first := h[0] 216 last := h[len(h)-1] 217 218 if p := last.TransportProtocol(); p != header.TCPProtocolNumber { 219 t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber) 220 } 221 222 // Verify the checksum. 223 tcp := header.TCP(last.Payload()) 224 l := uint16(len(tcp)) 225 226 xsum := header.Checksum([]byte(first.SourceAddress()), 0) 227 xsum = header.Checksum([]byte(first.DestinationAddress()), xsum) 228 xsum = header.Checksum([]byte{0, byte(last.TransportProtocol())}, xsum) 229 xsum = header.Checksum([]byte{byte(l >> 8), byte(l)}, xsum) 230 xsum = header.Checksum(tcp, xsum) 231 232 if xsum != 0 && xsum != 0xffff { 233 t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum()) 234 } 235 236 // Run the transport checkers. 237 for _, f := range checkers { 238 f(t, tcp) 239 } 240 if t.Failed() { 241 t.FailNow() 242 } 243 } 244 } 245 246 // UDP creates a checker that checks that the transport protocol is UDP and 247 // potentially additional transport header fields. 248 func UDP(checkers ...TransportChecker) NetworkChecker { 249 return func(t *testing.T, h []header.Network) { 250 t.Helper() 251 252 last := h[len(h)-1] 253 254 if p := last.TransportProtocol(); p != header.UDPProtocolNumber { 255 t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) 256 } 257 258 udp := header.UDP(last.Payload()) 259 for _, f := range checkers { 260 f(t, udp) 261 } 262 if t.Failed() { 263 t.FailNow() 264 } 265 } 266 } 267 268 // SrcPort creates a checker that checks the source port. 269 func SrcPort(port uint16) TransportChecker { 270 return func(t *testing.T, h header.Transport) { 271 t.Helper() 272 273 if p := h.SourcePort(); p != port { 274 t.Errorf("Bad source port, got %v, want %v", p, port) 275 } 276 } 277 } 278 279 // DstPort creates a checker that checks the destination port. 280 func DstPort(port uint16) TransportChecker { 281 return func(t *testing.T, h header.Transport) { 282 if p := h.DestinationPort(); p != port { 283 t.Errorf("Bad destination port, got %v, want %v", p, port) 284 } 285 } 286 } 287 288 // SeqNum creates a checker that checks the sequence number. 289 func SeqNum(seq uint32) TransportChecker { 290 return func(t *testing.T, h header.Transport) { 291 t.Helper() 292 293 tcp, ok := h.(header.TCP) 294 if !ok { 295 return 296 } 297 298 if s := tcp.SequenceNumber(); s != seq { 299 t.Errorf("Bad sequence number, got %v, want %v", s, seq) 300 } 301 } 302 } 303 304 // AckNum creates a checker that checks the ack number. 305 func AckNum(seq uint32) TransportChecker { 306 return func(t *testing.T, h header.Transport) { 307 t.Helper() 308 tcp, ok := h.(header.TCP) 309 if !ok { 310 return 311 } 312 313 if s := tcp.AckNumber(); s != seq { 314 t.Errorf("Bad ack number, got %v, want %v", s, seq) 315 } 316 } 317 } 318 319 // Window creates a checker that checks the tcp window. 320 func Window(window uint16) TransportChecker { 321 return func(t *testing.T, h header.Transport) { 322 tcp, ok := h.(header.TCP) 323 if !ok { 324 return 325 } 326 327 if w := tcp.WindowSize(); w != window { 328 t.Errorf("Bad window, got 0x%x, want 0x%x", w, window) 329 } 330 } 331 } 332 333 // TCPFlags creates a checker that checks the tcp flags. 334 func TCPFlags(flags uint8) TransportChecker { 335 return func(t *testing.T, h header.Transport) { 336 t.Helper() 337 338 tcp, ok := h.(header.TCP) 339 if !ok { 340 return 341 } 342 343 if f := tcp.Flags(); f != flags { 344 t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags) 345 } 346 } 347 } 348 349 // TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the 350 // given mask, match the supplied flags. 351 func TCPFlagsMatch(flags, mask uint8) TransportChecker { 352 return func(t *testing.T, h header.Transport) { 353 tcp, ok := h.(header.TCP) 354 if !ok { 355 return 356 } 357 358 if f := tcp.Flags(); (f & mask) != (flags & mask) { 359 t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask) 360 } 361 } 362 } 363 364 // TCPSynOptions creates a checker that checks the presence of TCP options in 365 // SYN segments. 366 // 367 // If wndscale is negative, the window scale option must not be present. 368 func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { 369 return func(t *testing.T, h header.Transport) { 370 tcp, ok := h.(header.TCP) 371 if !ok { 372 return 373 } 374 opts := tcp.Options() 375 limit := len(opts) 376 foundMSS := false 377 foundWS := false 378 foundTS := false 379 foundSACKPermitted := false 380 tsVal := uint32(0) 381 tsEcr := uint32(0) 382 for i := 0; i < limit; { 383 switch opts[i] { 384 case header.TCPOptionEOL: 385 i = limit 386 case header.TCPOptionNOP: 387 i++ 388 case header.TCPOptionMSS: 389 v := uint16(opts[i+2])<<8 | uint16(opts[i+3]) 390 if wantOpts.MSS != v { 391 t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS) 392 } 393 foundMSS = true 394 i += 4 395 case header.TCPOptionWS: 396 if wantOpts.WS < 0 { 397 t.Error("WS present when it shouldn't be") 398 } 399 v := int(opts[i+2]) 400 if v != wantOpts.WS { 401 t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS) 402 } 403 foundWS = true 404 i += 3 405 case header.TCPOptionTS: 406 if i+9 >= limit { 407 t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i) 408 } 409 if opts[i+1] != 10 { 410 t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit) 411 } 412 tsVal = binary.BigEndian.Uint32(opts[i+2:]) 413 tsEcr = uint32(0) 414 if tcp.Flags()&header.TCPFlagAck != 0 { 415 // If the syn is an SYN-ACK then read 416 // the tsEcr value as well. 417 tsEcr = binary.BigEndian.Uint32(opts[i+6:]) 418 } 419 foundTS = true 420 i += 10 421 case header.TCPOptionSACKPermitted: 422 if i+1 >= limit { 423 t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i) 424 } 425 if opts[i+1] != 2 { 426 t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit) 427 } 428 foundSACKPermitted = true 429 i += 2 430 431 default: 432 i += int(opts[i+1]) 433 } 434 } 435 436 if !foundMSS { 437 t.Errorf("MSS option not found. Options: %x", opts) 438 } 439 440 if !foundWS && wantOpts.WS >= 0 { 441 t.Errorf("WS option not found. Options: %x", opts) 442 } 443 if wantOpts.TS && !foundTS { 444 t.Errorf("TS option not found. Options: %x", opts) 445 } 446 if foundTS && tsVal == 0 { 447 t.Error("TS option specified but the timestamp value is zero") 448 } 449 if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 { 450 t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr) 451 } 452 if wantOpts.SACKPermitted && !foundSACKPermitted { 453 t.Errorf("SACKPermitted option not found. Options: %x", opts) 454 } 455 } 456 } 457 458 // TCPTimestampChecker creates a checker that validates that a TCP segment has a 459 // TCP Timestamp option if wantTS is true, it also compares the wantTSVal and 460 // wantTSEcr values with those in the TCP segment (if present). 461 // 462 // If wantTSVal or wantTSEcr is zero then the corresponding comparison is 463 // skipped. 464 func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker { 465 return func(t *testing.T, h header.Transport) { 466 tcp, ok := h.(header.TCP) 467 if !ok { 468 return 469 } 470 opts := []byte(tcp.Options()) 471 limit := len(opts) 472 foundTS := false 473 tsVal := uint32(0) 474 tsEcr := uint32(0) 475 for i := 0; i < limit; { 476 switch opts[i] { 477 case header.TCPOptionEOL: 478 i = limit 479 case header.TCPOptionNOP: 480 i++ 481 case header.TCPOptionTS: 482 if i+9 >= limit { 483 t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i) 484 } 485 if opts[i+1] != 10 { 486 t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1]) 487 } 488 tsVal = binary.BigEndian.Uint32(opts[i+2:]) 489 tsEcr = binary.BigEndian.Uint32(opts[i+6:]) 490 foundTS = true 491 i += 10 492 default: 493 // We don't recognize this option, just skip over it. 494 if i+2 > limit { 495 return 496 } 497 l := int(opts[i+1]) 498 if i < 2 || i+l > limit { 499 return 500 } 501 i += l 502 } 503 } 504 505 if wantTS != foundTS { 506 t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS) 507 } 508 if wantTS && wantTSVal != 0 && wantTSVal != tsVal { 509 t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal) 510 } 511 if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr { 512 t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr) 513 } 514 } 515 } 516 517 // TCPNoSACKBlockChecker creates a checker that verifies that the segment does not 518 // contain any SACK blocks in the TCP options. 519 func TCPNoSACKBlockChecker() TransportChecker { 520 return TCPSACKBlockChecker(nil) 521 } 522 523 // TCPSACKBlockChecker creates a checker that verifies that the segment does 524 // contain the specified SACK blocks in the TCP options. 525 func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { 526 return func(t *testing.T, h header.Transport) { 527 t.Helper() 528 tcp, ok := h.(header.TCP) 529 if !ok { 530 return 531 } 532 var gotSACKBlocks []header.SACKBlock 533 534 opts := []byte(tcp.Options()) 535 limit := len(opts) 536 for i := 0; i < limit; { 537 switch opts[i] { 538 case header.TCPOptionEOL: 539 i = limit 540 case header.TCPOptionNOP: 541 i++ 542 case header.TCPOptionSACK: 543 if i+2 > limit { 544 // Malformed SACK block. 545 t.Errorf("malformed SACK option in options: %v", opts) 546 } 547 sackOptionLen := int(opts[i+1]) 548 if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 { 549 // Malformed SACK block. 550 t.Errorf("malformed SACK option length in options: %v", opts) 551 } 552 numBlocks := sackOptionLen / 8 553 for j := 0; j < numBlocks; j++ { 554 start := binary.BigEndian.Uint32(opts[i+2+j*8:]) 555 end := binary.BigEndian.Uint32(opts[i+2+j*8+4:]) 556 gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{ 557 Start: seqnum.Value(start), 558 End: seqnum.Value(end), 559 }) 560 } 561 i += sackOptionLen 562 default: 563 // We don't recognize this option, just skip over it. 564 if i+2 > limit { 565 break 566 } 567 l := int(opts[i+1]) 568 if l < 2 || i+l > limit { 569 break 570 } 571 i += l 572 } 573 } 574 575 if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) { 576 t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks) 577 } 578 } 579 } 580 581 // Payload creates a checker that checks the payload. 582 func Payload(want []byte) TransportChecker { 583 return func(t *testing.T, h header.Transport) { 584 if got := h.Payload(); !reflect.DeepEqual(got, want) { 585 t.Errorf("Wrong payload, got %v, want %v", got, want) 586 } 587 } 588 } 589 590 // ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and 591 // potentially additional ICMPv4 header fields. 592 func ICMPv4(checkers ...TransportChecker) NetworkChecker { 593 return func(t *testing.T, h []header.Network) { 594 t.Helper() 595 596 last := h[len(h)-1] 597 598 if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber { 599 t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber) 600 } 601 602 icmp := header.ICMPv4(last.Payload()) 603 for _, f := range checkers { 604 f(t, icmp) 605 } 606 if t.Failed() { 607 t.FailNow() 608 } 609 } 610 } 611 612 // ICMPv4Type creates a checker that checks the ICMPv4 Type field. 613 func ICMPv4Type(want header.ICMPv4Type) TransportChecker { 614 return func(t *testing.T, h header.Transport) { 615 t.Helper() 616 icmpv4, ok := h.(header.ICMPv4) 617 if !ok { 618 t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) 619 } 620 if got := icmpv4.Type(); got != want { 621 t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) 622 } 623 } 624 } 625 626 // ICMPv4Code creates a checker that checks the ICMPv4 Code field. 627 func ICMPv4Code(want byte) TransportChecker { 628 return func(t *testing.T, h header.Transport) { 629 t.Helper() 630 icmpv4, ok := h.(header.ICMPv4) 631 if !ok { 632 t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) 633 } 634 if got := icmpv4.Code(); got != want { 635 t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) 636 } 637 } 638 } 639 640 // ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and 641 // potentially additional ICMPv6 header fields. 642 func ICMPv6(checkers ...TransportChecker) NetworkChecker { 643 return func(t *testing.T, h []header.Network) { 644 t.Helper() 645 646 last := h[len(h)-1] 647 648 if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber { 649 t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber) 650 } 651 652 icmp := header.ICMPv6(last.Payload()) 653 for _, f := range checkers { 654 f(t, icmp) 655 } 656 if t.Failed() { 657 t.FailNow() 658 } 659 } 660 } 661 662 // ICMPv6Type creates a checker that checks the ICMPv6 Type field. 663 func ICMPv6Type(want header.ICMPv6Type) TransportChecker { 664 return func(t *testing.T, h header.Transport) { 665 t.Helper() 666 icmpv6, ok := h.(header.ICMPv6) 667 if !ok { 668 t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) 669 } 670 if got := icmpv6.Type(); got != want { 671 t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) 672 } 673 } 674 } 675 676 // ICMPv6Code creates a checker that checks the ICMPv6 Code field. 677 func ICMPv6Code(want byte) TransportChecker { 678 return func(t *testing.T, h header.Transport) { 679 t.Helper() 680 icmpv6, ok := h.(header.ICMPv6) 681 if !ok { 682 t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) 683 } 684 if got := icmpv6.Code(); got != want { 685 t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) 686 } 687 } 688 }