gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/tests/integration/iptables_test.go (about)

     1  // Copyright 2021 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package iptables_test
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"math"
    21  	"testing"
    22  
    23  	"github.com/google/go-cmp/cmp"
    24  	"gvisor.dev/gvisor/pkg/buffer"
    25  	"gvisor.dev/gvisor/pkg/tcpip"
    26  	"gvisor.dev/gvisor/pkg/tcpip/checker"
    27  	"gvisor.dev/gvisor/pkg/tcpip/checksum"
    28  	"gvisor.dev/gvisor/pkg/tcpip/header"
    29  	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
    30  	"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
    31  	"gvisor.dev/gvisor/pkg/tcpip/network/arp"
    32  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    33  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
    34  	"gvisor.dev/gvisor/pkg/tcpip/prependable"
    35  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    36  	"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
    37  	"gvisor.dev/gvisor/pkg/tcpip/testutil"
    38  	"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
    39  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
    40  	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
    41  	"gvisor.dev/gvisor/pkg/waiter"
    42  )
    43  
    44  type inputIfNameMatcher struct {
    45  	name string
    46  }
    47  
    48  var _ stack.Matcher = (*inputIfNameMatcher)(nil)
    49  
    50  func (*inputIfNameMatcher) Name() string {
    51  	return "inputIfNameMatcher"
    52  }
    53  
    54  func (im *inputIfNameMatcher) Match(hook stack.Hook, _ *stack.PacketBuffer, inNicName, _ string) (bool, bool) {
    55  	return (hook == stack.Input && im.name != "" && im.name == inNicName), false
    56  }
    57  
    58  const (
    59  	nicID          = 1
    60  	nicName        = "nic1"
    61  	anotherNicName = "nic2"
    62  	linkAddr       = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
    63  	payloadSize    = 20
    64  )
    65  
    66  var (
    67  	srcAddrV4 = tcpip.AddrFromSlice([]byte("\x0a\x00\x00\x01"))
    68  	dstAddrV4 = tcpip.AddrFromSlice([]byte("\x0a\x00\x00\x02"))
    69  	srcAddrV6 = tcpip.AddrFromSlice([]byte("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"))
    70  	dstAddrV6 = tcpip.AddrFromSlice([]byte("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"))
    71  )
    72  
    73  func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) {
    74  	t.Helper()
    75  	s := stack.New(stack.Options{
    76  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv6.NewProtocol},
    77  		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
    78  	})
    79  	e := channel.New(0, header.IPv6MinimumMTU, linkAddr)
    80  	nicOpts := stack.NICOptions{Name: nicName}
    81  	if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
    82  		t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
    83  	}
    84  	protocolAddr := tcpip.ProtocolAddress{
    85  		Protocol:          header.IPv6ProtocolNumber,
    86  		AddressWithPrefix: dstAddrV6.WithPrefix(),
    87  	}
    88  	if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
    89  		t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
    90  	}
    91  	return s, e
    92  }
    93  
    94  func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) {
    95  	t.Helper()
    96  	s := stack.New(stack.Options{
    97  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol},
    98  		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
    99  	})
   100  	e := channel.New(0, header.IPv4MinimumMTU, linkAddr)
   101  	nicOpts := stack.NICOptions{Name: nicName}
   102  	if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
   103  		t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
   104  	}
   105  	protocolAddr := tcpip.ProtocolAddress{
   106  		Protocol:          header.IPv4ProtocolNumber,
   107  		AddressWithPrefix: dstAddrV4.WithPrefix(),
   108  	}
   109  	if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
   110  		t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
   111  	}
   112  	return s, e
   113  }
   114  
   115  func genPacketV6() *stack.PacketBuffer {
   116  	pktSize := header.IPv6MinimumSize + payloadSize
   117  	hdr := prependable.New(pktSize)
   118  	ip := header.IPv6(hdr.Prepend(pktSize))
   119  	ip.Encode(&header.IPv6Fields{
   120  		PayloadLength:     payloadSize,
   121  		TransportProtocol: 99,
   122  		HopLimit:          255,
   123  		SrcAddr:           srcAddrV6,
   124  		DstAddr:           dstAddrV6,
   125  	})
   126  	buf := buffer.MakeWithData(hdr.View())
   127  	return stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buf})
   128  }
   129  
   130  func genPacketV4() *stack.PacketBuffer {
   131  	pktSize := header.IPv4MinimumSize + payloadSize
   132  	hdr := prependable.New(pktSize)
   133  	ip := header.IPv4(hdr.Prepend(pktSize))
   134  	ip.Encode(&header.IPv4Fields{
   135  		TOS:            0,
   136  		TotalLength:    uint16(pktSize),
   137  		ID:             1,
   138  		Flags:          0,
   139  		FragmentOffset: 16,
   140  		TTL:            48,
   141  		Protocol:       99,
   142  		SrcAddr:        srcAddrV4,
   143  		DstAddr:        dstAddrV4,
   144  	})
   145  	ip.SetChecksum(0)
   146  	ip.SetChecksum(^ip.CalculateChecksum())
   147  	buf := buffer.MakeWithData(hdr.View())
   148  	return stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buf})
   149  }
   150  
   151  func TestIPTablesStatsForInput(t *testing.T) {
   152  	tests := []struct {
   153  		name               string
   154  		setupStack         func(*testing.T) (*stack.Stack, *channel.Endpoint)
   155  		setupFilter        func(*testing.T, *stack.Stack)
   156  		genPacket          func() *stack.PacketBuffer
   157  		proto              tcpip.NetworkProtocolNumber
   158  		expectReceived     int
   159  		expectInputDropped int
   160  	}{
   161  		{
   162  			name:               "IPv6 Accept",
   163  			setupStack:         genStackV6,
   164  			setupFilter:        func(*testing.T, *stack.Stack) { /* no filter */ },
   165  			genPacket:          genPacketV6,
   166  			proto:              header.IPv6ProtocolNumber,
   167  			expectReceived:     1,
   168  			expectInputDropped: 0,
   169  		},
   170  		{
   171  			name:               "IPv4 Accept",
   172  			setupStack:         genStackV4,
   173  			setupFilter:        func(*testing.T, *stack.Stack) { /* no filter */ },
   174  			genPacket:          genPacketV4,
   175  			proto:              header.IPv4ProtocolNumber,
   176  			expectReceived:     1,
   177  			expectInputDropped: 0,
   178  		},
   179  		{
   180  			name:       "IPv6 Drop (input interface matches)",
   181  			setupStack: genStackV6,
   182  			setupFilter: func(t *testing.T, s *stack.Stack) {
   183  				t.Helper()
   184  				ipt := s.IPTables()
   185  				filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
   186  				ruleIdx := filter.BuiltinChains[stack.Input]
   187  				filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName}
   188  				filter.Rules[ruleIdx].Target = &stack.DropTarget{}
   189  				filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}}
   190  				// Make sure the packet is not dropped by the next rule.
   191  				filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
   192  				ipt.ForceReplaceTable(stack.FilterID, filter, true /* ipv6 */)
   193  			},
   194  			genPacket:          genPacketV6,
   195  			proto:              header.IPv6ProtocolNumber,
   196  			expectReceived:     1,
   197  			expectInputDropped: 1,
   198  		},
   199  		{
   200  			name:       "IPv4 Drop (input interface matches)",
   201  			setupStack: genStackV4,
   202  			setupFilter: func(t *testing.T, s *stack.Stack) {
   203  				t.Helper()
   204  				ipt := s.IPTables()
   205  				filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
   206  				ruleIdx := filter.BuiltinChains[stack.Input]
   207  				filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName}
   208  				filter.Rules[ruleIdx].Target = &stack.DropTarget{}
   209  				filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}}
   210  				filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
   211  				ipt.ForceReplaceTable(stack.FilterID, filter, false /* ipv6 */)
   212  			},
   213  			genPacket:          genPacketV4,
   214  			proto:              header.IPv4ProtocolNumber,
   215  			expectReceived:     1,
   216  			expectInputDropped: 1,
   217  		},
   218  		{
   219  			name:       "IPv6 Accept (input interface does not match)",
   220  			setupStack: genStackV6,
   221  			setupFilter: func(t *testing.T, s *stack.Stack) {
   222  				t.Helper()
   223  				ipt := s.IPTables()
   224  				filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
   225  				ruleIdx := filter.BuiltinChains[stack.Input]
   226  				filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName}
   227  				filter.Rules[ruleIdx].Target = &stack.DropTarget{}
   228  				filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
   229  				ipt.ForceReplaceTable(stack.FilterID, filter, true /* ipv6 */)
   230  			},
   231  			genPacket:          genPacketV6,
   232  			proto:              header.IPv6ProtocolNumber,
   233  			expectReceived:     1,
   234  			expectInputDropped: 0,
   235  		},
   236  		{
   237  			name:       "IPv4 Accept (input interface does not match)",
   238  			setupStack: genStackV4,
   239  			setupFilter: func(t *testing.T, s *stack.Stack) {
   240  				t.Helper()
   241  				ipt := s.IPTables()
   242  				filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
   243  				ruleIdx := filter.BuiltinChains[stack.Input]
   244  				filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName}
   245  				filter.Rules[ruleIdx].Target = &stack.DropTarget{}
   246  				filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
   247  				ipt.ForceReplaceTable(stack.FilterID, filter, false /* ipv6 */)
   248  			},
   249  			genPacket:          genPacketV4,
   250  			proto:              header.IPv4ProtocolNumber,
   251  			expectReceived:     1,
   252  			expectInputDropped: 0,
   253  		},
   254  		{
   255  			name:       "IPv6 Drop (input interface does not match but invert is true)",
   256  			setupStack: genStackV6,
   257  			setupFilter: func(t *testing.T, s *stack.Stack) {
   258  				t.Helper()
   259  				ipt := s.IPTables()
   260  				filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
   261  				ruleIdx := filter.BuiltinChains[stack.Input]
   262  				filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{
   263  					InputInterface:       anotherNicName,
   264  					InputInterfaceInvert: true,
   265  				}
   266  				filter.Rules[ruleIdx].Target = &stack.DropTarget{}
   267  				filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
   268  				ipt.ForceReplaceTable(stack.FilterID, filter, true /* ipv6 */)
   269  			},
   270  			genPacket:          genPacketV6,
   271  			proto:              header.IPv6ProtocolNumber,
   272  			expectReceived:     1,
   273  			expectInputDropped: 1,
   274  		},
   275  		{
   276  			name:       "IPv4 Drop (input interface does not match but invert is true)",
   277  			setupStack: genStackV4,
   278  			setupFilter: func(t *testing.T, s *stack.Stack) {
   279  				t.Helper()
   280  				ipt := s.IPTables()
   281  				filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
   282  				ruleIdx := filter.BuiltinChains[stack.Input]
   283  				filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{
   284  					InputInterface:       anotherNicName,
   285  					InputInterfaceInvert: true,
   286  				}
   287  				filter.Rules[ruleIdx].Target = &stack.DropTarget{}
   288  				filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
   289  				ipt.ForceReplaceTable(stack.FilterID, filter, false /* ipv6 */)
   290  			},
   291  			genPacket:          genPacketV4,
   292  			proto:              header.IPv4ProtocolNumber,
   293  			expectReceived:     1,
   294  			expectInputDropped: 1,
   295  		},
   296  		{
   297  			name:       "IPv6 Accept (input interface does not match using a matcher)",
   298  			setupStack: genStackV6,
   299  			setupFilter: func(t *testing.T, s *stack.Stack) {
   300  				t.Helper()
   301  				ipt := s.IPTables()
   302  				filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
   303  				ruleIdx := filter.BuiltinChains[stack.Input]
   304  				filter.Rules[ruleIdx].Target = &stack.DropTarget{}
   305  				filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}}
   306  				filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
   307  				ipt.ForceReplaceTable(stack.FilterID, filter, true /* ipv6 */)
   308  			},
   309  			genPacket:          genPacketV6,
   310  			proto:              header.IPv6ProtocolNumber,
   311  			expectReceived:     1,
   312  			expectInputDropped: 0,
   313  		},
   314  		{
   315  			name:       "IPv4 Accept (input interface does not match using a matcher)",
   316  			setupStack: genStackV4,
   317  			setupFilter: func(t *testing.T, s *stack.Stack) {
   318  				t.Helper()
   319  				ipt := s.IPTables()
   320  				filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
   321  				ruleIdx := filter.BuiltinChains[stack.Input]
   322  				filter.Rules[ruleIdx].Target = &stack.DropTarget{}
   323  				filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}}
   324  				filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
   325  				ipt.ForceReplaceTable(stack.FilterID, filter, false /* ipv6 */)
   326  			},
   327  			genPacket:          genPacketV4,
   328  			proto:              header.IPv4ProtocolNumber,
   329  			expectReceived:     1,
   330  			expectInputDropped: 0,
   331  		},
   332  	}
   333  
   334  	for _, test := range tests {
   335  		t.Run(test.name, func(t *testing.T) {
   336  			s, e := test.setupStack(t)
   337  			defer s.Destroy()
   338  			test.setupFilter(t, s)
   339  			e.InjectInbound(test.proto, test.genPacket())
   340  
   341  			if got := int(s.Stats().IP.PacketsReceived.Value()); got != test.expectReceived {
   342  				t.Errorf("got PacketReceived = %d, want = %d", got, test.expectReceived)
   343  			}
   344  			if got := int(s.Stats().IP.IPTablesInputDropped.Value()); got != test.expectInputDropped {
   345  				t.Errorf("got IPTablesInputDropped = %d, want = %d", got, test.expectInputDropped)
   346  			}
   347  		})
   348  	}
   349  }
   350  
   351  var _ stack.LinkEndpoint = (*channelEndpoint)(nil)
   352  
   353  type channelEndpoint struct {
   354  	*channel.Endpoint
   355  
   356  	t *testing.T
   357  }
   358  
   359  var _ stack.Matcher = (*udpSourcePortMatcher)(nil)
   360  
   361  type udpSourcePortMatcher struct {
   362  	port uint16
   363  }
   364  
   365  func (*udpSourcePortMatcher) Name() string {
   366  	return "udpSourcePortMatcher"
   367  }
   368  
   369  func (m *udpSourcePortMatcher) Match(_ stack.Hook, pkt *stack.PacketBuffer, _, _ string) (matches, hotdrop bool) {
   370  	udp := header.UDP(pkt.TransportHeader().Slice())
   371  	if len(udp) < header.UDPMinimumSize {
   372  		// Drop immediately as the packet is invalid.
   373  		return false, true
   374  	}
   375  
   376  	return udp.SourcePort() == m.port, false
   377  }
   378  
   379  func TestIPTableWritePackets(t *testing.T) {
   380  	const (
   381  		nicID = 1
   382  
   383  		dropLocalPort = utils.LocalPort - 1
   384  		acceptPackets = 2
   385  		dropPackets   = 3
   386  	)
   387  
   388  	udpHdr := func(hdr []byte, srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16) {
   389  		u := header.UDP(hdr)
   390  		u.Encode(&header.UDPFields{
   391  			SrcPort: srcPort,
   392  			DstPort: dstPort,
   393  			Length:  header.UDPMinimumSize,
   394  		})
   395  		sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, srcAddr, dstAddr, header.UDPMinimumSize)
   396  		sum = checksum.Checksum(hdr, sum)
   397  		u.SetChecksum(^u.CalculateChecksum(sum))
   398  	}
   399  
   400  	tests := []struct {
   401  		name                string
   402  		setupFilter         func(*testing.T, *stack.Stack)
   403  		genPacket           func(*stack.Route) stack.PacketBufferList
   404  		proto               tcpip.NetworkProtocolNumber
   405  		remoteAddr          tcpip.Address
   406  		expectSent          uint64
   407  		expectOutputDropped uint64
   408  	}{
   409  		{
   410  			name:        "IPv4 Accept",
   411  			setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
   412  			genPacket: func(r *stack.Route) stack.PacketBufferList {
   413  				var pkts stack.PacketBufferList
   414  
   415  				pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   416  					ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
   417  				})
   418  				hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
   419  				udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort)
   420  				pkts.PushBack(pkt)
   421  
   422  				return pkts
   423  			},
   424  			proto:               header.IPv4ProtocolNumber,
   425  			remoteAddr:          dstAddrV4,
   426  			expectSent:          1,
   427  			expectOutputDropped: 0,
   428  		},
   429  		{
   430  			name: "IPv4 Drop Other Port",
   431  			setupFilter: func(t *testing.T, s *stack.Stack) {
   432  				t.Helper()
   433  
   434  				table := stack.Table{
   435  					Rules: []stack.Rule{
   436  						{
   437  							Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber},
   438  						},
   439  						{
   440  							Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber},
   441  						},
   442  						{
   443  							Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}},
   444  							Target:   &stack.DropTarget{NetworkProtocol: header.IPv4ProtocolNumber},
   445  						},
   446  						{
   447  							Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber},
   448  						},
   449  						{
   450  							Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber},
   451  						},
   452  					},
   453  					BuiltinChains: [stack.NumHooks]int{
   454  						stack.Prerouting:  stack.HookUnset,
   455  						stack.Input:       0,
   456  						stack.Forward:     1,
   457  						stack.Output:      2,
   458  						stack.Postrouting: stack.HookUnset,
   459  					},
   460  					Underflows: [stack.NumHooks]int{
   461  						stack.Prerouting:  stack.HookUnset,
   462  						stack.Input:       0,
   463  						stack.Forward:     1,
   464  						stack.Output:      2,
   465  						stack.Postrouting: stack.HookUnset,
   466  					},
   467  				}
   468  
   469  				s.IPTables().ForceReplaceTable(stack.FilterID, table, false /* ipv4 */)
   470  			},
   471  			genPacket: func(r *stack.Route) stack.PacketBufferList {
   472  				var pkts stack.PacketBufferList
   473  
   474  				for i := 0; i < acceptPackets; i++ {
   475  					pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   476  						ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
   477  					})
   478  					hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
   479  					udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort)
   480  					pkts.PushBack(pkt)
   481  				}
   482  				for i := 0; i < dropPackets; i++ {
   483  					pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   484  						ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
   485  					})
   486  					hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
   487  					udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), dropLocalPort, utils.RemotePort)
   488  					pkts.PushBack(pkt)
   489  				}
   490  
   491  				return pkts
   492  			},
   493  			proto:               header.IPv4ProtocolNumber,
   494  			remoteAddr:          dstAddrV4,
   495  			expectSent:          acceptPackets,
   496  			expectOutputDropped: dropPackets,
   497  		},
   498  		{
   499  			name:        "IPv6 Accept",
   500  			setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
   501  			genPacket: func(r *stack.Route) stack.PacketBufferList {
   502  				var pkts stack.PacketBufferList
   503  
   504  				pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   505  					ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
   506  				})
   507  				hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
   508  				udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort)
   509  				pkts.PushBack(pkt)
   510  
   511  				return pkts
   512  			},
   513  			proto:               header.IPv6ProtocolNumber,
   514  			remoteAddr:          dstAddrV6,
   515  			expectSent:          1,
   516  			expectOutputDropped: 0,
   517  		},
   518  		{
   519  			name: "IPv6 Drop Other Port",
   520  			setupFilter: func(t *testing.T, s *stack.Stack) {
   521  				t.Helper()
   522  
   523  				table := stack.Table{
   524  					Rules: []stack.Rule{
   525  						{
   526  							Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber},
   527  						},
   528  						{
   529  							Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber},
   530  						},
   531  						{
   532  							Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}},
   533  							Target:   &stack.DropTarget{NetworkProtocol: header.IPv6ProtocolNumber},
   534  						},
   535  						{
   536  							Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber},
   537  						},
   538  						{
   539  							Target: &stack.ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber},
   540  						},
   541  					},
   542  					BuiltinChains: [stack.NumHooks]int{
   543  						stack.Prerouting:  stack.HookUnset,
   544  						stack.Input:       0,
   545  						stack.Forward:     1,
   546  						stack.Output:      2,
   547  						stack.Postrouting: stack.HookUnset,
   548  					},
   549  					Underflows: [stack.NumHooks]int{
   550  						stack.Prerouting:  stack.HookUnset,
   551  						stack.Input:       0,
   552  						stack.Forward:     1,
   553  						stack.Output:      2,
   554  						stack.Postrouting: stack.HookUnset,
   555  					},
   556  				}
   557  
   558  				s.IPTables().ForceReplaceTable(stack.FilterID, table, true /* ipv6 */)
   559  			},
   560  			genPacket: func(r *stack.Route) stack.PacketBufferList {
   561  				var pkts stack.PacketBufferList
   562  
   563  				for i := 0; i < acceptPackets; i++ {
   564  					pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   565  						ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
   566  					})
   567  					hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
   568  					udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort)
   569  					pkts.PushBack(pkt)
   570  				}
   571  				for i := 0; i < dropPackets; i++ {
   572  					pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   573  						ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
   574  					})
   575  					hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
   576  					udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), dropLocalPort, utils.RemotePort)
   577  					pkts.PushBack(pkt)
   578  				}
   579  
   580  				return pkts
   581  			},
   582  			proto:               header.IPv6ProtocolNumber,
   583  			remoteAddr:          dstAddrV6,
   584  			expectSent:          acceptPackets,
   585  			expectOutputDropped: dropPackets,
   586  		},
   587  	}
   588  
   589  	for _, test := range tests {
   590  		t.Run(test.name, func(t *testing.T) {
   591  			s := stack.New(stack.Options{
   592  				NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   593  				TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   594  			})
   595  			defer s.Destroy()
   596  			e := channelEndpoint{
   597  				Endpoint: channel.New(4, header.IPv6MinimumMTU, linkAddr),
   598  				t:        t,
   599  			}
   600  			if err := s.CreateNIC(nicID, &e); err != nil {
   601  				t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
   602  			}
   603  			protocolAddrV6 := tcpip.ProtocolAddress{
   604  				Protocol:          header.IPv6ProtocolNumber,
   605  				AddressWithPrefix: srcAddrV6.WithPrefix(),
   606  			}
   607  			if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil {
   608  				t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err)
   609  			}
   610  			protocolAddrV4 := tcpip.ProtocolAddress{
   611  				Protocol:          header.IPv4ProtocolNumber,
   612  				AddressWithPrefix: srcAddrV4.WithPrefix(),
   613  			}
   614  			if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil {
   615  				t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err)
   616  			}
   617  
   618  			s.SetRouteTable([]tcpip.Route{
   619  				{
   620  					Destination: header.IPv4EmptySubnet,
   621  					NIC:         nicID,
   622  				},
   623  				{
   624  					Destination: header.IPv6EmptySubnet,
   625  					NIC:         nicID,
   626  				},
   627  			})
   628  
   629  			test.setupFilter(t, s)
   630  
   631  			r, err := s.FindRoute(nicID, tcpip.Address{}, test.remoteAddr, test.proto, false)
   632  			if err != nil {
   633  				t.Fatalf("FindRoute(%d, '', %s, %d, false): %s", nicID, test.remoteAddr, test.proto, err)
   634  			}
   635  			defer r.Release()
   636  
   637  			pkts := test.genPacket(r)
   638  			for _, pkt := range pkts.AsSlice() {
   639  				if err := r.WritePacket(stack.NetworkHeaderParams{
   640  					Protocol: header.UDPProtocolNumber,
   641  					TTL:      64,
   642  				}, pkt); err != nil {
   643  					t.Fatalf("WritePacket(...): %s", err)
   644  				}
   645  				pkt.DecRef()
   646  			}
   647  
   648  			if got := s.Stats().IP.PacketsSent.Value(); got != test.expectSent {
   649  				t.Errorf("got PacketSent = %d, want = %d", got, test.expectSent)
   650  			}
   651  			if got := s.Stats().IP.IPTablesOutputDropped.Value(); got != test.expectOutputDropped {
   652  				t.Errorf("got IPTablesOutputDropped = %d, want = %d", got, test.expectOutputDropped)
   653  			}
   654  		})
   655  	}
   656  }
   657  
   658  const ttl = 64
   659  
   660  var (
   661  	ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10")
   662  	ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a")
   663  )
   664  
