github.com/FlowerWrong/netstack@v0.0.0-20191009141956-e5848263af28/tcpip/checker/checker.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package checker provides helper functions to check networking packets for
    16  // validity.
    17  package checker
    18  
    19  import (
    20  	"encoding/binary"
    21  	"reflect"
    22  	"testing"
    23  
    24  	"github.com/FlowerWrong/netstack/tcpip"
    25  	"github.com/FlowerWrong/netstack/tcpip/header"
    26  	"github.com/FlowerWrong/netstack/tcpip/seqnum"
    27  )
    28  
    29  // NetworkChecker is a function to check a property of a network packet.
    30  type NetworkChecker func(*testing.T, []header.Network)
    31  
    32  // TransportChecker is a function to check a property of a transport packet.
    33  type TransportChecker func(*testing.T, header.Transport)
    34  
    35  // IPv4 checks the validity and properties of the given IPv4 packet. It is
    36  // expected to be used in conjunction with other network checkers for specific
    37  // properties. For example, to check the source and destination address, one
    38  // would call:
    39  //
    40  // checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y))
    41  func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) {
    42  	t.Helper()
    43  
    44  	ipv4 := header.IPv4(b)
    45  
    46  	if !ipv4.IsValid(len(b)) {
    47  		t.Error("Not a valid IPv4 packet")
    48  	}
    49  
    50  	xsum := ipv4.CalculateChecksum()
    51  	if xsum != 0 && xsum != 0xffff {
    52  		t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum())
    53  	}
    54  
    55  	for _, f := range checkers {
    56  		f(t, []header.Network{ipv4})
    57  	}
    58  	if t.Failed() {
    59  		t.FailNow()
    60  	}
    61  }
    62  
    63  // IPv6 checks the validity and properties of the given IPv6 packet. The usage
    64  // is similar to IPv4.
    65  func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) {
    66  	t.Helper()
    67  
    68  	ipv6 := header.IPv6(b)
    69  	if !ipv6.IsValid(len(b)) {
    70  		t.Error("Not a valid IPv6 packet")
    71  	}
    72  
    73  	for _, f := range checkers {
    74  		f(t, []header.Network{ipv6})
    75  	}
    76  	if t.Failed() {
    77  		t.FailNow()
    78  	}
    79  }
    80  
    81  // SrcAddr creates a checker that checks the source address.
    82  func SrcAddr(addr tcpip.Address) NetworkChecker {
    83  	return func(t *testing.T, h []header.Network) {
    84  		t.Helper()
    85  
    86  		if a := h[0].SourceAddress(); a != addr {
    87  			t.Errorf("Bad source address, got %v, want %v", a, addr)
    88  		}
    89  	}
    90  }
    91  
    92  // DstAddr creates a checker that checks the destination address.
    93  func DstAddr(addr tcpip.Address) NetworkChecker {
    94  	return func(t *testing.T, h []header.Network) {
    95  		t.Helper()
    96  
    97  		if a := h[0].DestinationAddress(); a != addr {
    98  			t.Errorf("Bad destination address, got %v, want %v", a, addr)
    99  		}
   100  	}
   101  }
   102  
   103  // TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6).
   104  func TTL(ttl uint8) NetworkChecker {
   105  	return func(t *testing.T, h []header.Network) {
   106  		var v uint8
   107  		switch ip := h[0].(type) {
   108  		case header.IPv4:
   109  			v = ip.TTL()
   110  		case header.IPv6:
   111  			v = ip.HopLimit()
   112  		}
   113  		if v != ttl {
   114  			t.Fatalf("Bad TTL, got %v, want %v", v, ttl)
   115  		}
   116  	}
   117  }
   118  
   119  // PayloadLen creates a checker that checks the payload length.
   120  func PayloadLen(plen int) NetworkChecker {
   121  	return func(t *testing.T, h []header.Network) {
   122  		t.Helper()
   123  
   124  		if l := len(h[0].Payload()); l != plen {
   125  			t.Errorf("Bad payload length, got %v, want %v", l, plen)
   126  		}
   127  	}
   128  }
   129  
   130  // FragmentOffset creates a checker that checks the FragmentOffset field.
   131  func FragmentOffset(offset uint16) NetworkChecker {
   132  	return func(t *testing.T, h []header.Network) {
   133  		t.Helper()
   134  
   135  		// We only do this of IPv4 for now.
   136  		switch ip := h[0].(type) {
   137  		case header.IPv4:
   138  			if v := ip.FragmentOffset(); v != offset {
   139  				t.Errorf("Bad fragment offset, got %v, want %v", v, offset)
   140  			}
   141  		}
   142  	}
   143  }
   144  
   145  // FragmentFlags creates a checker that checks the fragment flags field.
   146  func FragmentFlags(flags uint8) NetworkChecker {
   147  	return func(t *testing.T, h []header.Network) {
   148  		t.Helper()
   149  
   150  		// We only do this of IPv4 for now.
   151  		switch ip := h[0].(type) {
   152  		case header.IPv4:
   153  			if v := ip.Flags(); v != flags {
   154  				t.Errorf("Bad fragment offset, got %v, want %v", v, flags)
   155  			}
   156  		}
   157  	}
   158  }
   159  
   160  // TOS creates a checker that checks the TOS field.
   161  func TOS(tos uint8, label uint32) NetworkChecker {
   162  	return func(t *testing.T, h []header.Network) {
   163  		t.Helper()
   164  
   165  		if v, l := h[0].TOS(); v != tos || l != label {
   166  			t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label)
   167  		}
   168  	}
   169  }
   170  
   171  // Raw creates a checker that checks the bytes of payload.
   172  // The checker always checks the payload of the last network header.
   173  // For instance, in case of IPv6 fragments, the payload that will be checked
   174  // is the one containing the actual data that the packet is carrying, without
   175  // the bytes added by the IPv6 fragmentation.
   176  func Raw(want []byte) NetworkChecker {
   177  	return func(t *testing.T, h []header.Network) {
   178  		t.Helper()
   179  
   180  		if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) {
   181  			t.Errorf("Wrong payload, got %v, want %v", got, want)
   182  		}
   183  	}
   184  }
   185  
   186  // IPv6Fragment creates a checker that validates an IPv6 fragment.
   187  func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker {
   188  	return func(t *testing.T, h []header.Network) {
   189  		t.Helper()
   190  
   191  		if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader {
   192  			t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
   193  		}
   194  
   195  		ipv6Frag := header.IPv6Fragment(h[0].Payload())
   196  		if !ipv6Frag.IsValid() {
   197  			t.Error("Not a valid IPv6 fragment")
   198  		}
   199  
   200  		for _, f := range checkers {
   201  			f(t, []header.Network{h[0], ipv6Frag})
   202  		}
   203  		if t.Failed() {
   204  			t.FailNow()
   205  		}
   206  	}
   207  }
   208  
   209  // TCP creates a checker that checks that the transport protocol is TCP and
   210  // potentially additional transport header fields.
   211  func TCP(checkers ...TransportChecker) NetworkChecker {
   212  	return func(t *testing.T, h []header.Network) {
   213  		t.Helper()
   214  
   215  		first := h[0]
   216  		last := h[len(h)-1]
   217  
   218  		if p := last.TransportProtocol(); p != header.TCPProtocolNumber {
   219  			t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber)
   220  		}
   221  
   222  		// Verify the checksum.
   223  		tcp := header.TCP(last.Payload())
   224  		l := uint16(len(tcp))
   225  
   226  		xsum := header.Checksum([]byte(first.SourceAddress()), 0)
   227  		xsum = header.Checksum([]byte(first.DestinationAddress()), xsum)
   228  		xsum = header.Checksum([]byte{0, byte(last.TransportProtocol())}, xsum)
   229  		xsum = header.Checksum([]byte{byte(l >> 8), byte(l)}, xsum)
   230  		xsum = header.Checksum(tcp, xsum)
   231  
   232  		if xsum != 0 && xsum != 0xffff {
   233  			t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum())
   234  		}
   235  
   236  		// Run the transport checkers.
   237  		for _, f := range checkers {
   238  			f(t, tcp)
   239  		}
   240  		if t.Failed() {
   241  			t.FailNow()
   242  		}
   243  	}
   244  }
   245  
   246  // UDP creates a checker that checks that the transport protocol is UDP and
   247  // potentially additional transport header fields.
   248  func UDP(checkers ...TransportChecker) NetworkChecker {
   249  	return func(t *testing.T, h []header.Network) {
   250  		t.Helper()
   251  
   252  		last := h[len(h)-1]
   253  
   254  		if p := last.TransportProtocol(); p != header.UDPProtocolNumber {
   255  			t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
   256  		}
   257  
   258  		udp := header.UDP(last.Payload())
   259  		for _, f := range checkers {
   260  			f(t, udp)
   261  		}
   262  		if t.Failed() {
   263  			t.FailNow()
   264  		}
   265  	}
   266  }
   267  
   268  // SrcPort creates a checker that checks the source port.
   269  func SrcPort(port uint16) TransportChecker {
   270  	return func(t *testing.T, h header.Transport) {
   271  		t.Helper()
   272  
   273  		if p := h.SourcePort(); p != port {
   274  			t.Errorf("Bad source port, got %v, want %v", p, port)
   275  		}
   276  	}
   277  }
   278  
   279  // DstPort creates a checker that checks the destination port.
   280  func DstPort(port uint16) TransportChecker {
   281  	return func(t *testing.T, h header.Transport) {
   282  		if p := h.DestinationPort(); p != port {
   283  			t.Errorf("Bad destination port, got %v, want %v", p, port)
   284  		}
   285  	}
   286  }
   287  
   288  // SeqNum creates a checker that checks the sequence number.
   289  func SeqNum(seq uint32) TransportChecker {
   290  	return func(t *testing.T, h header.Transport) {
   291  		t.Helper()
   292  
   293  		tcp, ok := h.(header.TCP)
   294  		if !ok {
   295  			return
   296  		}
   297  
   298  		if s := tcp.SequenceNumber(); s != seq {
   299  			t.Errorf("Bad sequence number, got %v, want %v", s, seq)
   300  		}
   301  	}
   302  }
   303  
   304  // AckNum creates a checker that checks the ack number.
   305  func AckNum(seq uint32) TransportChecker {
   306  	return func(t *testing.T, h header.Transport) {
   307  		t.Helper()
   308  		tcp, ok := h.(header.TCP)
   309  		if !ok {
   310  			return
   311  		}
   312  
   313  		if s := tcp.AckNumber(); s != seq {
   314  			t.Errorf("Bad ack number, got %v, want %v", s, seq)
   315  		}
   316  	}
   317  }
   318  
   319  // Window creates a checker that checks the tcp window.
   320  func Window(window uint16) TransportChecker {
   321  	return func(t *testing.T, h header.Transport) {
   322  		tcp, ok := h.(header.TCP)
   323  		if !ok {
   324  			return
   325  		}
   326  
   327  		if w := tcp.WindowSize(); w != window {
   328  			t.Errorf("Bad window, got 0x%x, want 0x%x", w, window)
   329  		}
   330  	}
   331  }
   332  
   333  // TCPFlags creates a checker that checks the tcp flags.
   334  func TCPFlags(flags uint8) TransportChecker {
   335  	return func(t *testing.T, h header.Transport) {
   336  		t.Helper()
   337  
   338  		tcp, ok := h.(header.TCP)
   339  		if !ok {
   340  			return
   341  		}
   342  
   343  		if f := tcp.Flags(); f != flags {
   344  			t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags)
   345  		}
   346  	}
   347  }
   348  
   349  // TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the
   350  // given mask, match the supplied flags.
   351  func TCPFlagsMatch(flags, mask uint8) TransportChecker {
   352  	return func(t *testing.T, h header.Transport) {
   353  		tcp, ok := h.(header.TCP)
   354  		if !ok {
   355  			return
   356  		}
   357  
   358  		if f := tcp.Flags(); (f & mask) != (flags & mask) {
   359  			t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask)
   360  		}
   361  	}
   362  }
   363  
   364  // TCPSynOptions creates a checker that checks the presence of TCP options in
   365  // SYN segments.
   366  //
   367  // If wndscale is negative, the window scale option must not be present.
   368  func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
   369  	return func(t *testing.T, h header.Transport) {
   370  		tcp, ok := h.(header.TCP)
   371  		if !ok {
   372  			return
   373  		}
   374  		opts := tcp.Options()
   375  		limit := len(opts)
   376  		foundMSS := false
   377  		foundWS := false
   378  		foundTS := false
   379  		foundSACKPermitted := false
   380  		tsVal := uint32(0)
   381  		tsEcr := uint32(0)
   382  		for i := 0; i < limit; {
   383  			switch opts[i] {
   384  			case header.TCPOptionEOL:
   385  				i = limit
   386  			case header.TCPOptionNOP:
   387  				i++
   388  			case header.TCPOptionMSS:
   389  				v := uint16(opts[i+2])<<8 | uint16(opts[i+3])
   390  				if wantOpts.MSS != v {
   391  					t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS)
   392  				}
   393  				foundMSS = true
   394  				i += 4
   395  			case header.TCPOptionWS:
   396  				if wantOpts.WS < 0 {
   397  					t.Error("WS present when it shouldn't be")
   398  				}
   399  				v := int(opts[i+2])
   400  				if v != wantOpts.WS {
   401  					t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS)
   402  				}
   403  				foundWS = true
   404  				i += 3
   405  			case header.TCPOptionTS:
   406  				if i+9 >= limit {
   407  					t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i)
   408  				}
   409  				if opts[i+1] != 10 {
   410  					t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit)
   411  				}
   412  				tsVal = binary.BigEndian.Uint32(opts[i+2:])
   413  				tsEcr = uint32(0)
   414  				if tcp.Flags()&header.TCPFlagAck != 0 {
   415  					// If the syn is an SYN-ACK then read
   416  					// the tsEcr value as well.
   417  					tsEcr = binary.BigEndian.Uint32(opts[i+6:])
   418  				}
   419  				foundTS = true
   420  				i += 10
   421  			case header.TCPOptionSACKPermitted:
   422  				if i+1 >= limit {
   423  					t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i)
   424  				}
   425  				if opts[i+1] != 2 {
   426  					t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit)
   427  				}
   428  				foundSACKPermitted = true
   429  				i += 2
   430  
   431  			default:
   432  				i += int(opts[i+1])
   433  			}
   434  		}
   435  
   436  		if !foundMSS {
   437  			t.Errorf("MSS option not found. Options: %x", opts)
   438  		}
   439  
   440  		if !foundWS && wantOpts.WS >= 0 {
   441  			t.Errorf("WS option not found. Options: %x", opts)
   442  		}
   443  		if wantOpts.TS && !foundTS {
   444  			t.Errorf("TS option not found. Options: %x", opts)
   445  		}
   446  		if foundTS && tsVal == 0 {
   447  			t.Error("TS option specified but the timestamp value is zero")
   448  		}
   449  		if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 {
   450  			t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr)
   451  		}
   452  		if wantOpts.SACKPermitted && !foundSACKPermitted {
   453  			t.Errorf("SACKPermitted option not found. Options: %x", opts)
   454  		}
   455  	}
   456  }
   457  
   458  // TCPTimestampChecker creates a checker that validates that a TCP segment has a
   459  // TCP Timestamp option if wantTS is true, it also compares the wantTSVal and
   460  // wantTSEcr values with those in the TCP segment (if present).
   461  //
   462  // If wantTSVal or wantTSEcr is zero then the corresponding comparison is
   463  // skipped.
   464  func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker {
   465  	return func(t *testing.T, h header.Transport) {
   466  		tcp, ok := h.(header.TCP)
   467  		if !ok {
   468  			return
   469  		}
   470  		opts := []byte(tcp.Options())
   471  		limit := len(opts)
   472  		foundTS := false
   473  		tsVal := uint32(0)
   474  		tsEcr := uint32(0)
   475  		for i := 0; i < limit; {
   476  			switch opts[i] {
   477  			case header.TCPOptionEOL:
   478  				i = limit
   479  			case header.TCPOptionNOP:
   480  				i++
   481  			case header.TCPOptionTS:
   482  				if i+9 >= limit {
   483  					t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
   484  				}
   485  				if opts[i+1] != 10 {
   486  					t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1])
   487  				}
   488  				tsVal = binary.BigEndian.Uint32(opts[i+2:])
   489  				tsEcr = binary.BigEndian.Uint32(opts[i+6:])
   490  				foundTS = true
   491  				i += 10
   492  			default:
   493  				// We don't recognize this option, just skip over it.
   494  				if i+2 > limit {
   495  					return
   496  				}
   497  				l := int(opts[i+1])
   498  				if i < 2 || i+l > limit {
   499  					return
   500  				}
   501  				i += l
   502  			}
   503  		}
   504  
   505  		if wantTS != foundTS {
   506  			t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS)
   507  		}
   508  		if wantTS && wantTSVal != 0 && wantTSVal != tsVal {
   509  			t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal)
   510  		}
   511  		if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr {
   512  			t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr)
   513  		}
   514  	}
   515  }
   516  
   517  // TCPNoSACKBlockChecker creates a checker that verifies that the segment does not
   518  // contain any SACK blocks in the TCP options.
   519  func TCPNoSACKBlockChecker() TransportChecker {
   520  	return TCPSACKBlockChecker(nil)
   521  }
   522  
   523  // TCPSACKBlockChecker creates a checker that verifies that the segment does
   524  // contain the specified SACK blocks in the TCP options.
   525  func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
   526  	return func(t *testing.T, h header.Transport) {
   527  		t.Helper()
   528  		tcp, ok := h.(header.TCP)
   529  		if !ok {
   530  			return
   531  		}
   532  		var gotSACKBlocks []header.SACKBlock
   533  
   534  		opts := []byte(tcp.Options())
   535  		limit := len(opts)
   536  		for i := 0; i < limit; {
   537  			switch opts[i] {
   538  			case header.TCPOptionEOL:
   539  				i = limit
   540  			case header.TCPOptionNOP:
   541  				i++
   542  			case header.TCPOptionSACK:
   543  				if i+2 > limit {
   544  					// Malformed SACK block.
   545  					t.Errorf("malformed SACK option in options: %v", opts)
   546  				}
   547  				sackOptionLen := int(opts[i+1])
   548  				if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
   549  					// Malformed SACK block.
   550  					t.Errorf("malformed SACK option length in options: %v", opts)
   551  				}
   552  				numBlocks := sackOptionLen / 8
   553  				for j := 0; j < numBlocks; j++ {
   554  					start := binary.BigEndian.Uint32(opts[i+2+j*8:])
   555  					end := binary.BigEndian.Uint32(opts[i+2+j*8+4:])
   556  					gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{
   557  						Start: seqnum.Value(start),
   558  						End:   seqnum.Value(end),
   559  					})
   560  				}
   561  				i += sackOptionLen
   562  			default:
   563  				// We don't recognize this option, just skip over it.
   564  				if i+2 > limit {
   565  					break
   566  				}
   567  				l := int(opts[i+1])
   568  				if l < 2 || i+l > limit {
   569  					break
   570  				}
   571  				i += l
   572  			}
   573  		}
   574  
   575  		if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) {
   576  			t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks)
   577  		}
   578  	}
   579  }
   580  
   581  // Payload creates a checker that checks the payload.
   582  func Payload(want []byte) TransportChecker {
   583  	return func(t *testing.T, h header.Transport) {
   584  		if got := h.Payload(); !reflect.DeepEqual(got, want) {
   585  			t.Errorf("Wrong payload, got %v, want %v", got, want)
   586  		}
   587  	}
   588  }
   589  
   590  // ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and
   591  // potentially additional ICMPv4 header fields.
   592  func ICMPv4(checkers ...TransportChecker) NetworkChecker {
   593  	return func(t *testing.T, h []header.Network) {
   594  		t.Helper()
   595  
   596  		last := h[len(h)-1]
   597  
   598  		if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber {
   599  			t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber)
   600  		}
   601  
   602  		icmp := header.ICMPv4(last.Payload())
   603  		for _, f := range checkers {
   604  			f(t, icmp)
   605  		}
   606  		if t.Failed() {
   607  			t.FailNow()
   608  		}
   609  	}
   610  }
   611  
   612  // ICMPv4Type creates a checker that checks the ICMPv4 Type field.
   613  func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
   614  	return func(t *testing.T, h header.Transport) {
   615  		t.Helper()
   616  		icmpv4, ok := h.(header.ICMPv4)
   617  		if !ok {
   618  			t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
   619  		}
   620  		if got := icmpv4.Type(); got != want {
   621  			t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
   622  		}
   623  	}
   624  }
   625  
   626  // ICMPv4Code creates a checker that checks the ICMPv4 Code field.
   627  func ICMPv4Code(want byte) TransportChecker {
   628  	return func(t *testing.T, h header.Transport) {
   629  		t.Helper()
   630  		icmpv4, ok := h.(header.ICMPv4)
   631  		if !ok {
   632  			t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
   633  		}
   634  		if got := icmpv4.Code(); got != want {
   635  			t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
   636  		}
   637  	}
   638  }
   639  
   640  // ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and
   641  // potentially additional ICMPv6 header fields.
   642  func ICMPv6(checkers ...TransportChecker) NetworkChecker {
   643  	return func(t *testing.T, h []header.Network) {
   644  		t.Helper()
   645  
   646  		last := h[len(h)-1]
   647  
   648  		if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber {
   649  			t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber)
   650  		}
   651  
   652  		icmp := header.ICMPv6(last.Payload())
   653  		for _, f := range checkers {
   654  			f(t, icmp)
   655  		}
   656  		if t.Failed() {
   657  			t.FailNow()
   658  		}
   659  	}
   660  }
   661  
   662  // ICMPv6Type creates a checker that checks the ICMPv6 Type field.
   663  func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
   664  	return func(t *testing.T, h header.Transport) {
   665  		t.Helper()
   666  		icmpv6, ok := h.(header.ICMPv6)
   667  		if !ok {
   668  			t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
   669  		}
   670  		if got := icmpv6.Type(); got != want {
   671  			t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
   672  		}
   673  	}
   674  }
   675  
   676  // ICMPv6Code creates a checker that checks the ICMPv6 Code field.
   677  func ICMPv6Code(want byte) TransportChecker {
   678  	return func(t *testing.T, h header.Transport) {
   679  		t.Helper()
   680  		icmpv6, ok := h.(header.ICMPv6)
   681  		if !ok {
   682  			t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
   683  		}
   684  		if got := icmpv6.Code(); got != want {
   685  			t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
   686  		}
   687  	}
   688  }