gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/tests/integration/forward_test.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package forward_test
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"testing"
    21  
    22  	"github.com/google/go-cmp/cmp"
    23  	"gvisor.dev/gvisor/pkg/buffer"
    24  	"gvisor.dev/gvisor/pkg/tcpip"
    25  	"gvisor.dev/gvisor/pkg/tcpip/checker"
    26  	"gvisor.dev/gvisor/pkg/tcpip/header"
    27  	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
    28  	"gvisor.dev/gvisor/pkg/tcpip/network/arp"
    29  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    30  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
    31  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    32  	"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
    33  	"gvisor.dev/gvisor/pkg/tcpip/testutil"
    34  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
    35  	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
    36  	"gvisor.dev/gvisor/pkg/waiter"
    37  )
    38  
    39  const ttl = 64
    40  
    41  func rxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
    42  	utils.RxICMPv4EchoRequest(e, src, dst, ttl)
    43  }
    44  
    45  func rxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
    46  	utils.RxICMPv6EchoRequest(e, src, dst, ttl)
    47  }
    48  
    49  func forwardedICMPv4EchoRequestChecker(t *testing.T, v *buffer.View, src, dst tcpip.Address) {
    50  	checker.IPv4(t, v,
    51  		checker.SrcAddr(src),
    52  		checker.DstAddr(dst),
    53  		checker.TTL(ttl-1),
    54  		checker.ICMPv4(
    55  			checker.ICMPv4Type(header.ICMPv4Echo)))
    56  }
    57  
    58  func forwardedICMPv6EchoRequestChecker(t *testing.T, v *buffer.View, src, dst tcpip.Address) {
    59  	checker.IPv6(t, v,
    60  		checker.SrcAddr(src),
    61  		checker.DstAddr(dst),
    62  		checker.TTL(ttl-1),
    63  		checker.ICMPv6(
    64  			checker.ICMPv6Type(header.ICMPv6EchoRequest)))
    65  }
    66  
    67  func TestForwarding(t *testing.T) {
    68  	const listenPort = 8080
    69  
    70  	type endpointAndAddresses struct {
    71  		serverEP         tcpip.Endpoint
    72  		serverAddr       tcpip.Address
    73  		serverReadableCH chan struct{}
    74  
    75  		clientEP         tcpip.Endpoint
    76  		clientAddr       tcpip.Address
    77  		clientReadableCH chan struct{}
    78  	}
    79  
    80  	newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) {
    81  		t.Helper()
    82  		var wq waiter.Queue
    83  		we, ch := waiter.NewChannelEntry(waiter.ReadableEvents)
    84  		wq.EventRegister(&we)
    85  		ep, err := s.NewEndpoint(transProto, netProto, &wq)
    86  		if err != nil {
    87  			t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err)
    88  		}
    89  
    90  		t.Cleanup(func() {
    91  			wq.EventUnregister(&we)
    92  		})
    93  
    94  		return ep, ch
    95  	}
    96  
    97  	tests := []struct {
    98  		name       string
    99  		epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses
   100  	}{
   101  		{
   102  			name: "IPv4 host1 server with host2 client",
   103  			epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
   104  				ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber)
   105  				ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
   106  				return endpointAndAddresses{
   107  					serverEP:         ep1,
   108  					serverAddr:       utils.Host1IPv4Addr.AddressWithPrefix.Address,
   109  					serverReadableCH: ep1WECH,
   110  
   111  					clientEP:         ep2,
   112  					clientAddr:       utils.Host2IPv4Addr.AddressWithPrefix.Address,
   113  					clientReadableCH: ep2WECH,
   114  				}
   115  			},
   116  		},
   117  		{
   118  			name: "IPv6 host2 server with host1 client",
   119  			epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
   120  				ep1, ep1WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
   121  				ep2, ep2WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber)
   122  				return endpointAndAddresses{
   123  					serverEP:         ep1,
   124  					serverAddr:       utils.Host2IPv6Addr.AddressWithPrefix.Address,
   125  					serverReadableCH: ep1WECH,
   126  
   127  					clientEP:         ep2,
   128  					clientAddr:       utils.Host1IPv6Addr.AddressWithPrefix.Address,
   129  					clientReadableCH: ep2WECH,
   130  				}
   131  			},
   132  		},
   133  		{
   134  			name: "IPv4 host2 server with routerNIC1 client",
   135  			epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
   136  				ep1, ep1WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
   137  				ep2, ep2WECH := newEP(t, routerStack, proto, ipv4.ProtocolNumber)
   138  				return endpointAndAddresses{
   139  					serverEP:         ep1,
   140  					serverAddr:       utils.Host2IPv4Addr.AddressWithPrefix.Address,
   141  					serverReadableCH: ep1WECH,
   142  
   143  					clientEP:         ep2,
   144  					clientAddr:       utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
   145  					clientReadableCH: ep2WECH,
   146  				}
   147  			},
   148  		},
   149  		{
   150  			name: "IPv6 routerNIC2 server with host1 client",
   151  			epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
   152  				ep1, ep1WECH := newEP(t, routerStack, proto, ipv6.ProtocolNumber)
   153  				ep2, ep2WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber)
   154  				return endpointAndAddresses{
   155  					serverEP:         ep1,
   156  					serverAddr:       utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
   157  					serverReadableCH: ep1WECH,
   158  
   159  					clientEP:         ep2,
   160  					clientAddr:       utils.Host1IPv6Addr.AddressWithPrefix.Address,
   161  					clientReadableCH: ep2WECH,
   162  				}
   163  			},
   164  		},
   165  	}
   166  
   167  	subTests := []struct {
   168  		name               string
   169  		proto              tcpip.TransportProtocolNumber
   170  		expectedConnectErr tcpip.Error
   171  		setupServer        func(t *testing.T, ep tcpip.Endpoint)
   172  		setupServerConn    func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{})
   173  		needRemoteAddr     bool
   174  	}{
   175  		{
   176  			name:               "UDP",
   177  			proto:              udp.ProtocolNumber,
   178  			expectedConnectErr: nil,
   179  			setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
   180  				t.Helper()
   181  
   182  				if err := ep.Connect(clientAddr); err != nil {
   183  					t.Fatalf("ep.Connect(%#v): %s", clientAddr, err)
   184  				}
   185  				return nil, nil
   186  			},
   187  			needRemoteAddr: true,
   188  		},
   189  		{
   190  			name:               "TCP",
   191  			proto:              tcp.ProtocolNumber,
   192  			expectedConnectErr: &tcpip.ErrConnectStarted{},
   193  			setupServer: func(t *testing.T, ep tcpip.Endpoint) {
   194  				t.Helper()
   195  
   196  				if err := ep.Listen(1); err != nil {
   197  					t.Fatalf("ep.Listen(1): %s", err)
   198  				}
   199  			},
   200  			setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
   201  				t.Helper()
   202  
   203  				var addr tcpip.FullAddress
   204  				for {
   205  					newEP, wq, err := ep.Accept(&addr)
   206  					if _, ok := err.(*tcpip.ErrWouldBlock); ok {
   207  						<-ch
   208  						continue
   209  					}
   210  					if err != nil {
   211  						t.Fatalf("ep.Accept(_): %s", err)
   212  					}
   213  					if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath(
   214  						"NIC",
   215  					)); diff != "" {
   216  						t.Errorf("accepted address mismatch (-want +got):\n%s", diff)
   217  					}
   218  
   219  					we, newCH := waiter.NewChannelEntry(waiter.ReadableEvents)
   220  					wq.EventRegister(&we)
   221  					return newEP, newCH
   222  				}
   223  			},
   224  			needRemoteAddr: false,
   225  		},
   226  	}
   227  
   228  	for _, test := range tests {
   229  		t.Run(test.name, func(t *testing.T) {
   230  			for _, subTest := range subTests {
   231  				t.Run(subTest.name, func(t *testing.T) {
   232  					stackOpts := stack.Options{
   233  						NetworkProtocols:   []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
   234  						TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
   235  					}
   236  
   237  					host1Stack := stack.New(stackOpts)
   238  					defer host1Stack.Destroy()
   239  					routerStack := stack.New(stackOpts)
   240  					defer routerStack.Destroy()
   241  					host2Stack := stack.New(stackOpts)
   242  					defer host2Stack.Destroy()
   243  					utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack)
   244  
   245  					epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto)
   246  					defer epsAndAddrs.serverEP.Close()
   247  					defer epsAndAddrs.clientEP.Close()
   248  
   249  					serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort}
   250  					if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil {
   251  						t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err)
   252  					}
   253  					clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
   254  					if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
   255  						t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
   256  					}
   257  
   258  					if subTest.setupServer != nil {
   259  						subTest.setupServer(t, epsAndAddrs.serverEP)
   260  					}
   261  					{
   262  						err := epsAndAddrs.clientEP.Connect(serverAddr)
   263  						if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" {
   264  							t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff)
   265  						}
   266  					}
   267  					if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil {
   268  						t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err)
   269  					} else {
   270  						clientAddr = addr
   271  						clientAddr.NIC = 0
   272  					}
   273  
   274  					serverEP := epsAndAddrs.serverEP
   275  					serverCH := epsAndAddrs.serverReadableCH
   276  					if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, clientAddr); ep != nil {
   277  						defer ep.Close()
   278  						serverEP = ep
   279  						serverCH = ch
   280  					}
   281  
   282  					write := func(ep tcpip.Endpoint, data []byte) {
   283  						t.Helper()
   284  
   285  						var r bytes.Reader
   286  						r.Reset(data)
   287  						var wOpts tcpip.WriteOptions
   288  						n, err := ep.Write(&r, wOpts)
   289  						if err != nil {
   290  							t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
   291  						}
   292  						if want := int64(len(data)); n != want {
   293  							t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
   294  						}
   295  					}
   296  
   297  					data := []byte{1, 2, 3, 4}
   298  					write(epsAndAddrs.clientEP, data)
   299  
   300  					read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) {
   301  						t.Helper()
   302  
   303  						var buf bytes.Buffer
   304  						var res tcpip.ReadResult
   305  						for {
   306  							var err tcpip.Error
   307  							opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
   308  							res, err = ep.Read(&buf, opts)
   309  							if _, ok := err.(*tcpip.ErrWouldBlock); ok {
   310  								<-ch
   311  								continue
   312  							}
   313  							if err != nil {
   314  								t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
   315  							}
   316  							break
   317  						}
   318  
   319  						readResult := tcpip.ReadResult{
   320  							Count: len(data),
   321  							Total: len(data),
   322  						}
   323  						if subTest.needRemoteAddr {
   324  							readResult.RemoteAddr = expectedFrom
   325  						}
   326  						if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath(
   327  							"ControlMessages",
   328  							"RemoteAddr.NIC",
   329  						)); diff != "" {
   330  							t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
   331  						}
   332  						if diff := cmp.Diff(buf.Bytes(), data); diff != "" {
   333  							t.Errorf("received data mismatch (-want +got):\n%s", diff)
   334  						}
   335  
   336  						if t.Failed() {
   337  							t.FailNow()
   338  						}
   339  					}
   340  
   341  					read(serverCH, serverEP, data, clientAddr)
   342  
   343  					data = []byte{5, 6, 7, 8, 9, 10, 11, 12}
   344  					write(serverEP, data)
   345  					read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr)
   346  				})
   347  			}
   348  		})
   349  	}
   350  }
   351  
   352  type fillableLinkEndpoint struct {
   353  	*channel.Endpoint
   354  	full bool
   355  }
   356  
   357  func (e *fillableLinkEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
   358  	if e.full {
   359  		return 0, &tcpip.ErrNoBufferSpace{}
   360  	}
   361  
   362  	return e.Endpoint.WritePackets(pkts)
   363  }
   364  
   365  func TestUnicastForwarding(t *testing.T) {
   366  	const (
   367  		nicID1 = 1
   368  		nicID2 = 2
   369  	)
   370  
   371  	var (
   372  		ipv4LinkLocalUnicastAddr = testutil.MustParse4("169.254.0.10")
   373  		ipv6LinkLocalUnicastAddr = testutil.MustParse6("fe80::a")
   374  	)
   375  
   376  	tests := []struct {
   377  		name             string
   378  		netProto         tcpip.NetworkProtocolNumber
   379  		srcAddr, dstAddr tcpip.Address
   380  		rx               func(*channel.Endpoint, tcpip.Address, tcpip.Address)
   381  		expectForward    bool
   382  		checker          func(*testing.T, *buffer.View)
   383  	}{
   384  		{
   385  			name:          "IPv4 link-local source",
   386  			netProto:      ipv4.ProtocolNumber,
   387  			srcAddr:       ipv4LinkLocalUnicastAddr,
   388  			dstAddr:       utils.RemoteIPv4Addr,
   389  			rx:            rxICMPv4EchoRequest,
   390  			expectForward: false,
   391  		},
   392  		{
   393  			name:          "IPv4 link-local destination",
   394  			netProto:      ipv4.ProtocolNumber,
   395  			srcAddr:       utils.RemoteIPv4Addr,
   396  			dstAddr:       ipv4LinkLocalUnicastAddr,
   397  			rx:            rxICMPv4EchoRequest,
   398  			expectForward: false,
   399  		},
   400  		{
   401  			name:          "IPv4 non-link-local unicast",
   402  			netProto:      ipv4.ProtocolNumber,
   403  			srcAddr:       utils.RemoteIPv4Addr,
   404  			dstAddr:       utils.Ipv4Addr2.AddressWithPrefix.Address,
   405  			rx:            rxICMPv4EchoRequest,
   406  			expectForward: true,
   407  			checker: func(t *testing.T, v *buffer.View) {
   408  				forwardedICMPv4EchoRequestChecker(t, v, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
   409  			},
   410  		},
   411  		{
   412  			name:          "IPv6 link-local source",
   413  			netProto:      ipv6.ProtocolNumber,
   414  			srcAddr:       ipv6LinkLocalUnicastAddr,
   415  			dstAddr:       utils.RemoteIPv6Addr,
   416  			rx:            rxICMPv6EchoRequest,
   417  			expectForward: false,
   418  		},
   419  		{
   420  			name:          "IPv6 link-local destination",
   421  			netProto:      ipv6.ProtocolNumber,
   422  			srcAddr:       utils.RemoteIPv6Addr,
   423  			dstAddr:       ipv6LinkLocalUnicastAddr,
   424  			rx:            rxICMPv6EchoRequest,
   425  			expectForward: false,
   426  		},
   427  		{
   428  			name:          "IPv6 non-link-local unicast",
   429  			netProto:      ipv6.ProtocolNumber,
   430  			srcAddr:       utils.RemoteIPv6Addr,
   431  			dstAddr:       utils.Ipv6Addr2.AddressWithPrefix.Address,
   432  			rx:            rxICMPv6EchoRequest,
   433  			expectForward: true,
   434  			checker: func(t *testing.T, v *buffer.View) {
   435  				forwardedICMPv6EchoRequestChecker(t, v, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
   436  			},
   437  		},
   438  	}
   439  
   440  	for _, test := range tests {
   441  		t.Run(test.name, func(t *testing.T) {
   442  			for _, full := range []bool{true, false} {
   443  				t.Run(fmt.Sprintf("Full=%t", full), func(t *testing.T) {
   444  					s := stack.New(stack.Options{
   445  						NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   446  						TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   447  					})
   448  
   449  					e1 := channel.New(1, header.IPv6MinimumMTU, "")
   450  					defer e1.Close()
   451  					if err := s.CreateNIC(nicID1, e1); err != nil {
   452  						t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err)
   453  					}
   454  
   455  					e2 := fillableLinkEndpoint{Endpoint: channel.New(1, header.IPv6MinimumMTU, ""), full: full}
   456  					defer e2.Close()
   457  					if err := s.CreateNIC(nicID2, &e2); err != nil {
   458  						t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err)
   459  					}
   460  
   461  					protocolAddrV4 := tcpip.ProtocolAddress{
   462  						Protocol:          ipv4.ProtocolNumber,
   463  						AddressWithPrefix: utils.Ipv4Addr,
   464  					}
   465  					if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
   466  						t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
   467  					}
   468  					protocolAddrV6 := tcpip.ProtocolAddress{
   469  						Protocol:          ipv6.ProtocolNumber,
   470  						AddressWithPrefix: utils.Ipv6Addr,
   471  					}
   472  					if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
   473  						t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
   474  					}
   475  
   476  					if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
   477  						t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
   478  					}
   479  					if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
   480  						t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
   481  					}
   482  
   483  					s.SetRouteTable([]tcpip.Route{
   484  						{
   485  							Destination: header.IPv4EmptySubnet,
   486  							NIC:         nicID2,
   487  						},
   488  						{
   489  							Destination: header.IPv6EmptySubnet,
   490  							NIC:         nicID2,
   491  						},
   492  					})
   493  
   494  					test.rx(e1, test.srcAddr, test.dstAddr)
   495  
   496  					expectForward := test.expectForward && !full
   497  					p := e2.Read()
   498  					if (p != nil) != expectForward {
   499  						t.Fatalf("got e2.Read() = %#v, want = (_ == nil) = %t", p, expectForward)
   500  					}
   501  
   502  					if expectForward {
   503  						payload := stack.PayloadSince(p.NetworkHeader())
   504  						defer payload.Release()
   505  						test.checker(t, payload)
   506  						p.DecRef()
   507  					}
   508  
   509  					checkOutgoingDeviceNoBufferSpaceCounter := func(nicID tcpip.NICID, expectErr bool) {
   510  						t.Helper()
   511  
   512  						expectCounter := uint64(0)
   513  						if expectErr {
   514  							expectCounter = 1
   515  						}
   516  
   517  						netEP, err := s.GetNetworkEndpoint(nicID, test.netProto)
   518  						if err != nil {
   519  							t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, test.netProto, err)
   520  						}
   521  
   522  						stats := netEP.Stats()
   523  						ipStats, ok := stats.(stack.IPNetworkEndpointStats)
   524  						if !ok {
   525  							t.Fatalf("%#v is not a %T", stats, ipStats)
   526  						}
   527  
   528  						if got := ipStats.IPStats().Forwarding.OutgoingDeviceNoBufferSpace.Value(); got != expectCounter {
   529  							t.Errorf("got ipStats.IPStats().Forwarding.OutgoingDeviceNoBufferSpace.Value() = %d, want = %d", got, expectCounter)
   530  						}
   531  					}
   532  					checkOutgoingDeviceNoBufferSpaceCounter(nicID1, test.expectForward && full)
   533  					checkOutgoingDeviceNoBufferSpaceCounter(nicID2, false)
   534  				})
   535  			}
   536  		})
   537  	}
   538  }
   539  
   540  func TestPerInterfaceForwarding(t *testing.T) {
   541  	const (
   542  		nicID1 = 1
   543  		nicID2 = 2
   544  	)
   545  
   546  	tests := []struct {
   547  		name             string
   548  		srcAddr, dstAddr tcpip.Address
   549  		rx               func(*channel.Endpoint, tcpip.Address, tcpip.Address)
   550  		checker          func(*testing.T, *buffer.View)
   551  	}{
   552  		{
   553  			name:    "IPv4 unicast",
   554  			srcAddr: utils.RemoteIPv4Addr,
   555  			dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
   556  			rx:      rxICMPv4EchoRequest,
   557  			checker: func(t *testing.T, v *buffer.View) {
   558  				forwardedICMPv4EchoRequestChecker(t, v, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
   559  			},
   560  		},
   561  		{
   562  			name:    "IPv6 unicast",
   563  			srcAddr: utils.RemoteIPv6Addr,
   564  			dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
   565  			rx:      rxICMPv6EchoRequest,
   566  			checker: func(t *testing.T, v *buffer.View) {
   567  				forwardedICMPv6EchoRequestChecker(t, v, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
   568  			},
   569  		},
   570  	}
   571  
   572  	netProtos := [...]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber, ipv6.ProtocolNumber}
   573  
   574  	for _, test := range tests {
   575  		t.Run(test.name, func(t *testing.T) {
   576  			s := stack.New(stack.Options{
   577  				NetworkProtocols: []stack.NetworkProtocolFactory{
   578  					// ARP is not used in this test but it is a network protocol that does
   579  					// not support forwarding. We install the protocol to make sure that
   580  					// forwarding information for a NIC is only reported for network
   581  					// protocols that support forwarding.
   582  					arp.NewProtocol,
   583  
   584  					ipv4.NewProtocol,
   585  					ipv6.NewProtocol,
   586  				},
   587  				TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   588  			})
   589  
   590  			e1 := channel.New(1, header.IPv6MinimumMTU, "")
   591  			defer e1.Close()
   592  			if err := s.CreateNIC(nicID1, e1); err != nil {
   593  				t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err)
   594  			}
   595  
   596  			e2 := channel.New(1, header.IPv6MinimumMTU, "")
   597  			defer e2.Close()
   598  			if err := s.CreateNIC(nicID2, e2); err != nil {
   599  				t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err)
   600  			}
   601  
   602  			for _, add := range [...]struct {
   603  				nicID tcpip.NICID
   604  				addr  tcpip.ProtocolAddress
   605  			}{
   606  				{
   607  					nicID: nicID1,
   608  					addr:  utils.RouterNIC1IPv4Addr,
   609  				},
   610  				{
   611  					nicID: nicID1,
   612  					addr:  utils.RouterNIC1IPv6Addr,
   613  				},
   614  				{
   615  					nicID: nicID2,
   616  					addr:  utils.RouterNIC2IPv4Addr,
   617  				},
   618  				{
   619  					nicID: nicID2,
   620  					addr:  utils.RouterNIC2IPv6Addr,
   621  				},
   622  			} {
   623  				if err := s.AddProtocolAddress(add.nicID, add.addr, stack.AddressProperties{}); err != nil {
   624  					t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", add.nicID, add.addr, err)
   625  				}
   626  			}
   627  
   628  			// Only enable forwarding on NIC1 and make sure that only packets arriving
   629  			// on NIC1 are forwarded.
   630  			for _, netProto := range netProtos {
   631  				if _, err := s.SetNICForwarding(nicID1, netProto, true); err != nil {
   632  					t.Fatalf("s.SetNICForwarding(%d, %d, true): %s", nicID1, netProtos, err)
   633  				}
   634  			}
   635  
   636  			nicsInfo := s.NICInfo()
   637  			for _, subTest := range [...]struct {
   638  				nicID            tcpip.NICID
   639  				nicEP            *channel.Endpoint
   640  				otherNICID       tcpip.NICID
   641  				otherNICEP       *channel.Endpoint
   642  				expectForwarding bool
   643  			}{
   644  				{
   645  					nicID:            nicID1,
   646  					nicEP:            e1,
   647  					otherNICID:       nicID2,
   648  					otherNICEP:       e2,
   649  					expectForwarding: true,
   650  				},
   651  				{
   652  					nicID:            nicID2,
   653  					nicEP:            e2,
   654  					otherNICID:       nicID1,
   655  					otherNICEP:       e1,
   656  					expectForwarding: false,
   657  				},
   658  			} {
   659  				t.Run(fmt.Sprintf("Packet arriving at NIC%d", subTest.nicID), func(t *testing.T) {
   660  					nicInfo, ok := nicsInfo[subTest.nicID]
   661  					if !ok {
   662  						t.Errorf("expected NIC info for NIC %d; got = %#v", subTest.nicID, nicsInfo)
   663  					} else {
   664  						forwarding := make(map[tcpip.NetworkProtocolNumber]bool)
   665  						for _, netProto := range netProtos {
   666  							forwarding[netProto] = subTest.expectForwarding
   667  						}
   668  
   669  						if diff := cmp.Diff(forwarding, nicInfo.Forwarding); diff != "" {
   670  							t.Errorf("nicsInfo[%d].Forwarding mismatch (-want +got):\n%s", subTest.nicID, diff)
   671  						}
   672  					}
   673  
   674  					s.SetRouteTable([]tcpip.Route{
   675  						{
   676  							Destination: header.IPv4EmptySubnet,
   677  							NIC:         subTest.otherNICID,
   678  						},
   679  						{
   680  							Destination: header.IPv6EmptySubnet,
   681  							NIC:         subTest.otherNICID,
   682  						},
   683  					})
   684  
   685  					test.rx(subTest.nicEP, test.srcAddr, test.dstAddr)
   686  					if p := subTest.nicEP.Read(); p != nil {
   687  						t.Errorf("unexpectedly got a response from the interface the packet arrived on: %#v", p)
   688  						p.DecRef()
   689  					}
   690  					p := subTest.otherNICEP.Read()
   691  					if (p != nil) != subTest.expectForwarding {
   692  						t.Errorf("got otherNICEP.Read() = (%#v, %t), want = (_, %t)", p, ok, subTest.expectForwarding)
   693  					}
   694  					if p != nil {
   695  						payload := stack.PayloadSince(p.NetworkHeader())
   696  						defer payload.Release()
   697  						test.checker(t, payload)
   698  						p.DecRef()
   699  					}
   700  				})
   701  			}
   702  		})
   703  	}
   704  }