// rxICMPv4EchoReply forwards e, src and dst to utils.RxICMPv4EchoReply,
// supplying the package-level ttl as the final argument.
func rxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address) {
	utils.RxICMPv4EchoReply(e, src, dst, ttl)
}
   668  
// rxICMPv6EchoReply forwards e, src and dst to utils.RxICMPv6EchoReply,
// supplying the package-level ttl as the final argument.
func rxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address) {
	utils.RxICMPv6EchoReply(e, src, dst, ttl)
}
   672  
   673  func forwardedICMPv4EchoReplyChecker(t *testing.T, v *buffer.View, src, dst tcpip.Address) {
   674  	checker.IPv4(t, v,
   675  		checker.SrcAddr(src),
   676  		checker.DstAddr(dst),
   677  		checker.TTL(ttl-1),
   678  		checker.ICMPv4(
   679  			checker.ICMPv4Type(header.ICMPv4EchoReply)))
   680  }
   681  
   682  func forwardedICMPv6EchoReplyChecker(t *testing.T, v *buffer.View, src, dst tcpip.Address) {
   683  	checker.IPv6(t, v,
   684  		checker.SrcAddr(src),
   685  		checker.DstAddr(dst),
   686  		checker.TTL(ttl-1),
   687  		checker.ICMPv6(
   688  			checker.ICMPv6Type(header.ICMPv6EchoReply)))
   689  }
   690  
   691  func boolToInt(v bool) uint64 {
   692  	if v {
   693  		return 1
   694  	}
   695  	return 0
   696  }
   697  
   698  func setupDropFilter(hook stack.Hook, f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) {
   699  	return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) {
   700  		t.Helper()
   701  
   702  		ipv6 := netProto == ipv6.ProtocolNumber
   703  
   704  		ipt := s.IPTables()
   705  		filter := ipt.GetTable(stack.FilterID, ipv6)
   706  		ruleIdx := filter.BuiltinChains[hook]
   707  		filter.Rules[ruleIdx].Filter = f
   708  		filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto}
   709  		// Make sure the packet is not dropped by the next rule.
   710  		filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto}
   711  		ipt.ForceReplaceTable(stack.FilterID, filter, ipv6)
   712  	}
   713  }
   714  
   715  func TestForwardingHook(t *testing.T) {
   716  	const (
   717  		nicID1 = 1
   718  		nicID2 = 2
   719  
   720  		nic1Name = "nic1"
   721  		nic2Name = "nic2"
   722  
   723  		otherNICName = "otherNIC"
   724  	)
   725  
   726  	tests := []struct {
   727  		name             string
   728  		netProto         tcpip.NetworkProtocolNumber
   729  		local            bool
   730  		srcAddr, dstAddr tcpip.Address
   731  		rx               func(*channel.Endpoint, tcpip.Address, tcpip.Address)
   732  		checker          func(*testing.T, *buffer.View)
   733  	}{
   734  		{
   735  			name:     "IPv4 remote",
   736  			netProto: ipv4.ProtocolNumber,
   737  			local:    false,
   738  			srcAddr:  utils.RemoteIPv4Addr,
   739  			dstAddr:  utils.Ipv4Addr2.AddressWithPrefix.Address,
   740  			rx:       rxICMPv4EchoReply,
   741  			checker: func(t *testing.T, v *buffer.View) {
   742  				forwardedICMPv4EchoReplyChecker(t, v, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
   743  			},
   744  		},
   745  		{
   746  			name:     "IPv4 local",
   747  			netProto: ipv4.ProtocolNumber,
   748  			local:    true,
   749  			srcAddr:  utils.RemoteIPv4Addr,
   750  			dstAddr:  utils.Ipv4Addr.Address,
   751  			rx:       rxICMPv4EchoReply,
   752  		},
   753  		{
   754  			name:     "IPv6 remote",
   755  			netProto: ipv6.ProtocolNumber,
   756  			local:    false,
   757  			srcAddr:  utils.RemoteIPv6Addr,
   758  			dstAddr:  utils.Ipv6Addr2.AddressWithPrefix.Address,
   759  			rx:       rxICMPv6EchoReply,
   760  			checker: func(t *testing.T, v *buffer.View) {
   761  				forwardedICMPv6EchoReplyChecker(t, v, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
   762  			},
   763  		},
   764  		{
   765  			name:     "IPv6 local",
   766  			netProto: ipv6.ProtocolNumber,
   767  			local:    true,
   768  			srcAddr:  utils.RemoteIPv6Addr,
   769  			dstAddr:  utils.Ipv6Addr.Address,
   770  			rx:       rxICMPv6EchoReply,
   771  		},
   772  	}
   773  
   774  	subTests := []struct {
   775  		name          string
   776  		setupFilter   func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber)
   777  		expectForward bool
   778  	}{
   779  		{
   780  			name:          "Accept",
   781  			setupFilter:   func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ },
   782  			expectForward: true,
   783  		},
   784  
   785  		{
   786  			name:          "Drop",
   787  			setupFilter:   setupDropFilter(stack.Forward, stack.IPHeaderFilter{}),
   788  			expectForward: false,
   789  		},
   790  		{
   791  			name:          "Drop with input NIC filtering",
   792  			setupFilter:   setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name}),
   793  			expectForward: false,
   794  		},
   795  		{
   796  			name:          "Drop with output NIC filtering",
   797  			setupFilter:   setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name}),
   798  			expectForward: false,
   799  		},
   800  		{
   801  			name:          "Drop with input and output NIC filtering",
   802  			setupFilter:   setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}),
   803  			expectForward: false,
   804  		},
   805  
   806  		{
   807  			name:          "Drop with other input NIC filtering",
   808  			setupFilter:   setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName}),
   809  			expectForward: true,
   810  		},
   811  		{
   812  			name:          "Drop with other output NIC filtering",
   813  			setupFilter:   setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: otherNICName}),
   814  			expectForward: true,
   815  		},
   816  		{
   817  			name:          "Drop with other input and output NIC filtering",
   818  			setupFilter:   setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}),
   819  			expectForward: true,
   820  		},
   821  		{
   822  			name:          "Drop with input and other output NIC filtering",
   823  			setupFilter:   setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}),
   824  			expectForward: true,
   825  		},
   826  		{
   827  			name:          "Drop with other input and other output NIC filtering",
   828  			setupFilter:   setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}),
   829  			expectForward: true,
   830  		},
   831  
   832  		{
   833  			name:          "Drop with inverted input NIC filtering",
   834  			setupFilter:   setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}),
   835  			expectForward: true,
   836  		},
   837  		{
   838  			name:          "Drop with inverted output NIC filtering",
   839  			setupFilter:   setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}),
   840  			expectForward: true,
   841  		},
   842  	}
   843  
   844  	for _, test := range tests {
   845  		t.Run(test.name, func(t *testing.T) {
   846  			for _, subTest := range subTests {
   847  				t.Run(subTest.name, func(t *testing.T) {
   848  					s := stack.New(stack.Options{
   849  						NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   850  					})
   851  					defer s.Destroy()
   852  
   853  					subTest.setupFilter(t, s, test.netProto)
   854  
   855  					e1 := channel.New(1, header.IPv6MinimumMTU, "")
   856  					if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
   857  						t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
   858  					}
   859  
   860  					e2 := channel.New(1, header.IPv6MinimumMTU, "")
   861  					if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
   862  						t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
   863  					}
   864  
   865  					protocolAddrV4 := tcpip.ProtocolAddress{
   866  						Protocol:          ipv4.ProtocolNumber,
   867  						AddressWithPrefix: utils.Ipv4Addr.Address.WithPrefix(),
   868  					}
   869  					if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
   870  						t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
   871  					}
   872  					protocolAddrV6 := tcpip.ProtocolAddress{
   873  						Protocol:          ipv6.ProtocolNumber,
   874  						AddressWithPrefix: utils.Ipv6Addr.Address.WithPrefix(),
   875  					}
   876  					if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
   877  						t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
   878  					}
   879  
   880  					if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
   881  						t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
   882  					}
   883  					if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
   884  						t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
   885  					}
   886  
   887  					s.SetRouteTable([]tcpip.Route{
   888  						{
   889  							Destination: header.IPv4EmptySubnet,
   890  							NIC:         nicID2,
   891  						},
   892  						{
   893  							Destination: header.IPv6EmptySubnet,
   894  							NIC:         nicID2,
   895  						},
   896  					})
   897  
   898  					test.rx(e1, test.srcAddr, test.dstAddr)
   899  
   900  					expectTransmitPacket := subTest.expectForward && !test.local
   901  
   902  					ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto)
   903  					if err != nil {
   904  						t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err)
   905  					}
   906  					ep1Stats := ep1.Stats()
   907  					ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats)
   908  					if !ok {
   909  						t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats)
   910  					}
   911  					ip1Stats := ipEP1Stats.IPStats()
   912  
   913  					if got := ip1Stats.PacketsReceived.Value(); got != 1 {
   914  						t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got)
   915  					}
   916  					if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 {
   917  						t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
   918  					}
   919  					if got, want := ip1Stats.IPTablesForwardDropped.Value(), boolToInt(!subTest.expectForward); got != want {
   920  						t.Errorf("got ip1Stats.IPTablesForwardDropped.Value() = %d, want = %d", got, want)
   921  					}
   922  					if got := ip1Stats.PacketsSent.Value(); got != 0 {
   923  						t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = 0", got)
   924  					}
   925  
   926  					ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto)
   927  					if err != nil {
   928  						t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err)
   929  					}
   930  					ep2Stats := ep2.Stats()
   931  					ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats)
   932  					if !ok {
   933  						t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats)
   934  					}
   935  					ip2Stats := ipEP2Stats.IPStats()
   936  					if got := ip2Stats.PacketsReceived.Value(); got != 0 {
   937  						t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got)
   938  					}
   939  					if got, want := ip2Stats.ValidPacketsReceived.Value(), boolToInt(subTest.expectForward && test.local); got != want {
   940  						t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = %d", got, want)
   941  					}
   942  					if got, want := ip2Stats.PacketsSent.Value(), boolToInt(expectTransmitPacket); got != want {
   943  						t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = %d", got, want)
   944  					}
   945  
   946  					p := e2.Read()
   947  					if (p != nil) != expectTransmitPacket {
   948  						t.Fatalf("got e2.Read() = %#v, want = (_ == nil) = %t", p, expectTransmitPacket)
   949  					}
   950  					if expectTransmitPacket {
   951  						payload := stack.PayloadSince(p.NetworkHeader())
   952  						defer payload.Release()
   953  						test.checker(t, payload)
   954  						p.DecRef()
   955  					}
   956  				})
   957  			}
   958  		})
   959  	}
   960  }
   961  
   962  func TestFilteringEchoPacketsWithLocalForwarding(t *testing.T) {
   963  	const (
   964  		nicID1 = 1
   965  		nicID2 = 2
   966  
   967  		nic1Name = "nic1"
   968  		nic2Name = "nic2"
   969  
   970  		otherNICName = "otherNIC"
   971  	)
   972  
   973  	tests := []struct {
   974  		name     string
   975  		netProto tcpip.NetworkProtocolNumber
   976  		rx       func(*channel.Endpoint)
   977  		checker  func(*testing.T, *buffer.View)
   978  	}{
   979  		{
   980  			name:     "IPv4",
   981  			netProto: ipv4.ProtocolNumber,
   982  			rx: func(e *channel.Endpoint) {
   983  				utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address, ttl)
   984  			},
   985  			checker: func(t *testing.T, v *buffer.View) {
   986  				checker.IPv4(t, v,
   987  					checker.SrcAddr(utils.Ipv4Addr2.AddressWithPrefix.Address),
   988  					checker.DstAddr(utils.RemoteIPv4Addr),
   989  					checker.ICMPv4(
   990  						checker.ICMPv4Type(header.ICMPv4EchoReply)))
   991  			},
   992  		},
   993  		{
   994  			name:     "IPv6",
   995  			netProto: ipv6.ProtocolNumber,
   996  			rx: func(e *channel.Endpoint) {
   997  				utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address, ttl)
   998  			},
   999  			checker: func(t *testing.T, v *buffer.View) {
  1000  				checker.IPv6(t, v,
  1001  					checker.SrcAddr(utils.Ipv6Addr2.AddressWithPrefix.Address),
  1002  					checker.DstAddr(utils.RemoteIPv6Addr),
  1003  					checker.ICMPv6(
  1004  						checker.ICMPv6Type(header.ICMPv6EchoReply)))
  1005  			},
  1006  		},
  1007  	}
  1008  
  1009  	type droppedEcho int
  1010  	const (
  1011  		_ droppedEcho = iota
  1012  		noneDropped
  1013  		echoRequestDroppedAtInput
  1014  		echoRequestDroppedAtForward
  1015  		echoReplyDropped
  1016  	)
  1017  
  1018  	subTests := []struct {
  1019  		name         string
  1020  		setupFilter  func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber)
  1021  		expectResult droppedEcho
  1022  	}{
  1023  		{
  1024  			name:         "Accept",
  1025  			setupFilter:  func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ },
  1026  			expectResult: noneDropped,
  1027  		},
  1028  
  1029  		{
  1030  			name:         "Input Drop",
  1031  			setupFilter:  setupDropFilter(stack.Input, stack.IPHeaderFilter{}),
  1032  			expectResult: echoRequestDroppedAtInput,
  1033  		},
  1034  		{
  1035  			name:         "Input Drop with input NIC filtering on arrival NIC",
  1036  			setupFilter:  setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic1Name}),
  1037  			expectResult: echoRequestDroppedAtInput,
  1038  		},
  1039  		{
  1040  			name:         "Input Drop with input NIC filtering on delivered NIC",
  1041  			setupFilter:  setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic2Name}),
  1042  			expectResult: noneDropped,
  1043  		},
  1044  
  1045  		{
  1046  			name:         "Input Drop with input NIC filtering on other NIC",
  1047  			setupFilter:  setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: otherNICName}),
  1048  			expectResult: noneDropped,
  1049  		},
  1050  
  1051  		{
  1052  			name:         "Forward Drop",
  1053  			setupFilter:  setupDropFilter(stack.Forward, stack.IPHeaderFilter{}),
  1054  			expectResult: echoRequestDroppedAtForward,
  1055  		},
  1056  
  1057  		{
  1058  			name:         "Output Drop",
  1059  			setupFilter:  setupDropFilter(stack.Output, stack.IPHeaderFilter{}),
  1060  			expectResult: echoReplyDropped,
  1061  		},
  1062  	}
  1063  
  1064  	for _, test := range tests {
  1065  		t.Run(test.name, func(t *testing.T) {
  1066  			for _, subTest := range subTests {
  1067  				t.Run(subTest.name, func(t *testing.T) {
  1068  					s := stack.New(stack.Options{
  1069  						NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  1070  					})
  1071  					defer s.Destroy()
  1072  
  1073  					subTest.setupFilter(t, s, test.netProto)
  1074  
  1075  					e1 := channel.New(1, header.IPv6MinimumMTU, "")
  1076  					if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
  1077  						t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
  1078  					}
  1079  					if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil {
  1080  						t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv4Addr1, err)
  1081  					}
  1082  					if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil {
  1083  						t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv6Addr1, err)
  1084  					}
  1085  
  1086  					e2 := channel.New(1, header.IPv6MinimumMTU, "")
  1087  					if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
  1088  						t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
  1089  					}
  1090  					if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil {
  1091  						t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv4Addr2, err)
  1092  					}
  1093  					if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil {
  1094  						t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv6Addr2, err)
  1095  					}
  1096  
  1097  					if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
  1098  						t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
  1099  					}
  1100  					if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
  1101  						t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
  1102  					}
  1103  
  1104  					s.SetRouteTable([]tcpip.Route{
  1105  						{
  1106  							Destination: header.IPv4EmptySubnet,
  1107  							NIC:         nicID1,
  1108  						},
  1109  						{
  1110  							Destination: header.IPv6EmptySubnet,
  1111  							NIC:         nicID1,
  1112  						},
  1113  					})
  1114  
  1115  					test.rx(e1)
  1116  
  1117  					ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto)
  1118  					if err != nil {
  1119  						t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err)
  1120  					}
  1121  					ep1Stats := ep1.Stats()
  1122  					ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats)
  1123  					if !ok {
  1124  						t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats)
  1125  					}
  1126  					ip1Stats := ipEP1Stats.IPStats()
  1127  
  1128  					if got := ip1Stats.PacketsReceived.Value(); got != 1 {
  1129  						t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got)
  1130  					}
  1131  					if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 {
  1132  						t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
  1133  					}
  1134  
  1135  					expectedIP1StatIPTablesForawrdDropped := uint64(0)
  1136  					expectedIP1StatIPTablesOutputDropped := uint64(0)
  1137  					expectedIP1StatPacketsSent := uint64(0)
  1138  					expectedIP2StatValidPacketsReceived := uint64(1)
  1139  					expectedIP2StatIPTablesInputDropped := uint64(0)
  1140  					switch subTest.expectResult {
  1141  					case noneDropped:
  1142  						expectedIP1StatPacketsSent = 1
  1143  					case echoRequestDroppedAtInput:
  1144  						expectedIP2StatIPTablesInputDropped = 1
  1145  					case echoRequestDroppedAtForward:
  1146  						expectedIP1StatIPTablesForawrdDropped = 1
  1147  						expectedIP2StatValidPacketsReceived = 0
  1148  					case echoReplyDropped:
  1149  						expectedIP1StatIPTablesOutputDropped = 1
  1150  					default:
  1151  						t.Fatalf("unhandled expectResult = %d", subTest.expectResult)
  1152  					}
  1153  
  1154  					if got := ip1Stats.IPTablesForwardDropped.Value(); got != expectedIP1StatIPTablesForawrdDropped {
  1155  						t.Errorf("got ip1Stats.IPTablesForwardDropped.Value() = %d, want = %d", got, expectedIP1StatIPTablesForawrdDropped)
  1156  					}
  1157  					if got := ip1Stats.IPTablesOutputDropped.Value(); got != expectedIP1StatIPTablesOutputDropped {
  1158  						t.Errorf("got ip1Stats.IPTablesOutputDropped.Value() = %d, want = %d", got, expectedIP1StatIPTablesOutputDropped)
  1159  					}
  1160  					if got := ip1Stats.PacketsSent.Value(); got != expectedIP1StatPacketsSent {
  1161  						t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = %d", got, expectedIP1StatPacketsSent)
  1162  					}
  1163  
  1164  					ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto)
  1165  					if err != nil {
  1166  						t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err)
  1167  					}
  1168  					ep2Stats := ep2.Stats()
  1169  					ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats)
  1170  					if !ok {
  1171  						t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats)
  1172  					}
  1173  					ip2Stats := ipEP2Stats.IPStats()
  1174  					if got := ip2Stats.PacketsReceived.Value(); got != 0 {
  1175  						t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got)
  1176  					}
  1177  					if got := ip2Stats.ValidPacketsReceived.Value(); got != expectedIP2StatValidPacketsReceived {
  1178  						t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = %d", got, expectedIP2StatValidPacketsReceived)
  1179  					}
  1180  					if got := ip2Stats.IPTablesInputDropped.Value(); got != expectedIP2StatIPTablesInputDropped {
  1181  						t.Errorf("got ip2Stats.IPTablesInputDropped.Value() = %d, want = %d", got, expectedIP2StatIPTablesInputDropped)
  1182  					}
  1183  					if got := ip2Stats.PacketsSent.Value(); got != 0 {
  1184  						t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = 0", got)
  1185  					}
  1186  
  1187  					expectPacket := subTest.expectResult == noneDropped
  1188  					p := e1.Read()
  1189  					if (p != nil) != expectPacket {
  1190  						t.Errorf("got e1.Read() = %#v, want = (_ == nil) = %t", p, expectPacket)
  1191  					}
  1192  					if p != nil {
  1193  						payload := stack.PayloadSince(p.NetworkHeader())
  1194  						defer payload.Release()
  1195  						test.checker(t, payload)
  1196  						p.DecRef()
  1197  					}
  1198  					if p := e2.Read(); p != nil {
  1199  						t.Errorf("got e1.Read() = %#v, want = nil)", p)
  1200  						p.DecRef()
  1201  					}
  1202  				})
  1203  			}
  1204  		})
  1205  	}
  1206  }
  1207  
  1208  func setupNAT(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, hook stack.Hook, filter stack.IPHeaderFilter, target stack.Target) {
  1209  	t.Helper()
  1210  
  1211  	ipv6 := netProto == ipv6.ProtocolNumber
  1212  	ipt := s.IPTables()
  1213  	table := ipt.GetTable(stack.NATID, ipv6)
  1214  	ruleIdx := table.BuiltinChains[hook]
  1215  	table.Rules[ruleIdx].Filter = filter
  1216  	table.Rules[ruleIdx].Target = target
  1217  	// Make sure the packet is not dropped by the next rule.
  1218  	table.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
  1219  	ipt.ForceReplaceTable(stack.NATID, table, ipv6)
  1220  }
  1221  
  1222  func setupDNAT(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, target stack.Target) {
  1223  	t.Helper()
  1224  
  1225  	setupNAT(
  1226  		t,
  1227  		s,
  1228  		netProto,
  1229  		stack.Prerouting,
  1230  		stack.IPHeaderFilter{
  1231  			Protocol:       transProto,
  1232  			CheckProtocol:  true,
  1233  			InputInterface: utils.RouterNIC2Name,
  1234  		},
  1235  		target)
  1236  }
  1237  
  1238  func setupSNAT(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, target stack.Target) {
  1239  	t.Helper()
  1240  
  1241  	setupNAT(
  1242  		t,
  1243  		s,
  1244  		netProto,
  1245  		stack.Postrouting,
  1246  		stack.IPHeaderFilter{
  1247  			Protocol:        transProto,
  1248  			CheckProtocol:   true,
  1249  			OutputInterface: utils.RouterNIC1Name,
  1250  		},
  1251  		target)
  1252  }
  1253  
  1254  func setupTwiceNAT(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, dnatAddr tcpip.Address, dnatTarget, snatTarget stack.Target) {
  1255  	t.Helper()
  1256  
  1257  	ipv6 := netProto == ipv6.ProtocolNumber
  1258  	ipt := s.IPTables()
  1259  
  1260  	table := stack.Table{
  1261  		Rules: []stack.Rule{
  1262  			// Prerouting
  1263  			{
  1264  				Filter: stack.IPHeaderFilter{
  1265  					Protocol:       transProto,
  1266  					CheckProtocol:  true,
  1267  					InputInterface: utils.RouterNIC2Name,
  1268  				},
  1269  				Target: dnatTarget,
  1270  			},
  1271  			{
  1272  				Target: &stack.AcceptTarget{},
  1273  			},
  1274  
  1275  			// Input
  1276  			{
  1277  				Target: &stack.AcceptTarget{},
  1278  			},
  1279  
  1280  			// Forward
  1281  			{
  1282  				Target: &stack.AcceptTarget{},
  1283  			},
  1284  
  1285  			// Output
  1286  			{
  1287  				Target: &stack.AcceptTarget{},
  1288  			},
  1289  
  1290  			// Postrouting
  1291  			{
  1292  				Filter: stack.IPHeaderFilter{
  1293  					Protocol:        transProto,
  1294  					CheckProtocol:   true,
  1295  					OutputInterface: utils.RouterNIC1Name,
  1296  				},
  1297  				Target: snatTarget,
  1298  			},
  1299  			{
  1300  				Target: &stack.AcceptTarget{},
  1301  			},
  1302  		},
  1303  		BuiltinChains: [stack.NumHooks]int{
  1304  			stack.Prerouting:  0,
  1305  			stack.Input:       2,
  1306  			stack.Forward:     3,
  1307  			stack.Output:      4,
  1308  			stack.Postrouting: 5,
  1309  		},
  1310  	}
  1311  
  1312  	ipt.ForceReplaceTable(stack.NATID, table, ipv6)
  1313  }
  1314  
  1315  type natType struct {
  1316  	name     string
  1317  	setupNAT func(_ *testing.T, _ *stack.Stack, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address, dnatPort uint16)
  1318  }
  1319  
  1320  var (
  1321  	snatTypes = []natType{
  1322  		{
  1323  			name: "SNAT",
  1324  			setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, _ tcpip.Address, _ uint16) {
  1325  				t.Helper()
  1326  
  1327  				setupSNAT(t, s, netProto, transProto, &stack.SNATTarget{NetworkProtocol: netProto, Addr: snatAddr, ChangeAddress: true, ChangePort: true})
  1328  			},
  1329  		},
  1330  		{
  1331  			name: "Masquerade",
  1332  			setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, _ tcpip.Address, _ uint16) {
  1333  				t.Helper()
  1334  
  1335  				setupSNAT(t, s, netProto, transProto, &stack.MasqueradeTarget{NetworkProtocol: netProto})
  1336  			},
  1337  		},
  1338  	}
  1339  
  1340  	dnatTarget = natType{
  1341  		name: "DNAT",
  1342  		setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, dnatAddr tcpip.Address, dnatPort uint16) {
  1343  			t.Helper()
  1344  
  1345  			setupDNAT(t, s, netProto, transProto, &stack.DNATTarget{NetworkProtocol: netProto, Addr: dnatAddr, Port: dnatPort, ChangeAddress: true, ChangePort: true})
  1346  		},
  1347  	}
  1348  
  1349  	dnatTypes = []natType{
  1350  		{
  1351  			name: "Redirect",
  1352  			setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, _ tcpip.Address, dnatPort uint16) {
  1353  				t.Helper()
  1354  
  1355  				setupDNAT(t, s, netProto, transProto, &stack.RedirectTarget{NetworkProtocol: netProto, Port: dnatPort})
  1356  			},
  1357  		},
  1358  		dnatTarget,
  1359  	}
  1360  
  1361  	twiceNATTypes = []natType{
  1362  		{
  1363  			name: "DNAT-Masquerade",
  1364  			setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address, dnatPort uint16) {
  1365  				t.Helper()
  1366  
  1367  				setupTwiceNAT(t, s, netProto, transProto, dnatAddr, &stack.DNATTarget{NetworkProtocol: netProto, Addr: dnatAddr, Port: dnatPort, ChangeAddress: true, ChangePort: true}, &stack.MasqueradeTarget{NetworkProtocol: netProto})
  1368  			},
  1369  		},
  1370  		{
  1371  			name: "DNAT-SNAT",
  1372  			setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address, dnatPort uint16) {
  1373  				t.Helper()
  1374  
  1375  				setupTwiceNAT(t, s, netProto, transProto, dnatAddr, &stack.DNATTarget{NetworkProtocol: netProto, Addr: dnatAddr, Port: dnatPort, ChangeAddress: true, ChangePort: true}, &stack.SNATTarget{NetworkProtocol: netProto, Addr: snatAddr, ChangeAddress: true, ChangePort: true})
  1376  			},
  1377  		},
  1378  	}
  1379  )
  1380  
  1381  func TestNATEcho(t *testing.T) {
  1382  	const ident = 1
  1383  
  1384  	v4EchoPkt := func(srcAddr, dstAddr tcpip.Address, reply bool) []byte {
  1385  		icmpType := header.ICMPv4Echo
  1386  		if reply {
  1387  			icmpType = header.ICMPv4EchoReply
  1388  		}
  1389  
  1390  		return icmpv4Packet(srcAddr, dstAddr, icmpType, ident)
  1391  	}
  1392  
  1393  	checkV4EchoPkt := func(t *testing.T, v *buffer.View, srcAddr, dstAddr tcpip.Address, reply bool) {
  1394  		t.Helper()
  1395  
  1396  		icmpType := header.ICMPv4Echo
  1397  		if reply {
  1398  			icmpType = header.ICMPv4EchoReply
  1399  		}
  1400  
  1401  		checker.IPv4(t, v,
  1402  			checker.SrcAddr(srcAddr),
  1403  			checker.DstAddr(dstAddr),
  1404  			checker.ICMPv4(
  1405  				checker.ICMPv4Type(icmpType),
  1406  				checker.ICMPv4Checksum(),
  1407  			),
  1408  		)
  1409  	}
  1410  
  1411  	v6EchoPkt := func(srcAddr, dstAddr tcpip.Address, reply bool) []byte {
  1412  		icmpType := header.ICMPv6EchoRequest
  1413  		if reply {
  1414  			icmpType = header.ICMPv6EchoReply
  1415  		}
  1416  
  1417  		return icmpv6Packet(srcAddr, dstAddr, icmpType, ident)
  1418  	}
  1419  
  1420  	checkV6EchoPkt := func(t *testing.T, v *buffer.View, srcAddr, dstAddr tcpip.Address, reply bool) {
  1421  		t.Helper()
  1422  
  1423  		icmpType := header.ICMPv6EchoRequest
  1424  		if reply {
  1425  			icmpType = header.ICMPv6EchoReply
  1426  		}
  1427  
  1428  		checker.IPv6(t, v,
  1429  			checker.SrcAddr(srcAddr),
  1430  			checker.DstAddr(dstAddr),
  1431  			checker.ICMPv6(
  1432  				checker.ICMPv6Type(icmpType),
  1433  			),
  1434  		)
  1435  	}
  1436  
  1437  	type natTypeTest struct {
  1438  		name                                   string
  1439  		natTypes                               []natType
  1440  		requestSrc, requestDst                 tcpip.Address
  1441  		expectedRequestSrc, expectedRequestDst tcpip.Address
  1442  	}
  1443  
  1444  	tests := []struct {
  1445  		name         string
  1446  		netProto     tcpip.NetworkProtocolNumber
  1447  		transProto   tcpip.TransportProtocolNumber
  1448  		echoPkt      func(srcAddr, dstAddr tcpip.Address, reply bool) []byte
  1449  		checkEchoPkt func(t *testing.T, v *buffer.View, srcAddr, dstAddr tcpip.Address, reply bool)
  1450  
  1451  		natTypes []natTypeTest
  1452  	}{
  1453  		{
  1454  			name:         "IPv4",
  1455  			netProto:     header.IPv4ProtocolNumber,
  1456  			transProto:   header.ICMPv4ProtocolNumber,
  1457  			echoPkt:      v4EchoPkt,
  1458  			checkEchoPkt: checkV4EchoPkt,
  1459  
  1460  			natTypes: []natTypeTest{
  1461  				{
  1462  					name:               "SNAT",
  1463  					natTypes:           snatTypes,
  1464  					requestSrc:         utils.Host2IPv4Addr.AddressWithPrefix.Address,
  1465  					requestDst:         utils.Host1IPv4Addr.AddressWithPrefix.Address,
  1466  					expectedRequestSrc: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
  1467  					expectedRequestDst: utils.Host1IPv4Addr.AddressWithPrefix.Address,
  1468  				},
  1469  				{
  1470  					name:               "DNAT",
  1471  					natTypes:           []natType{dnatTarget},
  1472  					requestSrc:         utils.Host2IPv4Addr.AddressWithPrefix.Address,
  1473  					requestDst:         utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
  1474  					expectedRequestSrc: utils.Host2IPv4Addr.AddressWithPrefix.Address,
  1475  					expectedRequestDst: utils.Host1IPv4Addr.AddressWithPrefix.Address,
  1476  				},
  1477  				{
  1478  					name:               "Twice-NAT",
  1479  					natTypes:           twiceNATTypes,
  1480  					requestSrc:         utils.Host2IPv4Addr.AddressWithPrefix.Address,
  1481  					requestDst:         utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
  1482  					expectedRequestSrc: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
  1483  					expectedRequestDst: utils.Host1IPv4Addr.AddressWithPrefix.Address,
  1484  				},
  1485  			},
  1486  		},
  1487  		{
  1488  			name:         "IPv6",
  1489  			netProto:     header.IPv6ProtocolNumber,
  1490  			transProto:   header.ICMPv6ProtocolNumber,
  1491  			echoPkt:      v6EchoPkt,
  1492  			checkEchoPkt: checkV6EchoPkt,
  1493  
  1494  			natTypes: []natTypeTest{
  1495  				{
  1496  					name:               "SNAT",
  1497  					natTypes:           snatTypes,
  1498  					requestSrc:         utils.Host2IPv6Addr.AddressWithPrefix.Address,
  1499  					requestDst:         utils.Host1IPv6Addr.AddressWithPrefix.Address,
  1500  					expectedRequestSrc: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
  1501  					expectedRequestDst: utils.Host1IPv6Addr.AddressWithPrefix.Address,
  1502  				},
  1503  				{
  1504  					name:               "DNAT",
  1505  					natTypes:           []natType{dnatTarget},
  1506  					requestSrc:         utils.Host2IPv6Addr.AddressWithPrefix.Address,
  1507  					requestDst:         utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
  1508  					expectedRequestSrc: utils.Host2IPv6Addr.AddressWithPrefix.Address,
  1509  					expectedRequestDst: utils.Host1IPv6Addr.AddressWithPrefix.Address,
  1510  				},
  1511  				{
  1512  					name:               "Twice-NAT",
  1513  					natTypes:           twiceNATTypes,
  1514  					requestSrc:         utils.Host2IPv6Addr.AddressWithPrefix.Address,
  1515  					requestDst:         utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
  1516  					expectedRequestSrc: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
  1517  					expectedRequestDst: utils.Host1IPv6Addr.AddressWithPrefix.Address,
  1518  				},
  1519  			},
  1520  		},
  1521  	}
  1522  
  1523  	for _, test := range tests {
  1524  		t.Run(test.name, func(t *testing.T) {
  1525  			for _, natTypeTest := range test.natTypes {
  1526  				t.Run(natTypeTest.name, func(t *testing.T) {
  1527  					for _, natType := range natTypeTest.natTypes {
  1528  						t.Run(natType.name, func(t *testing.T) {
  1529  							s := stack.New(stack.Options{
  1530  								NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  1531  								TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
  1532  							})
  1533  							defer s.Destroy()
  1534  
  1535  							ep1 := channel.New(1, header.IPv6MinimumMTU, "")
  1536  							ep2 := channel.New(1, header.IPv6MinimumMTU, "")
  1537  							utils.SetupRouterStack(t, s, ep1, ep2)
  1538  
  1539  							natType.setupNAT(t, s, test.netProto, test.transProto, natTypeTest.expectedRequestSrc, natTypeTest.expectedRequestDst, 0 /* dnatPort */)
  1540  
  1541  							// Send and check the Echo Request.
  1542  							{
  1543  								ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
  1544  									Payload: buffer.MakeWithData(test.echoPkt(natTypeTest.requestSrc, natTypeTest.requestDst, false /* reply */)),
  1545  								}))
  1546  								pkt := ep1.Read()
  1547  								if pkt == nil {
  1548  									t.Fatal("expected to read a packet on ep1")
  1549  								}
  1550  								payload := stack.PayloadSince(pkt.NetworkHeader())
  1551  								defer payload.Release()
  1552  								test.checkEchoPkt(t, payload, natTypeTest.expectedRequestSrc, natTypeTest.expectedRequestDst, false /* reply */)
  1553  								pkt.DecRef()
  1554  							}
  1555  
  1556  							if t.Failed() {
  1557  								t.FailNow()
  1558  							}
  1559  
  1560  							// Send and check the Echo Reply.
  1561  							{
  1562  								ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
  1563  									Payload: buffer.MakeWithData(test.echoPkt(natTypeTest.expectedRequestDst, natTypeTest.expectedRequestSrc, true /* reply */)),
  1564  								}))
  1565  								pkt := ep2.Read()
  1566  								if pkt == nil {
  1567  									t.Fatal("expected to read a packet on ep2")
  1568  								}
  1569  								payload := stack.PayloadSince(pkt.NetworkHeader())
  1570  								defer payload.Release()
  1571  								test.checkEchoPkt(t, payload, natTypeTest.requestDst, natTypeTest.requestSrc, true /* reply */)
  1572  								pkt.DecRef()
  1573  							}
  1574  						})
  1575  					}
  1576  				})
  1577  			}
  1578  		})
  1579  	}
  1580  }
  1581  
  1582  func TestNAT(t *testing.T) {
  1583  	const listenPort uint16 = 8080
  1584  
  1585  	type endpointAndAddresses struct {
  1586  		serverEP          tcpip.Endpoint
  1587  		serverAddr        tcpip.FullAddress
  1588  		serverReadableCH  chan struct{}
  1589  		serverConnectAddr tcpip.Address
  1590  
  1591  		clientEP          tcpip.Endpoint
  1592  		clientAddr        tcpip.Address
  1593  		clientReadableCH  chan struct{}
  1594  		clientConnectAddr tcpip.FullAddress
  1595  	}
  1596  
  1597  	newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) {
  1598  		t.Helper()
  1599  		var wq waiter.Queue
  1600  		we, ch := waiter.NewChannelEntry(waiter.ReadableEvents)
  1601  		wq.EventRegister(&we)
  1602  		t.Cleanup(func() {
  1603  			wq.EventUnregister(&we)
  1604  		})
  1605  
  1606  		ep, err := s.NewEndpoint(transProto, netProto, &wq)
  1607  		if err != nil {
  1608  			t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err)
  1609  		}
  1610  		t.Cleanup(ep.Close)
  1611  
  1612  		return ep, ch
  1613  	}
  1614  
  1615  	tests := []struct {
  1616  		name     string
  1617  		netProto tcpip.NetworkProtocolNumber
  1618  		// Setups up the stacks in such a way that:
  1619  		//
  1620  		//	- Host2 is the client for all tests.
  1621  		//	- When performing SNAT only:
  1622  		//   + Host1 is the server.
  1623  		//   + NAT will transform client-originating packets' source addresses to
  1624  		//     the router's NIC1's address before reaching Host1.
  1625  		//	- When performing DNAT only:
  1626  		//   + Router is the server.
  1627  		//   + Client will send packets directed to Host1.
  1628  		//   + NAT will transform client-originating packets' destination addresses
  1629  		//     to the router's NIC2's address.
  1630  		//	- When performing Twice-NAT:
  1631  		//   + Host1 is the server.
  1632  		//   + Client will send packets directed to router's NIC2.
  1633  		//   + NAT will transform client originating packets' destination addresses
  1634  		//     to Host1's address.
  1635  		//   + NAT will transform client-originating packets' source addresses to
  1636  		//     the router's NIC1's address before reaching Host1.
  1637  		epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses
  1638  		natTypes   []natType
  1639  	}{
  1640  		{
  1641  			name:     "IPv4 SNAT",
  1642  			netProto: ipv4.ProtocolNumber,
  1643  			epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
  1644  				t.Helper()
  1645  
  1646  				listenerStack := host1Stack
  1647  				serverAddr := tcpip.FullAddress{
  1648  					Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
  1649  					Port: listenPort,
  1650  				}
  1651  				serverConnectAddr := utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address
  1652  				clientConnectPort := serverAddr.Port
  1653  				ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber)
  1654  				ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
  1655  				return endpointAndAddresses{
  1656  					serverEP:          ep1,
  1657  					serverAddr:        serverAddr,
  1658  					serverReadableCH:  ep1WECH,
  1659  					serverConnectAddr: serverConnectAddr,
  1660  
  1661  					clientEP:         ep2,
  1662  					clientAddr:       utils.Host2IPv4Addr.AddressWithPrefix.Address,
  1663  					clientReadableCH: ep2WECH,
  1664  					clientConnectAddr: tcpip.FullAddress{
  1665  						Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
  1666  						Port: clientConnectPort,
  1667  					},
  1668  				}
  1669  			},
  1670  			natTypes: snatTypes,
  1671  		},
  1672  		{
  1673  			name:     "IPv4 DNAT",
  1674  			netProto: ipv4.ProtocolNumber,
  1675  			epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
  1676  				t.Helper()
  1677  
  1678  				// If we are performing DNAT, then the packet will be redirected
  1679  				// to the router.
  1680  				listenerStack := routerStack
  1681  				serverAddr := tcpip.FullAddress{
  1682  					Addr: utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
  1683  					Port: listenPort,
  1684  				}
  1685  				serverConnectAddr := utils.Host2IPv4Addr.AddressWithPrefix.Address
  1686  				// DNAT will update the destination port to what the server is
  1687  				// bound to.
  1688  				clientConnectPort := serverAddr.Port + 1
  1689  				ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber)
  1690  				ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
  1691  				return endpointAndAddresses{
  1692  					serverEP:          ep1,
  1693  					serverAddr:        serverAddr,
  1694  					serverReadableCH:  ep1WECH,
  1695  					serverConnectAddr: serverConnectAddr,
  1696  
  1697  					clientEP:         ep2,
  1698  					clientAddr:       utils.Host2IPv4Addr.AddressWithPrefix.Address,
  1699  					clientReadableCH: ep2WECH,
  1700  					clientConnectAddr: tcpip.FullAddress{
  1701  						Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
  1702  						Port: clientConnectPort,
  1703  					},
  1704  				}
  1705  			},
  1706  			natTypes: dnatTypes,
  1707  		},
  1708  		{
  1709  			name:     "IPv4 Twice-NAT",
  1710  			netProto: ipv4.ProtocolNumber,
  1711  			epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
  1712  				t.Helper()
  1713  
  1714  				listenerStack := host1Stack
  1715  				serverAddr := tcpip.FullAddress{
  1716  					Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
  1717  					Port: listenPort,
  1718  				}
  1719  				serverConnectAddr := utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address
  1720  				clientConnectPort := serverAddr.Port
  1721  				ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber)
  1722  				ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
  1723  				return endpointAndAddresses{
  1724  					serverEP:          ep1,
  1725  					serverAddr:        serverAddr,
  1726  					serverReadableCH:  ep1WECH,
  1727  					serverConnectAddr: serverConnectAddr,
  1728  
  1729  					clientEP:         ep2,
  1730  					clientAddr:       utils.Host2IPv4Addr.AddressWithPrefix.Address,
  1731  					clientReadableCH: ep2WECH,
  1732  					clientConnectAddr: tcpip.FullAddress{
  1733  						Addr: utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
  1734  						Port: clientConnectPort,
  1735  					},
  1736  				}
  1737  			},
  1738  			natTypes: twiceNATTypes,
  1739  		},
  1740  		{
  1741  			name:     "IPv6 SNAT",
  1742  			netProto: ipv6.ProtocolNumber,
  1743  			epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
  1744  				t.Helper()
  1745  
  1746  				listenerStack := host1Stack
  1747  				serverAddr := tcpip.FullAddress{
  1748  					Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
  1749  					Port: listenPort,
  1750  				}
  1751  				serverConnectAddr := utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address
  1752  				clientConnectPort := serverAddr.Port
  1753  				ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber)
  1754  				ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
  1755  				return endpointAndAddresses{
  1756  					serverEP:          ep1,
  1757  					serverAddr:        serverAddr,
  1758  					serverReadableCH:  ep1WECH,
  1759  					serverConnectAddr: serverConnectAddr,
  1760  
  1761  					clientEP:         ep2,
  1762  					clientAddr:       utils.Host2IPv6Addr.AddressWithPrefix.Address,
  1763  					clientReadableCH: ep2WECH,
  1764  					clientConnectAddr: tcpip.FullAddress{
  1765  						Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
  1766  						Port: clientConnectPort,
  1767  					},
  1768  				}
  1769  			},
  1770  			natTypes: snatTypes,
  1771  		},
  1772  		{
  1773  			name:     "IPv6 DNAT",
  1774  			netProto: ipv6.ProtocolNumber,
  1775  			epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
  1776  				t.Helper()
  1777  
  1778  				// If we are performing DNAT, then the packet will be redirected
  1779  				// to the router.
  1780  				listenerStack := routerStack
  1781  				serverAddr := tcpip.FullAddress{
  1782  					Addr: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
  1783  					Port: listenPort,
  1784  				}
  1785  				serverConnectAddr := utils.Host2IPv6Addr.AddressWithPrefix.Address
  1786  				// DNAT will update the destination port to what the server is
  1787  				// bound to.
  1788  				clientConnectPort := serverAddr.Port + 1
  1789  				ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber)
  1790  				ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
  1791  				return endpointAndAddresses{
  1792  					serverEP:          ep1,
  1793  					serverAddr:        serverAddr,
  1794  					serverReadableCH:  ep1WECH,
  1795  					serverConnectAddr: serverConnectAddr,
  1796  
  1797  					clientEP:         ep2,
  1798  					clientAddr:       utils.Host2IPv6Addr.AddressWithPrefix.Address,
  1799  					clientReadableCH: ep2WECH,
  1800  					clientConnectAddr: tcpip.FullAddress{
  1801  						Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
  1802  						Port: clientConnectPort,
  1803  					},
  1804  				}
  1805  			},
  1806  			natTypes: dnatTypes,
  1807  		},
  1808  		{
  1809  			name:     "IPv6 Twice-NAT",
  1810  			netProto: ipv6.ProtocolNumber,
  1811  			epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
  1812  				t.Helper()
  1813  
  1814  				listenerStack := host1Stack
  1815  				serverAddr := tcpip.FullAddress{
  1816  					Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
  1817  					Port: listenPort,
  1818  				}
  1819  				serverConnectAddr := utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address
  1820  				clientConnectPort := serverAddr.Port
  1821  				ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber)
  1822  				ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
  1823  				return endpointAndAddresses{
  1824  					serverEP:          ep1,
  1825  					serverAddr:        serverAddr,
  1826  					serverReadableCH:  ep1WECH,
  1827  					serverConnectAddr: serverConnectAddr,
  1828  
  1829  					clientEP:         ep2,
  1830  					clientAddr:       utils.Host2IPv6Addr.AddressWithPrefix.Address,
  1831  					clientReadableCH: ep2WECH,
  1832  					clientConnectAddr: tcpip.FullAddress{
  1833  						Addr: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
  1834  						Port: clientConnectPort,
  1835  					},
  1836  				}
  1837  			},
  1838  			natTypes: twiceNATTypes,
  1839  		},
  1840  	}
  1841  
  1842  	subTests := []struct {
  1843  		name               string
  1844  		proto              tcpip.TransportProtocolNumber
  1845  		expectedConnectErr tcpip.Error
  1846  		setupServer        func(t *testing.T, ep tcpip.Endpoint)
  1847  		setupServerConn    func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{})
  1848  		needRemoteAddr     bool
  1849  	}{
  1850  		{
  1851  			name:               "UDP",
  1852  			proto:              udp.ProtocolNumber,
  1853  			expectedConnectErr: nil,
  1854  			setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
  1855  				t.Helper()
  1856  
  1857  				if err := ep.Connect(clientAddr); err != nil {
  1858  					t.Fatalf("ep.Connect(%#v): %s", clientAddr, err)
  1859  				}
  1860  				return nil, nil
  1861  			},
  1862  			needRemoteAddr: true,
  1863  		},
  1864  		{
  1865  			name:               "TCP",
  1866  			proto:              tcp.ProtocolNumber,
  1867  			expectedConnectErr: &tcpip.ErrConnectStarted{},
  1868  			setupServer: func(t *testing.T, ep tcpip.Endpoint) {
  1869  				t.Helper()
  1870  
  1871  				if err := ep.Listen(1); err != nil {
  1872  					t.Fatalf("ep.Listen(1): %s", err)
  1873  				}
  1874  			},
  1875  			setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
  1876  				t.Helper()
  1877  
  1878  				var addr tcpip.FullAddress
  1879  				for {
  1880  					newEP, wq, err := ep.Accept(&addr)
  1881  					if _, ok := err.(*tcpip.ErrWouldBlock); ok {
  1882  						<-ch
  1883  						continue
  1884  					}
  1885  					if err != nil {
  1886  						t.Fatalf("ep.Accept(_): %s", err)
  1887  					}
  1888  					if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath(
  1889  						"NIC",
  1890  					)); diff != "" {
  1891  						t.Errorf("accepted address mismatch (-want +got):\n%s", diff)
  1892  					}
  1893  
  1894  					we, newCH := waiter.NewChannelEntry(waiter.ReadableEvents)
  1895  					wq.EventRegister(&we)
  1896  					return newEP, newCH
  1897  				}
  1898  			},
  1899  			needRemoteAddr: false,
  1900  		},
  1901  	}
  1902  
  1903  	for _, test := range tests {
  1904  		t.Run(test.name, func(t *testing.T) {
  1905  			for _, subTest := range subTests {
  1906  				t.Run(subTest.name, func(t *testing.T) {
  1907  					for _, natType := range test.natTypes {
  1908  						t.Run(natType.name, func(t *testing.T) {
  1909  							stackOpts := stack.Options{
  1910  								NetworkProtocols:   []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
  1911  								TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
  1912  							}
  1913  
  1914  							host1Stack := stack.New(stackOpts)
  1915  							defer host1Stack.Destroy()
  1916  							routerStack := stack.New(stackOpts)
  1917  							defer routerStack.Destroy()
  1918  							host2Stack := stack.New(stackOpts)
  1919  							defer host2Stack.Destroy()
  1920  							utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack)
  1921  
  1922  							epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto)
  1923  							natType.setupNAT(t, routerStack, test.netProto, subTest.proto, epsAndAddrs.serverConnectAddr, epsAndAddrs.serverAddr.Addr, listenPort)
  1924  
  1925  							if err := epsAndAddrs.serverEP.Bind(epsAndAddrs.serverAddr); err != nil {
  1926  								t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", epsAndAddrs.serverAddr, err)
  1927  							}
  1928  							clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
  1929  							if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
  1930  								t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
  1931  							}
  1932  
  1933  							if subTest.setupServer != nil {
  1934  								subTest.setupServer(t, epsAndAddrs.serverEP)
  1935  							}
  1936  							{
  1937  								err := epsAndAddrs.clientEP.Connect(epsAndAddrs.clientConnectAddr)
  1938  								if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" {
  1939  									t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", epsAndAddrs.clientConnectAddr, diff)
  1940  								}
  1941  							}
  1942  							serverConnectAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverConnectAddr}
  1943  							if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil {
  1944  								t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err)
  1945  							} else {
  1946  								serverConnectAddr.Port = addr.Port
  1947  							}
  1948  
  1949  							serverEP := epsAndAddrs.serverEP
  1950  							serverCH := epsAndAddrs.serverReadableCH
  1951  							if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, serverConnectAddr); ep != nil {
  1952  								defer ep.Close()
  1953  								serverEP = ep
  1954  								serverCH = ch
  1955  							}
  1956  
  1957  							write := func(ep tcpip.Endpoint, data []byte) {
  1958  								t.Helper()
  1959  
  1960  								var r bytes.Reader
  1961  								r.Reset(data)
  1962  								var wOpts tcpip.WriteOptions
  1963  								n, err := ep.Write(&r, wOpts)
  1964  								if err != nil {
  1965  									t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
  1966  								}
  1967  								if want := int64(len(data)); n != want {
  1968  									t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
  1969  								}
  1970  							}
  1971  
  1972  							read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) {
  1973  								t.Helper()
  1974  
  1975  								var buf bytes.Buffer
  1976  								var res tcpip.ReadResult
  1977  								for {
  1978  									var err tcpip.Error
  1979  									opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
  1980  									res, err = ep.Read(&buf, opts)
  1981  									if _, ok := err.(*tcpip.ErrWouldBlock); ok {
  1982  										<-ch
  1983  										continue
  1984  									}
  1985  									if err != nil {
  1986  										t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
  1987  									}
  1988  									break
  1989  								}
  1990  
  1991  								readResult := tcpip.ReadResult{
  1992  									Count: len(data),
  1993  									Total: len(data),
  1994  								}
  1995  								if subTest.needRemoteAddr {
  1996  									readResult.RemoteAddr = expectedFrom
  1997  								}
  1998  								if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath(
  1999  									"ControlMessages",
  2000  									"RemoteAddr.NIC",
  2001  								)); diff != "" {
  2002  									t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
  2003  								}
  2004  								if diff := cmp.Diff(buf.Bytes(), data); diff != "" {
  2005  									t.Errorf("received data mismatch (-want +got):\n%s", diff)
  2006  								}
  2007  
  2008  								if t.Failed() {
  2009  									t.FailNow()
  2010  								}
  2011  							}
  2012  
  2013  							{
  2014  								data := []byte{1, 2, 3, 4}
  2015  								write(epsAndAddrs.clientEP, data)
  2016  								read(serverCH, serverEP, data, serverConnectAddr)
  2017  							}
  2018  
  2019  							{
  2020  								data := []byte{5, 6, 7, 8, 9, 10, 11, 12}
  2021  								write(serverEP, data)
  2022  								read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, epsAndAddrs.clientConnectAddr)
  2023  							}
  2024  						})
  2025  					}
  2026  				})
  2027  			}
  2028  		})
  2029  	}
  2030  }
  2031  
  2032  func encodeIPv4Header(v []byte, totalLen int, transProto tcpip.TransportProtocolNumber, srcAddr, dstAddr tcpip.Address) {
  2033  	ip := header.IPv4(v)
  2034  	ip.Encode(&header.IPv4Fields{
  2035  		TotalLength: uint16(totalLen),
  2036  		Protocol:    uint8(transProto),
  2037  		TTL:         64,
  2038  		SrcAddr:     srcAddr,
  2039  		DstAddr:     dstAddr,
  2040  	})
  2041  	ip.SetChecksum(^ip.CalculateChecksum())
  2042  }
  2043  
  2044  func encodeIPv6Header(v []byte, payloadLen int, transProto tcpip.TransportProtocolNumber, srcAddr, dstAddr tcpip.Address) {
  2045  	ip := header.IPv6(v)
  2046  	ip.Encode(&header.IPv6Fields{
  2047  		PayloadLength:     uint16(payloadLen),
  2048  		TransportProtocol: transProto,
  2049  		HopLimit:          64,
  2050  		SrcAddr:           srcAddr,
  2051  		DstAddr:           dstAddr,
  2052  	})
  2053  }
  2054  
  2055  func udpv4Packet(srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16, dataSize int) []byte {
  2056  	udpSize := header.UDPMinimumSize + dataSize
  2057  	hdr := prependable.New(header.IPv4MinimumSize + udpSize)
  2058  	udp := header.UDP(hdr.Prepend(udpSize))
  2059  	udp.SetSourcePort(srcPort)
  2060  	udp.SetDestinationPort(dstPort)
  2061  	udp.SetLength(uint16(udpSize))
  2062  	udp.SetChecksum(0)
  2063  	udp.SetChecksum(^udp.CalculateChecksum(header.PseudoHeaderChecksum(
  2064  		header.UDPProtocolNumber,
  2065  		srcAddr,
  2066  		dstAddr,
  2067  		uint16(len(udp)),
  2068  	)))
  2069  	encodeIPv4Header(
  2070  		hdr.Prepend(header.IPv4MinimumSize),
  2071  		hdr.UsedLength(),
  2072  		header.UDPProtocolNumber,
  2073  		srcAddr,
  2074  		dstAddr,
  2075  	)
  2076  	return hdr.View()
  2077  }
  2078  
  2079  func tcpv4Packet(srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16, dataSize int) []byte {
  2080  	tcpSize := header.TCPMinimumSize + dataSize
  2081  	hdr := prependable.New(header.IPv4MinimumSize + tcpSize)
  2082  	tcp := header.TCP(hdr.Prepend(tcpSize))
  2083  	tcp.SetSourcePort(srcPort)
  2084  	tcp.SetDestinationPort(dstPort)
  2085  	tcp.SetDataOffset(header.TCPMinimumSize)
  2086  	tcp.SetChecksum(0)
  2087  	tcp.SetChecksum(^tcp.CalculateChecksum(header.PseudoHeaderChecksum(
  2088  		header.TCPProtocolNumber,
  2089  		srcAddr,
  2090  		dstAddr,
  2091  		uint16(len(tcp)),
  2092  	)))
  2093  	encodeIPv4Header(
  2094  		hdr.Prepend(header.IPv4MinimumSize),
  2095  		hdr.UsedLength(),
  2096  		header.TCPProtocolNumber,
  2097  		srcAddr,
  2098  		dstAddr,
  2099  	)
  2100  	return hdr.View()
  2101  }
  2102  
  2103  func icmpv4Packet(srcAddr, dstAddr tcpip.Address, icmpType header.ICMPv4Type, ident uint16) []byte {
  2104  	hdr := prependable.New(header.IPv4MinimumSize + header.ICMPv4MinimumSize)
  2105  	icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
  2106  	icmp.SetType(icmpType)
  2107  	icmp.SetIdent(ident)
  2108  	icmp.SetChecksum(0)
  2109  	icmp.SetChecksum(^checksum.Checksum(icmp, 0))
  2110  	encodeIPv4Header(
  2111  		hdr.Prepend(header.IPv4MinimumSize),
  2112  		hdr.UsedLength(),
  2113  		header.ICMPv4ProtocolNumber,
  2114  		srcAddr,
  2115  		dstAddr,
  2116  	)
  2117  	return hdr.View()
  2118  }
  2119  
  2120  func udpv6Packet(srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16, dataSize int) []byte {
  2121  	udpSize := header.UDPMinimumSize + dataSize
  2122  	hdr := prependable.New(header.IPv6MinimumSize + udpSize)
  2123  	udp := header.UDP(hdr.Prepend(udpSize))
  2124  	udp.SetSourcePort(srcPort)
  2125  	udp.SetDestinationPort(dstPort)
  2126  	udp.SetLength(uint16(udpSize))
  2127  	udp.SetChecksum(0)
  2128  	udp.SetChecksum(^udp.CalculateChecksum(header.PseudoHeaderChecksum(
  2129  		header.UDPProtocolNumber,
  2130  		srcAddr,
  2131  		dstAddr,
  2132  		uint16(len(udp)),
  2133  	)))
  2134  	encodeIPv6Header(
  2135  		hdr.Prepend(header.IPv6MinimumSize),
  2136  		len(udp),
  2137  		header.UDPProtocolNumber,
  2138  		srcAddr,
  2139  		dstAddr,
  2140  	)
  2141  	return hdr.View()
  2142  }
  2143  
  2144  func tcpv6Packet(srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16, dataSize int) []byte {
  2145  	tcpSize := header.TCPMinimumSize + dataSize
  2146  	hdr := prependable.New(header.IPv6MinimumSize + tcpSize)
  2147  	tcp := header.TCP(hdr.Prepend(tcpSize))
  2148  	tcp.SetSourcePort(srcPort)
  2149  	tcp.SetDestinationPort(dstPort)
  2150  	tcp.SetDataOffset(header.TCPMinimumSize)
  2151  	tcp.SetChecksum(0)
  2152  	tcp.SetChecksum(^tcp.CalculateChecksum(header.PseudoHeaderChecksum(
  2153  		header.TCPProtocolNumber,
  2154  		srcAddr,
  2155  		dstAddr,
  2156  		uint16(len(tcp)),
  2157  	)))
  2158  	encodeIPv6Header(
  2159  		hdr.Prepend(header.IPv6MinimumSize),
  2160  		len(tcp),
  2161  		header.TCPProtocolNumber,
  2162  		srcAddr,
  2163  		dstAddr,
  2164  	)
  2165  	return hdr.View()
  2166  }
  2167  
  2168  func icmpv6Packet(srcAddr, dstAddr tcpip.Address, icmpType header.ICMPv6Type, ident uint16) []byte {
  2169  	hdr := prependable.New(header.IPv6MinimumSize + header.ICMPv6MinimumSize)
  2170  	icmp := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
  2171  	icmp.SetType(icmpType)
  2172  	icmp.SetIdent(ident)
  2173  	icmp.SetChecksum(0)
  2174  	icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
  2175  		Header: icmp,
  2176  		Src:    srcAddr,
  2177  		Dst:    dstAddr,
  2178  	}))
  2179  	encodeIPv6Header(
  2180  		hdr.Prepend(header.IPv6MinimumSize),
  2181  		len(icmp),
  2182  		header.ICMPv6ProtocolNumber,
  2183  		srcAddr,
  2184  		dstAddr,
  2185  	)
  2186  	return hdr.View()
  2187  }
  2188  
  2189  func TestNATICMPError(t *testing.T) {
  2190  	const (
  2191  		srcPort  = 1234
  2192  		dstPort  = 5432
  2193  		dataSize = 4
  2194  	)
  2195  
  2196  	type icmpTypeTest struct {
  2197  		name           string
  2198  		val            uint8
  2199  		expectResponse bool
  2200  	}
  2201  
  2202  	type transportTypeTest struct {
  2203  		name       string
  2204  		proto      tcpip.TransportProtocolNumber
  2205  		buf        []byte
  2206  		checkNATed func(*testing.T, *buffer.View)
  2207  	}
  2208  
  2209  	tests := []struct {
  2210  		name            string
  2211  		netProto        tcpip.NetworkProtocolNumber
  2212  		host1Addr       tcpip.Address
  2213  		icmpError       func(*testing.T, []byte, uint8) []byte
  2214  		decrementTTL    func([]byte)
  2215  		checkNATedError func(*testing.T, *buffer.View, []byte, uint8)
  2216  
  2217  		transportTypes []transportTypeTest
  2218  		icmpTypes      []icmpTypeTest
  2219  	}{
  2220  		{
  2221  			name:      "IPv4",
  2222  			netProto:  ipv4.ProtocolNumber,
  2223  			host1Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
  2224  			icmpError: func(t *testing.T, original []byte, icmpType uint8) []byte {
  2225  				hdr := prependable.New(header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(original))
  2226  				if n := copy(hdr.Prepend(len(original)), original); n != len(original) {
  2227  					t.Fatalf("got copy(...) = %d, want = %d", n, len(original))
  2228  				}
  2229  				icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
  2230  				icmp.SetType(header.ICMPv4Type(icmpType))
  2231  				icmp.SetChecksum(0)
  2232  				icmp.SetChecksum(header.ICMPv4Checksum(icmp, 0))
  2233  				encodeIPv4Header(
  2234  					hdr.Prepend(header.IPv4MinimumSize),
  2235  					hdr.UsedLength(),
  2236  					header.ICMPv4ProtocolNumber,
  2237  					utils.Host1IPv4Addr.AddressWithPrefix.Address,
  2238  					utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
  2239  				)
  2240  				return hdr.View()
  2241  			},
  2242  			decrementTTL: func(v []byte) {
  2243  				ip := header.IPv4(v)
  2244  				ip.SetTTL(ip.TTL() - 1)
  2245  				ip.SetChecksum(0)
  2246  				ip.SetChecksum(^ip.CalculateChecksum())
  2247  			},
  2248  			checkNATedError: func(t *testing.T, v *buffer.View, original []byte, icmpType uint8) {
  2249  				checker.IPv4(t, v,
  2250  					checker.SrcAddr(utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address),
  2251  					checker.DstAddr(utils.Host2IPv4Addr.AddressWithPrefix.Address),
  2252  					checker.ICMPv4(
  2253  						checker.ICMPv4Type(header.ICMPv4Type(icmpType)),
  2254  						checker.ICMPv4Checksum(),
  2255  						checker.ICMPv4Payload(original),
  2256  					),
  2257  				)
  2258  			},
  2259  			transportTypes: []transportTypeTest{
  2260  				{
  2261  					name:  "UDP",
  2262  					proto: header.UDPProtocolNumber,
  2263  					buf: func() []byte {
  2264  						return udpv4Packet(utils.Host2IPv4Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, srcPort, dstPort, dataSize)
  2265  					}(),
  2266  					checkNATed: func(t *testing.T, v *buffer.View) {
  2267  						checker.IPv4(t, v,
  2268  							checker.SrcAddr(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address),
  2269  							checker.DstAddr(utils.Host1IPv4Addr.AddressWithPrefix.Address),
  2270  							checker.UDP(
  2271  								checker.SrcPort(srcPort),
  2272  								checker.DstPort(dstPort),
  2273  							),
  2274  						)
  2275  					},
  2276  				},
  2277  				{
  2278  					name:  "TCP",
  2279  					proto: header.TCPProtocolNumber,
  2280  					buf: func() []byte {
  2281  						return tcpv4Packet(utils.Host2IPv4Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, srcPort, dstPort, dataSize)
  2282  					}(),
  2283  					checkNATed: func(t *testing.T, v *buffer.View) {
  2284  						checker.IPv4(t, v,
  2285  							checker.SrcAddr(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address),
  2286  							checker.DstAddr(utils.Host1IPv4Addr.AddressWithPrefix.Address),
  2287  							checker.TCP(
  2288  								checker.SrcPort(srcPort),
  2289  								checker.DstPort(dstPort),
  2290  							),
  2291  						)
  2292  					},
  2293  				},
  2294  			},
  2295  			icmpTypes: []icmpTypeTest{
  2296  				{
  2297  					name:           "Destination Unreachable",
  2298  					val:            uint8(header.ICMPv4DstUnreachable),
  2299  					expectResponse: true,
  2300  				},
  2301  				{
  2302  					name:           "Time Exceeded",
  2303  					val:            uint8(header.ICMPv4TimeExceeded),
  2304  					expectResponse: true,
  2305  				},
  2306  				{
  2307  					name:           "Parameter Problem",
  2308  					val:            uint8(header.ICMPv4ParamProblem),
  2309  					expectResponse: true,
  2310  				},
  2311  				{
  2312  					name:           "Echo Request",
  2313  					val:            uint8(header.ICMPv4Echo),
  2314  					expectResponse: false,
  2315  				},
  2316  				{
  2317  					name:           "Echo Reply",
  2318  					val:            uint8(header.ICMPv4EchoReply),
  2319  					expectResponse: false,
  2320  				},
  2321  			},
  2322  		},
  2323  		{
  2324  			name:      "IPv6",
  2325  			netProto:  ipv6.ProtocolNumber,
  2326  			host1Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
  2327  			icmpError: func(t *testing.T, original []byte, icmpType uint8) []byte {
  2328  				payloadLen := header.ICMPv6MinimumSize + len(original)
  2329  				hdr := prependable.New(header.IPv6MinimumSize + payloadLen)
  2330  				icmp := header.ICMPv6(hdr.Prepend(payloadLen))
  2331  				icmp.SetType(header.ICMPv6Type(icmpType))
  2332  				if n := copy(icmp.Payload(), original); n != len(original) {
  2333  					t.Fatalf("got copy(...) = %d, want = %d", n, len(original))
  2334  				}
  2335  				icmp.SetChecksum(0)
  2336  				icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
  2337  					Header: icmp,
  2338  					Src:    utils.Host1IPv6Addr.AddressWithPrefix.Address,
  2339  					Dst:    utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
  2340  				}))
  2341  				encodeIPv6Header(
  2342  					hdr.Prepend(header.IPv6MinimumSize),
  2343  					payloadLen,
  2344  					header.ICMPv6ProtocolNumber,
  2345  					utils.Host1IPv6Addr.AddressWithPrefix.Address,
  2346  					utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
  2347  				)
  2348  				return hdr.View()
  2349  			},
  2350  			decrementTTL: func(v []byte) {
  2351  				ip := header.IPv6(v)
  2352  				ip.SetHopLimit(ip.HopLimit() - 1)
  2353  			},
  2354  			checkNATedError: func(t *testing.T, v *buffer.View, original []byte, icmpType uint8) {
  2355  				checker.IPv6(t, v,
  2356  					checker.SrcAddr(utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address),
  2357  					checker.DstAddr(utils.Host2IPv6Addr.AddressWithPrefix.Address),
  2358  					checker.ICMPv6(
  2359  						checker.ICMPv6Type(header.ICMPv6Type(icmpType)),
  2360  						checker.ICMPv6Payload(original),
  2361  					),
  2362  				)
  2363  			},
  2364  			transportTypes: []transportTypeTest{
  2365  				{
  2366  					name:  "UDP",
  2367  					proto: header.UDPProtocolNumber,
  2368  					buf: func() []byte {
  2369  						return udpv6Packet(utils.Host2IPv6Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, srcPort, dstPort, dataSize)
  2370  					}(),
  2371  					checkNATed: func(t *testing.T, v *buffer.View) {
  2372  						checker.IPv6(t, v,
  2373  							checker.SrcAddr(utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address),
  2374  							checker.DstAddr(utils.Host1IPv6Addr.AddressWithPrefix.Address),
  2375  							checker.UDP(
  2376  								checker.SrcPort(srcPort),
  2377  								checker.DstPort(dstPort),
  2378  							),
  2379  						)
  2380  					},
  2381  				},
  2382  				{
  2383  					name:  "TCP",
  2384  					proto: header.TCPProtocolNumber,
  2385  					buf: func() []byte {
  2386  						return tcpv6Packet(utils.Host2IPv6Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, srcPort, dstPort, dataSize)
  2387  					}(),
  2388  					checkNATed: func(t *testing.T, v *buffer.View) {
  2389  						checker.IPv6(t, v,
  2390  							checker.SrcAddr(utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address),
  2391  							checker.DstAddr(utils.Host1IPv6Addr.AddressWithPrefix.Address),
  2392  							checker.TCP(
  2393  								checker.SrcPort(srcPort),
  2394  								checker.DstPort(dstPort),
  2395  							),
  2396  						)
  2397  					},
  2398  				},
  2399  			},
  2400  			icmpTypes: []icmpTypeTest{
  2401  				{
  2402  					name:           "Destination Unreachable",
  2403  					val:            uint8(header.ICMPv6DstUnreachable),
  2404  					expectResponse: true,
  2405  				},
  2406  				{
  2407  					name:           "Packet Too Big",
  2408  					val:            uint8(header.ICMPv6PacketTooBig),
  2409  					expectResponse: true,
  2410  				},
  2411  				{
  2412  					name:           "Time Exceeded",
  2413  					val:            uint8(header.ICMPv6TimeExceeded),
  2414  					expectResponse: true,
  2415  				},
  2416  				{
  2417  					name:           "Parameter Problem",
  2418  					val:            uint8(header.ICMPv6ParamProblem),
  2419  					expectResponse: true,
  2420  				},
  2421  				{
  2422  					name:           "Echo Request",
  2423  					val:            uint8(header.ICMPv6EchoRequest),
  2424  					expectResponse: false,
  2425  				},
  2426  				{
  2427  					name:           "Echo Reply",
  2428  					val:            uint8(header.ICMPv6EchoReply),
  2429  					expectResponse: false,
  2430  				},
  2431  			},
  2432  		},
  2433  	}
  2434  
  2435  	trimTests := []struct {
  2436  		name            string
  2437  		trimLen         int
  2438  		expectNATedICMP bool
  2439  	}{
  2440  		{
  2441  			name:            "Trim nothing",
  2442  			trimLen:         0,
  2443  			expectNATedICMP: true,
  2444  		},
  2445  		{
  2446  			name:            "Trim data",
  2447  			trimLen:         dataSize,
  2448  			expectNATedICMP: true,
  2449  		},
  2450  		{
  2451  			name:            "Trim data and transport header",
  2452  			trimLen:         dataSize + 1,
  2453  			expectNATedICMP: false,
  2454  		},
  2455  	}
  2456  
  2457  	for _, test := range tests {
  2458  		t.Run(test.name, func(t *testing.T) {
  2459  			for _, transportType := range test.transportTypes {
  2460  				t.Run(transportType.name, func(t *testing.T) {
  2461  					for _, icmpType := range test.icmpTypes {
  2462  						t.Run(icmpType.name, func(t *testing.T) {
  2463  							for _, trimTest := range trimTests {
  2464  								t.Run(trimTest.name, func(t *testing.T) {
  2465  									s := stack.New(stack.Options{
  2466  										NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  2467  										TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
  2468  									})
  2469  									defer s.Destroy()
  2470  
  2471  									ep1 := channel.New(1, header.IPv6MinimumMTU, "")
  2472  									ep2 := channel.New(1, header.IPv6MinimumMTU, "")
  2473  									utils.SetupRouterStack(t, s, ep1, ep2)
  2474  
  2475  									ipv6 := test.netProto == ipv6.ProtocolNumber
  2476  									ipt := s.IPTables()
  2477  
  2478  									table := stack.Table{
  2479  										Rules: []stack.Rule{
  2480  											// Prerouting
  2481  											{
  2482  												Filter: stack.IPHeaderFilter{
  2483  													Protocol:       transportType.proto,
  2484  													CheckProtocol:  true,
  2485  													InputInterface: utils.RouterNIC2Name,
  2486  												},
  2487  												Target: &stack.DNATTarget{NetworkProtocol: test.netProto, Addr: test.host1Addr, Port: dstPort, ChangeAddress: true, ChangePort: true},
  2488  											},
  2489  											{
  2490  												Target: &stack.AcceptTarget{},
  2491  											},
  2492  
  2493  											// Input
  2494  											{
  2495  												Target: &stack.AcceptTarget{},
  2496  											},
  2497  
  2498  											// Forward
  2499  											{
  2500  												Target: &stack.AcceptTarget{},
  2501  											},
  2502  
  2503  											// Output
  2504  											{
  2505  												Target: &stack.AcceptTarget{},
  2506  											},
  2507  
  2508  											// Postrouting
  2509  											{
  2510  												Filter: stack.IPHeaderFilter{
  2511  													Protocol:        transportType.proto,
  2512  													CheckProtocol:   true,
  2513  													OutputInterface: utils.RouterNIC1Name,
  2514  												},
  2515  												Target: &stack.MasqueradeTarget{NetworkProtocol: test.netProto},
  2516  											},
  2517  											{
  2518  												Target: &stack.AcceptTarget{},
  2519  											},
  2520  										},
  2521  										BuiltinChains: [stack.NumHooks]int{
  2522  											stack.Prerouting:  0,
  2523  											stack.Input:       2,
  2524  											stack.Forward:     3,
  2525  											stack.Output:      4,
  2526  											stack.Postrouting: 5,
  2527  										},
  2528  									}
  2529  
  2530  									ipt.ForceReplaceTable(stack.NATID, table, ipv6)
  2531  
  2532  									buf := transportType.buf
  2533  
  2534  									ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
  2535  										Payload: buffer.MakeWithData(append([]byte{}, buf...)),
  2536  									}))
  2537  
  2538  									{
  2539  										pkt := ep1.Read()
  2540  										if pkt == nil {
  2541  											t.Fatal("expected to read a packet on ep1")
  2542  										}
  2543  										pktView := stack.PayloadSince(pkt.NetworkHeader())
  2544  										defer pktView.Release()
  2545  										pkt.DecRef()
  2546  										transportType.checkNATed(t, pktView)
  2547  										if t.Failed() {
  2548  											t.FailNow()
  2549  										}
  2550  
  2551  										pktSlice := pktView.AsSlice()[:pktView.Size()-trimTest.trimLen]
  2552  										buf = buf[:len(buf)-trimTest.trimLen]
  2553  
  2554  										ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
  2555  											Payload: buffer.MakeWithData(test.icmpError(t, pktSlice, icmpType.val)),
  2556  										}))
  2557  									}
  2558  
  2559  									pkt := ep2.Read()
  2560  									expectResponse := icmpType.expectResponse && trimTest.expectNATedICMP
  2561  									if (pkt != nil) != expectResponse {
  2562  										t.Fatalf("got ep2.Read() = %#v, want = (_ == nil) = %t", pkt, expectResponse)
  2563  									}
  2564  									if !expectResponse {
  2565  										return
  2566  									}
  2567  									test.decrementTTL(buf)
  2568  									payload := stack.PayloadSince(pkt.NetworkHeader())
  2569  									defer payload.Release()
  2570  									test.checkNATedError(t, payload, buf, icmpType.val)
  2571  									pkt.DecRef()
  2572  								})
  2573  							}
  2574  						})
  2575  					}
  2576  				})
  2577  			}
  2578  		})
  2579  	}
  2580  }
  2581  
  2582  func TestSNATHandlePortOrIdentConflicts(t *testing.T) {
  2583  	const dstPort = 5432
  2584  
  2585  	type portOrIdentRange struct {
  2586  		first uint16
  2587  		last  uint16
  2588  	}
  2589  
  2590  	type srcPortOrIdentRangeTest struct {
  2591  		name          string
  2592  		originalRange portOrIdentRange
  2593  		targetRange   portOrIdentRange
  2594  	}
  2595  
  2596  	srcPortRanges := []srcPortOrIdentRangeTest{
  2597  		{
  2598  			name:          "Less than 512",
  2599  			originalRange: portOrIdentRange{first: 1, last: 511},
  2600  			targetRange:   portOrIdentRange{first: 1, last: 511},
  2601  		},
  2602  		{
  2603  			name:          "Greater than or equal to 512 but less than 1024",
  2604  			originalRange: portOrIdentRange{first: 512, last: 1023},
  2605  			targetRange:   portOrIdentRange{first: 1, last: 1023},
  2606  		},
  2607  		{
  2608  			name:          "Greater than or equal to 1024",
  2609  			originalRange: portOrIdentRange{first: 1024, last: math.MaxUint16},
  2610  			targetRange:   portOrIdentRange{first: 1024, last: math.MaxUint16},
  2611  		},
  2612  	}
  2613  
  2614  	// Unlike TCP/UDP, the Ident may be mapped to any 16-bit value.
  2615  	identRanges := []srcPortOrIdentRangeTest{
  2616  		{
  2617  			name:          "Less than 512",
  2618  			originalRange: portOrIdentRange{first: 0, last: 511},
  2619  			targetRange:   portOrIdentRange{first: 0, last: math.MaxUint16},
  2620  		},
  2621  		{
  2622  			name:          "Greater than or equal to 512 but less than 1024",
  2623  			originalRange: portOrIdentRange{first: 512, last: 1023},
  2624  			targetRange:   portOrIdentRange{first: 0, last: math.MaxUint16},
  2625  		},
  2626  		{
  2627  			name:          "Greater than or equal to 1024",
  2628  			originalRange: portOrIdentRange{first: 1024, last: math.MaxUint16},
  2629  			targetRange:   portOrIdentRange{first: 0, last: math.MaxUint16},
  2630  		},
  2631  	}
  2632  
  2633  	type transportTypeTest struct {
  2634  		name                 string
  2635  		proto                tcpip.TransportProtocolNumber
  2636  		buf                  func(tcpip.Address, uint16) []byte
  2637  		checkNATed           func(*testing.T, *buffer.View, uint16, bool, portOrIdentRange)
  2638  		srcPortOrIdentRanges []srcPortOrIdentRangeTest
  2639  	}
  2640  
  2641  	compareSrcPortOrIdent := func(t *testing.T, gotPort uint16, originalSrcPort uint16, firstPacket bool, expectedRange portOrIdentRange) {
  2642  		t.Helper()
  2643  
  2644  		if firstPacket {
  2645  			if gotPort != originalSrcPort {
  2646  				t.Errorf("got port/ident = %d, want = %d", gotPort, originalSrcPort)
  2647  			}
  2648  			return
  2649  		}
  2650  
  2651  		if gotPort < expectedRange.first || gotPort > expectedRange.last {
  2652  			t.Errorf("got port/ident = %d, want in range [%d, %d]", gotPort, expectedRange.first, expectedRange.last)
  2653  		}
  2654  	}
  2655  
  2656  	tests := []struct {
  2657  		name           string
  2658  		netProto       tcpip.NetworkProtocolNumber
  2659  		routerNIC1Addr tcpip.Address
  2660  		srcAddrs       []tcpip.Address
  2661  		transportTypes []transportTypeTest
  2662  	}{
  2663  		{
  2664  			name:           "IPv4",
  2665  			netProto:       ipv4.ProtocolNumber,
  2666  			routerNIC1Addr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
  2667  			srcAddrs: []tcpip.Address{
  2668  				utils.Ipv4Addr1.AddressWithPrefix.Address,
  2669  				utils.Ipv4Addr2.AddressWithPrefix.Address,
  2670  				utils.Ipv4Addr3.AddressWithPrefix.Address,
  2671  			},
  2672  			transportTypes: []transportTypeTest{
  2673  				{
  2674  					name:  "UDP",
  2675  					proto: header.UDPProtocolNumber,
  2676  					buf: func(srcAddr tcpip.Address, srcPort uint16) []byte {
  2677  						return udpv4Packet(srcAddr, utils.Host1IPv4Addr.AddressWithPrefix.Address, srcPort, dstPort, 0 /* dataSize */)
  2678  					},
  2679  					checkNATed: func(t *testing.T, v *buffer.View, originalSrcPort uint16, firstPacket bool, expectedRange portOrIdentRange) {
  2680  						checker.IPv4(t, v,
  2681  							checker.SrcAddr(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address),
  2682  							checker.DstAddr(utils.Host1IPv4Addr.AddressWithPrefix.Address),
  2683  							checker.UDP(
  2684  								checker.DstPort(dstPort),
  2685  							),
  2686  						)
  2687  
  2688  						if !t.Failed() {
  2689  							compareSrcPortOrIdent(t, header.UDP(header.IPv4(v.AsSlice()).Payload()).SourcePort(), originalSrcPort, firstPacket, expectedRange)
  2690  						}
  2691  					},
  2692  					srcPortOrIdentRanges: srcPortRanges,
  2693  				},
  2694  				{
  2695  					name:  "TCP",
  2696  					proto: header.TCPProtocolNumber,
  2697  					buf: func(srcAddr tcpip.Address, srcPort uint16) []byte {
  2698  						return tcpv4Packet(srcAddr, utils.Host1IPv4Addr.AddressWithPrefix.Address, srcPort, dstPort, 0 /* dataSize */)
  2699  					},
  2700  					checkNATed: func(t *testing.T, v *buffer.View, originalSrcPort uint16, firstPacket bool, expectedRange portOrIdentRange) {
  2701  						checker.IPv4(t, v,
  2702  							checker.SrcAddr(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address),
  2703  							checker.DstAddr(utils.Host1IPv4Addr.AddressWithPrefix.Address),
  2704  							checker.TCP(
  2705  								checker.DstPort(dstPort),
  2706  							),
  2707  						)
  2708  
  2709  						if !t.Failed() {
  2710  							compareSrcPortOrIdent(t, header.TCP(header.IPv4(v.AsSlice()).Payload()).SourcePort(), originalSrcPort, firstPacket, expectedRange)
  2711  						}
  2712  					},
  2713  					srcPortOrIdentRanges: srcPortRanges,
  2714  				},
  2715  				{
  2716  					name:  "ICMP Echo",
  2717  					proto: header.ICMPv4ProtocolNumber,
  2718  					buf: func(srcAddr tcpip.Address, ident uint16) []byte {
  2719  						return icmpv4Packet(srcAddr, utils.Host1IPv4Addr.AddressWithPrefix.Address, header.ICMPv4Echo, ident)
  2720  					},
  2721  					checkNATed: func(t *testing.T, v *buffer.View, originalIdent uint16, firstPacket bool, expectedRange portOrIdentRange) {
  2722  						checker.IPv4(t, v,
  2723  							checker.SrcAddr(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address),
  2724  							checker.DstAddr(utils.Host1IPv4Addr.AddressWithPrefix.Address),
  2725  							checker.ICMPv4(
  2726  								checker.ICMPv4Type(header.ICMPv4Echo),
  2727  								checker.ICMPv4Checksum(),
  2728  							),
  2729  						)
  2730  
  2731  						if !t.Failed() {
  2732  							compareSrcPortOrIdent(t, header.ICMPv4(header.IPv4(v.AsSlice()).Payload()).Ident(), originalIdent, firstPacket, expectedRange)
  2733  						}
  2734  					},
  2735  					srcPortOrIdentRanges: identRanges,
  2736  				},
  2737  			},
  2738  		},
  2739  		{
  2740  			name:           "IPv6",
  2741  			netProto:       ipv6.ProtocolNumber,
  2742  			routerNIC1Addr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
  2743  			srcAddrs: []tcpip.Address{
  2744  				utils.Ipv6Addr1.AddressWithPrefix.Address,
  2745  				utils.Ipv6Addr2.AddressWithPrefix.Address,
  2746  				utils.Ipv6Addr2.AddressWithPrefix.Address,
  2747  			},
  2748  			transportTypes: []transportTypeTest{
  2749  				{
  2750  					name:  "UDP",
  2751  					proto: header.UDPProtocolNumber,
  2752  					buf: func(srcAddr tcpip.Address, srcPort uint16) []byte {
  2753  						return udpv6Packet(srcAddr, utils.Host1IPv6Addr.AddressWithPrefix.Address, srcPort, dstPort, 0 /* dataSize */)
  2754  					},
  2755  					checkNATed: func(t *testing.T, v *buffer.View, originalSrcPort uint16, firstPacket bool, expectedRange portOrIdentRange) {
  2756  						checker.IPv6(t, v,
  2757  							checker.SrcAddr(utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address),
  2758  							checker.DstAddr(utils.Host1IPv6Addr.AddressWithPrefix.Address),
  2759  							checker.UDP(
  2760  								checker.DstPort(dstPort),
  2761  							),
  2762  						)
  2763  
  2764  						if !t.Failed() {
  2765  							compareSrcPortOrIdent(t, header.UDP(header.IPv6(v.AsSlice()).Payload()).SourcePort(), originalSrcPort, firstPacket, expectedRange)
  2766  						}
  2767  					},
  2768  					srcPortOrIdentRanges: srcPortRanges,
  2769  				},
  2770  				{
  2771  					name:  "TCP",
  2772  					proto: header.TCPProtocolNumber,
  2773  					buf: func(srcAddr tcpip.Address, srcPort uint16) []byte {
  2774  						return tcpv6Packet(srcAddr, utils.Host1IPv6Addr.AddressWithPrefix.Address, srcPort, dstPort, 0 /* dataSize */)
  2775  					},
  2776  					checkNATed: func(t *testing.T, v *buffer.View, originalSrcPort uint16, firstPacket bool, expectedRange portOrIdentRange) {
  2777  						checker.IPv6(t, v,
  2778  							checker.SrcAddr(utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address),
  2779  							checker.DstAddr(utils.Host1IPv6Addr.AddressWithPrefix.Address),
  2780  							checker.TCP(
  2781  								checker.DstPort(dstPort),
  2782  							),
  2783  						)
  2784  
  2785  						if !t.Failed() {
  2786  							compareSrcPortOrIdent(t, header.TCP(header.IPv6(v.AsSlice()).Payload()).SourcePort(), originalSrcPort, firstPacket, expectedRange)
  2787  						}
  2788  					},
  2789  					srcPortOrIdentRanges: srcPortRanges,
  2790  				},
  2791  				{
  2792  					name:  "ICMP Echo",
  2793  					proto: header.ICMPv6ProtocolNumber,
  2794  					buf: func(srcAddr tcpip.Address, ident uint16) []byte {
  2795  						return icmpv6Packet(srcAddr, utils.Host1IPv6Addr.AddressWithPrefix.Address, header.ICMPv6EchoRequest, ident)
  2796  					},
  2797  					checkNATed: func(t *testing.T, v *buffer.View, originalIdent uint16, firstPacket bool, expectedRange portOrIdentRange) {
  2798  						checker.IPv6(t, v,
  2799  							checker.SrcAddr(utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address),
  2800  							checker.DstAddr(utils.Host1IPv6Addr.AddressWithPrefix.Address),
  2801  							checker.ICMPv6(
  2802  								checker.ICMPv6Type(header.ICMPv6EchoRequest),
  2803  							),
  2804  						)
  2805  
  2806  						if !t.Failed() {
  2807  							compareSrcPortOrIdent(t, header.ICMPv6(header.IPv6(v.AsSlice()).Payload()).Ident(), originalIdent, firstPacket, expectedRange)
  2808  						}
  2809  					},
  2810  					srcPortOrIdentRanges: identRanges,
  2811  				},
  2812  			},
  2813  		},
  2814  	}
  2815  
  2816  	natTypes := []struct {
  2817  		name   string
  2818  		target func(tcpip.NetworkProtocolNumber, tcpip.Address) stack.Target
  2819  	}{
  2820  		{
  2821  			name: "Masquerade",
  2822  			target: func(netProto tcpip.NetworkProtocolNumber, _ tcpip.Address) stack.Target {
  2823  				return &stack.MasqueradeTarget{NetworkProtocol: netProto}
  2824  			},
  2825  		},
  2826  		{
  2827  			name: "SNAT",
  2828  			target: func(netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) stack.Target {
  2829  				return &stack.SNATTarget{NetworkProtocol: netProto, Addr: addr, ChangeAddress: true, ChangePort: true}
  2830  			},
  2831  		},
  2832  	}
  2833  
  2834  	for _, test := range tests {
  2835  		t.Run(test.name, func(t *testing.T) {
  2836  			for _, transportType := range test.transportTypes {
  2837  				t.Run(transportType.name, func(t *testing.T) {
  2838  					for _, natType := range natTypes {
  2839  						t.Run(natType.name, func(t *testing.T) {
  2840  							for _, srcPortOrIdentRange := range transportType.srcPortOrIdentRanges {
  2841  								t.Run(srcPortOrIdentRange.name, func(t *testing.T) {
  2842  									for _, srcPortOrIdent := range [2]uint16{srcPortOrIdentRange.originalRange.first, srcPortOrIdentRange.originalRange.last} {
  2843  										t.Run(fmt.Sprintf("OriginalSrcPortOrIdent=%d", srcPortOrIdent), func(t *testing.T) {
  2844  											s := stack.New(stack.Options{
  2845  												NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  2846  												TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
  2847  											})
  2848  											defer s.Destroy()
  2849  
  2850  											ep1 := channel.New(1, header.IPv6MinimumMTU, "")
  2851  											ep2 := channel.New(1, header.IPv6MinimumMTU, "")
  2852  											utils.SetupRouterStack(t, s, ep1, ep2)
  2853  
  2854  											ipv6 := test.netProto == ipv6.ProtocolNumber
  2855  											ipt := s.IPTables()
  2856  
  2857  											table := stack.Table{
  2858  												Rules: []stack.Rule{
  2859  													// Prerouting
  2860  													{
  2861  														Target: &stack.AcceptTarget{},
  2862  													},
  2863  
  2864  													// Input
  2865  													{
  2866  														Target: &stack.AcceptTarget{},
  2867  													},
  2868  
  2869  													// Forward
  2870  													{
  2871  														Target: &stack.AcceptTarget{},
  2872  													},
  2873  
  2874  													// Output
  2875  													{
  2876  														Target: &stack.AcceptTarget{},
  2877  													},
  2878  
  2879  													// Postrouting
  2880  													{
  2881  														Filter: stack.IPHeaderFilter{
  2882  															Protocol:        transportType.proto,
  2883  															CheckProtocol:   true,
  2884  															OutputInterface: utils.RouterNIC1Name,
  2885  														},
  2886  														Target: natType.target(test.netProto, test.routerNIC1Addr),
  2887  													},
  2888  													{
  2889  														Target: &stack.AcceptTarget{},
  2890  													},
  2891  												},
  2892  												BuiltinChains: [stack.NumHooks]int{
  2893  													stack.Prerouting:  0,
  2894  													stack.Input:       1,
  2895  													stack.Forward:     2,
  2896  													stack.Output:      3,
  2897  													stack.Postrouting: 4,
  2898  												},
  2899  											}
  2900  
  2901  											ipt.ForceReplaceTable(stack.NATID, table, ipv6)
  2902  
  2903  											for i, srcAddr := range test.srcAddrs {
  2904  												t.Run(fmt.Sprintf("Packet#%d", i), func(t *testing.T) {
  2905  													ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
  2906  														Payload: buffer.MakeWithData(transportType.buf(srcAddr, srcPortOrIdent)),
  2907  													}))
  2908  
  2909  													pkt := ep1.Read()
  2910  													if pkt == nil {
  2911  														t.Fatal("expected to read a packet on ep1")
  2912  													}
  2913  													pktView := stack.PayloadSince(pkt.NetworkHeader())
  2914  													defer pktView.Release()
  2915  													pkt.DecRef()
  2916  													transportType.checkNATed(t, pktView, srcPortOrIdent, i == 0, srcPortOrIdentRange.targetRange)
  2917  												})
  2918  											}
  2919  										})
  2920  									}
  2921  								})
  2922  							}
  2923  						})
  2924  					}
  2925  				})
  2926  			}
  2927  		})
  2928  	}
  2929  }
  2930  
  2931  func TestSNATLocallyGeneratedTrafficPorts(t *testing.T) {
  2932  	s := stack.New(stack.Options{
  2933  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  2934  		TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
  2935  	})
  2936  	defer s.Destroy()
  2937  
  2938  	ep1 := channel.New(1, header.IPv4MinimumMTU, "")
  2939  	ep2 := channel.New(1, header.IPv4MinimumMTU, "")
  2940  	utils.SetupRouterStack(t, s, ep1, ep2)
  2941  
  2942  	// Configure Masquerade NAT on the router stack.
  2943  	ipt := s.IPTables()
  2944  	table := stack.Table{
  2945  		Rules: []stack.Rule{
  2946  			// Prerouting
  2947  			{
  2948  				Target: &stack.AcceptTarget{},
  2949  			},
  2950  
  2951  			// Input
  2952  			{
  2953  				Target: &stack.AcceptTarget{},
  2954  			},
  2955  
  2956  			// Forward
  2957  			{
  2958  				Target: &stack.AcceptTarget{},
  2959  			},
  2960  
  2961  			// Output
  2962  			{
  2963  				Target: &stack.AcceptTarget{},
  2964  			},
  2965  
  2966  			// Postrouting
  2967  			{
  2968  				Filter: stack.IPHeaderFilter{
  2969  					Protocol:        udp.ProtocolNumber,
  2970  					CheckProtocol:   true,
  2971  					OutputInterface: utils.RouterNIC2Name,
  2972  				},
  2973  				Target: &stack.MasqueradeTarget{NetworkProtocol: ipv4.ProtocolNumber},
  2974  			},
  2975  			{
  2976  				Target: &stack.AcceptTarget{},
  2977  			},
  2978  		},
  2979  		BuiltinChains: [stack.NumHooks]int{
  2980  			stack.Prerouting:  0,
  2981  			stack.Input:       1,
  2982  			stack.Forward:     2,
  2983  			stack.Output:      3,
  2984  			stack.Postrouting: 4,
  2985  		},
  2986  	}
  2987  	ipt.ForceReplaceTable(stack.NATID, table, false /* ipv6 */)
  2988  
  2989  	routerNIC2Addr := utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address
  2990  	ep1Addr := utils.Host1IPv4Addr.AddressWithPrefix.Address
  2991  	var ep1Port uint16 = 1234
  2992  	ep2Addr := utils.Host2IPv4Addr.AddressWithPrefix.Address
  2993  	var ep2Port uint16 = 2345
  2994  
  2995  	// Inject an incoming packet on NIC1 destined to an address that will be
  2996  	// routed out of NIC2. Expect that we can read the packet on ep2 coming from
  2997  	// the stack's address assigned on NIC2, because it should have performed
  2998  	// Masquerade NAT on the forwarded traffic.
  2999  	ep1.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
  3000  		Payload: buffer.MakeWithData(udpv4Packet(ep1Addr, ep2Addr, ep1Port, ep2Port, 0 /* dataSize */)),
  3001  	}))
  3002  	pkt := ep2.Read()
  3003  	if pkt == nil {
  3004  		t.Fatal("expected to read a packet on ep2")
  3005  	}
  3006  	pktView := stack.PayloadSince(pkt.NetworkHeader())
  3007  	defer pktView.Release()
  3008  	pkt.DecRef()
  3009  	checker.IPv4(t, pktView,
  3010  		checker.SrcAddr(routerNIC2Addr),
  3011  		checker.DstAddr(ep2Addr),
  3012  		checker.UDP(
  3013  			checker.SrcPort(ep1Port),
  3014  			checker.DstPort(ep2Port),
  3015  		),
  3016  	)
  3017  
  3018  	// Now bind a UDP socket on the stack itself to the same port used by the
  3019  	// previous packet, and send a packet to the same address.
  3020  	var wq waiter.Queue
  3021  	we, ch := waiter.NewChannelEntry(waiter.ReadableEvents)
  3022  	wq.EventRegister(&we)
  3023  	defer wq.EventUnregister(&we)
  3024  
  3025  	ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
  3026  	if err != nil {
  3027  		t.Fatalf("s.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ipv4.ProtocolNumber, err)
  3028  	}
  3029  	defer ep.Close()
  3030  
  3031  	srcAddr := tcpip.FullAddress{Addr: routerNIC2Addr, Port: ep1Port}
  3032  	if err := ep.Bind(srcAddr); err != nil {
  3033  		t.Fatalf("ep.Bind(%#v): %s", srcAddr, err)
  3034  	}
  3035  	dstAddr := tcpip.FullAddress{Addr: ep2Addr, Port: ep2Port}
  3036  	if err := ep.Connect(dstAddr); err != nil {
  3037  		t.Fatalf("ep.Connect(%#v): %s", dstAddr, err)
  3038  	}
  3039  
  3040  	data := []byte{1, 2, 3, 4}
  3041  	var r bytes.Reader
  3042  	r.Reset(data)
  3043  	var wOpts tcpip.WriteOptions
  3044  	n, err := ep.Write(&r, wOpts)
  3045  	if err != nil {
  3046  		t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
  3047  	}
  3048  	if want := int64(len(data)); n != want {
  3049  		t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
  3050  	}
  3051  
  3052  	// The router should perform source port remapping for the locally generated
  3053  	// traffic so that it does not conflict with the existing conntrack entry, so
  3054  	// ep2 should observe the traffic as coming from the router's address, but
  3055  	// *not* from the same port as the traffic from ep1 before.
  3056  	pkt = ep2.Read()
  3057  	if pkt == nil {
  3058  		t.Fatal("expected to read a packet on ep2")
  3059  	}
  3060  	pktView = stack.PayloadSince(pkt.NetworkHeader())
  3061  	defer pktView.Release()
  3062  	pkt.DecRef()
  3063  	checker.IPv4(t, pktView,
  3064  		checker.SrcAddr(routerNIC2Addr),
  3065  		checker.DstAddr(ep2Addr),
  3066  		checker.UDP(
  3067  			checker.DstPort(ep2Port),
  3068  			checker.Payload(data),
  3069  		),
  3070  	)
  3071  	gotPort := header.UDP(header.IPv4(pktView.AsSlice()).Payload()).SourcePort()
  3072  	if gotPort == ep1Port {
  3073  		t.Errorf("got src port == ep1Port (%d), should be remapped to avoid conflict", gotPort)
  3074  	}
  3075  
  3076  	// We should also be able to reply on either connection, by injecting inbound
  3077  	// traffic on ep2 destined to the router.
  3078  	//
  3079  	// Traffic destined to the port originally used in the traffic injected on ep1
  3080  	// should go to ep1.
  3081  	ep2.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
  3082  		Payload: buffer.MakeWithData(udpv4Packet(ep2Addr, routerNIC2Addr, ep2Port, ep1Port, 0 /* dataSize */)),
  3083  	}))
  3084  	pkt = ep1.Read()
  3085  	if pkt == nil {
  3086  		t.Fatal("expected to read a packet on ep2")
  3087  	}
  3088  	pktView = stack.PayloadSince(pkt.NetworkHeader())
  3089  	defer pktView.Release()
  3090  	pkt.DecRef()
  3091  	checker.IPv4(t, pktView,
  3092  		checker.SrcAddr(ep2Addr),
  3093  		checker.DstAddr(ep1Addr),
  3094  		checker.UDP(
  3095  			checker.SrcPort(ep2Port),
  3096  			checker.DstPort(ep1Port),
  3097  		),
  3098  	)
  3099  
  3100  	// And traffic destined to the remapped source port chosen by conntrack for
  3101  	// the socket bound on the stack should go to the socket.
  3102  	reply := udpv4Packet(ep2Addr, routerNIC2Addr, ep2Port, gotPort, 0 /* dataSize */)
  3103  	reply = append(reply, data...)
  3104  	ep2.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
  3105  		Payload: buffer.MakeWithData(reply),
  3106  	}))
  3107  	var buf bytes.Buffer
  3108  	var res tcpip.ReadResult
  3109  	for {
  3110  		var err tcpip.Error
  3111  		res, err = ep.Read(&buf, tcpip.ReadOptions{})
  3112  		if _, ok := err.(*tcpip.ErrWouldBlock); ok {
  3113  			<-ch
  3114  			continue
  3115  		}
  3116  		if err != nil {
  3117  			t.Fatalf("ep.Read(_, {}): %s", err)
  3118  		}
  3119  		break
  3120  	}
  3121  	if diff := cmp.Diff(
  3122  		tcpip.ReadResult{
  3123  			Count: 0,
  3124  			Total: 0,
  3125  		},
  3126  		res,
  3127  		checker.IgnoreCmpPath("ControlMessages"),
  3128  	); diff != "" {
  3129  		t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
  3130  	}
  3131  }
  3132  
  3133  func TestLocallyRoutedPackets(t *testing.T) {
  3134  	const nicID = 1
  3135  
  3136  	tests := []struct {
  3137  		name     string
  3138  		netProto tcpip.NetworkProtocolNumber
  3139  		addr     tcpip.Address
  3140  	}{
  3141  		{
  3142  			name:     "IPv4",
  3143  			netProto: ipv4.ProtocolNumber,
  3144  			addr:     utils.Host1IPv4Addr.AddressWithPrefix.Address,
  3145  		},
  3146  		{
  3147  			name:     "IPv6",
  3148  			netProto: ipv6.ProtocolNumber,
  3149  			addr:     utils.Host1IPv6Addr.AddressWithPrefix.Address,
  3150  		},
  3151  	}
  3152  
  3153  	for _, test := range tests {
  3154  		t.Run(test.name, func(t *testing.T) {
  3155  			s := stack.New(stack.Options{
  3156  				NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  3157  				TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
  3158  			})
  3159  			defer s.Destroy()
  3160  
  3161  			if err := s.CreateNIC(nicID, loopback.New()); err != nil {
  3162  				t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
  3163  			}
  3164  			protocolAddr := tcpip.ProtocolAddress{
  3165  				Protocol:          test.netProto,
  3166  				AddressWithPrefix: test.addr.WithPrefix(),
  3167  			}
  3168  			if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
  3169  				t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
  3170  			}
  3171  
  3172  			s.SetRouteTable([]tcpip.Route{
  3173  				{
  3174  					Destination: protocolAddr.AddressWithPrefix.Subnet(),
  3175  					NIC:         nicID,
  3176  				},
  3177  			})
  3178  
  3179  			// Set IPTables so we create entries in the conntrack table.
  3180  			{
  3181  				ipv6 := test.netProto == ipv6.ProtocolNumber
  3182  				ipt := s.IPTables()
  3183  				filter := ipt.GetTable(stack.FilterID, ipv6)
  3184  				ipt.ForceReplaceTable(stack.FilterID, filter, ipv6)
  3185  			}
  3186  
  3187  			var wq waiter.Queue
  3188  			we, ch := waiter.NewChannelEntry(waiter.ReadableEvents)
  3189  			wq.EventRegister(&we)
  3190  			defer wq.EventUnregister(&we)
  3191  
  3192  			ep, err := s.NewEndpoint(udp.ProtocolNumber, test.netProto, &wq)
  3193  			if err != nil {
  3194  				t.Fatalf("s.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.netProto, err)
  3195  			}
  3196  			defer ep.Close()
  3197  
  3198  			fullAddr := tcpip.FullAddress{Addr: test.addr, Port: 1234}
  3199  			if err := ep.Bind(fullAddr); err != nil {
  3200  				t.Fatalf("ep.Bind(%#v): %s", fullAddr, err)
  3201  			}
  3202  			if err := ep.Connect(fullAddr); err != nil {
  3203  				t.Fatalf("ep.Connect(%#v): %s", fullAddr, err)
  3204  			}
  3205  
  3206  			data := []byte{1, 2, 3, 4}
  3207  
  3208  			var r bytes.Reader
  3209  			r.Reset(data)
  3210  			var wOpts tcpip.WriteOptions
  3211  			n, err := ep.Write(&r, wOpts)
  3212  			if err != nil {
  3213  				t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
  3214  			}
  3215  			if want := int64(len(data)); n != want {
  3216  				t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
  3217  			}
  3218  
  3219  			var buf bytes.Buffer
  3220  			var res tcpip.ReadResult
  3221  			for {
  3222  				var err tcpip.Error
  3223  				res, err = ep.Read(&buf, tcpip.ReadOptions{})
  3224  				if _, ok := err.(*tcpip.ErrWouldBlock); ok {
  3225  					<-ch
  3226  					continue
  3227  				}
  3228  				if err != nil {
  3229  					t.Fatalf("ep.Read(_, {}): %s", err)
  3230  				}
  3231  				break
  3232  			}
  3233  			if diff := cmp.Diff(
  3234  				tcpip.ReadResult{
  3235  					Count: len(data),
  3236  					Total: len(data),
  3237  				},
  3238  				res,
  3239  				checker.IgnoreCmpPath("ControlMessages"),
  3240  			); diff != "" {
  3241  				t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
  3242  			}
  3243  			if diff := cmp.Diff(buf.Bytes(), data); diff != "" {
  3244  				t.Errorf("received data mismatch (-want +got):\n%s", diff)
  3245  			}
  3246  		})
  3247  	}
  3248  }
  3249  
  3250  type icmpv4Matcher struct {
  3251  	icmpType header.ICMPv4Type
  3252  }
  3253  
  3254  func (m *icmpv4Matcher) Match(_ stack.Hook, pkt *stack.PacketBuffer, _, _ string) (matches bool, hotdrop bool) {
  3255  	if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber {
  3256  		return false, false
  3257  	}
  3258  
  3259  	if pkt.TransportProtocolNumber != header.ICMPv4ProtocolNumber {
  3260  		return false, false
  3261  	}
  3262  
  3263  	return header.ICMPv4(pkt.TransportHeader().Slice()).Type() == m.icmpType, false
  3264  }
  3265  
  3266  type icmpv6Matcher struct {
  3267  	icmpType header.ICMPv6Type
  3268  }
  3269  
  3270  func (m *icmpv6Matcher) Match(_ stack.Hook, pkt *stack.PacketBuffer, _, _ string) (matches bool, hotdrop bool) {
  3271  	if pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
  3272  		return false, false
  3273  	}
  3274  
  3275  	if pkt.TransportProtocolNumber != header.ICMPv6ProtocolNumber {
  3276  		return false, false
  3277  	}
  3278  
  3279  	return header.ICMPv6(pkt.TransportHeader().Slice()).Type() == m.icmpType, false
  3280  }
  3281  
  3282  func TestRejectWith(t *testing.T) {
  3283  	type natHook struct {
  3284  		hook    stack.Hook
  3285  		dstAddr tcpip.Address
  3286  		matcher stack.Matcher
  3287  
  3288  		errorICMPDstAddr tcpip.Address
  3289  		errorICMPPayload []byte
  3290  	}
  3291  
  3292  	type rejectWithVal struct {
  3293  		name          string
  3294  		val           int
  3295  		errorICMPCode uint8
  3296  	}
  3297  
  3298  	rxICMPv4EchoRequest := func(dst tcpip.Address) []byte {
  3299  		return utils.ICMPv4Echo(utils.Host1IPv4Addr.AddressWithPrefix.Address, dst, ttl, header.ICMPv4Echo)
  3300  	}
  3301  
  3302  	rxICMPv6EchoRequest := func(dst tcpip.Address) []byte {
  3303  		return utils.ICMPv6Echo(utils.Host1IPv6Addr.AddressWithPrefix.Address, dst, ttl, header.ICMPv6EchoRequest)
  3304  	}
  3305  
  3306  	tests := []struct {
  3307  		name              string
  3308  		netProto          tcpip.NetworkProtocolNumber
  3309  		rxICMPEchoRequest func(tcpip.Address) []byte
  3310  		icmpChecker       func(*testing.T, *buffer.View, tcpip.Address, uint8, uint8, []byte)
  3311  
  3312  		natHooks []natHook
  3313  
  3314  		rejectTarget   func(*testing.T, stack.NetworkProtocol, int) stack.Target
  3315  		rejectWithVals []rejectWithVal
  3316  		errorICMPType  uint8
  3317  	}{
  3318  		{
  3319  			name:              "IPv4",
  3320  			netProto:          header.IPv4ProtocolNumber,
  3321  			rxICMPEchoRequest: rxICMPv4EchoRequest,
  3322  
  3323  			icmpChecker: func(t *testing.T, v *buffer.View, dstAddr tcpip.Address, icmpType, icmpCode uint8, origPayload []byte) {
  3324  				t.Helper()
  3325  
  3326  				checker.IPv4(t, v,
  3327  					checker.SrcAddr(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address),
  3328  					checker.DstAddr(dstAddr),
  3329  					checker.ICMPv4(
  3330  						checker.ICMPv4Checksum(),
  3331  						checker.ICMPv4Type(header.ICMPv4Type(icmpType)),
  3332  						checker.ICMPv4Code(header.ICMPv4Code(icmpCode)),
  3333  						checker.ICMPv4Payload(origPayload),
  3334  					),
  3335  				)
  3336  			},
  3337  			natHooks: []natHook{
  3338  				{
  3339  					hook:             stack.Input,
  3340  					dstAddr:          utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
  3341  					matcher:          &icmpv4Matcher{icmpType: header.ICMPv4Echo},
  3342  					errorICMPDstAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
  3343  					errorICMPPayload: rxICMPv4EchoRequest(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address),
  3344  				},
  3345  				{
  3346  					hook:             stack.Forward,
  3347  					dstAddr:          utils.Host2IPv4Addr.AddressWithPrefix.Address,
  3348  					matcher:          &icmpv4Matcher{icmpType: header.ICMPv4Echo},
  3349  					errorICMPDstAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
  3350  					errorICMPPayload: rxICMPv4EchoRequest(utils.Host2IPv4Addr.AddressWithPrefix.Address),
  3351  				},
  3352  				{
  3353  					hook:             stack.Output,
  3354  					dstAddr:          utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
  3355  					matcher:          &icmpv4Matcher{icmpType: header.ICMPv4EchoReply},
  3356  					errorICMPDstAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
  3357  					errorICMPPayload: utils.ICMPv4Echo(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, utils.Host1IPv4Addr.AddressWithPrefix.Address, ttl, header.ICMPv4EchoReply),
  3358  				},
  3359  			},
  3360  			rejectTarget: func(t *testing.T, netProto stack.NetworkProtocol, rejectWith int) stack.Target {
  3361  				handler, ok := netProto.(stack.RejectIPv4WithHandler)
  3362  				if !ok {
  3363  					t.Fatalf("expected %T to implement %T", netProto, handler)
  3364  				}
  3365  
  3366  				return &stack.RejectIPv4Target{
  3367  					Handler:    handler,
  3368  					RejectWith: stack.RejectIPv4WithICMPType(rejectWith),
  3369  				}
  3370  			},
  3371  			rejectWithVals: []rejectWithVal{
  3372  				{
  3373  					name:          "ICMP Network Unreachable",
  3374  					val:           int(stack.RejectIPv4WithICMPNetUnreachable),
  3375  					errorICMPCode: uint8(header.ICMPv4NetUnreachable),
  3376  				},
  3377  				{
  3378  					name:          "ICMP Host Unreachable",
  3379  					val:           int(stack.RejectIPv4WithICMPHostUnreachable),
  3380  					errorICMPCode: uint8(header.ICMPv4HostUnreachable),
  3381  				},
  3382  				{
  3383  					name:          "ICMP Port Unreachable",
  3384  					val:           int(stack.RejectIPv4WithICMPPortUnreachable),
  3385  					errorICMPCode: uint8(header.ICMPv4PortUnreachable),
  3386  				},
  3387  				{
  3388  					name:          "ICMP Network Prohibited",
  3389  					val:           int(stack.RejectIPv4WithICMPNetProhibited),
  3390  					errorICMPCode: uint8(header.ICMPv4NetProhibited),
  3391  				},
  3392  				{
  3393  					name:          "ICMP Host Prohibited",
  3394  					val:           int(stack.RejectIPv4WithICMPHostProhibited),
  3395  					errorICMPCode: uint8(header.ICMPv4HostProhibited),
  3396  				},
  3397  				{
  3398  					name:          "ICMP Administratively Prohibited",
  3399  					val:           int(stack.RejectIPv4WithICMPAdminProhibited),
  3400  					errorICMPCode: uint8(header.ICMPv4AdminProhibited),
  3401  				},
  3402  			},
  3403  			errorICMPType: uint8(header.ICMPv4DstUnreachable),
  3404  		},
  3405  		{
  3406  			name:              "IPv6",
  3407  			netProto:          header.IPv6ProtocolNumber,
  3408  			rxICMPEchoRequest: rxICMPv6EchoRequest,
  3409  
  3410  			icmpChecker: func(t *testing.T, v *buffer.View, dstAddr tcpip.Address, icmpType, icmpCode uint8, origPayload []byte) {
  3411  				t.Helper()
  3412  
  3413  				checker.IPv6(t, v,
  3414  					checker.SrcAddr(utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address),
  3415  					checker.DstAddr(dstAddr),
  3416  					checker.ICMPv6(
  3417  						checker.ICMPv6Type(header.ICMPv6Type(icmpType)),
  3418  						checker.ICMPv6Code(header.ICMPv6Code(icmpCode)),
  3419  						checker.ICMPv6Payload(origPayload),
  3420  					),
  3421  				)
  3422  			},
  3423  			natHooks: []natHook{
  3424  				{
  3425  					hook:             stack.Input,
  3426  					dstAddr:          utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
  3427  					matcher:          &icmpv6Matcher{icmpType: header.ICMPv6EchoRequest},
  3428  					errorICMPDstAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
  3429  					errorICMPPayload: rxICMPv6EchoRequest(utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address),
  3430  				},
  3431  				{
  3432  					hook:             stack.Forward,
  3433  					dstAddr:          utils.Host2IPv6Addr.AddressWithPrefix.Address,
  3434  					matcher:          &icmpv6Matcher{icmpType: header.ICMPv6EchoRequest},
  3435  					errorICMPDstAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
  3436  					errorICMPPayload: rxICMPv6EchoRequest(utils.Host2IPv6Addr.AddressWithPrefix.Address),
  3437  				},
  3438  				{
  3439  					hook:             stack.Output,
  3440  					dstAddr:          utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
  3441  					matcher:          &icmpv6Matcher{icmpType: header.ICMPv6EchoReply},
  3442  					errorICMPDstAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
  3443  					errorICMPPayload: utils.ICMPv6Echo(utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, utils.Host1IPv6Addr.AddressWithPrefix.Address, ttl, header.ICMPv6EchoReply),
  3444  				},
  3445  			},
  3446  			rejectTarget: func(t *testing.T, netProto stack.NetworkProtocol, rejectWith int) stack.Target {
  3447  				handler, ok := netProto.(stack.RejectIPv6WithHandler)
  3448  				if !ok {
  3449  					t.Fatalf("expected %T to implement %T", netProto, handler)
  3450  				}
  3451  
  3452  				return &stack.RejectIPv6Target{
  3453  					Handler:    handler,
  3454  					RejectWith: stack.RejectIPv6WithICMPType(rejectWith),
  3455  				}
  3456  			},
  3457  			rejectWithVals: []rejectWithVal{
  3458  				{
  3459  					name:          "ICMP No Route",
  3460  					val:           int(stack.RejectIPv6WithICMPNoRoute),
  3461  					errorICMPCode: uint8(header.ICMPv6NetworkUnreachable),
  3462  				},
  3463  				{
  3464  					name:          "ICMP Address Unreachable",
  3465  					val:           int(stack.RejectIPv6WithICMPAddrUnreachable),
  3466  					errorICMPCode: uint8(header.ICMPv6AddressUnreachable),
  3467  				},
  3468  				{
  3469  					name:          "ICMP Port Unreachable",
  3470  					val:           int(stack.RejectIPv6WithICMPPortUnreachable),
  3471  					errorICMPCode: uint8(header.ICMPv6PortUnreachable),
  3472  				},
  3473  				{
  3474  					name:          "ICMP Administratively Prohibited",
  3475  					val:           int(stack.RejectIPv6WithICMPAdminProhibited),
  3476  					errorICMPCode: uint8(header.ICMPv6Prohibited),
  3477  				},
  3478  			},
  3479  			errorICMPType: uint8(header.ICMPv6DstUnreachable),
  3480  		},
  3481  	}
  3482  
  3483  	for _, test := range tests {
  3484  		t.Run(test.name, func(t *testing.T) {
  3485  			for _, natHook := range test.natHooks {
  3486  				t.Run(natHook.hook.String(), func(t *testing.T) {
  3487  					for _, rejectWith := range test.rejectWithVals {
  3488  						t.Run(rejectWith.name, func(t *testing.T) {
  3489  							s := stack.New(stack.Options{
  3490  								NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  3491  								TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
  3492  							})
  3493  							defer s.Destroy()
  3494  
  3495  							ep1 := channel.New(1, header.IPv6MinimumMTU, "")
  3496  							ep2 := channel.New(1, header.IPv6MinimumMTU, "")
  3497  							utils.SetupRouterStack(t, s, ep1, ep2)
  3498  
  3499  							{
  3500  								ipv6 := test.netProto == ipv6.ProtocolNumber
  3501  								ipt := s.IPTables()
  3502  								filter := ipt.GetTable(stack.FilterID, ipv6)
  3503  								ruleIdx := filter.BuiltinChains[natHook.hook]
  3504  								filter.Rules[ruleIdx].Matchers = []stack.Matcher{natHook.matcher}
  3505  								filter.Rules[ruleIdx].Target = test.rejectTarget(t, s.NetworkProtocolInstance(test.netProto), rejectWith.val)
  3506  								// Make sure the packet is not dropped by the next rule.
  3507  								filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
  3508  								ipt.ForceReplaceTable(stack.FilterID, filter, ipv6)
  3509  							}
  3510  
  3511  							func() {
  3512  								pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
  3513  									Payload: buffer.MakeWithData(test.rxICMPEchoRequest(natHook.dstAddr)),
  3514  								})
  3515  								defer pkt.DecRef()
  3516  								ep1.InjectInbound(test.netProto, pkt)
  3517  							}()
  3518  
  3519  							{
  3520  								pkt := ep1.Read()
  3521  								if pkt == nil {
  3522  									t.Fatal("expected to read a packet on ep1")
  3523  								}
  3524  								payload := stack.PayloadSince(pkt.NetworkHeader())
  3525  								defer payload.Release()
  3526  								test.icmpChecker(
  3527  									t,
  3528  									payload,
  3529  									natHook.errorICMPDstAddr,
  3530  									test.errorICMPType,
  3531  									rejectWith.errorICMPCode,
  3532  									natHook.errorICMPPayload,
  3533  								)
  3534  								pkt.DecRef()
  3535  							}
  3536  						})
  3537  					}
  3538  				})
  3539  			}
  3540  		})
  3541  	}
  3542  }
  3543  
  3544  // TestInvalidTransportHeader tests that bad transport headers (with a bad
  3545  // length/offset field) don't panic.
  3546  func TestInvalidTransportHeader(t *testing.T) {
  3547  	tests := []struct {
  3548  		name       string
  3549  		setupStack func(*testing.T) (*stack.Stack, *channel.Endpoint)
  3550  		genPacket  func(int8) *stack.PacketBuffer
  3551  		offset     int8
  3552  	}{
  3553  		{
  3554  			name:       "TCP4 offset small",
  3555  			setupStack: genStackV4,
  3556  			genPacket:  genTCP4,
  3557  			offset:     -1,
  3558  		},
  3559  		{
  3560  			name:       "TCP4 offset large",
  3561  			setupStack: genStackV4,
  3562  			genPacket:  genTCP4,
  3563  			offset:     1,
  3564  		},
  3565  		{
  3566  			name:       "UDP4 offset small",
  3567  			setupStack: genStackV4,
  3568  			genPacket:  genUDP4,
  3569  			offset:     -1,
  3570  		},
  3571  		{
  3572  			name:       "UDP4 offset large",
  3573  			setupStack: genStackV4,
  3574  			genPacket:  genUDP4,
  3575  			offset:     1,
  3576  		},
  3577  		{
  3578  			name:       "TCP6 offset small",
  3579  			setupStack: genStackV6,
  3580  			genPacket:  genTCP6,
  3581  			offset:     -1,
  3582  		},
  3583  		{
  3584  			name:       "TCP6 offset large",
  3585  			setupStack: genStackV6,
  3586  			genPacket:  genTCP6,
  3587  			offset:     1,
  3588  		},
  3589  		{
  3590  			name:       "UDP6 offset small",
  3591  			setupStack: genStackV6,
  3592  			genPacket:  genUDP6,
  3593  			offset:     -1,
  3594  		},
  3595  		{
  3596  			name:       "UDP6 offset large",
  3597  			setupStack: genStackV6,
  3598  			genPacket:  genUDP6,
  3599  			offset:     1,
  3600  		},
  3601  	}
  3602  
  3603  	for _, test := range tests {
  3604  		t.Run(test.name, func(t *testing.T) {
  3605  			s, e := test.setupStack(t)
  3606  
  3607  			// Enable iptables and conntrack.
  3608  			ipt := s.IPTables()
  3609  			filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
  3610  			ipt.ForceReplaceTable(stack.FilterID, filter, false /* ipv6 */)
  3611  
  3612  			// This can panic if conntrack isn't checking lengths.
  3613  			e.InjectInbound(header.IPv4ProtocolNumber, test.genPacket(test.offset))
  3614  		})
  3615  	}
  3616  }
  3617  
  3618  func genTCP4(offset int8) *stack.PacketBuffer {
  3619  	pktSize := header.IPv4MinimumSize + header.TCPMinimumSize
  3620  	hdr := prependable.New(pktSize)
  3621  
  3622  	tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize))
  3623  	tcp.Encode(&header.TCPFields{
  3624  		SeqNum:     0,
  3625  		AckNum:     0,
  3626  		DataOffset: header.TCPMinimumSize + uint8(offset)*4, // DataOffset must be a multiple of 4.
  3627  		Flags:      header.TCPFlagSyn,
  3628  		Checksum:   0,
  3629  	})
  3630  
  3631  	ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
  3632  	ip.Encode(&header.IPv4Fields{
  3633  		TOS:            0,
  3634  		TotalLength:    uint16(pktSize),
  3635  		ID:             1,
  3636  		Flags:          0,
  3637  		FragmentOffset: 0,
  3638  		TTL:            48,
  3639  		Protocol:       uint8(header.TCPProtocolNumber),
  3640  		SrcAddr:        srcAddrV4,
  3641  		DstAddr:        dstAddrV4,
  3642  	})
  3643  	ip.SetChecksum(0)
  3644  	ip.SetChecksum(^ip.CalculateChecksum())
  3645  
  3646  	buf := buffer.MakeWithData(append([]byte{}, hdr.View()...))
  3647  	return stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buf})
  3648  }
  3649  
  3650  func genTCP6(offset int8) *stack.PacketBuffer {
  3651  	pktSize := header.IPv6MinimumSize + header.TCPMinimumSize
  3652  	hdr := prependable.New(pktSize)
  3653  
  3654  	tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize))
  3655  	tcp.Encode(&header.TCPFields{
  3656  		SeqNum:     0,
  3657  		AckNum:     0,
  3658  		DataOffset: header.TCPMinimumSize + uint8(offset)*4, // DataOffset must be a multiple of 4.
  3659  		Flags:      header.TCPFlagSyn,
  3660  		Checksum:   0,
  3661  	})
  3662  
  3663  	ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
  3664  	ip.Encode(&header.IPv6Fields{
  3665  		PayloadLength:     header.TCPMinimumSize,
  3666  		TransportProtocol: header.TCPProtocolNumber,
  3667  		HopLimit:          255,
  3668  		SrcAddr:           srcAddrV6,
  3669  		DstAddr:           dstAddrV6,
  3670  	})
  3671  
  3672  	buf := buffer.MakeWithData(append([]byte{}, hdr.View()...))
  3673  	return stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buf})
  3674  }
  3675  
  3676  func genUDP4(offset int8) *stack.PacketBuffer {
  3677  	pktSize := header.IPv4MinimumSize + header.UDPMinimumSize
  3678  	hdr := prependable.New(pktSize)
  3679  
  3680  	udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
  3681  	udp.Encode(&header.UDPFields{
  3682  		SrcPort:  343,
  3683  		DstPort:  2401,
  3684  		Length:   header.UDPMinimumSize + uint16(offset),
  3685  		Checksum: 0,
  3686  	})
  3687  
  3688  	ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
  3689  	ip.Encode(&header.IPv4Fields{
  3690  		TOS:            0,
  3691  		TotalLength:    uint16(pktSize),
  3692  		ID:             1,
  3693  		Flags:          0,
  3694  		FragmentOffset: 0,
  3695  		TTL:            48,
  3696  		Protocol:       uint8(header.UDPProtocolNumber),
  3697  		SrcAddr:        srcAddrV4,
  3698  		DstAddr:        dstAddrV4,
  3699  	})
  3700  	ip.SetChecksum(0)
  3701  	ip.SetChecksum(^ip.CalculateChecksum())
  3702  
  3703  	buf := buffer.MakeWithData(append([]byte{}, hdr.View()...))
  3704  	return stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buf})
  3705  }
  3706  
  3707  func genUDP6(offset int8) *stack.PacketBuffer {
  3708  	pktSize := header.IPv6MinimumSize + header.UDPMinimumSize
  3709  	hdr := prependable.New(pktSize)
  3710  
  3711  	udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
  3712  	udp.Encode(&header.UDPFields{
  3713  		SrcPort:  343,
  3714  		DstPort:  2401,
  3715  		Length:   header.UDPMinimumSize + uint16(offset),
  3716  		Checksum: 0,
  3717  	})
  3718  
  3719  	ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
  3720  	ip.Encode(&header.IPv6Fields{
  3721  		PayloadLength:     header.UDPMinimumSize,
  3722  		TransportProtocol: header.UDPProtocolNumber,
  3723  		HopLimit:          255,
  3724  		SrcAddr:           srcAddrV6,
  3725  		DstAddr:           dstAddrV6,
  3726  	})
  3727  
  3728  	buf := buffer.MakeWithData(append([]byte{}, hdr.View()...))
  3729  	return stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buf})
  3730  }