github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/checker/checker.go (about)

     1  // Copyright 2021 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package checker provides helper functions to check networking packets for
    16  // validity.
    17  package checker
    18  
    19  import (
    20  	"encoding/binary"
    21  	"reflect"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/google/go-cmp/cmp"
    26  	"github.com/SagerNet/gvisor/pkg/tcpip"
    27  	"github.com/SagerNet/gvisor/pkg/tcpip/buffer"
    28  	"github.com/SagerNet/gvisor/pkg/tcpip/header"
    29  	"github.com/SagerNet/gvisor/pkg/tcpip/seqnum"
    30  )
    31  
    32  // NetworkChecker is a function to check a property of a network packet.
    33  type NetworkChecker func(*testing.T, []header.Network)
    34  
    35  // TransportChecker is a function to check a property of a transport packet.
    36  type TransportChecker func(*testing.T, header.Transport)
    37  
    38  // ControlMessagesChecker is a function to check a property of ancillary data.
    39  type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages)
    40  
    41  // IPv4 checks the validity and properties of the given IPv4 packet. It is
    42  // expected to be used in conjunction with other network checkers for specific
    43  // properties. For example, to check the source and destination address, one
    44  // would call:
    45  //
    46  // checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y))
    47  func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) {
    48  	t.Helper()
    49  
    50  	ipv4 := header.IPv4(b)
    51  
    52  	if !ipv4.IsValid(len(b)) {
    53  		t.Fatalf("Not a valid IPv4 packet: %x", ipv4)
    54  	}
    55  
    56  	if !ipv4.IsChecksumValid() {
    57  		t.Errorf("Bad checksum, got = %d", ipv4.Checksum())
    58  	}
    59  
    60  	for _, f := range checkers {
    61  		f(t, []header.Network{ipv4})
    62  	}
    63  	if t.Failed() {
    64  		t.FailNow()
    65  	}
    66  }
    67  
    68  // IPv6 checks the validity and properties of the given IPv6 packet. The usage
    69  // is similar to IPv4.
    70  func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) {
    71  	t.Helper()
    72  
    73  	ipv6 := header.IPv6(b)
    74  	if !ipv6.IsValid(len(b)) {
    75  		t.Fatalf("Not a valid IPv6 packet: %x", ipv6)
    76  	}
    77  
    78  	for _, f := range checkers {
    79  		f(t, []header.Network{ipv6})
    80  	}
    81  	if t.Failed() {
    82  		t.FailNow()
    83  	}
    84  }
    85  
    86  // SrcAddr creates a checker that checks the source address.
    87  func SrcAddr(addr tcpip.Address) NetworkChecker {
    88  	return func(t *testing.T, h []header.Network) {
    89  		t.Helper()
    90  
    91  		if a := h[0].SourceAddress(); a != addr {
    92  			t.Errorf("Bad source address, got %v, want %v", a, addr)
    93  		}
    94  	}
    95  }
    96  
    97  // DstAddr creates a checker that checks the destination address.
    98  func DstAddr(addr tcpip.Address) NetworkChecker {
    99  	return func(t *testing.T, h []header.Network) {
   100  		t.Helper()
   101  
   102  		if a := h[0].DestinationAddress(); a != addr {
   103  			t.Errorf("Bad destination address, got %v, want %v", a, addr)
   104  		}
   105  	}
   106  }
   107  
   108  // TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6).
   109  func TTL(ttl uint8) NetworkChecker {
   110  	return func(t *testing.T, h []header.Network) {
   111  		t.Helper()
   112  
   113  		var v uint8
   114  		switch ip := h[0].(type) {
   115  		case header.IPv4:
   116  			v = ip.TTL()
   117  		case header.IPv6:
   118  			v = ip.HopLimit()
   119  		case *ipv6HeaderWithExtHdr:
   120  			v = ip.HopLimit()
   121  		default:
   122  			t.Fatalf("unrecognized header type %T for TTL evaluation", ip)
   123  		}
   124  		if v != ttl {
   125  			t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl)
   126  		}
   127  	}
   128  }
   129  
   130  // IPFullLength creates a checker for the full IP packet length. The
   131  // expected size is checked against both the Total Length in the
   132  // header and the number of bytes received.
   133  func IPFullLength(packetLength uint16) NetworkChecker {
   134  	return func(t *testing.T, h []header.Network) {
   135  		t.Helper()
   136  
   137  		var v uint16
   138  		var l uint16
   139  		switch ip := h[0].(type) {
   140  		case header.IPv4:
   141  			v = ip.TotalLength()
   142  			l = uint16(len(ip))
   143  		case header.IPv6:
   144  			v = ip.PayloadLength() + header.IPv6FixedHeaderSize
   145  			l = uint16(len(ip))
   146  		default:
   147  			t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4 or header.IPv6", ip)
   148  		}
   149  		if l != packetLength {
   150  			t.Errorf("bad packet length, got = %d, want = %d", l, packetLength)
   151  		}
   152  		if v != packetLength {
   153  			t.Errorf("unexpected packet length in header, got = %d, want = %d", v, packetLength)
   154  		}
   155  	}
   156  }
   157  
   158  // IPv4HeaderLength creates a checker that checks the IPv4 Header length.
   159  func IPv4HeaderLength(headerLength int) NetworkChecker {
   160  	return func(t *testing.T, h []header.Network) {
   161  		t.Helper()
   162  
   163  		switch ip := h[0].(type) {
   164  		case header.IPv4:
   165  			if hl := ip.HeaderLength(); hl != uint8(headerLength) {
   166  				t.Errorf("Bad header length, got = %d, want = %d", hl, headerLength)
   167  			}
   168  		default:
   169  			t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", ip)
   170  		}
   171  	}
   172  }
   173  
   174  // PayloadLen creates a checker that checks the payload length.
   175  func PayloadLen(payloadLength int) NetworkChecker {
   176  	return func(t *testing.T, h []header.Network) {
   177  		t.Helper()
   178  
   179  		if l := len(h[0].Payload()); l != payloadLength {
   180  			t.Errorf("Bad payload length, got = %d, want = %d", l, payloadLength)
   181  		}
   182  	}
   183  }
   184  
   185  // IPPayload creates a checker that checks the payload.
   186  func IPPayload(payload []byte) NetworkChecker {
   187  	return func(t *testing.T, h []header.Network) {
   188  		t.Helper()
   189  
   190  		got := h[0].Payload()
   191  
   192  		// cmp.Diff does not consider nil slices equal to empty slices, but we do.
   193  		if len(got) == 0 && len(payload) == 0 {
   194  			return
   195  		}
   196  
   197  		if diff := cmp.Diff(payload, got); diff != "" {
   198  			t.Errorf("payload mismatch (-want +got):\n%s", diff)
   199  		}
   200  	}
   201  }
   202  
   203  // IPv4Options returns a checker that checks the options in an IPv4 packet.
   204  func IPv4Options(want header.IPv4Options) NetworkChecker {
   205  	return func(t *testing.T, h []header.Network) {
   206  		t.Helper()
   207  
   208  		ip, ok := h[0].(header.IPv4)
   209  		if !ok {
   210  			t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0])
   211  		}
   212  		options := ip.Options()
   213  		// cmp.Diff does not consider nil slices equal to empty slices, but we do.
   214  		if len(want) == 0 && len(options) == 0 {
   215  			return
   216  		}
   217  		if diff := cmp.Diff(want, options); diff != "" {
   218  			t.Errorf("options mismatch (-want +got):\n%s", diff)
   219  		}
   220  	}
   221  }
   222  
   223  // IPv4RouterAlert returns a checker that checks that the RouterAlert option is
   224  // set in an IPv4 packet.
   225  func IPv4RouterAlert() NetworkChecker {
   226  	return func(t *testing.T, h []header.Network) {
   227  		t.Helper()
   228  		ip, ok := h[0].(header.IPv4)
   229  		if !ok {
   230  			t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0])
   231  		}
   232  		iterator := ip.Options().MakeIterator()
   233  		for {
   234  			opt, done, err := iterator.Next()
   235  			if err != nil {
   236  				t.Fatalf("error acquiring next IPv4 option at offset %d", err.Pointer)
   237  			}
   238  			if done {
   239  				break
   240  			}
   241  			if opt.Type() != header.IPv4OptionRouterAlertType {
   242  				continue
   243  			}
   244  			want := [header.IPv4OptionRouterAlertLength]byte{
   245  				byte(header.IPv4OptionRouterAlertType),
   246  				header.IPv4OptionRouterAlertLength,
   247  				header.IPv4OptionRouterAlertValue,
   248  				header.IPv4OptionRouterAlertValue,
   249  			}
   250  			if diff := cmp.Diff(want[:], opt.Contents()); diff != "" {
   251  				t.Errorf("router alert option mismatch (-want +got):\n%s", diff)
   252  			}
   253  			return
   254  		}
   255  		t.Errorf("failed to find router alert option in %v", ip.Options())
   256  	}
   257  }
   258  
   259  // FragmentOffset creates a checker that checks the FragmentOffset field.
   260  func FragmentOffset(offset uint16) NetworkChecker {
   261  	return func(t *testing.T, h []header.Network) {
   262  		t.Helper()
   263  
   264  		// We only do this for IPv4 for now.
   265  		switch ip := h[0].(type) {
   266  		case header.IPv4:
   267  			if v := ip.FragmentOffset(); v != offset {
   268  				t.Errorf("Bad fragment offset, got = %d, want = %d", v, offset)
   269  			}
   270  		}
   271  	}
   272  }
   273  
   274  // FragmentFlags creates a checker that checks the fragment flags field.
   275  func FragmentFlags(flags uint8) NetworkChecker {
   276  	return func(t *testing.T, h []header.Network) {
   277  		t.Helper()
   278  
   279  		// We only do this for IPv4 for now.
   280  		switch ip := h[0].(type) {
   281  		case header.IPv4:
   282  			if v := ip.Flags(); v != flags {
   283  				t.Errorf("Bad fragment offset, got = %d, want = %d", v, flags)
   284  			}
   285  		}
   286  	}
   287  }
   288  
   289  // ReceiveTClass creates a checker that checks the TCLASS field in
   290  // ControlMessages.
   291  func ReceiveTClass(want uint32) ControlMessagesChecker {
   292  	return func(t *testing.T, cm tcpip.ControlMessages) {
   293  		t.Helper()
   294  		if !cm.HasTClass {
   295  			t.Errorf("got cm.HasTClass = %t, want = true", cm.HasTClass)
   296  		} else if got := cm.TClass; got != want {
   297  			t.Errorf("got cm.TClass = %d, want %d", got, want)
   298  		}
   299  	}
   300  }
   301  
   302  // ReceiveTOS creates a checker that checks the TOS field in ControlMessages.
   303  func ReceiveTOS(want uint8) ControlMessagesChecker {
   304  	return func(t *testing.T, cm tcpip.ControlMessages) {
   305  		t.Helper()
   306  		if !cm.HasTOS {
   307  			t.Errorf("got cm.HasTOS = %t, want = true", cm.HasTOS)
   308  		} else if got := cm.TOS; got != want {
   309  			t.Errorf("got cm.TOS = %d, want %d", got, want)
   310  		}
   311  	}
   312  }
   313  
   314  // ReceiveIPPacketInfo creates a checker that checks the PacketInfo field in
   315  // ControlMessages.
   316  func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker {
   317  	return func(t *testing.T, cm tcpip.ControlMessages) {
   318  		t.Helper()
   319  		if !cm.HasIPPacketInfo {
   320  			t.Errorf("got cm.HasIPPacketInfo = %t, want = true", cm.HasIPPacketInfo)
   321  		} else if diff := cmp.Diff(want, cm.PacketInfo); diff != "" {
   322  			t.Errorf("IPPacketInfo mismatch (-want +got):\n%s", diff)
   323  		}
   324  	}
   325  }
   326  
   327  // ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress
   328  // field in ControlMessages.
   329  func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker {
   330  	return func(t *testing.T, cm tcpip.ControlMessages) {
   331  		t.Helper()
   332  		if !cm.HasOriginalDstAddress {
   333  			t.Errorf("got cm.HasOriginalDstAddress = %t, want = true", cm.HasOriginalDstAddress)
   334  		} else if diff := cmp.Diff(want, cm.OriginalDstAddress); diff != "" {
   335  			t.Errorf("OriginalDstAddress mismatch (-want +got):\n%s", diff)
   336  		}
   337  	}
   338  }
   339  
   340  // TOS creates a checker that checks the TOS field.
   341  func TOS(tos uint8, label uint32) NetworkChecker {
   342  	return func(t *testing.T, h []header.Network) {
   343  		t.Helper()
   344  
   345  		if v, l := h[0].TOS(); v != tos || l != label {
   346  			t.Errorf("Bad TOS, got = (%d, %d), want = (%d,%d)", v, l, tos, label)
   347  		}
   348  	}
   349  }
   350  
   351  // Raw creates a checker that checks the bytes of payload.
   352  // The checker always checks the payload of the last network header.
   353  // For instance, in case of IPv6 fragments, the payload that will be checked
   354  // is the one containing the actual data that the packet is carrying, without
   355  // the bytes added by the IPv6 fragmentation.
   356  func Raw(want []byte) NetworkChecker {
   357  	return func(t *testing.T, h []header.Network) {
   358  		t.Helper()
   359  
   360  		if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) {
   361  			t.Errorf("Wrong payload, got %v, want %v", got, want)
   362  		}
   363  	}
   364  }
   365  
   366  // IPv6Fragment creates a checker that validates an IPv6 fragment.
   367  func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker {
   368  	return func(t *testing.T, h []header.Network) {
   369  		t.Helper()
   370  
   371  		if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader {
   372  			t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber)
   373  		}
   374  
   375  		ipv6Frag := header.IPv6Fragment(h[0].Payload())
   376  		if !ipv6Frag.IsValid() {
   377  			t.Error("Not a valid IPv6 fragment")
   378  		}
   379  
   380  		for _, f := range checkers {
   381  			f(t, []header.Network{h[0], ipv6Frag})
   382  		}
   383  		if t.Failed() {
   384  			t.FailNow()
   385  		}
   386  	}
   387  }
   388  
   389  // TCP creates a checker that checks that the transport protocol is TCP and
   390  // potentially additional transport header fields.
   391  func TCP(checkers ...TransportChecker) NetworkChecker {
   392  	return func(t *testing.T, h []header.Network) {
   393  		t.Helper()
   394  
   395  		first := h[0]
   396  		last := h[len(h)-1]
   397  
   398  		if p := last.TransportProtocol(); p != header.TCPProtocolNumber {
   399  			t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber)
   400  		}
   401  
   402  		tcp := header.TCP(last.Payload())
   403  		payload := tcp.Payload()
   404  		payloadChecksum := header.Checksum(payload, 0)
   405  		if !tcp.IsChecksumValid(first.SourceAddress(), first.DestinationAddress(), payloadChecksum, uint16(len(payload))) {
   406  			t.Errorf("Bad checksum, got = %d", tcp.Checksum())
   407  		}
   408  
   409  		// Run the transport checkers.
   410  		for _, f := range checkers {
   411  			f(t, tcp)
   412  		}
   413  		if t.Failed() {
   414  			t.FailNow()
   415  		}
   416  	}
   417  }
   418  
   419  // UDP creates a checker that checks that the transport protocol is UDP and
   420  // potentially additional transport header fields.
   421  func UDP(checkers ...TransportChecker) NetworkChecker {
   422  	return func(t *testing.T, h []header.Network) {
   423  		t.Helper()
   424  
   425  		last := h[len(h)-1]
   426  
   427  		if p := last.TransportProtocol(); p != header.UDPProtocolNumber {
   428  			t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber)
   429  		}
   430  
   431  		udp := header.UDP(last.Payload())
   432  		for _, f := range checkers {
   433  			f(t, udp)
   434  		}
   435  		if t.Failed() {
   436  			t.FailNow()
   437  		}
   438  	}
   439  }
   440  
   441  // SrcPort creates a checker that checks the source port.
   442  func SrcPort(port uint16) TransportChecker {
   443  	return func(t *testing.T, h header.Transport) {
   444  		t.Helper()
   445  
   446  		if p := h.SourcePort(); p != port {
   447  			t.Errorf("Bad source port, got = %d, want = %d", p, port)
   448  		}
   449  	}
   450  }
   451  
   452  // DstPort creates a checker that checks the destination port.
   453  func DstPort(port uint16) TransportChecker {
   454  	return func(t *testing.T, h header.Transport) {
   455  		t.Helper()
   456  
   457  		if p := h.DestinationPort(); p != port {
   458  			t.Errorf("Bad destination port, got = %d, want = %d", p, port)
   459  		}
   460  	}
   461  }
   462  
   463  // NoChecksum creates a checker that checks if the checksum is zero.
   464  func NoChecksum(noChecksum bool) TransportChecker {
   465  	return func(t *testing.T, h header.Transport) {
   466  		t.Helper()
   467  
   468  		udp, ok := h.(header.UDP)
   469  		if !ok {
   470  			t.Fatalf("UDP header not found in h: %T", h)
   471  		}
   472  
   473  		if b := udp.Checksum() == 0; b != noChecksum {
   474  			t.Errorf("bad checksum state, got %t, want %t", b, noChecksum)
   475  		}
   476  	}
   477  }
   478  
   479  // TCPSeqNum creates a checker that checks the sequence number.
   480  func TCPSeqNum(seq uint32) TransportChecker {
   481  	return func(t *testing.T, h header.Transport) {
   482  		t.Helper()
   483  
   484  		tcp, ok := h.(header.TCP)
   485  		if !ok {
   486  			t.Fatalf("TCP header not found in h: %T", h)
   487  		}
   488  
   489  		if s := tcp.SequenceNumber(); s != seq {
   490  			t.Errorf("Bad sequence number, got = %d, want = %d", s, seq)
   491  		}
   492  	}
   493  }
   494  
   495  // TCPAckNum creates a checker that checks the ack number.
   496  func TCPAckNum(seq uint32) TransportChecker {
   497  	return func(t *testing.T, h header.Transport) {
   498  		t.Helper()
   499  
   500  		tcp, ok := h.(header.TCP)
   501  		if !ok {
   502  			t.Fatalf("TCP header not found in h: %T", h)
   503  		}
   504  
   505  		if s := tcp.AckNumber(); s != seq {
   506  			t.Errorf("Bad ack number, got = %d, want = %d", s, seq)
   507  		}
   508  	}
   509  }
   510  
   511  // TCPWindow creates a checker that checks the tcp window.
   512  func TCPWindow(window uint16) TransportChecker {
   513  	return func(t *testing.T, h header.Transport) {
   514  		t.Helper()
   515  
   516  		tcp, ok := h.(header.TCP)
   517  		if !ok {
   518  			t.Fatalf("TCP header not found in hdr : %T", h)
   519  		}
   520  
   521  		if w := tcp.WindowSize(); w != window {
   522  			t.Errorf("Bad window, got %d, want %d", w, window)
   523  		}
   524  	}
   525  }
   526  
   527  // TCPWindowGreaterThanEq creates a checker that checks that the TCP window
   528  // is greater than or equal to the provided value.
   529  func TCPWindowGreaterThanEq(window uint16) TransportChecker {
   530  	return func(t *testing.T, h header.Transport) {
   531  		t.Helper()
   532  
   533  		tcp, ok := h.(header.TCP)
   534  		if !ok {
   535  			t.Fatalf("TCP header not found in h: %T", h)
   536  		}
   537  
   538  		if w := tcp.WindowSize(); w < window {
   539  			t.Errorf("Bad window, got %d, want > %d", w, window)
   540  		}
   541  	}
   542  }
   543  
   544  // TCPWindowLessThanEq creates a checker that checks that the tcp window
   545  // is less than or equal to the provided value.
   546  func TCPWindowLessThanEq(window uint16) TransportChecker {
   547  	return func(t *testing.T, h header.Transport) {
   548  		t.Helper()
   549  
   550  		tcp, ok := h.(header.TCP)
   551  		if !ok {
   552  			t.Fatalf("TCP header not found in h: %T", h)
   553  		}
   554  
   555  		if w := tcp.WindowSize(); w > window {
   556  			t.Errorf("Bad window, got %d, want < %d", w, window)
   557  		}
   558  	}
   559  }
   560  
   561  // TCPFlags creates a checker that checks the tcp flags.
   562  func TCPFlags(flags header.TCPFlags) TransportChecker {
   563  	return func(t *testing.T, h header.Transport) {
   564  		t.Helper()
   565  
   566  		tcp, ok := h.(header.TCP)
   567  		if !ok {
   568  			t.Fatalf("TCP header not found in h: %T", h)
   569  		}
   570  
   571  		if got := tcp.Flags(); got != flags {
   572  			t.Errorf("got tcp.Flags() = %s, want %s", got, flags)
   573  		}
   574  	}
   575  }
   576  
   577  // TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the
   578  // given mask, match the supplied flags.
   579  func TCPFlagsMatch(flags, mask header.TCPFlags) TransportChecker {
   580  	return func(t *testing.T, h header.Transport) {
   581  		t.Helper()
   582  
   583  		tcp, ok := h.(header.TCP)
   584  		if !ok {
   585  			t.Fatalf("TCP header not found in h: %T", h)
   586  		}
   587  
   588  		if got := tcp.Flags(); (got & mask) != (flags & mask) {
   589  			t.Errorf("got tcp.Flags() = %s, want %s, mask %s", got, flags, mask)
   590  		}
   591  	}
   592  }
   593  
   594  // TCPSynOptions creates a checker that checks the presence of TCP options in
   595  // SYN segments.
   596  //
   597  // If wndscale is negative, the window scale option must not be present.
   598  func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
   599  	return func(t *testing.T, h header.Transport) {
   600  		t.Helper()
   601  
   602  		tcp, ok := h.(header.TCP)
   603  		if !ok {
   604  			return
   605  		}
   606  		opts := tcp.Options()
   607  		limit := len(opts)
   608  		foundMSS := false
   609  		foundWS := false
   610  		foundTS := false
   611  		foundSACKPermitted := false
   612  		tsVal := uint32(0)
   613  		tsEcr := uint32(0)
   614  		for i := 0; i < limit; {
   615  			switch opts[i] {
   616  			case header.TCPOptionEOL:
   617  				i = limit
   618  			case header.TCPOptionNOP:
   619  				i++
   620  			case header.TCPOptionMSS:
   621  				v := uint16(opts[i+2])<<8 | uint16(opts[i+3])
   622  				if wantOpts.MSS != v {
   623  					t.Errorf("Bad MSS, got = %d, want = %d", v, wantOpts.MSS)
   624  				}
   625  				foundMSS = true
   626  				i += 4
   627  			case header.TCPOptionWS:
   628  				if wantOpts.WS < 0 {
   629  					t.Error("WS present when it shouldn't be")
   630  				}
   631  				v := int(opts[i+2])
   632  				if v != wantOpts.WS {
   633  					t.Errorf("Bad WS, got = %d, want = %d", v, wantOpts.WS)
   634  				}
   635  				foundWS = true
   636  				i += 3
   637  			case header.TCPOptionTS:
   638  				if i+9 >= limit {
   639  					t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i)
   640  				}
   641  				if opts[i+1] != 10 {
   642  					t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit)
   643  				}
   644  				tsVal = binary.BigEndian.Uint32(opts[i+2:])
   645  				tsEcr = uint32(0)
   646  				if tcp.Flags()&header.TCPFlagAck != 0 {
   647  					// If the syn is an SYN-ACK then read
   648  					// the tsEcr value as well.
   649  					tsEcr = binary.BigEndian.Uint32(opts[i+6:])
   650  				}
   651  				foundTS = true
   652  				i += 10
   653  			case header.TCPOptionSACKPermitted:
   654  				if i+1 >= limit {
   655  					t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i)
   656  				}
   657  				if opts[i+1] != 2 {
   658  					t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit)
   659  				}
   660  				foundSACKPermitted = true
   661  				i += 2
   662  
   663  			default:
   664  				i += int(opts[i+1])
   665  			}
   666  		}
   667  
   668  		if !foundMSS {
   669  			t.Errorf("MSS option not found. Options: %x", opts)
   670  		}
   671  
   672  		if !foundWS && wantOpts.WS >= 0 {
   673  			t.Errorf("WS option not found. Options: %x", opts)
   674  		}
   675  		if wantOpts.TS && !foundTS {
   676  			t.Errorf("TS option not found. Options: %x", opts)
   677  		}
   678  		if foundTS && tsVal == 0 {
   679  			t.Error("TS option specified but the timestamp value is zero")
   680  		}
   681  		if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 {
   682  			t.Errorf("TS option specified but TSEcr is incorrect, got = %d, want = %d", tsEcr, wantOpts.TSEcr)
   683  		}
   684  		if wantOpts.SACKPermitted && !foundSACKPermitted {
   685  			t.Errorf("SACKPermitted option not found. Options: %x", opts)
   686  		}
   687  	}
   688  }
   689  
   690  // TCPTimestampChecker creates a checker that validates that a TCP segment has a
   691  // TCP Timestamp option if wantTS is true, it also compares the wantTSVal and
   692  // wantTSEcr values with those in the TCP segment (if present).
   693  //
   694  // If wantTSVal or wantTSEcr is zero then the corresponding comparison is
   695  // skipped.
   696  func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker {
   697  	return func(t *testing.T, h header.Transport) {
   698  		t.Helper()
   699  
   700  		tcp, ok := h.(header.TCP)
   701  		if !ok {
   702  			return
   703  		}
   704  		opts := tcp.Options()
   705  		limit := len(opts)
   706  		foundTS := false
   707  		tsVal := uint32(0)
   708  		tsEcr := uint32(0)
   709  		for i := 0; i < limit; {
   710  			switch opts[i] {
   711  			case header.TCPOptionEOL:
   712  				i = limit
   713  			case header.TCPOptionNOP:
   714  				i++
   715  			case header.TCPOptionTS:
   716  				if i+9 >= limit {
   717  					t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
   718  				}
   719  				if opts[i+1] != 10 {
   720  					t.Errorf("TS option found, but bad length specified: got = %d, want = 10", opts[i+1])
   721  				}
   722  				tsVal = binary.BigEndian.Uint32(opts[i+2:])
   723  				tsEcr = binary.BigEndian.Uint32(opts[i+6:])
   724  				foundTS = true
   725  				i += 10
   726  			default:
   727  				// We don't recognize this option, just skip over it.
   728  				if i+2 > limit {
   729  					return
   730  				}
   731  				l := int(opts[i+1])
   732  				if i < 2 || i+l > limit {
   733  					return
   734  				}
   735  				i += l
   736  			}
   737  		}
   738  
   739  		if wantTS != foundTS {
   740  			t.Errorf("TS Option mismatch, got TS= %t, want TS= %t", foundTS, wantTS)
   741  		}
   742  		if wantTS && wantTSVal != 0 && wantTSVal != tsVal {
   743  			t.Errorf("Timestamp value is incorrect, got = %d, want = %d", tsVal, wantTSVal)
   744  		}
   745  		if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr {
   746  			t.Errorf("Timestamp Echo Reply is incorrect, got = %d, want = %d", tsEcr, wantTSEcr)
   747  		}
   748  	}
   749  }
   750  
   751  // TCPSACKBlockChecker creates a checker that verifies that the segment does
   752  // contain the specified SACK blocks in the TCP options.
   753  func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
   754  	return func(t *testing.T, h header.Transport) {
   755  		t.Helper()
   756  		tcp, ok := h.(header.TCP)
   757  		if !ok {
   758  			return
   759  		}
   760  		var gotSACKBlocks []header.SACKBlock
   761  
   762  		opts := tcp.Options()
   763  		limit := len(opts)
   764  		for i := 0; i < limit; {
   765  			switch opts[i] {
   766  			case header.TCPOptionEOL:
   767  				i = limit
   768  			case header.TCPOptionNOP:
   769  				i++
   770  			case header.TCPOptionSACK:
   771  				if i+2 > limit {
   772  					// Malformed SACK block.
   773  					t.Errorf("malformed SACK option in options: %v", opts)
   774  				}
   775  				sackOptionLen := int(opts[i+1])
   776  				if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
   777  					// Malformed SACK block.
   778  					t.Errorf("malformed SACK option length in options: %v", opts)
   779  				}
   780  				numBlocks := sackOptionLen / 8
   781  				for j := 0; j < numBlocks; j++ {
   782  					start := binary.BigEndian.Uint32(opts[i+2+j*8:])
   783  					end := binary.BigEndian.Uint32(opts[i+2+j*8+4:])
   784  					gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{
   785  						Start: seqnum.Value(start),
   786  						End:   seqnum.Value(end),
   787  					})
   788  				}
   789  				i += sackOptionLen
   790  			default:
   791  				// We don't recognize this option, just skip over it.
   792  				if i+2 > limit {
   793  					break
   794  				}
   795  				l := int(opts[i+1])
   796  				if l < 2 || i+l > limit {
   797  					break
   798  				}
   799  				i += l
   800  			}
   801  		}
   802  
   803  		if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) {
   804  			t.Errorf("SACKBlocks are not equal, got = %v, want = %v", gotSACKBlocks, sackBlocks)
   805  		}
   806  	}
   807  }
   808  
   809  // Payload creates a checker that checks the payload.
   810  func Payload(want []byte) TransportChecker {
   811  	return func(t *testing.T, h header.Transport) {
   812  		t.Helper()
   813  
   814  		if got := h.Payload(); !reflect.DeepEqual(got, want) {
   815  			t.Errorf("Wrong payload, got %v, want %v", got, want)
   816  		}
   817  	}
   818  }
   819  
   820  // ICMPv4 creates a checker that checks that the transport protocol is ICMPv4
   821  // and potentially additional ICMPv4 header fields.
   822  func ICMPv4(checkers ...TransportChecker) NetworkChecker {
   823  	return func(t *testing.T, h []header.Network) {
   824  		t.Helper()
   825  
   826  		last := h[len(h)-1]
   827  
   828  		if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber {
   829  			t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber)
   830  		}
   831  
   832  		icmp := header.ICMPv4(last.Payload())
   833  		for _, f := range checkers {
   834  			f(t, icmp)
   835  		}
   836  		if t.Failed() {
   837  			t.FailNow()
   838  		}
   839  	}
   840  }
   841  
   842  // ICMPv4Type creates a checker that checks the ICMPv4 Type field.
   843  func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
   844  	return func(t *testing.T, h header.Transport) {
   845  		t.Helper()
   846  
   847  		icmpv4, ok := h.(header.ICMPv4)
   848  		if !ok {
   849  			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
   850  		}
   851  		if got := icmpv4.Type(); got != want {
   852  			t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want)
   853  		}
   854  	}
   855  }
   856  
   857  // ICMPv4Code creates a checker that checks the ICMPv4 Code field.
   858  func ICMPv4Code(want header.ICMPv4Code) TransportChecker {
   859  	return func(t *testing.T, h header.Transport) {
   860  		t.Helper()
   861  
   862  		icmpv4, ok := h.(header.ICMPv4)
   863  		if !ok {
   864  			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
   865  		}
   866  		if got := icmpv4.Code(); got != want {
   867  			t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want)
   868  		}
   869  	}
   870  }
   871  
   872  // ICMPv4Ident creates a checker that checks the ICMPv4 echo Ident.
   873  func ICMPv4Ident(want uint16) TransportChecker {
   874  	return func(t *testing.T, h header.Transport) {
   875  		t.Helper()
   876  
   877  		icmpv4, ok := h.(header.ICMPv4)
   878  		if !ok {
   879  			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
   880  		}
   881  		if got := icmpv4.Ident(); got != want {
   882  			t.Fatalf("unexpected ICMP ident, got = %d, want = %d", got, want)
   883  		}
   884  	}
   885  }
   886  
   887  // ICMPv4Seq creates a checker that checks the ICMPv4 echo Sequence.
   888  func ICMPv4Seq(want uint16) TransportChecker {
   889  	return func(t *testing.T, h header.Transport) {
   890  		t.Helper()
   891  
   892  		icmpv4, ok := h.(header.ICMPv4)
   893  		if !ok {
   894  			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
   895  		}
   896  		if got := icmpv4.Sequence(); got != want {
   897  			t.Fatalf("unexpected ICMP sequence, got = %d, want = %d", got, want)
   898  		}
   899  	}
   900  }
   901  
   902  // ICMPv4Pointer creates a checker that checks the ICMPv4 Param Problem pointer.
   903  func ICMPv4Pointer(want uint8) TransportChecker {
   904  	return func(t *testing.T, h header.Transport) {
   905  		t.Helper()
   906  
   907  		icmpv4, ok := h.(header.ICMPv4)
   908  		if !ok {
   909  			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
   910  		}
   911  		if got := icmpv4.Pointer(); got != want {
   912  			t.Fatalf("unexpected ICMP Param Problem pointer, got = %d, want = %d", got, want)
   913  		}
   914  	}
   915  }
   916  
   917  // ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum.
   918  // This assumes that the payload exactly makes up the rest of the slice.
   919  func ICMPv4Checksum() TransportChecker {
   920  	return func(t *testing.T, h header.Transport) {
   921  		t.Helper()
   922  
   923  		icmpv4, ok := h.(header.ICMPv4)
   924  		if !ok {
   925  			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
   926  		}
   927  		heldChecksum := icmpv4.Checksum()
   928  		icmpv4.SetChecksum(0)
   929  		newChecksum := ^header.Checksum(icmpv4, 0)
   930  		icmpv4.SetChecksum(heldChecksum)
   931  		if heldChecksum != newChecksum {
   932  			t.Errorf("unexpected ICMP checksum, got = %d, want = %d", heldChecksum, newChecksum)
   933  		}
   934  	}
   935  }
   936  
   937  // ICMPv4Payload creates a checker that checks the payload in an ICMPv4 packet.
   938  func ICMPv4Payload(want []byte) TransportChecker {
   939  	return func(t *testing.T, h header.Transport) {
   940  		t.Helper()
   941  
   942  		icmpv4, ok := h.(header.ICMPv4)
   943  		if !ok {
   944  			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
   945  		}
   946  		payload := icmpv4.Payload()
   947  
   948  		// cmp.Diff does not consider nil slices equal to empty slices, but we do.
   949  		if len(want) == 0 && len(payload) == 0 {
   950  			return
   951  		}
   952  
   953  		if diff := cmp.Diff(want, payload); diff != "" {
   954  			t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
   955  		}
   956  	}
   957  }
   958  
   959  // ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and
   960  // potentially additional ICMPv6 header fields.
   961  //
   962  // ICMPv6 will validate the checksum field before calling checkers.
   963  func ICMPv6(checkers ...TransportChecker) NetworkChecker {
   964  	return func(t *testing.T, h []header.Network) {
   965  		t.Helper()
   966  
   967  		last := h[len(h)-1]
   968  
   969  		if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber {
   970  			t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber)
   971  		}
   972  
   973  		icmp := header.ICMPv6(last.Payload())
   974  		if got, want := icmp.Checksum(), header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
   975  			Header: icmp,
   976  			Src:    last.SourceAddress(),
   977  			Dst:    last.DestinationAddress(),
   978  		}); got != want {
   979  			t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want)
   980  		}
   981  
   982  		for _, f := range checkers {
   983  			f(t, icmp)
   984  		}
   985  		if t.Failed() {
   986  			t.FailNow()
   987  		}
   988  	}
   989  }
   990  
   991  // ICMPv6Type creates a checker that checks the ICMPv6 Type field.
   992  func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
   993  	return func(t *testing.T, h header.Transport) {
   994  		t.Helper()
   995  
   996  		icmpv6, ok := h.(header.ICMPv6)
   997  		if !ok {
   998  			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
   999  		}
  1000  		if got := icmpv6.Type(); got != want {
  1001  			t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want)
  1002  		}
  1003  	}
  1004  }
  1005  
  1006  // ICMPv6Code creates a checker that checks the ICMPv6 Code field.
  1007  func ICMPv6Code(want header.ICMPv6Code) TransportChecker {
  1008  	return func(t *testing.T, h header.Transport) {
  1009  		t.Helper()
  1010  
  1011  		icmpv6, ok := h.(header.ICMPv6)
  1012  		if !ok {
  1013  			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
  1014  		}
  1015  		if got := icmpv6.Code(); got != want {
  1016  			t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want)
  1017  		}
  1018  	}
  1019  }
  1020  
  1021  // ICMPv6TypeSpecific creates a checker that checks the ICMPv6 TypeSpecific
  1022  // field.
  1023  func ICMPv6TypeSpecific(want uint32) TransportChecker {
  1024  	return func(t *testing.T, h header.Transport) {
  1025  		t.Helper()
  1026  
  1027  		icmpv6, ok := h.(header.ICMPv6)
  1028  		if !ok {
  1029  			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
  1030  		}
  1031  		if got := icmpv6.TypeSpecific(); got != want {
  1032  			t.Fatalf("unexpected ICMP TypeSpecific, got = %d, want = %d", got, want)
  1033  		}
  1034  	}
  1035  }
  1036  
  1037  // ICMPv6Payload creates a checker that checks the payload in an ICMPv6 packet.
  1038  func ICMPv6Payload(want []byte) TransportChecker {
  1039  	return func(t *testing.T, h header.Transport) {
  1040  		t.Helper()
  1041  
  1042  		icmpv6, ok := h.(header.ICMPv6)
  1043  		if !ok {
  1044  			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
  1045  		}
  1046  		payload := icmpv6.Payload()
  1047  
  1048  		// cmp.Diff does not consider nil slices equal to empty slices, but we do.
  1049  		if len(want) == 0 && len(payload) == 0 {
  1050  			return
  1051  		}
  1052  
  1053  		if diff := cmp.Diff(want, payload); diff != "" {
  1054  			t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
  1055  		}
  1056  	}
  1057  }
  1058  
  1059  // MLD creates a checker that checks that the packet contains a valid MLD
  1060  // message for type of mldType, with potentially additional checks specified by
  1061  // checkers.
  1062  //
  1063  // Checkers may assume that a valid ICMPv6 is passed to it containing a valid
  1064  // MLD message as far as the size of the message (minSize) is concerned. The
  1065  // values within the message are up to checkers to validate.
  1066  func MLD(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
  1067  	return func(t *testing.T, h []header.Network) {
  1068  		t.Helper()
  1069  
  1070  		// Check normal ICMPv6 first.
  1071  		ICMPv6(
  1072  			ICMPv6Type(msgType),
  1073  			ICMPv6Code(0))(t, h)
  1074  
  1075  		last := h[len(h)-1]
  1076  
  1077  		icmp := header.ICMPv6(last.Payload())
  1078  		if got := len(icmp.MessageBody()); got < minSize {
  1079  			t.Fatalf("ICMPv6 MLD (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
  1080  		}
  1081  
  1082  		for _, f := range checkers {
  1083  			f(t, icmp)
  1084  		}
  1085  		if t.Failed() {
  1086  			t.FailNow()
  1087  		}
  1088  	}
  1089  }
  1090  
  1091  // MLDMaxRespDelay creates a checker that checks the Maximum Response Delay
  1092  // field of a MLD message.
  1093  //
  1094  // The returned TransportChecker assumes that a valid ICMPv6 is passed to it
  1095  // containing a valid MLD message as far as the size is concerned.
  1096  func MLDMaxRespDelay(want time.Duration) TransportChecker {
  1097  	return func(t *testing.T, h header.Transport) {
  1098  		t.Helper()
  1099  
  1100  		icmp := h.(header.ICMPv6)
  1101  		ns := header.MLD(icmp.MessageBody())
  1102  
  1103  		if got := ns.MaximumResponseDelay(); got != want {
  1104  			t.Errorf("got %T.MaximumResponseDelay() = %s, want = %s", ns, got, want)
  1105  		}
  1106  	}
  1107  }
  1108  
  1109  // MLDMulticastAddress creates a checker that checks the Multicast Address
  1110  // field of a MLD message.
  1111  //
  1112  // The returned TransportChecker assumes that a valid ICMPv6 is passed to it
  1113  // containing a valid MLD message as far as the size is concerned.
  1114  func MLDMulticastAddress(want tcpip.Address) TransportChecker {
  1115  	return func(t *testing.T, h header.Transport) {
  1116  		t.Helper()
  1117  
  1118  		icmp := h.(header.ICMPv6)
  1119  		ns := header.MLD(icmp.MessageBody())
  1120  
  1121  		if got := ns.MulticastAddress(); got != want {
  1122  			t.Errorf("got %T.MulticastAddress() = %s, want = %s", ns, got, want)
  1123  		}
  1124  	}
  1125  }
  1126  
  1127  // NDP creates a checker that checks that the packet contains a valid NDP
  1128  // message for type of ty, with potentially additional checks specified by
  1129  // checkers.
  1130  //
  1131  // Checkers may assume that a valid ICMPv6 is passed to it containing a valid
  1132  // NDP message as far as the size of the message (minSize) is concerned. The
  1133  // values within the message are up to checkers to validate.
  1134  func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
  1135  	return func(t *testing.T, h []header.Network) {
  1136  		t.Helper()
  1137  
  1138  		// Check normal ICMPv6 first.
  1139  		ICMPv6(
  1140  			ICMPv6Type(msgType),
  1141  			ICMPv6Code(0))(t, h)
  1142  
  1143  		last := h[len(h)-1]
  1144  
  1145  		icmp := header.ICMPv6(last.Payload())
  1146  		if got := len(icmp.MessageBody()); got < minSize {
  1147  			t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
  1148  		}
  1149  
  1150  		for _, f := range checkers {
  1151  			f(t, icmp)
  1152  		}
  1153  		if t.Failed() {
  1154  			t.FailNow()
  1155  		}
  1156  	}
  1157  }
  1158  
  1159  // NDPNS creates a checker that checks that the packet contains a valid NDP
  1160  // Neighbor Solicitation message (as per the raw wire format), with potentially
  1161  // additional checks specified by checkers.
  1162  //
  1163  // Checkers may assume that a valid ICMPv6 is passed to it containing a valid
  1164  // NDPNS message as far as the size of the message is concerned. The values
  1165  // within the message are up to checkers to validate.
  1166  func NDPNS(checkers ...TransportChecker) NetworkChecker {
  1167  	return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...)
  1168  }
  1169  
  1170  // NDPNSTargetAddress creates a checker that checks the Target Address field of
  1171  // a header.NDPNeighborSolicit.
  1172  //
  1173  // The returned TransportChecker assumes that a valid ICMPv6 is passed to it
  1174  // containing a valid NDPNS message as far as the size is concerned.
  1175  func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
  1176  	return func(t *testing.T, h header.Transport) {
  1177  		t.Helper()
  1178  
  1179  		icmp := h.(header.ICMPv6)
  1180  		ns := header.NDPNeighborSolicit(icmp.MessageBody())
  1181  
  1182  		if got := ns.TargetAddress(); got != want {
  1183  			t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want)
  1184  		}
  1185  	}
  1186  }
  1187  
  1188  // NDPNA creates a checker that checks that the packet contains a valid NDP
  1189  // Neighbor Advertisement message (as per the raw wire format), with potentially
  1190  // additional checks specified by checkers.
  1191  //
  1192  // Checkers may assume that a valid ICMPv6 is passed to it containing a valid
  1193  // NDPNA message as far as the size of the message is concerned. The values
  1194  // within the message are up to checkers to validate.
  1195  func NDPNA(checkers ...TransportChecker) NetworkChecker {
  1196  	return NDP(header.ICMPv6NeighborAdvert, header.NDPNAMinimumSize, checkers...)
  1197  }
  1198  
  1199  // NDPNATargetAddress creates a checker that checks the Target Address field of
  1200  // a header.NDPNeighborAdvert.
  1201  //
  1202  // The returned TransportChecker assumes that a valid ICMPv6 is passed to it
  1203  // containing a valid NDPNA message as far as the size is concerned.
  1204  func NDPNATargetAddress(want tcpip.Address) TransportChecker {
  1205  	return func(t *testing.T, h header.Transport) {
  1206  		t.Helper()
  1207  
  1208  		icmp := h.(header.ICMPv6)
  1209  		na := header.NDPNeighborAdvert(icmp.MessageBody())
  1210  
  1211  		if got := na.TargetAddress(); got != want {
  1212  			t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want)
  1213  		}
  1214  	}
  1215  }
  1216  
  1217  // NDPNASolicitedFlag creates a checker that checks the Solicited field of
  1218  // a header.NDPNeighborAdvert.
  1219  //
  1220  // The returned TransportChecker assumes that a valid ICMPv6 is passed to it
  1221  // containing a valid NDPNA message as far as the size is concerned.
  1222  func NDPNASolicitedFlag(want bool) TransportChecker {
  1223  	return func(t *testing.T, h header.Transport) {
  1224  		t.Helper()
  1225  
  1226  		icmp := h.(header.ICMPv6)
  1227  		na := header.NDPNeighborAdvert(icmp.MessageBody())
  1228  
  1229  		if got := na.SolicitedFlag(); got != want {
  1230  			t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want)
  1231  		}
  1232  	}
  1233  }
  1234  
  1235  // ndpOptions checks that optsBuf only contains opts.
  1236  func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption) {
  1237  	t.Helper()
  1238  
  1239  	it, err := optsBuf.Iter(true)
  1240  	if err != nil {
  1241  		t.Errorf("optsBuf.Iter(true): %s", err)
  1242  		return
  1243  	}
  1244  
  1245  	i := 0
  1246  	for {
  1247  		opt, done, err := it.Next()
  1248  		if err != nil {
  1249  			// This should never happen as Iter(true) above did not return an error.
  1250  			t.Fatalf("unexpected error when iterating over NDP options: %s", err)
  1251  		}
  1252  		if done {
  1253  			break
  1254  		}
  1255  
  1256  		if i >= len(opts) {
  1257  			t.Errorf("got unexpected option: %s", opt)
  1258  			continue
  1259  		}
  1260  
  1261  		switch wantOpt := opts[i].(type) {
  1262  		case header.NDPSourceLinkLayerAddressOption:
  1263  			gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption)
  1264  			if !ok {
  1265  				t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
  1266  			} else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
  1267  				t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
  1268  			}
  1269  		case header.NDPTargetLinkLayerAddressOption:
  1270  			gotOpt, ok := opt.(header.NDPTargetLinkLayerAddressOption)
  1271  			if !ok {
  1272  				t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
  1273  			} else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
  1274  				t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
  1275  			}
  1276  		case header.NDPNonceOption:
  1277  			gotOpt, ok := opt.(header.NDPNonceOption)
  1278  			if !ok {
  1279  				t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
  1280  			} else if diff := cmp.Diff(wantOpt.Nonce(), gotOpt.Nonce()); diff != "" {
  1281  				t.Errorf("nonce mismatch (-want +got):\n%s", diff)
  1282  			}
  1283  		default:
  1284  			t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt)
  1285  		}
  1286  
  1287  		i++
  1288  	}
  1289  
  1290  	if missing := opts[i:]; len(missing) > 0 {
  1291  		t.Errorf("missing options: %s", missing)
  1292  	}
  1293  }
  1294  
  1295  // NDPNAOptions creates a checker that checks that the packet contains the
  1296  // provided NDP options within an NDP Neighbor Solicitation message.
  1297  //
  1298  // The returned TransportChecker assumes that a valid ICMPv6 is passed to it
  1299  // containing a valid NDPNA message as far as the size is concerned.
  1300  func NDPNAOptions(opts []header.NDPOption) TransportChecker {
  1301  	return func(t *testing.T, h header.Transport) {
  1302  		t.Helper()
  1303  
  1304  		icmp := h.(header.ICMPv6)
  1305  		na := header.NDPNeighborAdvert(icmp.MessageBody())
  1306  		ndpOptions(t, na.Options(), opts)
  1307  	}
  1308  }
  1309  
  1310  // NDPNSOptions creates a checker that checks that the packet contains the
  1311  // provided NDP options within an NDP Neighbor Solicitation message.
  1312  //
  1313  // The returned TransportChecker assumes that a valid ICMPv6 is passed to it
  1314  // containing a valid NDPNS message as far as the size is concerned.
  1315  func NDPNSOptions(opts []header.NDPOption) TransportChecker {
  1316  	return func(t *testing.T, h header.Transport) {
  1317  		t.Helper()
  1318  
  1319  		icmp := h.(header.ICMPv6)
  1320  		ns := header.NDPNeighborSolicit(icmp.MessageBody())
  1321  		ndpOptions(t, ns.Options(), opts)
  1322  	}
  1323  }
  1324  
  1325  // NDPRS creates a checker that checks that the packet contains a valid NDP
  1326  // Router Solicitation message (as per the raw wire format).
  1327  //
  1328  // Checkers may assume that a valid ICMPv6 is passed to it containing a valid
  1329  // NDPRS as far as the size of the message is concerned. The values within the
  1330  // message are up to checkers to validate.
  1331  func NDPRS(checkers ...TransportChecker) NetworkChecker {
  1332  	return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize, checkers...)
  1333  }
  1334  
  1335  // NDPRSOptions creates a checker that checks that the packet contains the
  1336  // provided NDP options within an NDP Router Solicitation message.
  1337  //
  1338  // The returned TransportChecker assumes that a valid ICMPv6 is passed to it
  1339  // containing a valid NDPRS message as far as the size is concerned.
  1340  func NDPRSOptions(opts []header.NDPOption) TransportChecker {
  1341  	return func(t *testing.T, h header.Transport) {
  1342  		t.Helper()
  1343  
  1344  		icmp := h.(header.ICMPv6)
  1345  		rs := header.NDPRouterSolicit(icmp.MessageBody())
  1346  		ndpOptions(t, rs.Options(), opts)
  1347  	}
  1348  }
  1349  
  1350  // IGMP checks the validity and properties of the given IGMP packet. It is
  1351  // expected to be used in conjunction with other IGMP transport checkers for
  1352  // specific properties.
  1353  func IGMP(checkers ...TransportChecker) NetworkChecker {
  1354  	return func(t *testing.T, h []header.Network) {
  1355  		t.Helper()
  1356  
  1357  		last := h[len(h)-1]
  1358  
  1359  		if p := last.TransportProtocol(); p != header.IGMPProtocolNumber {
  1360  			t.Fatalf("Bad protocol, got %d, want %d", p, header.IGMPProtocolNumber)
  1361  		}
  1362  
  1363  		igmp := header.IGMP(last.Payload())
  1364  		for _, f := range checkers {
  1365  			f(t, igmp)
  1366  		}
  1367  		if t.Failed() {
  1368  			t.FailNow()
  1369  		}
  1370  	}
  1371  }
  1372  
  1373  // IGMPType creates a checker that checks the IGMP Type field.
  1374  func IGMPType(want header.IGMPType) TransportChecker {
  1375  	return func(t *testing.T, h header.Transport) {
  1376  		t.Helper()
  1377  
  1378  		igmp, ok := h.(header.IGMP)
  1379  		if !ok {
  1380  			t.Fatalf("got transport header = %T, want = header.IGMP", h)
  1381  		}
  1382  		if got := igmp.Type(); got != want {
  1383  			t.Errorf("got igmp.Type() = %d, want = %d", got, want)
  1384  		}
  1385  	}
  1386  }
  1387  
  1388  // IGMPMaxRespTime creates a checker that checks the IGMP Max Resp Time field.
  1389  func IGMPMaxRespTime(want time.Duration) TransportChecker {
  1390  	return func(t *testing.T, h header.Transport) {
  1391  		t.Helper()
  1392  
  1393  		igmp, ok := h.(header.IGMP)
  1394  		if !ok {
  1395  			t.Fatalf("got transport header = %T, want = header.IGMP", h)
  1396  		}
  1397  		if got := igmp.MaxRespTime(); got != want {
  1398  			t.Errorf("got igmp.MaxRespTime() = %s, want = %s", got, want)
  1399  		}
  1400  	}
  1401  }
  1402  
  1403  // IGMPGroupAddress creates a checker that checks the IGMP Group Address field.
  1404  func IGMPGroupAddress(want tcpip.Address) TransportChecker {
  1405  	return func(t *testing.T, h header.Transport) {
  1406  		t.Helper()
  1407  
  1408  		igmp, ok := h.(header.IGMP)
  1409  		if !ok {
  1410  			t.Fatalf("got transport header = %T, want = header.IGMP", h)
  1411  		}
  1412  		if got := igmp.GroupAddress(); got != want {
  1413  			t.Errorf("got igmp.GroupAddress() = %s, want = %s", got, want)
  1414  		}
  1415  	}
  1416  }
  1417  
  1418  // IPv6ExtHdrChecker is a function to check an extension header.
  1419  type IPv6ExtHdrChecker func(*testing.T, header.IPv6PayloadHeader)
  1420  
  1421  // IPv6WithExtHdr is like IPv6 but allows IPv6 packets with extension headers.
  1422  func IPv6WithExtHdr(t *testing.T, b []byte, checkers ...NetworkChecker) {
  1423  	t.Helper()
  1424  
  1425  	ipv6 := header.IPv6(b)
  1426  	if !ipv6.IsValid(len(b)) {
  1427  		t.Error("not a valid IPv6 packet")
  1428  		return
  1429  	}
  1430  
  1431  	payloadIterator := header.MakeIPv6PayloadIterator(
  1432  		header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()),
  1433  		buffer.View(ipv6.Payload()).ToVectorisedView(),
  1434  	)
  1435  
  1436  	var rawPayloadHeader header.IPv6RawPayloadHeader
  1437  	for {
  1438  		h, done, err := payloadIterator.Next()
  1439  		if err != nil {
  1440  			t.Errorf("payloadIterator.Next(): %s", err)
  1441  			return
  1442  		}
  1443  		if done {
  1444  			t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, true, _)", h, done)
  1445  			return
  1446  		}
  1447  		r, ok := h.(header.IPv6RawPayloadHeader)
  1448  		if ok {
  1449  			rawPayloadHeader = r
  1450  			break
  1451  		}
  1452  	}
  1453  
  1454  	networkHeader := ipv6HeaderWithExtHdr{
  1455  		IPv6:      ipv6,
  1456  		transport: tcpip.TransportProtocolNumber(rawPayloadHeader.Identifier),
  1457  		payload:   rawPayloadHeader.Buf.ToView(),
  1458  	}
  1459  
  1460  	for _, checker := range checkers {
  1461  		checker(t, []header.Network{&networkHeader})
  1462  	}
  1463  }
  1464  
  1465  // IPv6ExtHdr checks for the presence of extension headers.
  1466  //
  1467  // All the extension headers in headers will be checked exhaustively in the
  1468  // order provided.
  1469  func IPv6ExtHdr(headers ...IPv6ExtHdrChecker) NetworkChecker {
  1470  	return func(t *testing.T, h []header.Network) {
  1471  		t.Helper()
  1472  
  1473  		extHdrs, ok := h[0].(*ipv6HeaderWithExtHdr)
  1474  		if !ok {
  1475  			t.Errorf("got network header = %T, want = *ipv6HeaderWithExtHdr", h[0])
  1476  			return
  1477  		}
  1478  
  1479  		payloadIterator := header.MakeIPv6PayloadIterator(
  1480  			header.IPv6ExtensionHeaderIdentifier(extHdrs.IPv6.NextHeader()),
  1481  			buffer.View(extHdrs.IPv6.Payload()).ToVectorisedView(),
  1482  		)
  1483  
  1484  		for _, check := range headers {
  1485  			h, done, err := payloadIterator.Next()
  1486  			if err != nil {
  1487  				t.Errorf("payloadIterator.Next(): %s", err)
  1488  				return
  1489  			}
  1490  			if done {
  1491  				t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, false, _)", h, done)
  1492  				return
  1493  			}
  1494  			check(t, h)
  1495  		}
  1496  		// Validate we consumed all headers.
  1497  		//
  1498  		// The next one over should be a raw payload and then iterator should
  1499  		// terminate.
  1500  		wantDone := false
  1501  		for {
  1502  			h, done, err := payloadIterator.Next()
  1503  			if err != nil {
  1504  				t.Errorf("payloadIterator.Next(): %s", err)
  1505  				return
  1506  			}
  1507  			if done != wantDone {
  1508  				t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, %t, _)", h, done, wantDone)
  1509  				return
  1510  			}
  1511  			if done {
  1512  				break
  1513  			}
  1514  			if _, ok := h.(header.IPv6RawPayloadHeader); !ok {
  1515  				t.Errorf("got payloadIterator.Next() = (%T, _, _), want = (header.IPv6RawPayloadHeader, _, _)", h)
  1516  				continue
  1517  			}
  1518  			wantDone = true
  1519  		}
  1520  	}
  1521  }
  1522  
  1523  var _ header.Network = (*ipv6HeaderWithExtHdr)(nil)
  1524  
  1525  // ipv6HeaderWithExtHdr provides a header.Network implementation that takes
  1526  // extension headers into consideration, which is not the case with vanilla
  1527  // header.IPv6.
  1528  type ipv6HeaderWithExtHdr struct {
  1529  	header.IPv6
  1530  	transport tcpip.TransportProtocolNumber
  1531  	payload   []byte
  1532  }
  1533  
  1534  // TransportProtocol implements header.Network.
  1535  func (h *ipv6HeaderWithExtHdr) TransportProtocol() tcpip.TransportProtocolNumber {
  1536  	return h.transport
  1537  }
  1538  
  1539  // Payload implements header.Network.
  1540  func (h *ipv6HeaderWithExtHdr) Payload() []byte {
  1541  	return h.payload
  1542  }
  1543  
  1544  // IPv6ExtHdrOptionChecker is a function to check an extension header option.
  1545  type IPv6ExtHdrOptionChecker func(*testing.T, header.IPv6ExtHdrOption)
  1546  
  1547  // IPv6HopByHopExtensionHeader checks the extension header is a Hop by Hop
  1548  // extension header and validates the containing options with checkers.
  1549  //
  1550  // checkers must exhaustively contain all the expected options.
  1551  func IPv6HopByHopExtensionHeader(checkers ...IPv6ExtHdrOptionChecker) IPv6ExtHdrChecker {
  1552  	return func(t *testing.T, payloadHeader header.IPv6PayloadHeader) {
  1553  		t.Helper()
  1554  
  1555  		hbh, ok := payloadHeader.(header.IPv6HopByHopOptionsExtHdr)
  1556  		if !ok {
  1557  			t.Errorf("unexpected IPv6 payload header, got = %T, want = header.IPv6HopByHopOptionsExtHdr", payloadHeader)
  1558  			return
  1559  		}
  1560  		optionsIterator := hbh.Iter()
  1561  		for _, f := range checkers {
  1562  			opt, done, err := optionsIterator.Next()
  1563  			if err != nil {
  1564  				t.Errorf("optionsIterator.Next(): %s", err)
  1565  				return
  1566  			}
  1567  			if done {
  1568  				t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, false, _)", opt, done)
  1569  			}
  1570  			f(t, opt)
  1571  		}
  1572  		// Validate all options were consumed.
  1573  		for {
  1574  			opt, done, err := optionsIterator.Next()
  1575  			if err != nil {
  1576  				t.Errorf("optionsIterator.Next(): %s", err)
  1577  				return
  1578  			}
  1579  			if !done {
  1580  				t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, true, _)", opt, done)
  1581  			}
  1582  			if done {
  1583  				break
  1584  			}
  1585  		}
  1586  	}
  1587  }
  1588  
  1589  // IPv6RouterAlert validates that an extension header option is the RouterAlert
  1590  // option and matches on its value.
  1591  func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker {
  1592  	return func(t *testing.T, opt header.IPv6ExtHdrOption) {
  1593  		routerAlert, ok := opt.(*header.IPv6RouterAlertOption)
  1594  		if !ok {
  1595  			t.Errorf("unexpected extension header option, got = %T, want = header.IPv6RouterAlertOption", opt)
  1596  			return
  1597  		}
  1598  		if routerAlert.Value != want {
  1599  			t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, want)
  1600  		}
  1601  	}
  1602  }
  1603  
  1604  // IPv6UnknownOption validates that an extension header option is the
  1605  // unknown header option.
  1606  func IPv6UnknownOption() IPv6ExtHdrOptionChecker {
  1607  	return func(t *testing.T, opt header.IPv6ExtHdrOption) {
  1608  		_, ok := opt.(*header.IPv6UnknownExtHdrOption)
  1609  		if !ok {
  1610  			t.Errorf("got = %T, want = header.IPv6UnknownExtHdrOption", opt)
  1611  		}
  1612  	}
  1613  }
  1614  
  1615  // IgnoreCmpPath returns a cmp.Option that ignores listed field paths.
  1616  func IgnoreCmpPath(paths ...string) cmp.Option {
  1617  	ignores := map[string]struct{}{}
  1618  	for _, path := range paths {
  1619  		ignores[path] = struct{}{}
  1620  	}
  1621  	return cmp.FilterPath(func(path cmp.Path) bool {
  1622  		_, ok := ignores[path.String()]
  1623  		return ok
  1624  	}, cmp.Ignore())
  1625  }