gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/checker/checker.go (about)

     1  // Copyright 2021 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package checker provides helper functions to check networking packets for
    16  // validity.
    17  package checker
    18  
    19  import (
    20  	"encoding/binary"
    21  	"slices"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/google/go-cmp/cmp"
    26  	"gvisor.dev/gvisor/pkg/buffer"
    27  	"gvisor.dev/gvisor/pkg/tcpip"
    28  	"gvisor.dev/gvisor/pkg/tcpip/checksum"
    29  	"gvisor.dev/gvisor/pkg/tcpip/header"
    30  	"gvisor.dev/gvisor/pkg/tcpip/seqnum"
    31  )
    32  
    33  // NetworkChecker is a function to check a property of a network packet.
    34  type NetworkChecker func(*testing.T, []header.Network)
    35  
    36  // TransportChecker is a function to check a property of a transport packet.
    37  type TransportChecker func(*testing.T, header.Transport)
    38  
    39  // ControlMessagesChecker is a function to check a property of ancillary data.
    40  type ControlMessagesChecker func(*testing.T, tcpip.ReceivableControlMessages)
    41  
    42  // IPv4 checks the validity and properties of the given IPv4 packet. It is
    43  // expected to be used in conjunction with other network checkers for specific
    44  // properties. For example, to check the source and destination address, one
    45  // would call:
    46  //
    47  // checker.IPv4(t, v, checker.SrcAddr(x), checker.DstAddr(y))
    48  func IPv4(t *testing.T, v *buffer.View, checkers ...NetworkChecker) {
    49  	t.Helper()
    50  
    51  	ipv4 := header.IPv4(v.AsSlice())
    52  
    53  	if !ipv4.IsValid(len(v.AsSlice())) {
    54  		t.Fatalf("Not a valid IPv4 packet: %x", ipv4)
    55  	}
    56  
    57  	if !ipv4.IsChecksumValid() {
    58  		t.Errorf("Bad checksum, got = %d", ipv4.Checksum())
    59  	}
    60  
    61  	for _, f := range checkers {
    62  		f(t, []header.Network{ipv4})
    63  	}
    64  	if t.Failed() {
    65  		t.FailNow()
    66  	}
    67  }
    68  
    69  // IPv6 checks the validity and properties of the given IPv6 packet. The usage
    70  // is similar to IPv4.
    71  func IPv6(t *testing.T, v *buffer.View, checkers ...NetworkChecker) {
    72  	t.Helper()
    73  
    74  	ipv6 := header.IPv6(v.AsSlice())
    75  	if !ipv6.IsValid(len(v.AsSlice())) {
    76  		t.Fatalf("Not a valid IPv6 packet: %x", ipv6)
    77  	}
    78  
    79  	for _, f := range checkers {
    80  		f(t, []header.Network{ipv6})
    81  	}
    82  	if t.Failed() {
    83  		t.FailNow()
    84  	}
    85  }
    86  
    87  // SrcAddr creates a checker that checks the source address.
    88  func SrcAddr(addr tcpip.Address) NetworkChecker {
    89  	return func(t *testing.T, h []header.Network) {
    90  		t.Helper()
    91  
    92  		if a := h[0].SourceAddress(); a != addr {
    93  			t.Errorf("Bad source address, got %v, want %v", a, addr)
    94  		}
    95  	}
    96  }
    97  
    98  // DstAddr creates a checker that checks the destination address.
    99  func DstAddr(addr tcpip.Address) NetworkChecker {
   100  	return func(t *testing.T, h []header.Network) {
   101  		t.Helper()
   102  
   103  		if a := h[0].DestinationAddress(); a != addr {
   104  			t.Errorf("Bad destination address, got %v, want %v", a, addr)
   105  		}
   106  	}
   107  }
   108  
   109  // TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6).
   110  func TTL(ttl uint8) NetworkChecker {
   111  	return func(t *testing.T, h []header.Network) {
   112  		t.Helper()
   113  
   114  		var v uint8
   115  		switch ip := h[0].(type) {
   116  		case header.IPv4:
   117  			v = ip.TTL()
   118  		case header.IPv6:
   119  			v = ip.HopLimit()
   120  		case *ipv6HeaderWithExtHdr:
   121  			v = ip.HopLimit()
   122  		default:
   123  			t.Fatalf("unrecognized header type %T for TTL evaluation", ip)
   124  		}
   125  		if v != ttl {
   126  			t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl)
   127  		}
   128  	}
   129  }
   130  
   131  // IPFullLength creates a checker for the full IP packet length. The
   132  // expected size is checked against both the Total Length in the
   133  // header and the number of bytes received.
   134  func IPFullLength(packetLength uint16) NetworkChecker {
   135  	return func(t *testing.T, h []header.Network) {
   136  		t.Helper()
   137  
   138  		var v uint16
   139  		var l uint16
   140  		switch ip := h[0].(type) {
   141  		case header.IPv4:
   142  			v = ip.TotalLength()
   143  			l = uint16(len(ip))
   144  		case header.IPv6:
   145  			v = ip.PayloadLength() + header.IPv6FixedHeaderSize
   146  			l = uint16(len(ip))
   147  		default:
   148  			t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4 or header.IPv6", ip)
   149  		}
   150  		if l != packetLength {
   151  			t.Errorf("bad packet length, got = %d, want = %d", l, packetLength)
   152  		}
   153  		if v != packetLength {
   154  			t.Errorf("unexpected packet length in header, got = %d, want = %d", v, packetLength)
   155  		}
   156  	}
   157  }
   158  
   159  // IPv4HeaderLength creates a checker that checks the IPv4 Header length.
   160  func IPv4HeaderLength(headerLength int) NetworkChecker {
   161  	return func(t *testing.T, h []header.Network) {
   162  		t.Helper()
   163  
   164  		switch ip := h[0].(type) {
   165  		case header.IPv4:
   166  			if hl := ip.HeaderLength(); hl != uint8(headerLength) {
   167  				t.Errorf("Bad header length, got = %d, want = %d", hl, headerLength)
   168  			}
   169  		default:
   170  			t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", ip)
   171  		}
   172  	}
   173  }
   174  
   175  // PayloadLen creates a checker that checks the payload length.
   176  func PayloadLen(payloadLength int) NetworkChecker {
   177  	return func(t *testing.T, h []header.Network) {
   178  		t.Helper()
   179  
   180  		if l := len(h[0].Payload()); l != payloadLength {
   181  			t.Errorf("Bad payload length, got = %d, want = %d", l, payloadLength)
   182  		}
   183  	}
   184  }
   185  
   186  // IPPayload creates a checker that checks the payload.
   187  func IPPayload(payload []byte) NetworkChecker {
   188  	return func(t *testing.T, h []header.Network) {
   189  		t.Helper()
   190  
   191  		got := h[0].Payload()
   192  
   193  		// cmp.Diff does not consider nil slices equal to empty slices, but we do.
   194  		if len(got) == 0 && len(payload) == 0 {
   195  			return
   196  		}
   197  
   198  		if diff := cmp.Diff(payload, got); diff != "" {
   199  			t.Errorf("payload mismatch (-want +got):\n%s", diff)
   200  		}
   201  	}
   202  }
   203  
   204  // IPv4Options returns a checker that checks the options in an IPv4 packet.
   205  func IPv4Options(want header.IPv4Options) NetworkChecker {
   206  	return func(t *testing.T, h []header.Network) {
   207  		t.Helper()
   208  
   209  		ip, ok := h[0].(header.IPv4)
   210  		if !ok {
   211  			t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0])
   212  		}
   213  		options := ip.Options()
   214  		// cmp.Diff does not consider nil slices equal to empty slices, but we do.
   215  		if len(want) == 0 && len(options) == 0 {
   216  			return
   217  		}
   218  		if diff := cmp.Diff(want, options); diff != "" {
   219  			t.Errorf("options mismatch (-want +got):\n%s", diff)
   220  		}
   221  	}
   222  }
   223  
   224  // IPv4RouterAlert returns a checker that checks that the RouterAlert option is
   225  // set in an IPv4 packet.
   226  func IPv4RouterAlert() NetworkChecker {
   227  	return func(t *testing.T, h []header.Network) {
   228  		t.Helper()
   229  		ip, ok := h[0].(header.IPv4)
   230  		if !ok {
   231  			t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0])
   232  		}
   233  		iterator := ip.Options().MakeIterator()
   234  		for {
   235  			opt, done, err := iterator.Next()
   236  			if err != nil {
   237  				t.Fatalf("error acquiring next IPv4 option at offset %d", err.Pointer)
   238  			}
   239  			if done {
   240  				break
   241  			}
   242  			if opt.Type() != header.IPv4OptionRouterAlertType {
   243  				continue
   244  			}
   245  			want := [header.IPv4OptionRouterAlertLength]byte{
   246  				byte(header.IPv4OptionRouterAlertType),
   247  				header.IPv4OptionRouterAlertLength,
   248  				header.IPv4OptionRouterAlertValue,
   249  				header.IPv4OptionRouterAlertValue,
   250  			}
   251  			if diff := cmp.Diff(want[:], opt.Contents()); diff != "" {
   252  				t.Errorf("router alert option mismatch (-want +got):\n%s", diff)
   253  			}
   254  			return
   255  		}
   256  		t.Errorf("failed to find router alert option in %v", ip.Options())
   257  	}
   258  }
   259  
   260  // FragmentOffset creates a checker that checks the FragmentOffset field.
   261  func FragmentOffset(offset uint16) NetworkChecker {
   262  	return func(t *testing.T, h []header.Network) {
   263  		t.Helper()
   264  
   265  		// We only do this for IPv4 for now.
   266  		switch ip := h[0].(type) {
   267  		case header.IPv4:
   268  			if v := ip.FragmentOffset(); v != offset {
   269  				t.Errorf("Bad fragment offset, got = %d, want = %d", v, offset)
   270  			}
   271  		}
   272  	}
   273  }
   274  
   275  // FragmentFlags creates a checker that checks the fragment flags field.
   276  func FragmentFlags(flags uint8) NetworkChecker {
   277  	return func(t *testing.T, h []header.Network) {
   278  		t.Helper()
   279  
   280  		// We only do this for IPv4 for now.
   281  		switch ip := h[0].(type) {
   282  		case header.IPv4:
   283  			if v := ip.Flags(); v != flags {
   284  				t.Errorf("Bad fragment offset, got = %d, want = %d", v, flags)
   285  			}
   286  		}
   287  	}
   288  }
   289  
   290  // ReceiveTClass creates a checker that checks the TCLASS field in
   291  // ControlMessages.
   292  func ReceiveTClass(want uint32) ControlMessagesChecker {
   293  	return func(t *testing.T, cm tcpip.ReceivableControlMessages) {
   294  		t.Helper()
   295  		if !cm.HasTClass {
   296  			t.Error("got cm.HasTClass = false, want = true")
   297  		} else if got := cm.TClass; got != want {
   298  			t.Errorf("got cm.TClass = %d, want %d", got, want)
   299  		}
   300  	}
   301  }
   302  
   303  // NoTClassReceived creates a checker that checks the absence of the TCLASS
   304  // field in ControlMessages.
   305  func NoTClassReceived() ControlMessagesChecker {
   306  	return func(t *testing.T, cm tcpip.ReceivableControlMessages) {
   307  		t.Helper()
   308  		if cm.HasTClass {
   309  			t.Error("got cm.HasTClass = true, want = false")
   310  		}
   311  	}
   312  }
   313  
   314  // ReceiveTOS creates a checker that checks the TOS field in ControlMessages.
   315  func ReceiveTOS(want uint8) ControlMessagesChecker {
   316  	return func(t *testing.T, cm tcpip.ReceivableControlMessages) {
   317  		t.Helper()
   318  		if !cm.HasTOS {
   319  			t.Error("got cm.HasTOS = false, want = true")
   320  		} else if got := cm.TOS; got != want {
   321  			t.Errorf("got cm.TOS = %d, want %d", got, want)
   322  		}
   323  	}
   324  }
   325  
   326  // NoTOSReceived creates a checker that checks the absence of the TOS field in
   327  // ControlMessages.
   328  func NoTOSReceived() ControlMessagesChecker {
   329  	return func(t *testing.T, cm tcpip.ReceivableControlMessages) {
   330  		t.Helper()
   331  		if cm.HasTOS {
   332  			t.Error("got cm.HasTOS = true, want = false")
   333  		}
   334  	}
   335  }
   336  
   337  // ReceiveTTL creates a checker that checks the TTL field in
   338  // ControlMessages.
   339  func ReceiveTTL(want uint8) ControlMessagesChecker {
   340  	return func(t *testing.T, cm tcpip.ReceivableControlMessages) {
   341  		t.Helper()
   342  		if !cm.HasTTL {
   343  			t.Errorf("got cm.HasTTL = %t, want = true", cm.HasTTL)
   344  		} else if got := cm.TTL; got != want {
   345  			t.Errorf("got cm.TTL = %d, want = %d", got, want)
   346  		}
   347  	}
   348  }
   349  
   350  // NoTTLReceived creates a checker that checks the absence of the TTL field in
   351  // ControlMessages.
   352  func NoTTLReceived() ControlMessagesChecker {
   353  	return func(t *testing.T, cm tcpip.ReceivableControlMessages) {
   354  		t.Helper()
   355  		if cm.HasTTL {
   356  			t.Error("got cm.HasTTL = true, want = false")
   357  		}
   358  	}
   359  }
   360  
   361  // ReceiveHopLimit creates a checker that checks the HopLimit field in
   362  // ControlMessages.
   363  func ReceiveHopLimit(want uint8) ControlMessagesChecker {
   364  	return func(t *testing.T, cm tcpip.ReceivableControlMessages) {
   365  		t.Helper()
   366  		if !cm.HasHopLimit {
   367  			t.Errorf("got cm.HasHopLimit = %t, want = true", cm.HasHopLimit)
   368  		} else if got := cm.HopLimit; got != want {
   369  			t.Errorf("got cm.HopLimit = %d, want = %d", got, want)
   370  		}
   371  	}
   372  }
   373  
   374  // NoHopLimitReceived creates a checker that checks the absence of the HopLimit
   375  // field in ControlMessages.
   376  func NoHopLimitReceived() ControlMessagesChecker {
   377  	return func(t *testing.T, cm tcpip.ReceivableControlMessages) {
   378  		t.Helper()
   379  		if cm.HasHopLimit {
   380  			t.Error("got cm.HasHopLimit = true, want = false")
   381  		}
   382  	}
   383  }
   384  
   385  // ReceiveIPPacketInfo creates a checker that checks the PacketInfo field in
   386  // ControlMessages.
   387  func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker {
   388  	return func(t *testing.T, cm tcpip.ReceivableControlMessages) {
   389  		t.Helper()
   390  		if !cm.HasIPPacketInfo {
   391  			t.Error("got cm.HasIPPacketInfo = false, want = true")
   392  		} else if diff := cmp.Diff(want, cm.PacketInfo); diff != "" {
   393  			t.Errorf("IPPacketInfo mismatch (-want +got):\n%s", diff)
   394  		}
   395  	}
   396  }
   397  
   398  // NoIPPacketInfoReceived creates a checker that checks the PacketInfo field in
   399  // ControlMessages.
   400  func NoIPPacketInfoReceived() ControlMessagesChecker {
   401  	return func(t *testing.T, cm tcpip.ReceivableControlMessages) {
   402  		t.Helper()
   403  		if cm.HasIPPacketInfo {
   404  			t.Error("got cm.HasIPPacketInfo = true, want = false")
   405  		}
   406  	}
   407  }
   408  
   409  // ReceiveIPv6PacketInfo creates a checker that checks the IPv6PacketInfo field
   410  // in ControlMessages.
   411  func ReceiveIPv6PacketInfo(want tcpip.IPv6PacketInfo) ControlMessagesChecker {
   412  	return func(t *testing.T, cm tcpip.ReceivableControlMessages) {
   413  		t.Helper()
   414  		if !cm.HasIPv6PacketInfo {
   415  			t.Error("got cm.HasIPv6PacketInfo = false, want = true")
   416  		} else if diff := cmp.Diff(want, cm.IPv6PacketInfo); diff != "" {
   417  			t.Errorf("IPv6PacketInfo mismatch (-want +got):\n%s", diff)
   418  		}
   419  	}
   420  }
   421  
   422  // NoIPv6PacketInfoReceived creates a checker that checks the PacketInfo field
   423  // in ControlMessages.
   424  func NoIPv6PacketInfoReceived() ControlMessagesChecker {
   425  	return func(t *testing.T, cm tcpip.ReceivableControlMessages) {
   426  		t.Helper()
   427  		if cm.HasIPv6PacketInfo {
   428  			t.Error("got cm.HasIPv6PacketInfo = true, want = false")
   429  		}
   430  	}
   431  }
   432  
   433  // ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress
   434  // field in ControlMessages.
   435  func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker {
   436  	return func(t *testing.T, cm tcpip.ReceivableControlMessages) {
   437  		t.Helper()
   438  		if !cm.HasOriginalDstAddress {
   439  			t.Error("got cm.HasOriginalDstAddress = false, want = true")
   440  		} else if diff := cmp.Diff(want, cm.OriginalDstAddress); diff != "" {
   441  			t.Errorf("OriginalDstAddress mismatch (-want +got):\n%s", diff)
   442  		}
   443  	}
   444  }
   445  
   446  // TOS creates a checker that checks the TOS field.
   447  func TOS(tos uint8, label uint32) NetworkChecker {
   448  	return func(t *testing.T, h []header.Network) {
   449  		t.Helper()
   450  
   451  		if v, l := h[0].TOS(); v != tos || l != label {
   452  			t.Errorf("Bad TOS, got = (%d, %d), want = (%d,%d)", v, l, tos, label)
   453  		}
   454  	}
   455  }
   456  
   457  // Raw creates a checker that checks the bytes of payload.
   458  // The checker always checks the payload of the last network header.
   459  // For instance, in case of IPv6 fragments, the payload that will be checked
   460  // is the one containing the actual data that the packet is carrying, without
   461  // the bytes added by the IPv6 fragmentation.
   462  func Raw(want []byte) NetworkChecker {
   463  	return func(t *testing.T, h []header.Network) {
   464  		t.Helper()
   465  
   466  		if got := h[len(h)-1].Payload(); !slices.Equal(got, want) {
   467  			t.Errorf("Wrong payload, got %v, want %v", got, want)
   468  		}
   469  	}
   470  }
   471  
   472  // IPv6Fragment creates a checker that validates an IPv6 fragment.
   473  func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker {
   474  	return func(t *testing.T, h []header.Network) {
   475  		t.Helper()
   476  
   477  		if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader {
   478  			t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber)
   479  		}
   480  
   481  		ipv6Frag := header.IPv6Fragment(h[0].Payload())
   482  		if !ipv6Frag.IsValid() {
   483  			t.Error("Not a valid IPv6 fragment")
   484  		}
   485  
   486  		for _, f := range checkers {
   487  			f(t, []header.Network{h[0], ipv6Frag})
   488  		}
   489  		if t.Failed() {
   490  			t.FailNow()
   491  		}
   492  	}
   493  }
   494  
   495  // TCP creates a checker that checks that the transport protocol is TCP and
   496  // potentially additional transport header fields.
   497  func TCP(checkers ...TransportChecker) NetworkChecker {
   498  	return func(t *testing.T, h []header.Network) {
   499  		t.Helper()
   500  
   501  		first := h[0]
   502  		last := h[len(h)-1]
   503  
   504  		if p := last.TransportProtocol(); p != header.TCPProtocolNumber {
   505  			t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber)
   506  		}
   507  
   508  		tcp := header.TCP(last.Payload())
   509  		payload := tcp.Payload()
   510  		payloadChecksum := checksum.Checksum(payload, 0)
   511  		if !tcp.IsChecksumValid(first.SourceAddress(), first.DestinationAddress(), payloadChecksum, uint16(len(payload))) {
   512  			t.Errorf("Bad checksum, got = %d", tcp.Checksum())
   513  		}
   514  
   515  		// Run the transport checkers.
   516  		for _, f := range checkers {
   517  			f(t, tcp)
   518  		}
   519  		if t.Failed() {
   520  			t.FailNow()
   521  		}
   522  	}
   523  }
   524  
   525  // UDP creates a checker that checks that the transport protocol is UDP and
   526  // potentially additional transport header fields.
   527  func UDP(checkers ...TransportChecker) NetworkChecker {
   528  	return func(t *testing.T, h []header.Network) {
   529  		t.Helper()
   530  
   531  		last := h[len(h)-1]
   532  
   533  		if p := last.TransportProtocol(); p != header.UDPProtocolNumber {
   534  			t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber)
   535  		}
   536  
   537  		udp := header.UDP(last.Payload())
   538  		for _, f := range checkers {
   539  			f(t, udp)
   540  		}
   541  		if t.Failed() {
   542  			t.FailNow()
   543  		}
   544  	}
   545  }
   546  
   547  // SrcPort creates a checker that checks the source port.
   548  func SrcPort(port uint16) TransportChecker {
   549  	return func(t *testing.T, h header.Transport) {
   550  		t.Helper()
   551  
   552  		if p := h.SourcePort(); p != port {
   553  			t.Errorf("Bad source port, got = %d, want = %d", p, port)
   554  		}
   555  	}
   556  }
   557  
   558  // DstPort creates a checker that checks the destination port.
   559  func DstPort(port uint16) TransportChecker {
   560  	return func(t *testing.T, h header.Transport) {
   561  		t.Helper()
   562  
   563  		if p := h.DestinationPort(); p != port {
   564  			t.Errorf("Bad destination port, got = %d, want = %d", p, port)
   565  		}
   566  	}
   567  }
   568  
   569  // TransportChecksum creates a checker that checks the checksum value.
   570  func TransportChecksum(want uint16) TransportChecker {
   571  	return func(t *testing.T, transportHdr header.Transport) {
   572  		t.Helper()
   573  
   574  		if got := transportHdr.Checksum(); got != want {
   575  			t.Errorf("got transportHdr.Checksum() = %d, want = %d", got, want)
   576  		}
   577  	}
   578  }
   579  
   580  // NoChecksum creates a checker that checks if the checksum is zero.
   581  func NoChecksum(noChecksum bool) TransportChecker {
   582  	return func(t *testing.T, h header.Transport) {
   583  		t.Helper()
   584  
   585  		udp, ok := h.(header.UDP)
   586  		if !ok {
   587  			t.Fatalf("UDP header not found in h: %T", h)
   588  		}
   589  
   590  		if b := udp.Checksum() == 0; b != noChecksum {
   591  			t.Errorf("bad checksum state, got %t, want %t", b, noChecksum)
   592  		}
   593  	}
   594  }
   595  
   596  // TCPSeqNum creates a checker that checks the sequence number.
   597  func TCPSeqNum(seq uint32) TransportChecker {
   598  	return func(t *testing.T, h header.Transport) {
   599  		t.Helper()
   600  
   601  		tcp, ok := h.(header.TCP)
   602  		if !ok {
   603  			t.Fatalf("TCP header not found in h: %T", h)
   604  		}
   605  
   606  		if s := tcp.SequenceNumber(); s != seq {
   607  			t.Errorf("Bad sequence number, got = %d, want = %d", s, seq)
   608  		}
   609  	}
   610  }
   611  
   612  // TCPAckNum creates a checker that checks the ack number.
   613  func TCPAckNum(seq uint32) TransportChecker {
   614  	return func(t *testing.T, h header.Transport) {
   615  		t.Helper()
   616  
   617  		tcp, ok := h.(header.TCP)
   618  		if !ok {
   619  			t.Fatalf("TCP header not found in h: %T", h)
   620  		}
   621  
   622  		if s := tcp.AckNumber(); s != seq {
   623  			t.Errorf("Bad ack number, got = %d, want = %d", s, seq)
   624  		}
   625  	}
   626  }
   627  
   628  // TCPWindow creates a checker that checks the tcp window.
   629  func TCPWindow(window uint16) TransportChecker {
   630  	return func(t *testing.T, h header.Transport) {
   631  		t.Helper()
   632  
   633  		tcp, ok := h.(header.TCP)
   634  		if !ok {
   635  			t.Fatalf("TCP header not found in hdr : %T", h)
   636  		}
   637  
   638  		if w := tcp.WindowSize(); w != window {
   639  			t.Errorf("Bad window, got %d, want %d", w, window)
   640  		}
   641  	}
   642  }
   643  
   644  // TCPWindowGreaterThanEq creates a checker that checks that the TCP window
   645  // is greater than or equal to the provided value.
   646  func TCPWindowGreaterThanEq(window uint16) TransportChecker {
   647  	return func(t *testing.T, h header.Transport) {
   648  		t.Helper()
   649  
   650  		tcp, ok := h.(header.TCP)
   651  		if !ok {
   652  			t.Fatalf("TCP header not found in h: %T", h)
   653  		}
   654  
   655  		if w := tcp.WindowSize(); w < window {
   656  			t.Errorf("Bad window, got %d, want > %d", w, window)
   657  		}
   658  	}
   659  }
   660  
   661  // TCPWindowLessThanEq creates a checker that checks that the tcp window
   662  // is less than or equal to the provided value.
   663  func TCPWindowLessThanEq(window uint16) TransportChecker {
   664  	return func(t *testing.T, h header.Transport) {
   665  		t.Helper()
   666  
   667  		tcp, ok := h.(header.TCP)
   668  		if !ok {
   669  			t.Fatalf("TCP header not found in h: %T", h)
   670  		}
   671  
   672  		if w := tcp.WindowSize(); w > window {
   673  			t.Errorf("Bad window, got %d, want < %d", w, window)
   674  		}
   675  	}
   676  }
   677  
   678  // TCPFlags creates a checker that checks the tcp flags.
   679  func TCPFlags(flags header.TCPFlags) TransportChecker {
   680  	return func(t *testing.T, h header.Transport) {
   681  		t.Helper()
   682  
   683  		tcp, ok := h.(header.TCP)
   684  		if !ok {
   685  			t.Fatalf("TCP header not found in h: %T", h)
   686  		}
   687  
   688  		if got := tcp.Flags(); got != flags {
   689  			t.Errorf("got tcp.Flags() = %s, want %s", got, flags)
   690  		}
   691  	}
   692  }
   693  
   694  // TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the
   695  // given mask, match the supplied flags.
   696  func TCPFlagsMatch(flags, mask header.TCPFlags) TransportChecker {
   697  	return func(t *testing.T, h header.Transport) {
   698  		t.Helper()
   699  
   700  		tcp, ok := h.(header.TCP)
   701  		if !ok {
   702  			t.Fatalf("TCP header not found in h: %T", h)
   703  		}
   704  
   705  		if got := tcp.Flags(); (got & mask) != (flags & mask) {
   706  			t.Errorf("got tcp.Flags() = %s, want %s, mask %s", got, flags, mask)
   707  		}
   708  	}
   709  }
   710  
   711  // TCPSynOptions creates a checker that checks the presence of TCP options in
   712  // SYN segments.
   713  //
   714  // If wndscale is negative, the window scale option must not be present.
   715  func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
   716  	return func(t *testing.T, h header.Transport) {
   717  		t.Helper()
   718  
   719  		tcp, ok := h.(header.TCP)
   720  		if !ok {
   721  			return
   722  		}
   723  		opts := tcp.Options()
   724  		limit := len(opts)
   725  		foundMSS := false
   726  		foundWS := false
   727  		foundTS := false
   728  		foundSACKPermitted := false
   729  		tsVal := uint32(0)
   730  		tsEcr := uint32(0)
   731  		for i := 0; i < limit; {
   732  			switch opts[i] {
   733  			case header.TCPOptionEOL:
   734  				i = limit
   735  			case header.TCPOptionNOP:
   736  				i++
   737  			case header.TCPOptionMSS:
   738  				v := uint16(opts[i+2])<<8 | uint16(opts[i+3])
   739  				if wantOpts.MSS != v {
   740  					t.Errorf("Bad MSS, got = %d, want = %d", v, wantOpts.MSS)
   741  				}
   742  				foundMSS = true
   743  				i += 4
   744  			case header.TCPOptionWS:
   745  				if wantOpts.WS < 0 {
   746  					t.Error("WS present when it shouldn't be")
   747  				}
   748  				v := int(opts[i+2])
   749  				if v != wantOpts.WS {
   750  					t.Errorf("Bad WS, got = %d, want = %d", v, wantOpts.WS)
   751  				}
   752  				foundWS = true
   753  				i += 3
   754  			case header.TCPOptionTS:
   755  				if i+9 >= limit {
   756  					t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i)
   757  				}
   758  				if opts[i+1] != 10 {
   759  					t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit)
   760  				}
   761  				tsVal = binary.BigEndian.Uint32(opts[i+2:])
   762  				tsEcr = uint32(0)
   763  				if tcp.Flags()&header.TCPFlagAck != 0 {
   764  					// If the syn is an SYN-ACK then read
   765  					// the tsEcr value as well.
   766  					tsEcr = binary.BigEndian.Uint32(opts[i+6:])
   767  				}
   768  				foundTS = true
   769  				i += 10
   770  			case header.TCPOptionSACKPermitted:
   771  				if i+1 >= limit {
   772  					t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i)
   773  				}
   774  				if opts[i+1] != 2 {
   775  					t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit)
   776  				}
   777  				foundSACKPermitted = true
   778  				i += 2
   779  
   780  			default:
   781  				i += int(opts[i+1])
   782  			}
   783  		}
   784  
   785  		if !foundMSS {
   786  			t.Errorf("MSS option not found. Options: %x", opts)
   787  		}
   788  
   789  		if !foundWS && wantOpts.WS >= 0 {
   790  			t.Errorf("WS option not found. Options: %x", opts)
   791  		}
   792  		if wantOpts.TS && !foundTS {
   793  			t.Errorf("TS option not found. Options: %x", opts)
   794  		}
   795  		if foundTS && tsVal == 0 {
   796  			t.Error("TS option specified but the timestamp value is zero")
   797  		}
   798  		if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 {
   799  			t.Errorf("TS option specified but TSEcr is incorrect, got = %d, want = %d", tsEcr, wantOpts.TSEcr)
   800  		}
   801  		if wantOpts.SACKPermitted && !foundSACKPermitted {
   802  			t.Errorf("SACKPermitted option not found. Options: %x", opts)
   803  		}
   804  	}
   805  }
   806  
   807  // TCPTimestampChecker creates a checker that validates that a TCP segment has a
   808  // TCP Timestamp option if wantTS is true, it also compares the wantTSVal and
   809  // wantTSEcr values with those in the TCP segment (if present).
   810  //
   811  // If wantTSVal or wantTSEcr is zero then the corresponding comparison is
   812  // skipped.
   813  func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker {
   814  	return func(t *testing.T, h header.Transport) {
   815  		t.Helper()
   816  
   817  		tcp, ok := h.(header.TCP)
   818  		if !ok {
   819  			return
   820  		}
   821  		opts := tcp.Options()
   822  		limit := len(opts)
   823  		foundTS := false
   824  		tsVal := uint32(0)
   825  		tsEcr := uint32(0)
   826  		for i := 0; i < limit; {
   827  			switch opts[i] {
   828  			case header.TCPOptionEOL:
   829  				i = limit
   830  			case header.TCPOptionNOP:
   831  				i++
   832  			case header.TCPOptionTS:
   833  				if i+9 >= limit {
   834  					t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
   835  				}
   836  				if opts[i+1] != 10 {
   837  					t.Errorf("TS option found, but bad length specified: got = %d, want = 10", opts[i+1])
   838  				}
   839  				tsVal = binary.BigEndian.Uint32(opts[i+2:])
   840  				tsEcr = binary.BigEndian.Uint32(opts[i+6:])
   841  				foundTS = true
   842  				i += 10
   843  			default:
   844  				// We don't recognize this option, just skip over it.
   845  				if i+2 > limit {
   846  					return
   847  				}
   848  				l := int(opts[i+1])
   849  				if l < 2 || i+l > limit {
   850  					return
   851  				}
   852  				i += l
   853  			}
   854  		}
   855  
   856  		if wantTS != foundTS {
   857  			t.Errorf("TS Option mismatch, got TS= %t, want TS= %t", foundTS, wantTS)
   858  		}
   859  		if wantTS && wantTSVal != 0 && wantTSVal != tsVal {
   860  			t.Errorf("Timestamp value is incorrect, got = %d, want = %d", tsVal, wantTSVal)
   861  		}
   862  		if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr {
   863  			t.Errorf("Timestamp Echo Reply is incorrect, got = %d, want = %d", tsEcr, wantTSEcr)
   864  		}
   865  	}
   866  }
   867  
   868  // TCPSACKBlockChecker creates a checker that verifies that the segment does
   869  // contain the specified SACK blocks in the TCP options.
   870  func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
   871  	return func(t *testing.T, h header.Transport) {
   872  		t.Helper()
   873  		tcp, ok := h.(header.TCP)
   874  		if !ok {
   875  			return
   876  		}
   877  		var gotSACKBlocks []header.SACKBlock
   878  
   879  		opts := tcp.Options()
   880  		limit := len(opts)
   881  		for i := 0; i < limit; {
   882  			switch opts[i] {
   883  			case header.TCPOptionEOL:
   884  				i = limit
   885  			case header.TCPOptionNOP:
   886  				i++
   887  			case header.TCPOptionSACK:
   888  				if i+2 > limit {
   889  					// Malformed SACK block.
   890  					t.Errorf("malformed SACK option in options: %v", opts)
   891  				}
   892  				sackOptionLen := int(opts[i+1])
   893  				if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
   894  					// Malformed SACK block.
   895  					t.Errorf("malformed SACK option length in options: %v", opts)
   896  				}
   897  				numBlocks := sackOptionLen / 8
   898  				for j := 0; j < numBlocks; j++ {
   899  					start := binary.BigEndian.Uint32(opts[i+2+j*8:])
   900  					end := binary.BigEndian.Uint32(opts[i+2+j*8+4:])
   901  					gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{
   902  						Start: seqnum.Value(start),
   903  						End:   seqnum.Value(end),
   904  					})
   905  				}
   906  				i += sackOptionLen
   907  			default:
   908  				// We don't recognize this option, just skip over it.
   909  				if i+2 > limit {
   910  					break
   911  				}
   912  				l := int(opts[i+1])
   913  				if l < 2 || i+l > limit {
   914  					break
   915  				}
   916  				i += l
   917  			}
   918  		}
   919  
   920  		if !slices.Equal(gotSACKBlocks, sackBlocks) {
   921  			t.Errorf("SACKBlocks are not equal, got = %v, want = %v", gotSACKBlocks, sackBlocks)
   922  		}
   923  	}
   924  }
   925  
   926  // Payload creates a checker that checks the payload.
   927  func Payload(want []byte) TransportChecker {
   928  	return func(t *testing.T, h header.Transport) {
   929  		t.Helper()
   930  
   931  		if got := h.Payload(); !slices.Equal(got, want) {
   932  			t.Errorf("Wrong payload, got %v, want %v", got, want)
   933  		}
   934  	}
   935  }
   936  
   937  // ICMPv4 creates a checker that checks that the transport protocol is ICMPv4
   938  // and potentially additional ICMPv4 header fields.
   939  func ICMPv4(checkers ...TransportChecker) NetworkChecker {
   940  	return func(t *testing.T, h []header.Network) {
   941  		t.Helper()
   942  
   943  		last := h[len(h)-1]
   944  
   945  		if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber {
   946  			t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber)
   947  		}
   948  
   949  		icmp := header.ICMPv4(last.Payload())
   950  		for _, f := range checkers {
   951  			f(t, icmp)
   952  		}
   953  		if t.Failed() {
   954  			t.FailNow()
   955  		}
   956  	}
   957  }
   958  
   959  // ICMPv4Type creates a checker that checks the ICMPv4 Type field.
   960  func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
   961  	return func(t *testing.T, h header.Transport) {
   962  		t.Helper()
   963  
   964  		icmpv4, ok := h.(header.ICMPv4)
   965  		if !ok {
   966  			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
   967  		}
   968  		if got := icmpv4.Type(); got != want {
   969  			t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want)
   970  		}
   971  	}
   972  }
   973  
   974  // ICMPv4Code creates a checker that checks the ICMPv4 Code field.
   975  func ICMPv4Code(want header.ICMPv4Code) TransportChecker {
   976  	return func(t *testing.T, h header.Transport) {
   977  		t.Helper()
   978  
   979  		icmpv4, ok := h.(header.ICMPv4)
   980  		if !ok {
   981  			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
   982  		}
   983  		if got := icmpv4.Code(); got != want {
   984  			t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want)
   985  		}
   986  	}
   987  }
   988  
   989  // ICMPv4Ident creates a checker that checks the ICMPv4 echo Ident.
   990  func ICMPv4Ident(want uint16) TransportChecker {
   991  	return func(t *testing.T, h header.Transport) {
   992  		t.Helper()
   993  
   994  		icmpv4, ok := h.(header.ICMPv4)
   995  		if !ok {
   996  			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
   997  		}
   998  		if got := icmpv4.Ident(); got != want {
   999  			t.Fatalf("unexpected ICMP ident, got = %d, want = %d", got, want)
  1000  		}
  1001  	}
  1002  }
  1003  
  1004  // ICMPv4Seq creates a checker that checks the ICMPv4 echo Sequence.
  1005  func ICMPv4Seq(want uint16) TransportChecker {
  1006  	return func(t *testing.T, h header.Transport) {
  1007  		t.Helper()
  1008  
  1009  		icmpv4, ok := h.(header.ICMPv4)
  1010  		if !ok {
  1011  			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
  1012  		}
  1013  		if got := icmpv4.Sequence(); got != want {
  1014  			t.Fatalf("unexpected ICMP sequence, got = %d, want = %d", got, want)
  1015  		}
  1016  	}
  1017  }
  1018  
  1019  // ICMPv4Pointer creates a checker that checks the ICMPv4 Param Problem pointer.
  1020  func ICMPv4Pointer(want uint8) TransportChecker {
  1021  	return func(t *testing.T, h header.Transport) {
  1022  		t.Helper()
  1023  
  1024  		icmpv4, ok := h.(header.ICMPv4)
  1025  		if !ok {
  1026  			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
  1027  		}
  1028  		if got := icmpv4.Pointer(); got != want {
  1029  			t.Fatalf("unexpected ICMP Param Problem pointer, got = %d, want = %d", got, want)
  1030  		}
  1031  	}
  1032  }
  1033  
  1034  // ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum.
  1035  // This assumes that the payload exactly makes up the rest of the slice.
  1036  func ICMPv4Checksum() TransportChecker {
  1037  	return func(t *testing.T, h header.Transport) {
  1038  		t.Helper()
  1039  
  1040  		icmpv4, ok := h.(header.ICMPv4)
  1041  		if !ok {
  1042  			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
  1043  		}
  1044  		heldChecksum := icmpv4.Checksum()
  1045  		icmpv4.SetChecksum(0)
  1046  		newChecksum := ^checksum.Checksum(icmpv4, 0)
  1047  		icmpv4.SetChecksum(heldChecksum)
  1048  		if heldChecksum != newChecksum {
  1049  			t.Errorf("unexpected ICMP checksum, got = %d, want = %d", heldChecksum, newChecksum)
  1050  		}
  1051  	}
  1052  }
  1053  
  1054  // ICMPv4Payload creates a checker that checks the payload in an ICMPv4 packet.
  1055  func ICMPv4Payload(want []byte) TransportChecker {
  1056  	return func(t *testing.T, h header.Transport) {
  1057  		t.Helper()
  1058  
  1059  		icmpv4, ok := h.(header.ICMPv4)
  1060  		if !ok {
  1061  			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
  1062  		}
  1063  		payload := icmpv4.Payload()
  1064  
  1065  		// cmp.Diff does not consider nil slices equal to empty slices, but we do.
  1066  		if len(want) == 0 && len(payload) == 0 {
  1067  			return
  1068  		}
  1069  
  1070  		if diff := cmp.Diff(want, payload); diff != "" {
  1071  			t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
  1072  		}
  1073  	}
  1074  }
  1075  
  1076  // ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and
  1077  // potentially additional ICMPv6 header fields.
  1078  //
  1079  // ICMPv6 will validate the checksum field before calling checkers.
  1080  func ICMPv6(checkers ...TransportChecker) NetworkChecker {
  1081  	return func(t *testing.T, h []header.Network) {
  1082  		t.Helper()
  1083  
  1084  		last := h[len(h)-1]
  1085  
  1086  		if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber {
  1087  			t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber)
  1088  		}
  1089  
  1090  		icmp := header.ICMPv6(last.Payload())
  1091  		if got, want := icmp.Checksum(), header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
  1092  			Header: icmp,
  1093  			Src:    last.SourceAddress(),
  1094  			Dst:    last.DestinationAddress(),
  1095  		}); got != want {
  1096  			t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want)
  1097  		}
  1098  
  1099  		for _, f := range checkers {
  1100  			f(t, icmp)
  1101  		}
  1102  		if t.Failed() {
  1103  			t.FailNow()
  1104  		}
  1105  	}
  1106  }
  1107  
  1108  // ICMPv6Type creates a checker that checks the ICMPv6 Type field.
  1109  func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
  1110  	return func(t *testing.T, h header.Transport) {
  1111  		t.Helper()
  1112  
  1113  		icmpv6, ok := h.(header.ICMPv6)
  1114  		if !ok {
  1115  			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
  1116  		}
  1117  		if got := icmpv6.Type(); got != want {
  1118  			t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want)
  1119  		}
  1120  	}
  1121  }
  1122  
  1123  // ICMPv6Code creates a checker that checks the ICMPv6 Code field.
  1124  func ICMPv6Code(want header.ICMPv6Code) TransportChecker {
  1125  	return func(t *testing.T, h header.Transport) {
  1126  		t.Helper()
  1127  
  1128  		icmpv6, ok := h.(header.ICMPv6)
  1129  		if !ok {
  1130  			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
  1131  		}
  1132  		if got := icmpv6.Code(); got != want {
  1133  			t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want)
  1134  		}
  1135  	}
  1136  }
  1137  
  1138  // ICMPv6TypeSpecific creates a checker that checks the ICMPv6 TypeSpecific
  1139  // field.
  1140  func ICMPv6TypeSpecific(want uint32) TransportChecker {
  1141  	return func(t *testing.T, h header.Transport) {
  1142  		t.Helper()
  1143  
  1144  		icmpv6, ok := h.(header.ICMPv6)
  1145  		if !ok {
  1146  			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
  1147  		}
  1148  		if got := icmpv6.TypeSpecific(); got != want {
  1149  			t.Fatalf("unexpected ICMP TypeSpecific, got = %d, want = %d", got, want)
  1150  		}
  1151  	}
  1152  }
  1153  
  1154  // ICMPv6Payload creates a checker that checks the payload in an ICMPv6 packet.
  1155  func ICMPv6Payload(want []byte) TransportChecker {
  1156  	return func(t *testing.T, h header.Transport) {
  1157  		t.Helper()
  1158  
  1159  		icmpv6, ok := h.(header.ICMPv6)
  1160  		if !ok {
  1161  			t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
  1162  		}
  1163  		payload := icmpv6.Payload()
  1164  
  1165  		// cmp.Diff does not consider nil slices equal to empty slices, but we do.
  1166  		if len(want) == 0 && len(payload) == 0 {
  1167  			return
  1168  		}
  1169  
  1170  		if diff := cmp.Diff(want, payload); diff != "" {
  1171  			t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
  1172  		}
  1173  	}
  1174  }
  1175  
  1176  // MLD creates a checker that checks that the packet contains a valid MLD
  1177  // message for type of mldType, with potentially additional checks specified by
  1178  // checkers.
  1179  //
  1180  // Checkers may assume that a valid ICMPv6 is passed to it containing a valid
  1181  // MLD message as far as the size of the message (minSize) is concerned. The
  1182  // values within the message are up to checkers to validate.
  1183  func MLD(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
  1184  	return func(t *testing.T, h []header.Network) {
  1185  		t.Helper()
  1186  
  1187  		// Check normal ICMPv6 first.
  1188  		ICMPv6(
  1189  			ICMPv6Type(msgType),
  1190  			ICMPv6Code(0))(t, h)
  1191  
  1192  		last := h[len(h)-1]
  1193  
  1194  		icmp := header.ICMPv6(last.Payload())
  1195  		if got := len(icmp.MessageBody()); got < minSize {
  1196  			t.Fatalf("ICMPv6 MLD (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
  1197  		}
  1198  
  1199  		for _, f := range checkers {
  1200  			f(t, icmp)
  1201  		}
  1202  		if t.Failed() {
  1203  			t.FailNow()
  1204  		}
  1205  	}
  1206  }
  1207  
  1208  // MLDMaxRespDelay creates a checker that checks the Maximum Response Delay
  1209  // field of a MLD message.
  1210  //
  1211  // The returned TransportChecker assumes that a valid ICMPv6 is passed to it
  1212  // containing a valid MLD message as far as the size is concerned.
  1213  func MLDMaxRespDelay(want time.Duration) TransportChecker {
  1214  	return func(t *testing.T, h header.Transport) {
  1215  		t.Helper()
  1216  
  1217  		icmp := h.(header.ICMPv6)
  1218  		ns := header.MLD(icmp.MessageBody())
  1219  
  1220  		if got := ns.MaximumResponseDelay(); got != want {
  1221  			t.Errorf("got %T.MaximumResponseDelay() = %s, want = %s", ns, got, want)
  1222  		}
  1223  	}
  1224  }
  1225  
  1226  // MLDMulticastAddressUnordered creates a checker that checks that the multicast
  1227  // address in the MLD message is expected to be seen.
  1228  //
  1229  // The seen address is removed from the expected groups map.
  1230  //
  1231  // The returned TransportChecker assumes that a valid ICMPv6 is passed to it
  1232  // containing a valid MLD message as far as the size is concerned.
  1233  func MLDMulticastAddressUnordered(expectedGroups map[tcpip.Address]struct{}) TransportChecker {
  1234  	return func(t *testing.T, h header.Transport) {
  1235  		t.Helper()
  1236  
  1237  		icmp := h.(header.ICMPv6)
  1238  		ns := header.MLD(icmp.MessageBody())
  1239  
  1240  		addr := ns.MulticastAddress()
  1241  
  1242  		if _, ok := expectedGroups[addr]; !ok {
  1243  			t.Errorf("unexpected multicast group %s", addr)
  1244  		} else {
  1245  			delete(expectedGroups, addr)
  1246  		}
  1247  	}
  1248  }
  1249  
  1250  // MLDMulticastAddress creates a checker that checks the Multicast Address
  1251  // field of a MLD message.
  1252  //
  1253  // The returned TransportChecker assumes that a valid ICMPv6 is passed to it
  1254  // containing a valid MLD message as far as the size is concerned.
  1255  func MLDMulticastAddress(want tcpip.Address) TransportChecker {
  1256  	return MLDMulticastAddressUnordered(map[tcpip.Address]struct{}{
  1257  		want: struct{}{},
  1258  	})
  1259  }
  1260  
// MLDv2Report creates a checker that checks that the packet contains a valid
// MLDv2 report with the specified records.
//
// Each record whose address is in expectedRecords must have the expected
// record type, no auxiliary data and no sources. An address missing from
// expectedRecords is reported as an error. Once every expected record has been
// seen, the report must contain no further records.
//
// Note that observed records are removed from expectedRecords. No error is
// logged if the report does not have all the records expected.
func MLDv2Report(expectedRecords map[tcpip.Address]header.MLDv2ReportRecordType) NetworkChecker {
	return func(t *testing.T, h []header.Network) {
		t.Helper()

		// Check normal ICMPv6 first.
		ICMPv6(
			ICMPv6Type(header.ICMPv6MulticastListenerV2Report),
			ICMPv6Code(0))(t, h)

		last := h[len(h)-1]
		icmp := header.ICMPv6(last.Payload())
		report := header.MLDv2Report(icmp.MessageBody())

		records := report.MulticastAddressRecords()
		// Consume records until every expected address has been matched.
		for len(expectedRecords) != 0 {
			record, res := records.Next()
			switch res {
			case header.MLDv2ReportMulticastAddressRecordIteratorNextOk:
			case header.MLDv2ReportMulticastAddressRecordIteratorNextDone:
				// The report ran out before all expected records were seen.
				// This is deliberately not an error; the caller can inspect
				// what is left in expectedRecords.
				return
			default:
				t.Fatalf("unhandled res = %d", res)
			}

			addr := record.MulticastAddress()
			expectedRecordType, ok := expectedRecords[addr]
			if !ok {
				t.Errorf("unexpected record for address %s", addr)
				continue
			}

			if got, want := record.RecordType(), expectedRecordType; got != want {
				t.Errorf("got record.RecordType() = %d, want = %d", got, want)
			}

			if got := record.AuxDataLen(); got != 0 {
				t.Errorf("got record.AuxDataLen() = %d, want = 0", got)
			}

			sources, ok := record.Sources()
			if !ok {
				// The address stays in expectedRecords so it may still be
				// matched by a later record.
				t.Error("got record.Sources() = (_, false), want = (_, true)")
				continue
			}

			// Records are expected to carry no source addresses.
			if source, ok := sources.Next(); ok {
				t.Fatalf("got sources.Next() = (%s, true), want = (_, false)", source)
			}

			delete(expectedRecords, addr)
		}

		// All expected records were found; anything after them is unexpected.
		if record, res := records.Next(); res != header.MLDv2ReportMulticastAddressRecordIteratorNextDone {
			t.Fatalf("got records.Next() = (%#v, %d), want = (_, %d)", record, res, header.MLDv2ReportMulticastAddressRecordIteratorNextDone)
		}
	}
}
  1323  
  1324  // NDP creates a checker that checks that the packet contains a valid NDP
  1325  // message for type of ty, with potentially additional checks specified by
  1326  // checkers.
  1327  //
  1328  // Checkers may assume that a valid ICMPv6 is passed to it containing a valid
  1329  // NDP message as far as the size of the message (minSize) is concerned. The
  1330  // values within the message are up to checkers to validate.
  1331  func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
  1332  	return func(t *testing.T, h []header.Network) {
  1333  		t.Helper()
  1334  
  1335  		// Check normal ICMPv6 first.
  1336  		ICMPv6(
  1337  			ICMPv6Type(msgType),
  1338  			ICMPv6Code(0))(t, h)
  1339  
  1340  		last := h[len(h)-1]
  1341  
  1342  		icmp := header.ICMPv6(last.Payload())
  1343  		if got := len(icmp.MessageBody()); got < minSize {
  1344  			t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
  1345  		}
  1346  
  1347  		for _, f := range checkers {
  1348  			f(t, icmp)
  1349  		}
  1350  		if t.Failed() {
  1351  			t.FailNow()
  1352  		}
  1353  	}
  1354  }
  1355  
  1356  // NDPNS creates a checker that checks that the packet contains a valid NDP
  1357  // Neighbor Solicitation message (as per the raw wire format), with potentially
  1358  // additional checks specified by checkers.
  1359  //
  1360  // Checkers may assume that a valid ICMPv6 is passed to it containing a valid
  1361  // NDPNS message as far as the size of the message is concerned. The values
  1362  // within the message are up to checkers to validate.
  1363  func NDPNS(checkers ...TransportChecker) NetworkChecker {
  1364  	return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...)
  1365  }
  1366  
  1367  // NDPNSTargetAddress creates a checker that checks the Target Address field of
  1368  // a header.NDPNeighborSolicit.
  1369  //
  1370  // The returned TransportChecker assumes that a valid ICMPv6 is passed to it
  1371  // containing a valid NDPNS message as far as the size is concerned.
  1372  func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
  1373  	return func(t *testing.T, h header.Transport) {
  1374  		t.Helper()
  1375  
  1376  		icmp := h.(header.ICMPv6)
  1377  		ns := header.NDPNeighborSolicit(icmp.MessageBody())
  1378  
  1379  		if got := ns.TargetAddress(); got != want {
  1380  			t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want)
  1381  		}
  1382  	}
  1383  }
  1384  
  1385  // NDPNA creates a checker that checks that the packet contains a valid NDP
  1386  // Neighbor Advertisement message (as per the raw wire format), with potentially
  1387  // additional checks specified by checkers.
  1388  //
  1389  // Checkers may assume that a valid ICMPv6 is passed to it containing a valid
  1390  // NDPNA message as far as the size of the message is concerned. The values
  1391  // within the message are up to checkers to validate.
  1392  func NDPNA(checkers ...TransportChecker) NetworkChecker {
  1393  	return NDP(header.ICMPv6NeighborAdvert, header.NDPNAMinimumSize, checkers...)
  1394  }
  1395  
  1396  // NDPNATargetAddress creates a checker that checks the Target Address field of
  1397  // a header.NDPNeighborAdvert.
  1398  //
  1399  // The returned TransportChecker assumes that a valid ICMPv6 is passed to it
  1400  // containing a valid NDPNA message as far as the size is concerned.
  1401  func NDPNATargetAddress(want tcpip.Address) TransportChecker {
  1402  	return func(t *testing.T, h header.Transport) {
  1403  		t.Helper()
  1404  
  1405  		icmp := h.(header.ICMPv6)
  1406  		na := header.NDPNeighborAdvert(icmp.MessageBody())
  1407  
  1408  		if got := na.TargetAddress(); got != want {
  1409  			t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want)
  1410  		}
  1411  	}
  1412  }
  1413  
  1414  // NDPNASolicitedFlag creates a checker that checks the Solicited field of
  1415  // a header.NDPNeighborAdvert.
  1416  //
  1417  // The returned TransportChecker assumes that a valid ICMPv6 is passed to it
  1418  // containing a valid NDPNA message as far as the size is concerned.
  1419  func NDPNASolicitedFlag(want bool) TransportChecker {
  1420  	return func(t *testing.T, h header.Transport) {
  1421  		t.Helper()
  1422  
  1423  		icmp := h.(header.ICMPv6)
  1424  		na := header.NDPNeighborAdvert(icmp.MessageBody())
  1425  
  1426  		if got := na.SolicitedFlag(); got != want {
  1427  			t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want)
  1428  		}
  1429  	}
  1430  }
  1431  
