github.com/vpnishe/netstack@v1.10.6/tcpip/network/ipv6/icmp_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package ipv6
    16  
    17  import (
    18  	"reflect"
    19  	"strings"
    20  	"testing"
    21  
    22  	"github.com/vpnishe/netstack/tcpip"
    23  	"github.com/vpnishe/netstack/tcpip/buffer"
    24  	"github.com/vpnishe/netstack/tcpip/header"
    25  	"github.com/vpnishe/netstack/tcpip/link/channel"
    26  	"github.com/vpnishe/netstack/tcpip/link/sniffer"
    27  	"github.com/vpnishe/netstack/tcpip/stack"
    28  	"github.com/vpnishe/netstack/tcpip/transport/icmp"
    29  	"github.com/vpnishe/netstack/waiter"
    30  )
    31  
    32  const (
    33  	linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
    34  	linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
    35  )
    36  
    37  var (
    38  	lladdr0 = header.LinkLocalAddr(linkAddr0)
    39  	lladdr1 = header.LinkLocalAddr(linkAddr1)
    40  )
    41  
    42  type stubLinkEndpoint struct {
    43  	stack.LinkEndpoint
    44  }
    45  
    46  func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
    47  	return 0
    48  }
    49  
    50  func (*stubLinkEndpoint) MaxHeaderLength() uint16 {
    51  	return 0
    52  }
    53  
    54  func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress {
    55  	return ""
    56  }
    57  
    58  func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, tcpip.PacketBuffer) *tcpip.Error {
    59  	return nil
    60  }
    61  
    62  func (*stubLinkEndpoint) Attach(stack.NetworkDispatcher) {}
    63  
    64  type stubDispatcher struct {
    65  	stack.TransportDispatcher
    66  }
    67  
    68  func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, tcpip.PacketBuffer) {
    69  }
    70  
    71  type stubLinkAddressCache struct {
    72  	stack.LinkAddressCache
    73  }
    74  
    75  func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.NICID {
    76  	return 0
    77  }
    78  
    79  func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) {
    80  }
    81  
    82  func TestICMPCounts(t *testing.T) {
    83  	s := stack.New(stack.Options{
    84  		NetworkProtocols:   []stack.NetworkProtocol{NewProtocol()},
    85  		TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
    86  	})
    87  	{
    88  		if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
    89  			t.Fatalf("CreateNIC(_) = %s", err)
    90  		}
    91  		if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
    92  			t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
    93  		}
    94  	}
    95  	{
    96  		subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
    97  		if err != nil {
    98  			t.Fatal(err)
    99  		}
   100  		s.SetRouteTable(
   101  			[]tcpip.Route{{
   102  				Destination: subnet,
   103  				NIC:         1,
   104  			}},
   105  		)
   106  	}
   107  
   108  	netProto := s.NetworkProtocolInstance(ProtocolNumber)
   109  	if netProto == nil {
   110  		t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
   111  	}
   112  	ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{lladdr1, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil)
   113  	if err != nil {
   114  		t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
   115  	}
   116  
   117  	r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
   118  	if err != nil {
   119  		t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
   120  	}
   121  	defer r.Release()
   122  
   123  	types := []struct {
   124  		typ  header.ICMPv6Type
   125  		size int
   126  	}{
   127  		{header.ICMPv6DstUnreachable, header.ICMPv6DstUnreachableMinimumSize},
   128  		{header.ICMPv6PacketTooBig, header.ICMPv6PacketTooBigMinimumSize},
   129  		{header.ICMPv6TimeExceeded, header.ICMPv6MinimumSize},
   130  		{header.ICMPv6ParamProblem, header.ICMPv6MinimumSize},
   131  		{header.ICMPv6EchoRequest, header.ICMPv6EchoMinimumSize},
   132  		{header.ICMPv6EchoReply, header.ICMPv6EchoMinimumSize},
   133  		{header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize},
   134  		{header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize},
   135  		{header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize},
   136  		{header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize},
   137  		{header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize},
   138  	}
   139  
   140  	handleIPv6Payload := func(hdr buffer.Prependable) {
   141  		payloadLength := hdr.UsedLength()
   142  		ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
   143  		ip.Encode(&header.IPv6Fields{
   144  			PayloadLength: uint16(payloadLength),
   145  			NextHeader:    uint8(header.ICMPv6ProtocolNumber),
   146  			HopLimit:      header.NDPHopLimit,
   147  			SrcAddr:       r.LocalAddress,
   148  			DstAddr:       r.RemoteAddress,
   149  		})
   150  		ep.HandlePacket(&r, tcpip.PacketBuffer{
   151  			Data: hdr.View().ToVectorisedView(),
   152  		})
   153  	}
   154  
   155  	for _, typ := range types {
   156  		hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size)
   157  		pkt := header.ICMPv6(hdr.Prepend(typ.size))
   158  		pkt.SetType(typ.typ)
   159  		pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
   160  
   161  		handleIPv6Payload(hdr)
   162  	}
   163  
   164  	// Construct an empty ICMP packet so that
   165  	// Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
   166  	handleIPv6Payload(buffer.NewPrependable(header.IPv6MinimumSize))
   167  
   168  	icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
   169  	visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
   170  		if got, want := s.Value(), uint64(1); got != want {
   171  			t.Errorf("got %s = %d, want = %d", name, got, want)
   172  		}
   173  	})
   174  	if t.Failed() {
   175  		t.Logf("stats:\n%+v", s.Stats())
   176  	}
   177  }
   178  
   179  func visitStats(v reflect.Value, f func(string, *tcpip.StatCounter)) {
   180  	t := v.Type()
   181  	for i := 0; i < v.NumField(); i++ {
   182  		v := v.Field(i)
   183  		if s, ok := v.Interface().(*tcpip.StatCounter); ok {
   184  			f(t.Field(i).Name, s)
   185  		} else {
   186  			visitStats(v, f)
   187  		}
   188  	}
   189  }
   190  
   191  type testContext struct {
   192  	s0 *stack.Stack
   193  	s1 *stack.Stack
   194  
   195  	linkEP0 *channel.Endpoint
   196  	linkEP1 *channel.Endpoint
   197  }
   198  
   199  type endpointWithResolutionCapability struct {
   200  	stack.LinkEndpoint
   201  }
   202  
   203  func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapabilities {
   204  	return e.LinkEndpoint.Capabilities() | stack.CapabilityResolutionRequired
   205  }
   206  
   207  func newTestContext(t *testing.T) *testContext {
   208  	c := &testContext{
   209  		s0: stack.New(stack.Options{
   210  			NetworkProtocols:   []stack.NetworkProtocol{NewProtocol()},
   211  			TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
   212  		}),
   213  		s1: stack.New(stack.Options{
   214  			NetworkProtocols:   []stack.NetworkProtocol{NewProtocol()},
   215  			TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
   216  		}),
   217  	}
   218  
   219  	const defaultMTU = 65536
   220  	c.linkEP0 = channel.New(256, defaultMTU, linkAddr0)
   221  
   222  	wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0})
   223  	if testing.Verbose() {
   224  		wrappedEP0 = sniffer.New(wrappedEP0)
   225  	}
   226  	if err := c.s0.CreateNIC(1, wrappedEP0); err != nil {
   227  		t.Fatalf("CreateNIC s0: %v", err)
   228  	}
   229  	if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
   230  		t.Fatalf("AddAddress lladdr0: %v", err)
   231  	}
   232  
   233  	c.linkEP1 = channel.New(256, defaultMTU, linkAddr1)
   234  	wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1})
   235  	if err := c.s1.CreateNIC(1, wrappedEP1); err != nil {
   236  		t.Fatalf("CreateNIC failed: %v", err)
   237  	}
   238  	if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil {
   239  		t.Fatalf("AddAddress lladdr1: %v", err)
   240  	}
   241  
   242  	subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
   243  	if err != nil {
   244  		t.Fatal(err)
   245  	}
   246  	c.s0.SetRouteTable(
   247  		[]tcpip.Route{{
   248  			Destination: subnet0,
   249  			NIC:         1,
   250  		}},
   251  	)
   252  	subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0))))
   253  	if err != nil {
   254  		t.Fatal(err)
   255  	}
   256  	c.s1.SetRouteTable(
   257  		[]tcpip.Route{{
   258  			Destination: subnet1,
   259  			NIC:         1,
   260  		}},
   261  	)
   262  
   263  	return c
   264  }
   265  
   266  func (c *testContext) cleanup() {
   267  	close(c.linkEP0.C)
   268  	close(c.linkEP1.C)
   269  }
   270  
   271  type routeArgs struct {
   272  	src, dst *channel.Endpoint
   273  	typ      header.ICMPv6Type
   274  }
   275  
   276  func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) {
   277  	t.Helper()
   278  
   279  	pi := <-args.src.C
   280  
   281  	{
   282  		views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()}
   283  		size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size()
   284  		vv := buffer.NewVectorisedView(size, views)
   285  		args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), tcpip.PacketBuffer{
   286  			Data: vv,
   287  		})
   288  	}
   289  
   290  	if pi.Proto != ProtocolNumber {
   291  		t.Errorf("unexpected protocol number %d", pi.Proto)
   292  		return
   293  	}
   294  	ipv6 := header.IPv6(pi.Pkt.Header.View())
   295  	transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader())
   296  	if transProto != header.ICMPv6ProtocolNumber {
   297  		t.Errorf("unexpected transport protocol number %d", transProto)
   298  		return
   299  	}
   300  	icmpv6 := header.ICMPv6(ipv6.Payload())
   301  	if got, want := icmpv6.Type(), args.typ; got != want {
   302  		t.Errorf("got ICMPv6 type = %d, want = %d", got, want)
   303  		return
   304  	}
   305  	if fn != nil {
   306  		fn(t, icmpv6)
   307  	}
   308  }
   309  
   310  func TestLinkResolution(t *testing.T) {
   311  	c := newTestContext(t)
   312  	defer c.cleanup()
   313  
   314  	r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
   315  	if err != nil {
   316  		t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
   317  	}
   318  	defer r.Release()
   319  
   320  	hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize)
   321  	pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
   322  	pkt.SetType(header.ICMPv6EchoRequest)
   323  	pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
   324  	payload := tcpip.SlicePayload(hdr.View())
   325  
   326  	// We can't send our payload directly over the route because that
   327  	// doesn't provoke NDP discovery.
   328  	var wq waiter.Queue
   329  	ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq)
   330  	if err != nil {
   331  		t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
   332  	}
   333  
   334  	for {
   335  		_, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}})
   336  		if resCh != nil {
   337  			if err != tcpip.ErrNoLinkAddress {
   338  				t.Fatalf("ep.Write(_) = _, <non-nil>, %s, want = _, <non-nil>, tcpip.ErrNoLinkAddress", err)
   339  			}
   340  			for _, args := range []routeArgs{
   341  				{src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit},
   342  				{src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert},
   343  			} {
   344  				routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) {
   345  					if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want {
   346  						t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want)
   347  					}
   348  				})
   349  			}
   350  			<-resCh
   351  			continue
   352  		}
   353  		if err != nil {
   354  			t.Fatalf("ep.Write(_) = _, _, %s", err)
   355  		}
   356  		break
   357  	}
   358  
   359  	for _, args := range []routeArgs{
   360  		{src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6EchoRequest},
   361  		{src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6EchoReply},
   362  	} {
   363  		routeICMPv6Packet(t, args, nil)
   364  	}
   365  }
   366  
   367  func TestICMPChecksumValidationSimple(t *testing.T) {
   368  	types := []struct {
   369  		name        string
   370  		typ         header.ICMPv6Type
   371  		size        int
   372  		statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
   373  	}{
   374  		{
   375  			"DstUnreachable",
   376  			header.ICMPv6DstUnreachable,
   377  			header.ICMPv6DstUnreachableMinimumSize,
   378  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   379  				return stats.DstUnreachable
   380  			},
   381  		},
   382  		{
   383  			"PacketTooBig",
   384  			header.ICMPv6PacketTooBig,
   385  			header.ICMPv6PacketTooBigMinimumSize,
   386  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   387  				return stats.PacketTooBig
   388  			},
   389  		},
   390  		{
   391  			"TimeExceeded",
   392  			header.ICMPv6TimeExceeded,
   393  			header.ICMPv6MinimumSize,
   394  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   395  				return stats.TimeExceeded
   396  			},
   397  		},
   398  		{
   399  			"ParamProblem",
   400  			header.ICMPv6ParamProblem,
   401  			header.ICMPv6MinimumSize,
   402  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   403  				return stats.ParamProblem
   404  			},
   405  		},
   406  		{
   407  			"EchoRequest",
   408  			header.ICMPv6EchoRequest,
   409  			header.ICMPv6EchoMinimumSize,
   410  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   411  				return stats.EchoRequest
   412  			},
   413  		},
   414  		{
   415  			"EchoReply",
   416  			header.ICMPv6EchoReply,
   417  			header.ICMPv6EchoMinimumSize,
   418  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   419  				return stats.EchoReply
   420  			},
   421  		},
   422  		{
   423  			"RouterSolicit",
   424  			header.ICMPv6RouterSolicit,
   425  			header.ICMPv6MinimumSize,
   426  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   427  				return stats.RouterSolicit
   428  			},
   429  		},
   430  		{
   431  			"RouterAdvert",
   432  			header.ICMPv6RouterAdvert,
   433  			header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
   434  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   435  				return stats.RouterAdvert
   436  			},
   437  		},
   438  		{
   439  			"NeighborSolicit",
   440  			header.ICMPv6NeighborSolicit,
   441  			header.ICMPv6NeighborSolicitMinimumSize,
   442  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   443  				return stats.NeighborSolicit
   444  			},
   445  		},
   446  		{
   447  			"NeighborAdvert",
   448  			header.ICMPv6NeighborAdvert,
   449  			header.ICMPv6NeighborAdvertSize,
   450  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   451  				return stats.NeighborAdvert
   452  			},
   453  		},
   454  		{
   455  			"RedirectMsg",
   456  			header.ICMPv6RedirectMsg,
   457  			header.ICMPv6MinimumSize,
   458  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   459  				return stats.RedirectMsg
   460  			},
   461  		},
   462  	}
   463  
   464  	for _, typ := range types {
   465  		t.Run(typ.name, func(t *testing.T) {
   466  			e := channel.New(10, 1280, linkAddr0)
   467  			s := stack.New(stack.Options{
   468  				NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
   469  			})
   470  			if err := s.CreateNIC(1, e); err != nil {
   471  				t.Fatalf("CreateNIC(_) = %s", err)
   472  			}
   473  
   474  			if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
   475  				t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
   476  			}
   477  			{
   478  				subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
   479  				if err != nil {
   480  					t.Fatal(err)
   481  				}
   482  				s.SetRouteTable(
   483  					[]tcpip.Route{{
   484  						Destination: subnet,
   485  						NIC:         1,
   486  					}},
   487  				)
   488  			}
   489  
   490  			handleIPv6Payload := func(typ header.ICMPv6Type, size int, checksum bool) {
   491  				hdr := buffer.NewPrependable(header.IPv6MinimumSize + size)
   492  				pkt := header.ICMPv6(hdr.Prepend(size))
   493  				pkt.SetType(typ)
   494  				if checksum {
   495  					pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
   496  				}
   497  				ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
   498  				ip.Encode(&header.IPv6Fields{
   499  					PayloadLength: uint16(size),
   500  					NextHeader:    uint8(header.ICMPv6ProtocolNumber),
   501  					HopLimit:      header.NDPHopLimit,
   502  					SrcAddr:       lladdr1,
   503  					DstAddr:       lladdr0,
   504  				})
   505  				e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
   506  					Data: hdr.View().ToVectorisedView(),
   507  				})
   508  			}
   509  
   510  			stats := s.Stats().ICMP.V6PacketsReceived
   511  			invalid := stats.Invalid
   512  			typStat := typ.statCounter(stats)
   513  
   514  			// Initial stat counts should be 0.
   515  			if got := invalid.Value(); got != 0 {
   516  				t.Fatalf("got invalid = %d, want = 0", got)
   517  			}
   518  			if got := typStat.Value(); got != 0 {
   519  				t.Fatalf("got %s = %d, want = 0", typ.name, got)
   520  			}
   521  
   522  			// Without setting checksum, the incoming packet should
   523  			// be invalid.
   524  			handleIPv6Payload(typ.typ, typ.size, false)
   525  			if got := invalid.Value(); got != 1 {
   526  				t.Fatalf("got invalid = %d, want = 1", got)
   527  			}
   528  			// Rx count of type typ.typ should not have increased.
   529  			if got := typStat.Value(); got != 0 {
   530  				t.Fatalf("got %s = %d, want = 0", typ.name, got)
   531  			}
   532  
   533  			// When checksum is set, it should be received.
   534  			handleIPv6Payload(typ.typ, typ.size, true)
   535  			if got := typStat.Value(); got != 1 {
   536  				t.Fatalf("got %s = %d, want = 1", typ.name, got)
   537  			}
   538  			// Invalid count should not have increased again.
   539  			if got := invalid.Value(); got != 1 {
   540  				t.Fatalf("got invalid = %d, want = 1", got)
   541  			}
   542  		})
   543  	}
   544  }
   545  
   546  func TestICMPChecksumValidationWithPayload(t *testing.T) {
   547  	const simpleBodySize = 64
   548  	simpleBody := func(view buffer.View) {
   549  		for i := 0; i < simpleBodySize; i++ {
   550  			view[i] = uint8(i)
   551  		}
   552  	}
   553  
   554  	const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize
   555  	errorICMPBody := func(view buffer.View) {
   556  		ip := header.IPv6(view)
   557  		ip.Encode(&header.IPv6Fields{
   558  			PayloadLength: simpleBodySize,
   559  			NextHeader:    10,
   560  			HopLimit:      20,
   561  			SrcAddr:       lladdr0,
   562  			DstAddr:       lladdr1,
   563  		})
   564  		simpleBody(view[header.IPv6MinimumSize:])
   565  	}
   566  
   567  	types := []struct {
   568  		name        string
   569  		typ         header.ICMPv6Type
   570  		size        int
   571  		statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
   572  		payloadSize int
   573  		payload     func(buffer.View)
   574  	}{
   575  		{
   576  			"DstUnreachable",
   577  			header.ICMPv6DstUnreachable,
   578  			header.ICMPv6DstUnreachableMinimumSize,
   579  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   580  				return stats.DstUnreachable
   581  			},
   582  			errorICMPBodySize,
   583  			errorICMPBody,
   584  		},
   585  		{
   586  			"PacketTooBig",
   587  			header.ICMPv6PacketTooBig,
   588  			header.ICMPv6PacketTooBigMinimumSize,
   589  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   590  				return stats.PacketTooBig
   591  			},
   592  			errorICMPBodySize,
   593  			errorICMPBody,
   594  		},
   595  		{
   596  			"TimeExceeded",
   597  			header.ICMPv6TimeExceeded,
   598  			header.ICMPv6MinimumSize,
   599  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   600  				return stats.TimeExceeded
   601  			},
   602  			errorICMPBodySize,
   603  			errorICMPBody,
   604  		},
   605  		{
   606  			"ParamProblem",
   607  			header.ICMPv6ParamProblem,
   608  			header.ICMPv6MinimumSize,
   609  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   610  				return stats.ParamProblem
   611  			},
   612  			errorICMPBodySize,
   613  			errorICMPBody,
   614  		},
   615  		{
   616  			"EchoRequest",
   617  			header.ICMPv6EchoRequest,
   618  			header.ICMPv6EchoMinimumSize,
   619  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   620  				return stats.EchoRequest
   621  			},
   622  			simpleBodySize,
   623  			simpleBody,
   624  		},
   625  		{
   626  			"EchoReply",
   627  			header.ICMPv6EchoReply,
   628  			header.ICMPv6EchoMinimumSize,
   629  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   630  				return stats.EchoReply
   631  			},
   632  			simpleBodySize,
   633  			simpleBody,
   634  		},
   635  	}
   636  
   637  	for _, typ := range types {
   638  		t.Run(typ.name, func(t *testing.T) {
   639  			e := channel.New(10, 1280, linkAddr0)
   640  			s := stack.New(stack.Options{
   641  				NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
   642  			})
   643  			if err := s.CreateNIC(1, e); err != nil {
   644  				t.Fatalf("CreateNIC(_) = %s", err)
   645  			}
   646  
   647  			if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
   648  				t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
   649  			}
   650  			{
   651  				subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
   652  				if err != nil {
   653  					t.Fatal(err)
   654  				}
   655  				s.SetRouteTable(
   656  					[]tcpip.Route{{
   657  						Destination: subnet,
   658  						NIC:         1,
   659  					}},
   660  				)
   661  			}
   662  
   663  			handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) {
   664  				icmpSize := size + payloadSize
   665  				hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
   666  				pkt := header.ICMPv6(hdr.Prepend(icmpSize))
   667  				pkt.SetType(typ)
   668  				payloadFn(pkt.Payload())
   669  
   670  				if checksum {
   671  					pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
   672  				}
   673  
   674  				ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
   675  				ip.Encode(&header.IPv6Fields{
   676  					PayloadLength: uint16(icmpSize),
   677  					NextHeader:    uint8(header.ICMPv6ProtocolNumber),
   678  					HopLimit:      header.NDPHopLimit,
   679  					SrcAddr:       lladdr1,
   680  					DstAddr:       lladdr0,
   681  				})
   682  				e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
   683  					Data: hdr.View().ToVectorisedView(),
   684  				})
   685  			}
   686  
   687  			stats := s.Stats().ICMP.V6PacketsReceived
   688  			invalid := stats.Invalid
   689  			typStat := typ.statCounter(stats)
   690  
   691  			// Initial stat counts should be 0.
   692  			if got := invalid.Value(); got != 0 {
   693  				t.Fatalf("got invalid = %d, want = 0", got)
   694  			}
   695  			if got := typStat.Value(); got != 0 {
   696  				t.Fatalf("got %s = %d, want = 0", typ.name, got)
   697  			}
   698  
   699  			// Without setting checksum, the incoming packet should
   700  			// be invalid.
   701  			handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false)
   702  			if got := invalid.Value(); got != 1 {
   703  				t.Fatalf("got invalid = %d, want = 1", got)
   704  			}
   705  			// Rx count of type typ.typ should not have increased.
   706  			if got := typStat.Value(); got != 0 {
   707  				t.Fatalf("got %s = %d, want = 0", typ.name, got)
   708  			}
   709  
   710  			// When checksum is set, it should be received.
   711  			handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
   712  			if got := typStat.Value(); got != 1 {
   713  				t.Fatalf("got %s = %d, want = 1", typ.name, got)
   714  			}
   715  			// Invalid count should not have increased again.
   716  			if got := invalid.Value(); got != 1 {
   717  				t.Fatalf("got invalid = %d, want = 1", got)
   718  			}
   719  		})
   720  	}
   721  }
   722  
   723  func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
   724  	const simpleBodySize = 64
   725  	simpleBody := func(view buffer.View) {
   726  		for i := 0; i < simpleBodySize; i++ {
   727  			view[i] = uint8(i)
   728  		}
   729  	}
   730  
   731  	const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize
   732  	errorICMPBody := func(view buffer.View) {
   733  		ip := header.IPv6(view)
   734  		ip.Encode(&header.IPv6Fields{
   735  			PayloadLength: simpleBodySize,
   736  			NextHeader:    10,
   737  			HopLimit:      20,
   738  			SrcAddr:       lladdr0,
   739  			DstAddr:       lladdr1,
   740  		})
   741  		simpleBody(view[header.IPv6MinimumSize:])
   742  	}
   743  
   744  	types := []struct {
   745  		name        string
   746  		typ         header.ICMPv6Type
   747  		size        int
   748  		statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
   749  		payloadSize int
   750  		payload     func(buffer.View)
   751  	}{
   752  		{
   753  			"DstUnreachable",
   754  			header.ICMPv6DstUnreachable,
   755  			header.ICMPv6DstUnreachableMinimumSize,
   756  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   757  				return stats.DstUnreachable
   758  			},
   759  			errorICMPBodySize,
   760  			errorICMPBody,
   761  		},
   762  		{
   763  			"PacketTooBig",
   764  			header.ICMPv6PacketTooBig,
   765  			header.ICMPv6PacketTooBigMinimumSize,
   766  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   767  				return stats.PacketTooBig
   768  			},
   769  			errorICMPBodySize,
   770  			errorICMPBody,
   771  		},
   772  		{
   773  			"TimeExceeded",
   774  			header.ICMPv6TimeExceeded,
   775  			header.ICMPv6MinimumSize,
   776  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   777  				return stats.TimeExceeded
   778  			},
   779  			errorICMPBodySize,
   780  			errorICMPBody,
   781  		},
   782  		{
   783  			"ParamProblem",
   784  			header.ICMPv6ParamProblem,
   785  			header.ICMPv6MinimumSize,
   786  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   787  				return stats.ParamProblem
   788  			},
   789  			errorICMPBodySize,
   790  			errorICMPBody,
   791  		},
   792  		{
   793  			"EchoRequest",
   794  			header.ICMPv6EchoRequest,
   795  			header.ICMPv6EchoMinimumSize,
   796  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   797  				return stats.EchoRequest
   798  			},
   799  			simpleBodySize,
   800  			simpleBody,
   801  		},
   802  		{
   803  			"EchoReply",
   804  			header.ICMPv6EchoReply,
   805  			header.ICMPv6EchoMinimumSize,
   806  			func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
   807  				return stats.EchoReply
   808  			},
   809  			simpleBodySize,
   810  			simpleBody,
   811  		},
   812  	}
   813  
   814  	for _, typ := range types {
   815  		t.Run(typ.name, func(t *testing.T) {
   816  			e := channel.New(10, 1280, linkAddr0)
   817  			s := stack.New(stack.Options{
   818  				NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
   819  			})
   820  			if err := s.CreateNIC(1, e); err != nil {
   821  				t.Fatalf("CreateNIC(_) = %s", err)
   822  			}
   823  
   824  			if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
   825  				t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
   826  			}
   827  			{
   828  				subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
   829  				if err != nil {
   830  					t.Fatal(err)
   831  				}
   832  				s.SetRouteTable(
   833  					[]tcpip.Route{{
   834  						Destination: subnet,
   835  						NIC:         1,
   836  					}},
   837  				)
   838  			}
   839  
   840  			handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) {
   841  				hdr := buffer.NewPrependable(header.IPv6MinimumSize + size)
   842  				pkt := header.ICMPv6(hdr.Prepend(size))
   843  				pkt.SetType(typ)
   844  
   845  				payload := buffer.NewView(payloadSize)
   846  				payloadFn(payload)
   847  
   848  				if checksum {
   849  					pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView()))
   850  				}
   851  
   852  				ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
   853  				ip.Encode(&header.IPv6Fields{
   854  					PayloadLength: uint16(size + payloadSize),
   855  					NextHeader:    uint8(header.ICMPv6ProtocolNumber),
   856  					HopLimit:      header.NDPHopLimit,
   857  					SrcAddr:       lladdr1,
   858  					DstAddr:       lladdr0,
   859  				})
   860  				e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
   861  					Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}),
   862  				})
   863  			}
   864  
   865  			stats := s.Stats().ICMP.V6PacketsReceived
   866  			invalid := stats.Invalid
   867  			typStat := typ.statCounter(stats)
   868  
   869  			// Initial stat counts should be 0.
   870  			if got := invalid.Value(); got != 0 {
   871  				t.Fatalf("got invalid = %d, want = 0", got)
   872  			}
   873  			if got := typStat.Value(); got != 0 {
   874  				t.Fatalf("got %s = %d, want = 0", typ.name, got)
   875  			}
   876  
   877  			// Without setting checksum, the incoming packet should
   878  			// be invalid.
   879  			handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false)
   880  			if got := invalid.Value(); got != 1 {
   881  				t.Fatalf("got invalid = %d, want = 1", got)
   882  			}
   883  			// Rx count of type typ.typ should not have increased.
   884  			if got := typStat.Value(); got != 0 {
   885  				t.Fatalf("got %s = %d, want = 0", typ.name, got)
   886  			}
   887  
   888  			// When checksum is set, it should be received.
   889  			handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
   890  			if got := typStat.Value(); got != 1 {
   891  				t.Fatalf("got %s = %d, want = 1", typ.name, got)
   892  			}
   893  			// Invalid count should not have increased again.
   894  			if got := invalid.Value(); got != 1 {
   895  				t.Fatalf("got invalid = %d, want = 1", got)
   896  			}
   897  		})
   898  	}
   899  }