github.com/lightlus/netstack@v1.2.0/tcpip/checker/checker.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package checker provides helper functions to check networking packets for
    16  // validity.
    17  package checker
    18  
    19  import (
    20  	"encoding/binary"
    21  	"reflect"
    22  	"testing"
    23  
    24  	"github.com/lightlus/netstack/tcpip"
    25  	"github.com/lightlus/netstack/tcpip/buffer"
    26  	"github.com/lightlus/netstack/tcpip/header"
    27  	"github.com/lightlus/netstack/tcpip/seqnum"
    28  )
    29  
// NetworkChecker is a function to check a property of a network packet.
//
// It receives the test handle used to report failures and the chain of
// network headers for the packet. The checkers in this file read the
// outermost header from index 0 and the innermost one from the last index.
type NetworkChecker func(*testing.T, []header.Network)
    32  
// TransportChecker is a function to check a property of a transport packet.
//
// It receives the test handle used to report failures and the transport
// header to inspect. Protocol-specific checkers in this file type-assert the
// header to the concrete type they need.
type TransportChecker func(*testing.T, header.Transport)
    35  
    36  // IPv4 checks the validity and properties of the given IPv4 packet. It is
    37  // expected to be used in conjunction with other network checkers for specific
    38  // properties. For example, to check the source and destination address, one
    39  // would call:
    40  //
    41  // checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y))
    42  func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) {
    43  	t.Helper()
    44  
    45  	ipv4 := header.IPv4(b)
    46  
    47  	if !ipv4.IsValid(len(b)) {
    48  		t.Error("Not a valid IPv4 packet")
    49  	}
    50  
    51  	xsum := ipv4.CalculateChecksum()
    52  	if xsum != 0 && xsum != 0xffff {
    53  		t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum())
    54  	}
    55  
    56  	for _, f := range checkers {
    57  		f(t, []header.Network{ipv4})
    58  	}
    59  	if t.Failed() {
    60  		t.FailNow()
    61  	}
    62  }
    63  
    64  // IPv6 checks the validity and properties of the given IPv6 packet. The usage
    65  // is similar to IPv4.
    66  func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) {
    67  	t.Helper()
    68  
    69  	ipv6 := header.IPv6(b)
    70  	if !ipv6.IsValid(len(b)) {
    71  		t.Error("Not a valid IPv6 packet")
    72  	}
    73  
    74  	for _, f := range checkers {
    75  		f(t, []header.Network{ipv6})
    76  	}
    77  	if t.Failed() {
    78  		t.FailNow()
    79  	}
    80  }
    81  
    82  // SrcAddr creates a checker that checks the source address.
    83  func SrcAddr(addr tcpip.Address) NetworkChecker {
    84  	return func(t *testing.T, h []header.Network) {
    85  		t.Helper()
    86  
    87  		if a := h[0].SourceAddress(); a != addr {
    88  			t.Errorf("Bad source address, got %v, want %v", a, addr)
    89  		}
    90  	}
    91  }
    92  
    93  // DstAddr creates a checker that checks the destination address.
    94  func DstAddr(addr tcpip.Address) NetworkChecker {
    95  	return func(t *testing.T, h []header.Network) {
    96  		t.Helper()
    97  
    98  		if a := h[0].DestinationAddress(); a != addr {
    99  			t.Errorf("Bad destination address, got %v, want %v", a, addr)
   100  		}
   101  	}
   102  }
   103  
   104  // TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6).
   105  func TTL(ttl uint8) NetworkChecker {
   106  	return func(t *testing.T, h []header.Network) {
   107  		var v uint8
   108  		switch ip := h[0].(type) {
   109  		case header.IPv4:
   110  			v = ip.TTL()
   111  		case header.IPv6:
   112  			v = ip.HopLimit()
   113  		}
   114  		if v != ttl {
   115  			t.Fatalf("Bad TTL, got %v, want %v", v, ttl)
   116  		}
   117  	}
   118  }
   119  
   120  // PayloadLen creates a checker that checks the payload length.
   121  func PayloadLen(plen int) NetworkChecker {
   122  	return func(t *testing.T, h []header.Network) {
   123  		t.Helper()
   124  
   125  		if l := len(h[0].Payload()); l != plen {
   126  			t.Errorf("Bad payload length, got %v, want %v", l, plen)
   127  		}
   128  	}
   129  }
   130  
   131  // FragmentOffset creates a checker that checks the FragmentOffset field.
   132  func FragmentOffset(offset uint16) NetworkChecker {
   133  	return func(t *testing.T, h []header.Network) {
   134  		t.Helper()
   135  
   136  		// We only do this of IPv4 for now.
   137  		switch ip := h[0].(type) {
   138  		case header.IPv4:
   139  			if v := ip.FragmentOffset(); v != offset {
   140  				t.Errorf("Bad fragment offset, got %v, want %v", v, offset)
   141  			}
   142  		}
   143  	}
   144  }
   145  
   146  // FragmentFlags creates a checker that checks the fragment flags field.
   147  func FragmentFlags(flags uint8) NetworkChecker {
   148  	return func(t *testing.T, h []header.Network) {
   149  		t.Helper()
   150  
   151  		// We only do this of IPv4 for now.
   152  		switch ip := h[0].(type) {
   153  		case header.IPv4:
   154  			if v := ip.Flags(); v != flags {
   155  				t.Errorf("Bad fragment offset, got %v, want %v", v, flags)
   156  			}
   157  		}
   158  	}
   159  }
   160  
   161  // TOS creates a checker that checks the TOS field.
   162  func TOS(tos uint8, label uint32) NetworkChecker {
   163  	return func(t *testing.T, h []header.Network) {
   164  		t.Helper()
   165  
   166  		if v, l := h[0].TOS(); v != tos || l != label {
   167  			t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label)
   168  		}
   169  	}
   170  }
   171  
   172  // Raw creates a checker that checks the bytes of payload.
   173  // The checker always checks the payload of the last network header.
   174  // For instance, in case of IPv6 fragments, the payload that will be checked
   175  // is the one containing the actual data that the packet is carrying, without
   176  // the bytes added by the IPv6 fragmentation.
   177  func Raw(want []byte) NetworkChecker {
   178  	return func(t *testing.T, h []header.Network) {
   179  		t.Helper()
   180  
   181  		if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) {
   182  			t.Errorf("Wrong payload, got %v, want %v", got, want)
   183  		}
   184  	}
   185  }
   186  
   187  // IPv6Fragment creates a checker that validates an IPv6 fragment.
   188  func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker {
   189  	return func(t *testing.T, h []header.Network) {
   190  		t.Helper()
   191  
   192  		if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader {
   193  			t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
   194  		}
   195  
   196  		ipv6Frag := header.IPv6Fragment(h[0].Payload())
   197  		if !ipv6Frag.IsValid() {
   198  			t.Error("Not a valid IPv6 fragment")
   199  		}
   200  
   201  		for _, f := range checkers {
   202  			f(t, []header.Network{h[0], ipv6Frag})
   203  		}
   204  		if t.Failed() {
   205  			t.FailNow()
   206  		}
   207  	}
   208  }
   209  
   210  // TCP creates a checker that checks that the transport protocol is TCP and
   211  // potentially additional transport header fields.
   212  func TCP(checkers ...TransportChecker) NetworkChecker {
   213  	return func(t *testing.T, h []header.Network) {
   214  		t.Helper()
   215  
   216  		first := h[0]
   217  		last := h[len(h)-1]
   218  
   219  		if p := last.TransportProtocol(); p != header.TCPProtocolNumber {
   220  			t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber)
   221  		}
   222  
   223  		// Verify the checksum.
   224  		tcp := header.TCP(last.Payload())
   225  		l := uint16(len(tcp))
   226  
   227  		xsum := header.Checksum([]byte(first.SourceAddress()), 0)
   228  		xsum = header.Checksum([]byte(first.DestinationAddress()), xsum)
   229  		xsum = header.Checksum([]byte{0, byte(last.TransportProtocol())}, xsum)
   230  		xsum = header.Checksum([]byte{byte(l >> 8), byte(l)}, xsum)
   231  		xsum = header.Checksum(tcp, xsum)
   232  
   233  		if xsum != 0 && xsum != 0xffff {
   234  			t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum())
   235  		}
   236  
   237  		// Run the transport checkers.
   238  		for _, f := range checkers {
   239  			f(t, tcp)
   240  		}
   241  		if t.Failed() {
   242  			t.FailNow()
   243  		}
   244  	}
   245  }
   246  
   247  // UDP creates a checker that checks that the transport protocol is UDP and
   248  // potentially additional transport header fields.
   249  func UDP(checkers ...TransportChecker) NetworkChecker {
   250  	return func(t *testing.T, h []header.Network) {
   251  		t.Helper()
   252  
   253  		last := h[len(h)-1]
   254  
   255  		if p := last.TransportProtocol(); p != header.UDPProtocolNumber {
   256  			t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
   257  		}
   258  
   259  		udp := header.UDP(last.Payload())
   260  		for _, f := range checkers {
   261  			f(t, udp)
   262  		}
   263  		if t.Failed() {
   264  			t.FailNow()
   265  		}
   266  	}
   267  }
   268  
   269  // SrcPort creates a checker that checks the source port.
   270  func SrcPort(port uint16) TransportChecker {
   271  	return func(t *testing.T, h header.Transport) {
   272  		t.Helper()
   273  
   274  		if p := h.SourcePort(); p != port {
   275  			t.Errorf("Bad source port, got %v, want %v", p, port)
   276  		}
   277  	}
   278  }
   279  
   280  // DstPort creates a checker that checks the destination port.
   281  func DstPort(port uint16) TransportChecker {
   282  	return func(t *testing.T, h header.Transport) {
   283  		if p := h.DestinationPort(); p != port {
   284  			t.Errorf("Bad destination port, got %v, want %v", p, port)
   285  		}
   286  	}
   287  }
   288  
   289  // SeqNum creates a checker that checks the sequence number.
   290  func SeqNum(seq uint32) TransportChecker {
   291  	return func(t *testing.T, h header.Transport) {
   292  		t.Helper()
   293  
   294  		tcp, ok := h.(header.TCP)
   295  		if !ok {
   296  			return
   297  		}
   298  
   299  		if s := tcp.SequenceNumber(); s != seq {
   300  			t.Errorf("Bad sequence number, got %v, want %v", s, seq)
   301  		}
   302  	}
   303  }
   304  
   305  // AckNum creates a checker that checks the ack number.
   306  func AckNum(seq uint32) TransportChecker {
   307  	return func(t *testing.T, h header.Transport) {
   308  		t.Helper()
   309  		tcp, ok := h.(header.TCP)
   310  		if !ok {
   311  			return
   312  		}
   313  
   314  		if s := tcp.AckNumber(); s != seq {
   315  			t.Errorf("Bad ack number, got %v, want %v", s, seq)
   316  		}
   317  	}
   318  }
   319  
   320  // Window creates a checker that checks the tcp window.
   321  func Window(window uint16) TransportChecker {
   322  	return func(t *testing.T, h header.Transport) {
   323  		tcp, ok := h.(header.TCP)
   324  		if !ok {
   325  			return
   326  		}
   327  
   328  		if w := tcp.WindowSize(); w != window {
   329  			t.Errorf("Bad window, got 0x%x, want 0x%x", w, window)
   330  		}
   331  	}
   332  }
   333  
   334  // TCPFlags creates a checker that checks the tcp flags.
   335  func TCPFlags(flags uint8) TransportChecker {
   336  	return func(t *testing.T, h header.Transport) {
   337  		t.Helper()
   338  
   339  		tcp, ok := h.(header.TCP)
   340  		if !ok {
   341  			return
   342  		}
   343  
   344  		if f := tcp.Flags(); f != flags {
   345  			t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags)
   346  		}
   347  	}
   348  }
   349  
   350  // TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the
   351  // given mask, match the supplied flags.
   352  func TCPFlagsMatch(flags, mask uint8) TransportChecker {
   353  	return func(t *testing.T, h header.Transport) {
   354  		tcp, ok := h.(header.TCP)
   355  		if !ok {
   356  			return
   357  		}
   358  
   359  		if f := tcp.Flags(); (f & mask) != (flags & mask) {
   360  			t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask)
   361  		}
   362  	}
   363  }
   364  
   365  // TCPSynOptions creates a checker that checks the presence of TCP options in
   366  // SYN segments.
   367  //
   368  // If wndscale is negative, the window scale option must not be present.
   369  func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
   370  	return func(t *testing.T, h header.Transport) {
   371  		tcp, ok := h.(header.TCP)
   372  		if !ok {
   373  			return
   374  		}
   375  		opts := tcp.Options()
   376  		limit := len(opts)
   377  		foundMSS := false
   378  		foundWS := false
   379  		foundTS := false
   380  		foundSACKPermitted := false
   381  		tsVal := uint32(0)
   382  		tsEcr := uint32(0)
   383  		for i := 0; i < limit; {
   384  			switch opts[i] {
   385  			case header.TCPOptionEOL:
   386  				i = limit
   387  			case header.TCPOptionNOP:
   388  				i++
   389  			case header.TCPOptionMSS:
   390  				v := uint16(opts[i+2])<<8 | uint16(opts[i+3])
   391  				if wantOpts.MSS != v {
   392  					t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS)
   393  				}
   394  				foundMSS = true
   395  				i += 4
   396  			case header.TCPOptionWS:
   397  				if wantOpts.WS < 0 {
   398  					t.Error("WS present when it shouldn't be")
   399  				}
   400  				v := int(opts[i+2])
   401  				if v != wantOpts.WS {
   402  					t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS)
   403  				}
   404  				foundWS = true
   405  				i += 3
   406  			case header.TCPOptionTS:
   407  				if i+9 >= limit {
   408  					t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i)
   409  				}
   410  				if opts[i+1] != 10 {
   411  					t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit)
   412  				}
   413  				tsVal = binary.BigEndian.Uint32(opts[i+2:])
   414  				tsEcr = uint32(0)
   415  				if tcp.Flags()&header.TCPFlagAck != 0 {
   416  					// If the syn is an SYN-ACK then read
   417  					// the tsEcr value as well.
   418  					tsEcr = binary.BigEndian.Uint32(opts[i+6:])
   419  				}
   420  				foundTS = true
   421  				i += 10
   422  			case header.TCPOptionSACKPermitted:
   423  				if i+1 >= limit {
   424  					t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i)
   425  				}
   426  				if opts[i+1] != 2 {
   427  					t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit)
   428  				}
   429  				foundSACKPermitted = true
   430  				i += 2
   431  
   432  			default:
   433  				i += int(opts[i+1])
   434  			}
   435  		}
   436  
   437  		if !foundMSS {
   438  			t.Errorf("MSS option not found. Options: %x", opts)
   439  		}
   440  
   441  		if !foundWS && wantOpts.WS >= 0 {
   442  			t.Errorf("WS option not found. Options: %x", opts)
   443  		}
   444  		if wantOpts.TS && !foundTS {
   445  			t.Errorf("TS option not found. Options: %x", opts)
   446  		}
   447  		if foundTS && tsVal == 0 {
   448  			t.Error("TS option specified but the timestamp value is zero")
   449  		}
   450  		if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 {
   451  			t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr)
   452  		}
   453  		if wantOpts.SACKPermitted && !foundSACKPermitted {
   454  			t.Errorf("SACKPermitted option not found. Options: %x", opts)
   455  		}
   456  	}
   457  }
   458  
   459  // TCPTimestampChecker creates a checker that validates that a TCP segment has a
   460  // TCP Timestamp option if wantTS is true, it also compares the wantTSVal and
   461  // wantTSEcr values with those in the TCP segment (if present).
   462  //
   463  // If wantTSVal or wantTSEcr is zero then the corresponding comparison is
   464  // skipped.
   465  func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker {
   466  	return func(t *testing.T, h header.Transport) {
   467  		tcp, ok := h.(header.TCP)
   468  		if !ok {
   469  			return
   470  		}
   471  		opts := []byte(tcp.Options())
   472  		limit := len(opts)
   473  		foundTS := false
   474  		tsVal := uint32(0)
   475  		tsEcr := uint32(0)
   476  		for i := 0; i < limit; {
   477  			switch opts[i] {
   478  			case header.TCPOptionEOL:
   479  				i = limit
   480  			case header.TCPOptionNOP:
   481  				i++
   482  			case header.TCPOptionTS:
   483  				if i+9 >= limit {
   484  					t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
   485  				}
   486  				if opts[i+1] != 10 {
   487  					t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1])
   488  				}
   489  				tsVal = binary.BigEndian.Uint32(opts[i+2:])
   490  				tsEcr = binary.BigEndian.Uint32(opts[i+6:])
   491  				foundTS = true
   492  				i += 10
   493  			default:
   494  				// We don't recognize this option, just skip over it.
   495  				if i+2 > limit {
   496  					return
   497  				}
   498  				l := int(opts[i+1])
   499  				if i < 2 || i+l > limit {
   500  					return
   501  				}
   502  				i += l
   503  			}
   504  		}
   505  
   506  		if wantTS != foundTS {
   507  			t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS)
   508  		}
   509  		if wantTS && wantTSVal != 0 && wantTSVal != tsVal {
   510  			t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal)
   511  		}
   512  		if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr {
   513  			t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr)
   514  		}
   515  	}
   516  }
   517  
   518  // TCPNoSACKBlockChecker creates a checker that verifies that the segment does not
   519  // contain any SACK blocks in the TCP options.
   520  func TCPNoSACKBlockChecker() TransportChecker {
   521  	return TCPSACKBlockChecker(nil)
   522  }
   523  
   524  // TCPSACKBlockChecker creates a checker that verifies that the segment does
   525  // contain the specified SACK blocks in the TCP options.
   526  func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
   527  	return func(t *testing.T, h header.Transport) {
   528  		t.Helper()
   529  		tcp, ok := h.(header.TCP)
   530  		if !ok {
   531  			return
   532  		}
   533  		var gotSACKBlocks []header.SACKBlock
   534  
   535  		opts := []byte(tcp.Options())
   536  		limit := len(opts)
   537  		for i := 0; i < limit; {
   538  			switch opts[i] {
   539  			case header.TCPOptionEOL:
   540  				i = limit
   541  			case header.TCPOptionNOP:
   542  				i++
   543  			case header.TCPOptionSACK:
   544  				if i+2 > limit {
   545  					// Malformed SACK block.
   546  					t.Errorf("malformed SACK option in options: %v", opts)
   547  				}
   548  				sackOptionLen := int(opts[i+1])
   549  				if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
   550  					// Malformed SACK block.
   551  					t.Errorf("malformed SACK option length in options: %v", opts)
   552  				}
   553  				numBlocks := sackOptionLen / 8
   554  				for j := 0; j < numBlocks; j++ {
   555  					start := binary.BigEndian.Uint32(opts[i+2+j*8:])
   556  					end := binary.BigEndian.Uint32(opts[i+2+j*8+4:])
   557  					gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{
   558  						Start: seqnum.Value(start),
   559  						End:   seqnum.Value(end),
   560  					})
   561  				}
   562  				i += sackOptionLen
   563  			default:
   564  				// We don't recognize this option, just skip over it.
   565  				if i+2 > limit {
   566  					break
   567  				}
   568  				l := int(opts[i+1])
   569  				if l < 2 || i+l > limit {
   570  					break
   571  				}
   572  				i += l
   573  			}
   574  		}
   575  
   576  		if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) {
   577  			t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks)
   578  		}
   579  	}
   580  }
   581  
   582  // Payload creates a checker that checks the payload.
   583  func Payload(want []byte) TransportChecker {
   584  	return func(t *testing.T, h header.Transport) {
   585  		if got := h.Payload(); !reflect.DeepEqual(got, want) {
   586  			t.Errorf("Wrong payload, got %v, want %v", got, want)
   587  		}
   588  	}
   589  }
   590  
   591  // ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and
   592  // potentially additional ICMPv4 header fields.
   593  func ICMPv4(checkers ...TransportChecker) NetworkChecker {
   594  	return func(t *testing.T, h []header.Network) {
   595  		t.Helper()
   596  
   597  		last := h[len(h)-1]
   598  
   599  		if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber {
   600  			t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber)
   601  		}
   602  
   603  		icmp := header.ICMPv4(last.Payload())
   604  		for _, f := range checkers {
   605  			f(t, icmp)
   606  		}
   607  		if t.Failed() {
   608  			t.FailNow()
   609  		}
   610  	}
   611  }
   612  
   613  // ICMPv4Type creates a checker that checks the ICMPv4 Type field.
   614  func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
   615  	return func(t *testing.T, h header.Transport) {
   616  		t.Helper()
   617  		icmpv4, ok := h.(header.ICMPv4)
   618  		if !ok {
   619  			t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
   620  		}
   621  		if got := icmpv4.Type(); got != want {
   622  			t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
   623  		}
   624  	}
   625  }
   626  
   627  // ICMPv4Code creates a checker that checks the ICMPv4 Code field.
   628  func ICMPv4Code(want byte) TransportChecker {
   629  	return func(t *testing.T, h header.Transport) {
   630  		t.Helper()
   631  		icmpv4, ok := h.(header.ICMPv4)
   632  		if !ok {
   633  			t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
   634  		}
   635  		if got := icmpv4.Code(); got != want {
   636  			t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
   637  		}
   638  	}
   639  }
   640  
   641  // ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and
   642  // potentially additional ICMPv6 header fields.
   643  //
   644  // ICMPv6 will validate the checksum field before calling checkers.
   645  func ICMPv6(checkers ...TransportChecker) NetworkChecker {
   646  	return func(t *testing.T, h []header.Network) {
   647  		t.Helper()
   648  
   649  		last := h[len(h)-1]
   650  
   651  		if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber {
   652  			t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber)
   653  		}
   654  
   655  		icmp := header.ICMPv6(last.Payload())
   656  		if got, want := icmp.Checksum(), header.ICMPv6Checksum(icmp, last.SourceAddress(), last.DestinationAddress(), buffer.VectorisedView{}); got != want {
   657  			t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want)
   658  		}
   659  
   660  		for _, f := range checkers {
   661  			f(t, icmp)
   662  		}
   663  		if t.Failed() {
   664  			t.FailNow()
   665  		}
   666  	}
   667  }
   668  
   669  // ICMPv6Type creates a checker that checks the ICMPv6 Type field.
   670  func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
   671  	return func(t *testing.T, h header.Transport) {
   672  		t.Helper()
   673  		icmpv6, ok := h.(header.ICMPv6)
   674  		if !ok {
   675  			t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
   676  		}
   677  		if got := icmpv6.Type(); got != want {
   678  			t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
   679  		}
   680  	}
   681  }
   682  
   683  // ICMPv6Code creates a checker that checks the ICMPv6 Code field.
   684  func ICMPv6Code(want byte) TransportChecker {
   685  	return func(t *testing.T, h header.Transport) {
   686  		t.Helper()
   687  		icmpv6, ok := h.(header.ICMPv6)
   688  		if !ok {
   689  			t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
   690  		}
   691  		if got := icmpv6.Code(); got != want {
   692  			t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
   693  		}
   694  	}
   695  }
   696  
   697  // NDP creates a checker that checks that the packet contains a valid NDP
   698  // message for type of ty, with potentially additional checks specified by
   699  // checkers.
   700  //
   701  // checkers may assume that a valid ICMPv6 is passed to it containing a valid
   702  // NDP message as far as the size of the message (minSize) is concerned. The
   703  // values within the message are up to checkers to validate.
   704  func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
   705  	return func(t *testing.T, h []header.Network) {
   706  		t.Helper()
   707  
   708  		// Check normal ICMPv6 first.
   709  		ICMPv6(
   710  			ICMPv6Type(msgType),
   711  			ICMPv6Code(0))(t, h)
   712  
   713  		last := h[len(h)-1]
   714  
   715  		icmp := header.ICMPv6(last.Payload())
   716  		if got := len(icmp.NDPPayload()); got < minSize {
   717  			t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
   718  		}
   719  
   720  		for _, f := range checkers {
   721  			f(t, icmp)
   722  		}
   723  		if t.Failed() {
   724  			t.FailNow()
   725  		}
   726  	}
   727  }
   728  
   729  // NDPNS creates a checker that checks that the packet contains a valid NDP
   730  // Neighbor Solicitation message (as per the raw wire format), with potentially
   731  // additional checks specified by checkers.
   732  //
   733  // checkers may assume that a valid ICMPv6 is passed to it containing a valid
   734  // NDPNS message as far as the size of the messages concerned. The values within
   735  // the message are up to checkers to validate.
   736  func NDPNS(checkers ...TransportChecker) NetworkChecker {
   737  	return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...)
   738  }
   739  
   740  // NDPNSTargetAddress creates a checker that checks the Target Address field of
   741  // a header.NDPNeighborSolicit.
   742  //
   743  // The returned TransportChecker assumes that a valid ICMPv6 is passed to it
   744  // containing a valid NDPNS message as far as the size is concerned.
   745  func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
   746  	return func(t *testing.T, h header.Transport) {
   747  		t.Helper()
   748  
   749  		icmp := h.(header.ICMPv6)
   750  		ns := header.NDPNeighborSolicit(icmp.NDPPayload())
   751  
   752  		if got := ns.TargetAddress(); got != want {
   753  			t.Fatalf("got %T.TargetAddress = %s, want = %s", ns, got, want)
   754  		}
   755  	}
   756  }