// ndpOptions checks that optsBuf only contains opts.
//
// Options are compared positionally: the i-th option parsed from optsBuf must
// match opts[i]. Extra parsed options and unparsed expected options are both
// reported as errors. Only header.NDPSourceLinkLayerAddressOption,
// header.NDPTargetLinkLayerAddressOption and header.NDPNonceOption are
// supported as expected options; any other expected type stops the test.
func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption) {
	t.Helper()

	// NOTE(review): Iter(true) appears to validate the options up front (the
	// comment inside the loop relies on that); confirm against header.NDPOptions.
	it, err := optsBuf.Iter(true)
	if err != nil {
		t.Errorf("optsBuf.Iter(true): %s", err)
		return
	}

	// i indexes the next expected option in opts; it only advances when a
	// parsed option was compared against an expected one.
	i := 0
	for {
		opt, done, err := it.Next()
		if err != nil {
			// This should never happen as Iter(true) above did not return an error.
			t.Fatalf("unexpected error when iterating over NDP options: %s", err)
		}
		if done {
			break
		}

		// More options were parsed than expected; report and keep draining.
		if i >= len(opts) {
			t.Errorf("got unexpected option: %s", opt)
			continue
		}

		switch wantOpt := opts[i].(type) {
		case header.NDPSourceLinkLayerAddressOption:
			gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption)
			if !ok {
				t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
			} else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
				t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
			}
		case header.NDPTargetLinkLayerAddressOption:
			gotOpt, ok := opt.(header.NDPTargetLinkLayerAddressOption)
			if !ok {
				t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
			} else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
				t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
			}
		case header.NDPNonceOption:
			gotOpt, ok := opt.(header.NDPNonceOption)
			if !ok {
				t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
			} else if diff := cmp.Diff(wantOpt.Nonce(), gotOpt.Nonce()); diff != "" {
				t.Errorf("nonce mismatch (-want +got):\n%s", diff)
			}
		default:
			t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt)
		}

		i++
	}

	// Any expected options not consumed by the loop were absent from optsBuf.
	// i never exceeds len(opts), so this slice expression is always in range.
	if missing := opts[i:]; len(missing) > 0 {
		t.Errorf("missing options: %s", missing)
	}
}
  1491  
  1492  // NDPNAOptions creates a checker that checks that the packet contains the
  1493  // provided NDP options within an NDP Neighbor Solicitation message.
  1494  //
  1495  // The returned TransportChecker assumes that a valid ICMPv6 is passed to it
  1496  // containing a valid NDPNA message as far as the size is concerned.
  1497  func NDPNAOptions(opts []header.NDPOption) TransportChecker {
  1498  	return func(t *testing.T, h header.Transport) {
  1499  		t.Helper()
  1500  
  1501  		icmp := h.(header.ICMPv6)
  1502  		na := header.NDPNeighborAdvert(icmp.MessageBody())
  1503  		ndpOptions(t, na.Options(), opts)
  1504  	}
  1505  }
  1506  
  1507  // NDPNSOptions creates a checker that checks that the packet contains the
  1508  // provided NDP options within an NDP Neighbor Solicitation message.
  1509  //
  1510  // The returned TransportChecker assumes that a valid ICMPv6 is passed to it
  1511  // containing a valid NDPNS message as far as the size is concerned.
  1512  func NDPNSOptions(opts []header.NDPOption) TransportChecker {
  1513  	return func(t *testing.T, h header.Transport) {
  1514  		t.Helper()
  1515  
  1516  		icmp := h.(header.ICMPv6)
  1517  		ns := header.NDPNeighborSolicit(icmp.MessageBody())
  1518  		ndpOptions(t, ns.Options(), opts)
  1519  	}
  1520  }
  1521  
  1522  // NDPRS creates a checker that checks that the packet contains a valid NDP
  1523  // Router Solicitation message (as per the raw wire format).
  1524  //
  1525  // Checkers may assume that a valid ICMPv6 is passed to it containing a valid
  1526  // NDPRS as far as the size of the message is concerned. The values within the
  1527  // message are up to checkers to validate.
  1528  func NDPRS(checkers ...TransportChecker) NetworkChecker {
  1529  	return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize, checkers...)
  1530  }
  1531  
  1532  // NDPRSOptions creates a checker that checks that the packet contains the
  1533  // provided NDP options within an NDP Router Solicitation message.
  1534  //
  1535  // The returned TransportChecker assumes that a valid ICMPv6 is passed to it
  1536  // containing a valid NDPRS message as far as the size is concerned.
  1537  func NDPRSOptions(opts []header.NDPOption) TransportChecker {
  1538  	return func(t *testing.T, h header.Transport) {
  1539  		t.Helper()
  1540  
  1541  		icmp := h.(header.ICMPv6)
  1542  		rs := header.NDPRouterSolicit(icmp.MessageBody())
  1543  		ndpOptions(t, rs.Options(), opts)
  1544  	}
  1545  }
  1546  
  1547  // IGMP checks the validity and properties of the given IGMP packet. It is
  1548  // expected to be used in conjunction with other IGMP transport checkers for
  1549  // specific properties.
  1550  func IGMP(checkers ...TransportChecker) NetworkChecker {
  1551  	return func(t *testing.T, h []header.Network) {
  1552  		t.Helper()
  1553  
  1554  		last := h[len(h)-1]
  1555  
  1556  		if p := last.TransportProtocol(); p != header.IGMPProtocolNumber {
  1557  			t.Fatalf("Bad protocol, got %d, want %d", p, header.IGMPProtocolNumber)
  1558  		}
  1559  
  1560  		igmp := header.IGMP(last.Payload())
  1561  		for _, f := range checkers {
  1562  			f(t, igmp)
  1563  		}
  1564  		if t.Failed() {
  1565  			t.FailNow()
  1566  		}
  1567  	}
  1568  }
  1569  
  1570  // IGMPType creates a checker that checks the IGMP Type field.
  1571  func IGMPType(want header.IGMPType) TransportChecker {
  1572  	return func(t *testing.T, h header.Transport) {
  1573  		t.Helper()
  1574  
  1575  		igmp, ok := h.(header.IGMP)
  1576  		if !ok {
  1577  			t.Fatalf("got transport header = %T, want = header.IGMP", h)
  1578  		}
  1579  		if got := igmp.Type(); got != want {
  1580  			t.Errorf("got igmp.Type() = %d, want = %d", got, want)
  1581  		}
  1582  	}
  1583  }
  1584  
  1585  // IGMPMaxRespTime creates a checker that checks the IGMP Max Resp Time field.
  1586  func IGMPMaxRespTime(want time.Duration) TransportChecker {
  1587  	return func(t *testing.T, h header.Transport) {
  1588  		t.Helper()
  1589  
  1590  		igmp, ok := h.(header.IGMP)
  1591  		if !ok {
  1592  			t.Fatalf("got transport header = %T, want = header.IGMP", h)
  1593  		}
  1594  		if got := igmp.MaxRespTime(); got != want {
  1595  			t.Errorf("got igmp.MaxRespTime() = %s, want = %s", got, want)
  1596  		}
  1597  	}
  1598  }
  1599  
  1600  // IGMPGroupAddressUnordered creates a checker that checks that the group
  1601  // address in the IGMP message is expected to be seen.
  1602  //
  1603  // The seen address is removed from the expected groups map.
  1604  //
  1605  // The returned TransportChecker assumes that a valid IGMP is passed to it
  1606  // containing a valid IGMP message as far as the size is concerned.
  1607  func IGMPGroupAddressUnordered(expectedGroups map[tcpip.Address]struct{}) TransportChecker {
  1608  	return func(t *testing.T, h header.Transport) {
  1609  		t.Helper()
  1610  
  1611  		igmp, ok := h.(header.IGMP)
  1612  		if !ok {
  1613  			t.Fatalf("got transport header = %T, want = header.IGMP", h)
  1614  		}
  1615  
  1616  		addr := igmp.GroupAddress()
  1617  
  1618  		if _, ok := expectedGroups[addr]; !ok {
  1619  			t.Errorf("unexpected multicast group %s", addr)
  1620  		} else {
  1621  			delete(expectedGroups, addr)
  1622  		}
  1623  	}
  1624  }
  1625  
  1626  // IGMPGroupAddress creates a checker that checks the IGMP Group Address field.
  1627  func IGMPGroupAddress(want tcpip.Address) TransportChecker {
  1628  	return IGMPGroupAddressUnordered(map[tcpip.Address]struct{}{
  1629  		want: struct{}{},
  1630  	})
  1631  }
  1632  
