github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/tests/integration/forward_test.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package forward_test
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"testing"
    21  
    22  	"github.com/google/go-cmp/cmp"
    23  	"github.com/SagerNet/gvisor/pkg/tcpip"
    24  	"github.com/SagerNet/gvisor/pkg/tcpip/checker"
    25  	"github.com/SagerNet/gvisor/pkg/tcpip/header"
    26  	"github.com/SagerNet/gvisor/pkg/tcpip/link/channel"
    27  	"github.com/SagerNet/gvisor/pkg/tcpip/network/arp"
    28  	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv4"
    29  	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv6"
    30  	"github.com/SagerNet/gvisor/pkg/tcpip/stack"
    31  	"github.com/SagerNet/gvisor/pkg/tcpip/tests/utils"
    32  	"github.com/SagerNet/gvisor/pkg/tcpip/testutil"
    33  	"github.com/SagerNet/gvisor/pkg/tcpip/transport/tcp"
    34  	"github.com/SagerNet/gvisor/pkg/tcpip/transport/udp"
    35  	"github.com/SagerNet/gvisor/pkg/waiter"
    36  )
    37  
    38  const ttl = 64
    39  
    40  var (
    41  	ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10")
    42  	ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a")
    43  )
    44  
    45  func rxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
    46  	utils.RxICMPv4EchoRequest(e, src, dst, ttl)
    47  }
    48  
    49  func rxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
    50  	utils.RxICMPv6EchoRequest(e, src, dst, ttl)
    51  }
    52  
    53  func forwardedICMPv4EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
    54  	checker.IPv4(t, b,
    55  		checker.SrcAddr(src),
    56  		checker.DstAddr(dst),
    57  		checker.TTL(ttl-1),
    58  		checker.ICMPv4(
    59  			checker.ICMPv4Type(header.ICMPv4Echo)))
    60  }
    61  
    62  func forwardedICMPv6EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
    63  	checker.IPv6(t, b,
    64  		checker.SrcAddr(src),
    65  		checker.DstAddr(dst),
    66  		checker.TTL(ttl-1),
    67  		checker.ICMPv6(
    68  			checker.ICMPv6Type(header.ICMPv6EchoRequest)))
    69  }
    70  
    71  func TestForwarding(t *testing.T) {
    72  	const listenPort = 8080
    73  
    74  	type endpointAndAddresses struct {
    75  		serverEP         tcpip.Endpoint
    76  		serverAddr       tcpip.Address
    77  		serverReadableCH chan struct{}
    78  
    79  		clientEP         tcpip.Endpoint
    80  		clientAddr       tcpip.Address
    81  		clientReadableCH chan struct{}
    82  	}
    83  
    84  	newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) {
    85  		t.Helper()
    86  		var wq waiter.Queue
    87  		we, ch := waiter.NewChannelEntry(nil)
    88  		wq.EventRegister(&we, waiter.ReadableEvents)
    89  		ep, err := s.NewEndpoint(transProto, netProto, &wq)
    90  		if err != nil {
    91  			t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err)
    92  		}
    93  
    94  		t.Cleanup(func() {
    95  			wq.EventUnregister(&we)
    96  		})
    97  
    98  		return ep, ch
    99  	}
   100  
   101  	tests := []struct {
   102  		name       string
   103  		epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses
   104  	}{
   105  		{
   106  			name: "IPv4 host1 server with host2 client",
   107  			epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
   108  				ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber)
   109  				ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
   110  				return endpointAndAddresses{
   111  					serverEP:         ep1,
   112  					serverAddr:       utils.Host1IPv4Addr.AddressWithPrefix.Address,
   113  					serverReadableCH: ep1WECH,
   114  
   115  					clientEP:         ep2,
   116  					clientAddr:       utils.Host2IPv4Addr.AddressWithPrefix.Address,
   117  					clientReadableCH: ep2WECH,
   118  				}
   119  			},
   120  		},
   121  		{
   122  			name: "IPv6 host2 server with host1 client",
   123  			epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
   124  				ep1, ep1WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
   125  				ep2, ep2WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber)
   126  				return endpointAndAddresses{
   127  					serverEP:         ep1,
   128  					serverAddr:       utils.Host2IPv6Addr.AddressWithPrefix.Address,
   129  					serverReadableCH: ep1WECH,
   130  
   131  					clientEP:         ep2,
   132  					clientAddr:       utils.Host1IPv6Addr.AddressWithPrefix.Address,
   133  					clientReadableCH: ep2WECH,
   134  				}
   135  			},
   136  		},
   137  		{
   138  			name: "IPv4 host2 server with routerNIC1 client",
   139  			epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
   140  				ep1, ep1WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
   141  				ep2, ep2WECH := newEP(t, routerStack, proto, ipv4.ProtocolNumber)
   142  				return endpointAndAddresses{
   143  					serverEP:         ep1,
   144  					serverAddr:       utils.Host2IPv4Addr.AddressWithPrefix.Address,
   145  					serverReadableCH: ep1WECH,
   146  
   147  					clientEP:         ep2,
   148  					clientAddr:       utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
   149  					clientReadableCH: ep2WECH,
   150  				}
   151  			},
   152  		},
   153  		{
   154  			name: "IPv6 routerNIC2 server with host1 client",
   155  			epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
   156  				ep1, ep1WECH := newEP(t, routerStack, proto, ipv6.ProtocolNumber)
   157  				ep2, ep2WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber)
   158  				return endpointAndAddresses{
   159  					serverEP:         ep1,
   160  					serverAddr:       utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
   161  					serverReadableCH: ep1WECH,
   162  
   163  					clientEP:         ep2,
   164  					clientAddr:       utils.Host1IPv6Addr.AddressWithPrefix.Address,
   165  					clientReadableCH: ep2WECH,
   166  				}
   167  			},
   168  		},
   169  	}
   170  
   171  	subTests := []struct {
   172  		name               string
   173  		proto              tcpip.TransportProtocolNumber
   174  		expectedConnectErr tcpip.Error
   175  		setupServer        func(t *testing.T, ep tcpip.Endpoint)
   176  		setupServerConn    func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{})
   177  		needRemoteAddr     bool
   178  	}{
   179  		{
   180  			name:               "UDP",
   181  			proto:              udp.ProtocolNumber,
   182  			expectedConnectErr: nil,
   183  			setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
   184  				t.Helper()
   185  
   186  				if err := ep.Connect(clientAddr); err != nil {
   187  					t.Fatalf("ep.Connect(%#v): %s", clientAddr, err)
   188  				}
   189  				return nil, nil
   190  			},
   191  			needRemoteAddr: true,
   192  		},
   193  		{
   194  			name:               "TCP",
   195  			proto:              tcp.ProtocolNumber,
   196  			expectedConnectErr: &tcpip.ErrConnectStarted{},
   197  			setupServer: func(t *testing.T, ep tcpip.Endpoint) {
   198  				t.Helper()
   199  
   200  				if err := ep.Listen(1); err != nil {
   201  					t.Fatalf("ep.Listen(1): %s", err)
   202  				}
   203  			},
   204  			setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
   205  				t.Helper()
   206  
   207  				var addr tcpip.FullAddress
   208  				for {
   209  					newEP, wq, err := ep.Accept(&addr)
   210  					if _, ok := err.(*tcpip.ErrWouldBlock); ok {
   211  						<-ch
   212  						continue
   213  					}
   214  					if err != nil {
   215  						t.Fatalf("ep.Accept(_): %s", err)
   216  					}
   217  					if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath(
   218  						"NIC",
   219  					)); diff != "" {
   220  						t.Errorf("accepted address mismatch (-want +got):\n%s", diff)
   221  					}
   222  
   223  					we, newCH := waiter.NewChannelEntry(nil)
   224  					wq.EventRegister(&we, waiter.ReadableEvents)
   225  					return newEP, newCH
   226  				}
   227  			},
   228  			needRemoteAddr: false,
   229  		},
   230  	}
   231  
   232  	for _, test := range tests {
   233  		t.Run(test.name, func(t *testing.T) {
   234  			for _, subTest := range subTests {
   235  				t.Run(subTest.name, func(t *testing.T) {
   236  					stackOpts := stack.Options{
   237  						NetworkProtocols:   []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
   238  						TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
   239  					}
   240  
   241  					host1Stack := stack.New(stackOpts)
   242  					routerStack := stack.New(stackOpts)
   243  					host2Stack := stack.New(stackOpts)
   244  					utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack)
   245  
   246  					epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto)
   247  					defer epsAndAddrs.serverEP.Close()
   248  					defer epsAndAddrs.clientEP.Close()
   249  
   250  					serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort}
   251  					if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil {
   252  						t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err)
   253  					}
   254  					clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
   255  					if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
   256  						t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
   257  					}
   258  
   259  					if subTest.setupServer != nil {
   260  						subTest.setupServer(t, epsAndAddrs.serverEP)
   261  					}
   262  					{
   263  						err := epsAndAddrs.clientEP.Connect(serverAddr)
   264  						if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" {
   265  							t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff)
   266  						}
   267  					}
   268  					if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil {
   269  						t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err)
   270  					} else {
   271  						clientAddr = addr
   272  						clientAddr.NIC = 0
   273  					}
   274  
   275  					serverEP := epsAndAddrs.serverEP
   276  					serverCH := epsAndAddrs.serverReadableCH
   277  					if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, clientAddr); ep != nil {
   278  						defer ep.Close()
   279  						serverEP = ep
   280  						serverCH = ch
   281  					}
   282  
   283  					write := func(ep tcpip.Endpoint, data []byte) {
   284  						t.Helper()
   285  
   286  						var r bytes.Reader
   287  						r.Reset(data)
   288  						var wOpts tcpip.WriteOptions
   289  						n, err := ep.Write(&r, wOpts)
   290  						if err != nil {
   291  							t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
   292  						}
   293  						if want := int64(len(data)); n != want {
   294  							t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
   295  						}
   296  					}
   297  
   298  					data := []byte{1, 2, 3, 4}
   299  					write(epsAndAddrs.clientEP, data)
   300  
   301  					read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) {
   302  						t.Helper()
   303  
   304  						var buf bytes.Buffer
   305  						var res tcpip.ReadResult
   306  						for {
   307  							var err tcpip.Error
   308  							opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
   309  							res, err = ep.Read(&buf, opts)
   310  							if _, ok := err.(*tcpip.ErrWouldBlock); ok {
   311  								<-ch
   312  								continue
   313  							}
   314  							if err != nil {
   315  								t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
   316  							}
   317  							break
   318  						}
   319  
   320  						readResult := tcpip.ReadResult{
   321  							Count: len(data),
   322  							Total: len(data),
   323  						}
   324  						if subTest.needRemoteAddr {
   325  							readResult.RemoteAddr = expectedFrom
   326  						}
   327  						if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath(
   328  							"ControlMessages",
   329  							"RemoteAddr.NIC",
   330  						)); diff != "" {
   331  							t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
   332  						}
   333  						if diff := cmp.Diff(buf.Bytes(), data); diff != "" {
   334  							t.Errorf("received data mismatch (-want +got):\n%s", diff)
   335  						}
   336  
   337  						if t.Failed() {
   338  							t.FailNow()
   339  						}
   340  					}
   341  
   342  					read(serverCH, serverEP, data, clientAddr)
   343  
   344  					data = []byte{5, 6, 7, 8, 9, 10, 11, 12}
   345  					write(serverEP, data)
   346  					read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr)
   347  				})
   348  			}
   349  		})
   350  	}
   351  }
   352  
   353  func TestMulticastForwarding(t *testing.T) {
   354  	const (
   355  		nicID1 = 1
   356  		nicID2 = 2
   357  	)
   358  
   359  	var (
   360  		ipv4LinkLocalUnicastAddr   = testutil.MustParse4("169.254.0.10")
   361  		ipv4LinkLocalMulticastAddr = testutil.MustParse4("224.0.0.10")
   362  
   363  		ipv6LinkLocalUnicastAddr   = testutil.MustParse6("fe80::a")
   364  		ipv6LinkLocalMulticastAddr = testutil.MustParse6("ff02::a")
   365  	)
   366  
   367  	tests := []struct {
   368  		name             string
   369  		srcAddr, dstAddr tcpip.Address
   370  		rx               func(*channel.Endpoint, tcpip.Address, tcpip.Address)
   371  		expectForward    bool
   372  		checker          func(*testing.T, []byte)
   373  	}{
   374  		{
   375  			name:          "IPv4 link-local multicast destination",
   376  			srcAddr:       utils.RemoteIPv4Addr,
   377  			dstAddr:       ipv4LinkLocalMulticastAddr,
   378  			rx:            rxICMPv4EchoRequest,
   379  			expectForward: false,
   380  		},
   381  		{
   382  			name:          "IPv4 link-local source",
   383  			srcAddr:       ipv4LinkLocalUnicastAddr,
   384  			dstAddr:       utils.RemoteIPv4Addr,
   385  			rx:            rxICMPv4EchoRequest,
   386  			expectForward: false,
   387  		},
   388  		{
   389  			name:          "IPv4 link-local destination",
   390  			srcAddr:       utils.RemoteIPv4Addr,
   391  			dstAddr:       ipv4LinkLocalUnicastAddr,
   392  			rx:            rxICMPv4EchoRequest,
   393  			expectForward: false,
   394  		},
   395  		{
   396  			name:          "IPv4 non-link-local unicast",
   397  			srcAddr:       utils.RemoteIPv4Addr,
   398  			dstAddr:       utils.Ipv4Addr2.AddressWithPrefix.Address,
   399  			rx:            rxICMPv4EchoRequest,
   400  			expectForward: true,
   401  			checker: func(t *testing.T, b []byte) {
   402  				forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
   403  			},
   404  		},
   405  		{
   406  			name:          "IPv4 non-link-local multicast",
   407  			srcAddr:       utils.RemoteIPv4Addr,
   408  			dstAddr:       ipv4GlobalMulticastAddr,
   409  			rx:            rxICMPv4EchoRequest,
   410  			expectForward: true,
   411  			checker: func(t *testing.T, b []byte) {
   412  				forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
   413  			},
   414  		},
   415  
   416  		{
   417  			name:          "IPv6 link-local multicast destination",
   418  			srcAddr:       utils.RemoteIPv6Addr,
   419  			dstAddr:       ipv6LinkLocalMulticastAddr,
   420  			rx:            rxICMPv6EchoRequest,
   421  			expectForward: false,
   422  		},
   423  		{
   424  			name:          "IPv6 link-local source",
   425  			srcAddr:       ipv6LinkLocalUnicastAddr,
   426  			dstAddr:       utils.RemoteIPv6Addr,
   427  			rx:            rxICMPv6EchoRequest,
   428  			expectForward: false,
   429  		},
   430  		{
   431  			name:          "IPv6 link-local destination",
   432  			srcAddr:       utils.RemoteIPv6Addr,
   433  			dstAddr:       ipv6LinkLocalUnicastAddr,
   434  			rx:            rxICMPv6EchoRequest,
   435  			expectForward: false,
   436  		},
   437  		{
   438  			name:          "IPv6 non-link-local unicast",
   439  			srcAddr:       utils.RemoteIPv6Addr,
   440  			dstAddr:       utils.Ipv6Addr2.AddressWithPrefix.Address,
   441  			rx:            rxICMPv6EchoRequest,
   442  			expectForward: true,
   443  			checker: func(t *testing.T, b []byte) {
   444  				forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
   445  			},
   446  		},
   447  		{
   448  			name:          "IPv6 non-link-local multicast",
   449  			srcAddr:       utils.RemoteIPv6Addr,
   450  			dstAddr:       ipv6GlobalMulticastAddr,
   451  			rx:            rxICMPv6EchoRequest,
   452  			expectForward: true,
   453  			checker: func(t *testing.T, b []byte) {
   454  				forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
   455  			},
   456  		},
   457  	}
   458  
   459  	for _, test := range tests {
   460  		t.Run(test.name, func(t *testing.T) {
   461  			s := stack.New(stack.Options{
   462  				NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   463  				TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   464  			})
   465  
   466  			e1 := channel.New(1, header.IPv6MinimumMTU, "")
   467  			if err := s.CreateNIC(nicID1, e1); err != nil {
   468  				t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err)
   469  			}
   470  
   471  			e2 := channel.New(1, header.IPv6MinimumMTU, "")
   472  			if err := s.CreateNIC(nicID2, e2); err != nil {
   473  				t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err)
   474  			}
   475  
   476  			if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil {
   477  				t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err)
   478  			}
   479  			if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil {
   480  				t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err)
   481  			}
   482  
   483  			if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
   484  				t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
   485  			}
   486  			if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
   487  				t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
   488  			}
   489  
   490  			s.SetRouteTable([]tcpip.Route{
   491  				{
   492  					Destination: header.IPv4EmptySubnet,
   493  					NIC:         nicID2,
   494  				},
   495  				{
   496  					Destination: header.IPv6EmptySubnet,
   497  					NIC:         nicID2,
   498  				},
   499  			})
   500  
   501  			test.rx(e1, test.srcAddr, test.dstAddr)
   502  
   503  			p, ok := e2.Read()
   504  			if ok != test.expectForward {
   505  				t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, test.expectForward)
   506  			}
   507  
   508  			if test.expectForward {
   509  				test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
   510  			}
   511  		})
   512  	}
   513  }
   514  
   515  func TestPerInterfaceForwarding(t *testing.T) {
   516  	const (
   517  		nicID1 = 1
   518  		nicID2 = 2
   519  	)
   520  
   521  	tests := []struct {
   522  		name             string
   523  		srcAddr, dstAddr tcpip.Address
   524  		rx               func(*channel.Endpoint, tcpip.Address, tcpip.Address)
   525  		checker          func(*testing.T, []byte)
   526  	}{
   527  		{
   528  			name:    "IPv4 unicast",
   529  			srcAddr: utils.RemoteIPv4Addr,
   530  			dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
   531  			rx:      rxICMPv4EchoRequest,
   532  			checker: func(t *testing.T, b []byte) {
   533  				forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
   534  			},
   535  		},
   536  		{
   537  			name:    "IPv4 multicast",
   538  			srcAddr: utils.RemoteIPv4Addr,
   539  			dstAddr: ipv4GlobalMulticastAddr,
   540  			rx:      rxICMPv4EchoRequest,
   541  			checker: func(t *testing.T, b []byte) {
   542  				forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
   543  			},
   544  		},
   545  
   546  		{
   547  			name:    "IPv6 unicast",
   548  			srcAddr: utils.RemoteIPv6Addr,
   549  			dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
   550  			rx:      rxICMPv6EchoRequest,
   551  			checker: func(t *testing.T, b []byte) {
   552  				forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
   553  			},
   554  		},
   555  		{
   556  			name:    "IPv6 multicast",
   557  			srcAddr: utils.RemoteIPv6Addr,
   558  			dstAddr: ipv6GlobalMulticastAddr,
   559  			rx:      rxICMPv6EchoRequest,
   560  			checker: func(t *testing.T, b []byte) {
   561  				forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
   562  			},
   563  		},
   564  	}
   565  
   566  	netProtos := [...]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber, ipv6.ProtocolNumber}
   567  
   568  	for _, test := range tests {
   569  		t.Run(test.name, func(t *testing.T) {
   570  			s := stack.New(stack.Options{
   571  				NetworkProtocols: []stack.NetworkProtocolFactory{
   572  					// ARP is not used in this test but it is a network protocol that does
   573  					// not support forwarding. We install the protocol to make sure that
   574  					// forwarding information for a NIC is only reported for network
   575  					// protocols that support forwarding.
   576  					arp.NewProtocol,
   577  
   578  					ipv4.NewProtocol,
   579  					ipv6.NewProtocol,
   580  				},
   581  				TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   582  			})
   583  
   584  			e1 := channel.New(1, header.IPv6MinimumMTU, "")
   585  			if err := s.CreateNIC(nicID1, e1); err != nil {
   586  				t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err)
   587  			}
   588  
   589  			e2 := channel.New(1, header.IPv6MinimumMTU, "")
   590  			if err := s.CreateNIC(nicID2, e2); err != nil {
   591  				t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err)
   592  			}
   593  
   594  			for _, add := range [...]struct {
   595  				nicID tcpip.NICID
   596  				addr  tcpip.ProtocolAddress
   597  			}{
   598  				{
   599  					nicID: nicID1,
   600  					addr:  utils.RouterNIC1IPv4Addr,
   601  				},
   602  				{
   603  					nicID: nicID1,
   604  					addr:  utils.RouterNIC1IPv6Addr,
   605  				},
   606  				{
   607  					nicID: nicID2,
   608  					addr:  utils.RouterNIC2IPv4Addr,
   609  				},
   610  				{
   611  					nicID: nicID2,
   612  					addr:  utils.RouterNIC2IPv6Addr,
   613  				},
   614  			} {
   615  				if err := s.AddProtocolAddress(add.nicID, add.addr); err != nil {
   616  					t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", add.nicID, add.addr, err)
   617  				}
   618  			}
   619  
   620  			// Only enable forwarding on NIC1 and make sure that only packets arriving
   621  			// on NIC1 are forwarded.
   622  			for _, netProto := range netProtos {
   623  				if err := s.SetNICForwarding(nicID1, netProto, true); err != nil {
   624  					t.Fatalf("s.SetNICForwarding(%d, %d, true): %s", nicID1, netProtos, err)
   625  				}
   626  			}
   627  
   628  			nicsInfo := s.NICInfo()
   629  			for _, subTest := range [...]struct {
   630  				nicID            tcpip.NICID
   631  				nicEP            *channel.Endpoint
   632  				otherNICID       tcpip.NICID
   633  				otherNICEP       *channel.Endpoint
   634  				expectForwarding bool
   635  			}{
   636  				{
   637  					nicID:            nicID1,
   638  					nicEP:            e1,
   639  					otherNICID:       nicID2,
   640  					otherNICEP:       e2,
   641  					expectForwarding: true,
   642  				},
   643  				{
   644  					nicID:            nicID2,
   645  					nicEP:            e2,
   646  					otherNICID:       nicID2,
   647  					otherNICEP:       e1,
   648  					expectForwarding: false,
   649  				},
   650  			} {
   651  				t.Run(fmt.Sprintf("Packet arriving at NIC%d", subTest.nicID), func(t *testing.T) {
   652  					nicInfo, ok := nicsInfo[subTest.nicID]
   653  					if !ok {
   654  						t.Errorf("expected NIC info for NIC %d; got = %#v", subTest.nicID, nicsInfo)
   655  					} else {
   656  						forwarding := make(map[tcpip.NetworkProtocolNumber]bool)
   657  						for _, netProto := range netProtos {
   658  							forwarding[netProto] = subTest.expectForwarding
   659  						}
   660  
   661  						if diff := cmp.Diff(forwarding, nicInfo.Forwarding); diff != "" {
   662  							t.Errorf("nicsInfo[%d].Forwarding mismatch (-want +got):\n%s", subTest.nicID, diff)
   663  						}
   664  					}
   665  
   666  					s.SetRouteTable([]tcpip.Route{
   667  						{
   668  							Destination: header.IPv4EmptySubnet,
   669  							NIC:         subTest.otherNICID,
   670  						},
   671  						{
   672  							Destination: header.IPv6EmptySubnet,
   673  							NIC:         subTest.otherNICID,
   674  						},
   675  					})
   676  
   677  					test.rx(subTest.nicEP, test.srcAddr, test.dstAddr)
   678  					if p, ok := subTest.nicEP.Read(); ok {
   679  						t.Errorf("unexpectedly got a response from the interface the packet arrived on: %#v", p)
   680  					}
   681  					if p, ok := subTest.otherNICEP.Read(); ok != subTest.expectForwarding {
   682  						t.Errorf("got otherNICEP.Read() = (%#v, %t), want = (_, %t)", p, ok, subTest.expectForwarding)
   683  					} else if subTest.expectForwarding {
   684  						test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
   685  					}
   686  				})
   687  			}
   688  		})
   689  	}
   690  }