// IGMPv3Report creates a checker that checks that the packet contains a valid
// IGMPv3 report with the specified records.
//
// The innermost network header must carry IGMP, the message type must be an
// IGMPv3 membership report, and the report checksum must match
// header.IGMPCalculateChecksum. Each record whose group address is in
// expectedRecords must have the expected record type, no auxiliary data and
// no sources. A group address missing from expectedRecords is reported as an
// error. Once every expected record has been seen, the report must contain no
// further records.
//
// Note that observed records are removed from expectedRecords. No error is
// logged if the report does not have all the records expected.
func IGMPv3Report(expectedRecords map[tcpip.Address]header.IGMPv3ReportRecordType) NetworkChecker {
	return func(t *testing.T, h []header.Network) {
		t.Helper()

		last := h[len(h)-1]
		if p := last.TransportProtocol(); p != header.IGMPProtocolNumber {
			t.Fatalf("Bad protocol, got %d, want %d", p, header.IGMPProtocolNumber)
		}

		igmp := header.IGMP(last.Payload())
		if got := igmp.Type(); got != header.IGMPv3MembershipReport {
			t.Errorf("got igmp.Type() = %d, want = %d", got, header.IGMPv3MembershipReport)
		}

		report := header.IGMPv3Report(igmp)
		if got, want := report.Checksum(), header.IGMPCalculateChecksum(igmp); got != want {
			t.Errorf("got report.Checksum() = %d, want = %d", got, want)
		}

		records := report.GroupAddressRecords()
		// Consume records until every expected group has been matched.
		for len(expectedRecords) != 0 {
			record, res := records.Next()
			switch res {
			case header.IGMPv3ReportGroupAddressRecordIteratorNextOk:
			case header.IGMPv3ReportGroupAddressRecordIteratorNextDone:
				// The report ran out before all expected records were seen.
				// This is deliberately not an error; the caller can inspect
				// what is left in expectedRecords.
				return
			default:
				t.Fatalf("unhandled res = %d", res)
			}

			addr := record.GroupAddress()
			expectedRecordType, ok := expectedRecords[addr]
			if !ok {
				t.Errorf("unexpected record for address %s", addr)
				continue
			}

			if got, want := record.RecordType(), expectedRecordType; got != want {
				t.Errorf("got record.RecordType() = %d, want = %d", got, want)
			}

			if got := record.AuxDataLen(); got != 0 {
				t.Errorf("got record.AuxDataLen() = %d, want = 0", got)
			}

			sources, ok := record.Sources()
			if !ok {
				// The address stays in expectedRecords so it may still be
				// matched by a later record.
				t.Error("got record.Sources() = (_, false), want = (_, true)")
				continue
			}

			// Records are expected to carry no source addresses.
			if source, ok := sources.Next(); ok {
				t.Fatalf("got sources.Next() = (%s, true), want = (_, false)", source)
			}

			delete(expectedRecords, addr)
		}

		// All expected records were found; anything after them is unexpected.
		if record, res := records.Next(); res != header.IGMPv3ReportGroupAddressRecordIteratorNextDone {
			t.Fatalf("got records.Next() = (%#v, %d), want = (_, %d)", record, res, header.IGMPv3ReportGroupAddressRecordIteratorNextDone)
		}
	}
}
  1701  
// IPv6ExtHdrChecker is a function to check an extension header.
//
// IPv6ExtHdr invokes one checker per extension header, in order, passing the
// header.IPv6PayloadHeader yielded by the payload iterator. IPv6ExtHdr
// releases that header after the checker returns, so a checker should not
// retain it.
type IPv6ExtHdrChecker func(*testing.T, header.IPv6PayloadHeader)
  1704  
  1705  // IPv6WithExtHdr is like IPv6 but allows IPv6 packets with extension headers.
  1706  func IPv6WithExtHdr(t *testing.T, v *buffer.View, checkers ...NetworkChecker) {
  1707  	t.Helper()
  1708  
  1709  	ipv6 := header.IPv6(v.AsSlice())
  1710  	if !ipv6.IsValid(len(v.AsSlice())) {
  1711  		t.Error("not a valid IPv6 packet")
  1712  		return
  1713  	}
  1714  
  1715  	payloadIterator := header.MakeIPv6PayloadIterator(
  1716  		header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()),
  1717  		buffer.MakeWithData(ipv6.Payload()),
  1718  	)
  1719  	defer payloadIterator.Release()
  1720  
  1721  	var rawPayloadHeader header.IPv6RawPayloadHeader
  1722  	for {
  1723  		h, done, err := payloadIterator.Next()
  1724  		if err != nil {
  1725  			t.Errorf("payloadIterator.Next(): %s", err)
  1726  			return
  1727  		}
  1728  		if done {
  1729  			t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, true, _)", h, done)
  1730  			return
  1731  		}
  1732  		defer h.Release()
  1733  		r, ok := h.(header.IPv6RawPayloadHeader)
  1734  		if ok {
  1735  			rawPayloadHeader = r
  1736  			break
  1737  		}
  1738  	}
  1739  
  1740  	networkHeader := ipv6HeaderWithExtHdr{
  1741  		IPv6:      ipv6,
  1742  		transport: tcpip.TransportProtocolNumber(rawPayloadHeader.Identifier),
  1743  		payload:   rawPayloadHeader.Buf.Flatten(),
  1744  	}
  1745  
  1746  	for _, checker := range checkers {
  1747  		checker(t, []header.Network{&networkHeader})
  1748  	}
  1749  }
  1750  
  1751  // IPv6ExtHdr checks for the presence of extension headers.
  1752  //
  1753  // All the extension headers in headers will be checked exhaustively in the
  1754  // order provided.
  1755  func IPv6ExtHdr(headers ...IPv6ExtHdrChecker) NetworkChecker {
  1756  	return func(t *testing.T, h []header.Network) {
  1757  		t.Helper()
  1758  
  1759  		extHdrs, ok := h[0].(*ipv6HeaderWithExtHdr)
  1760  		if !ok {
  1761  			t.Errorf("got network header = %T, want = *ipv6HeaderWithExtHdr", h[0])
  1762  			return
  1763  		}
  1764  
  1765  		payloadIterator := header.MakeIPv6PayloadIterator(
  1766  			header.IPv6ExtensionHeaderIdentifier(extHdrs.IPv6.NextHeader()),
  1767  			buffer.MakeWithData(extHdrs.IPv6.Payload()),
  1768  		)
  1769  		defer payloadIterator.Release()
  1770  
  1771  		for _, check := range headers {
  1772  			h, done, err := payloadIterator.Next()
  1773  			if err != nil {
  1774  				t.Errorf("payloadIterator.Next(): %s", err)
  1775  				return
  1776  			}
  1777  			if done {
  1778  				t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, false, _)", h, done)
  1779  				return
  1780  			}
  1781  			check(t, h)
  1782  			h.Release()
  1783  		}
  1784  		// Validate we consumed all headers.
  1785  		//
  1786  		// The next one over should be a raw payload and then iterator should
  1787  		// terminate.
  1788  		wantDone := false
  1789  		for {
  1790  			h, done, err := payloadIterator.Next()
  1791  			if err != nil {
  1792  				t.Errorf("payloadIterator.Next(): %s", err)
  1793  				return
  1794  			}
  1795  			if done != wantDone {
  1796  				t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, %t, _)", h, done, wantDone)
  1797  				return
  1798  			}
  1799  			if done {
  1800  				break
  1801  			}
  1802  			if _, ok := h.(header.IPv6RawPayloadHeader); !ok {
  1803  				t.Errorf("got payloadIterator.Next() = (%T, _, _), want = (header.IPv6RawPayloadHeader, _, _)", h)
  1804  				continue
  1805  			} else {
  1806  				h.Release()
  1807  			}
  1808  			wantDone = true
  1809  		}
  1810  	}
  1811  }
  1812  
  1813  var _ header.Network = (*ipv6HeaderWithExtHdr)(nil)
  1814  
  1815  // ipv6HeaderWithExtHdr provides a header.Network implementation that takes
  1816  // extension headers into consideration, which is not the case with vanilla
  1817  // header.IPv6.
  1818  type ipv6HeaderWithExtHdr struct {
  1819  	header.IPv6
  1820  	transport tcpip.TransportProtocolNumber
  1821  	payload   []byte
  1822  }
  1823  
  1824  // TransportProtocol implements header.Network.
  1825  func (h *ipv6HeaderWithExtHdr) TransportProtocol() tcpip.TransportProtocolNumber {
  1826  	return h.transport
  1827  }
  1828  
  1829  // Payload implements header.Network.
  1830  func (h *ipv6HeaderWithExtHdr) Payload() []byte {
  1831  	return h.payload
  1832  }
  1833  
// IPv6ExtHdrOptionChecker is a function to check an extension header option.
//
// IPv6HopByHopExtensionHeader invokes one such checker per option, in the
// order the options appear in the header.
type IPv6ExtHdrOptionChecker func(*testing.T, header.IPv6ExtHdrOption)
  1836  
  1837  // IPv6HopByHopExtensionHeader checks the extension header is a Hop by Hop
  1838  // extension header and validates the containing options with checkers.
  1839  //
  1840  // checkers must exhaustively contain all the expected options.
  1841  func IPv6HopByHopExtensionHeader(checkers ...IPv6ExtHdrOptionChecker) IPv6ExtHdrChecker {
  1842  	return func(t *testing.T, payloadHeader header.IPv6PayloadHeader) {
  1843  		t.Helper()
  1844  
  1845  		hbh, ok := payloadHeader.(header.IPv6HopByHopOptionsExtHdr)
  1846  		if !ok {
  1847  			t.Errorf("unexpected IPv6 payload header, got = %T, want = header.IPv6HopByHopOptionsExtHdr", payloadHeader)
  1848  			return
  1849  		}
  1850  		optionsIterator := hbh.Iter()
  1851  		for _, f := range checkers {
  1852  			opt, done, err := optionsIterator.Next()
  1853  			if err != nil {
  1854  				t.Errorf("optionsIterator.Next(): %s", err)
  1855  				return
  1856  			}
  1857  			if done {
  1858  				t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, false, _)", opt, done)
  1859  			}
  1860  			f(t, opt)
  1861  			if uo, ok := opt.(*header.IPv6UnknownExtHdrOption); ok {
  1862  				uo.Data.Release()
  1863  			}
  1864  		}
  1865  		// Validate all options were consumed.
  1866  		for {
  1867  			opt, done, err := optionsIterator.Next()
  1868  			if err != nil {
  1869  				t.Errorf("optionsIterator.Next(): %s", err)
  1870  				return
  1871  			}
  1872  			if !done {
  1873  				t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, true, _)", opt, done)
  1874  			}
  1875  			if done {
  1876  				break
  1877  			}
  1878  			if uo, ok := opt.(*header.IPv6UnknownExtHdrOption); ok {
  1879  				uo.Data.Release()
  1880  			}
  1881  		}
  1882  	}
  1883  }
  1884  
  1885  // IPv6RouterAlert validates that an extension header option is the RouterAlert
  1886  // option and matches on its value.
  1887  func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker {
  1888  	return func(t *testing.T, opt header.IPv6ExtHdrOption) {
  1889  		routerAlert, ok := opt.(*header.IPv6RouterAlertOption)
  1890  		if !ok {
  1891  			t.Errorf("unexpected extension header option, got = %T, want = header.IPv6RouterAlertOption", opt)
  1892  			return
  1893  		}
  1894  		if routerAlert.Value != want {
  1895  			t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, want)
  1896  		}
  1897  	}
  1898  }
  1899  
  1900  // IPv6UnknownOption validates that an extension header option is the
  1901  // unknown header option.
  1902  func IPv6UnknownOption() IPv6ExtHdrOptionChecker {
  1903  	return func(t *testing.T, opt header.IPv6ExtHdrOption) {
  1904  		_, ok := opt.(*header.IPv6UnknownExtHdrOption)
  1905  		if !ok {
  1906  			t.Errorf("got = %T, want = header.IPv6UnknownExtHdrOption", opt)
  1907  		}
  1908  	}
  1909  }
  1910  
  1911  // IgnoreCmpPath returns a cmp.Option that ignores listed field paths.
  1912  func IgnoreCmpPath(paths ...string) cmp.Option {
  1913  	ignores := map[string]struct{}{}
  1914  	for _, path := range paths {
  1915  		ignores[path] = struct{}{}
  1916  	}
  1917  	return cmp.FilterPath(func(path cmp.Path) bool {
  1918  		_, ok := ignores[path.String()]
  1919  		return ok
  1920  	}, cmp.Ignore())
  1921  }