github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/tests/integration/route_test.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package route_test
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"testing"
    21  
    22  	"github.com/google/go-cmp/cmp"
    23  	"github.com/SagerNet/gvisor/pkg/tcpip"
    24  	"github.com/SagerNet/gvisor/pkg/tcpip/buffer"
    25  	"github.com/SagerNet/gvisor/pkg/tcpip/checker"
    26  	"github.com/SagerNet/gvisor/pkg/tcpip/header"
    27  	"github.com/SagerNet/gvisor/pkg/tcpip/link/channel"
    28  	"github.com/SagerNet/gvisor/pkg/tcpip/link/loopback"
    29  	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv4"
    30  	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv6"
    31  	"github.com/SagerNet/gvisor/pkg/tcpip/stack"
    32  	"github.com/SagerNet/gvisor/pkg/tcpip/tests/utils"
    33  	"github.com/SagerNet/gvisor/pkg/tcpip/testutil"
    34  	"github.com/SagerNet/gvisor/pkg/tcpip/transport/icmp"
    35  	"github.com/SagerNet/gvisor/pkg/tcpip/transport/udp"
    36  	"github.com/SagerNet/gvisor/pkg/waiter"
    37  )
    38  
    39  // TestLocalPing tests pinging a remote that is local the stack.
    40  //
    41  // This tests that a local route is created and packets do not leave the stack.
    42  func TestLocalPing(t *testing.T) {
    43  	const (
    44  		nicID = 1
    45  
    46  		// icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo
    47  		// request/reply packets.
    48  		icmpDataOffset = 8
    49  	)
    50  	ipv4Loopback := testutil.MustParse4("127.0.0.1")
    51  
    52  	channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") }
    53  	channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) {
    54  		channelEP := e.(*channel.Endpoint)
    55  		if n := channelEP.Drain(); n != 0 {
    56  			t.Fatalf("got channelEP.Drain() = %d, want = 0", n)
    57  		}
    58  	}
    59  
    60  	ipv4ICMPBuf := func(t *testing.T) buffer.View {
    61  		data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
    62  		hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data)))
    63  		hdr.SetType(header.ICMPv4Echo)
    64  		if n := copy(hdr.Payload(), data[:]); n != len(data) {
    65  			t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
    66  		}
    67  		return buffer.View(hdr)
    68  	}
    69  
    70  	ipv6ICMPBuf := func(t *testing.T) buffer.View {
    71  		data := [8]byte{1, 2, 3, 4, 5, 6, 7, 9}
    72  		hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data)))
    73  		hdr.SetType(header.ICMPv6EchoRequest)
    74  		if n := copy(hdr.Payload(), data[:]); n != len(data) {
    75  			t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
    76  		}
    77  		return buffer.View(hdr)
    78  	}
    79  
    80  	tests := []struct {
    81  		name               string
    82  		transProto         tcpip.TransportProtocolNumber
    83  		netProto           tcpip.NetworkProtocolNumber
    84  		linkEndpoint       func() stack.LinkEndpoint
    85  		localAddr          tcpip.Address
    86  		icmpBuf            func(*testing.T) buffer.View
    87  		expectedConnectErr tcpip.Error
    88  		checkLinkEndpoint  func(t *testing.T, e stack.LinkEndpoint)
    89  	}{
    90  		{
    91  			name:              "IPv4 loopback",
    92  			transProto:        icmp.ProtocolNumber4,
    93  			netProto:          ipv4.ProtocolNumber,
    94  			linkEndpoint:      loopback.New,
    95  			localAddr:         ipv4Loopback,
    96  			icmpBuf:           ipv4ICMPBuf,
    97  			checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
    98  		},
    99  		{
   100  			name:              "IPv6 loopback",
   101  			transProto:        icmp.ProtocolNumber6,
   102  			netProto:          ipv6.ProtocolNumber,
   103  			linkEndpoint:      loopback.New,
   104  			localAddr:         header.IPv6Loopback,
   105  			icmpBuf:           ipv6ICMPBuf,
   106  			checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
   107  		},
   108  		{
   109  			name:              "IPv4 non-loopback",
   110  			transProto:        icmp.ProtocolNumber4,
   111  			netProto:          ipv4.ProtocolNumber,
   112  			linkEndpoint:      channelEP,
   113  			localAddr:         utils.Ipv4Addr.Address,
   114  			icmpBuf:           ipv4ICMPBuf,
   115  			checkLinkEndpoint: channelEPCheck,
   116  		},
   117  		{
   118  			name:              "IPv6 non-loopback",
   119  			transProto:        icmp.ProtocolNumber6,
   120  			netProto:          ipv6.ProtocolNumber,
   121  			linkEndpoint:      channelEP,
   122  			localAddr:         utils.Ipv6Addr.Address,
   123  			icmpBuf:           ipv6ICMPBuf,
   124  			checkLinkEndpoint: channelEPCheck,
   125  		},
   126  		{
   127  			name:               "IPv4 loopback without local address",
   128  			transProto:         icmp.ProtocolNumber4,
   129  			netProto:           ipv4.ProtocolNumber,
   130  			linkEndpoint:       loopback.New,
   131  			icmpBuf:            ipv4ICMPBuf,
   132  			expectedConnectErr: &tcpip.ErrNoRoute{},
   133  			checkLinkEndpoint:  func(*testing.T, stack.LinkEndpoint) {},
   134  		},
   135  		{
   136  			name:               "IPv6 loopback without local address",
   137  			transProto:         icmp.ProtocolNumber6,
   138  			netProto:           ipv6.ProtocolNumber,
   139  			linkEndpoint:       loopback.New,
   140  			icmpBuf:            ipv6ICMPBuf,
   141  			expectedConnectErr: &tcpip.ErrNoRoute{},
   142  			checkLinkEndpoint:  func(*testing.T, stack.LinkEndpoint) {},
   143  		},
   144  		{
   145  			name:               "IPv4 non-loopback without local address",
   146  			transProto:         icmp.ProtocolNumber4,
   147  			netProto:           ipv4.ProtocolNumber,
   148  			linkEndpoint:       channelEP,
   149  			icmpBuf:            ipv4ICMPBuf,
   150  			expectedConnectErr: &tcpip.ErrNoRoute{},
   151  			checkLinkEndpoint:  channelEPCheck,
   152  		},
   153  		{
   154  			name:               "IPv6 non-loopback without local address",
   155  			transProto:         icmp.ProtocolNumber6,
   156  			netProto:           ipv6.ProtocolNumber,
   157  			linkEndpoint:       channelEP,
   158  			icmpBuf:            ipv6ICMPBuf,
   159  			expectedConnectErr: &tcpip.ErrNoRoute{},
   160  			checkLinkEndpoint:  channelEPCheck,
   161  		},
   162  	}
   163  
   164  	for _, test := range tests {
   165  		t.Run(test.name, func(t *testing.T) {
   166  			for _, allowExternalLoopback := range []bool{true, false} {
   167  				t.Run(fmt.Sprintf("AllowExternalLoopback=%t", allowExternalLoopback), func(t *testing.T) {
   168  					s := stack.New(stack.Options{
   169  						NetworkProtocols: []stack.NetworkProtocolFactory{
   170  							ipv4.NewProtocolWithOptions(ipv4.Options{
   171  								AllowExternalLoopbackTraffic: allowExternalLoopback,
   172  							}),
   173  							ipv6.NewProtocolWithOptions(ipv6.Options{
   174  								AllowExternalLoopbackTraffic: allowExternalLoopback,
   175  							}),
   176  						},
   177  						TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
   178  						HandleLocal:        true,
   179  					})
   180  					e := test.linkEndpoint()
   181  					if err := s.CreateNIC(nicID, e); err != nil {
   182  						t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
   183  					}
   184  
   185  					if len(test.localAddr) != 0 {
   186  						if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil {
   187  							t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err)
   188  						}
   189  					}
   190  
   191  					var wq waiter.Queue
   192  					we, ch := waiter.NewChannelEntry(nil)
   193  					wq.EventRegister(&we, waiter.ReadableEvents)
   194  					ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq)
   195  					if err != nil {
   196  						t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err)
   197  					}
   198  					defer ep.Close()
   199  
   200  					connAddr := tcpip.FullAddress{Addr: test.localAddr}
   201  					if err := ep.Connect(connAddr); err != test.expectedConnectErr {
   202  						t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr)
   203  					}
   204  
   205  					if test.expectedConnectErr != nil {
   206  						return
   207  					}
   208  
   209  					var r bytes.Reader
   210  					payload := test.icmpBuf(t)
   211  					r.Reset(payload)
   212  					var wOpts tcpip.WriteOptions
   213  					if n, err := ep.Write(&r, wOpts); err != nil {
   214  						t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err)
   215  					} else if n != int64(len(payload)) {
   216  						t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload))
   217  					}
   218  
   219  					// Wait for the endpoint to become readable.
   220  					<-ch
   221  
   222  					var w bytes.Buffer
   223  					rr, err := ep.Read(&w, tcpip.ReadOptions{
   224  						NeedRemoteAddr: true,
   225  					})
   226  					if err != nil {
   227  						t.Fatalf("ep.Read(...): %s", err)
   228  					}
   229  					if diff := cmp.Diff(buffer.View(w.Bytes()[icmpDataOffset:]), payload[icmpDataOffset:]); diff != "" {
   230  						t.Errorf("received data mismatch (-want +got):\n%s", diff)
   231  					}
   232  					if rr.RemoteAddr.Addr != test.localAddr {
   233  						t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr)
   234  					}
   235  
   236  					test.checkLinkEndpoint(t, e)
   237  				})
   238  			}
   239  		})
   240  	}
   241  }
   242  
   243  // TestLocalUDP tests sending UDP packets between two endpoints that are local
   244  // to the stack.
   245  //
   246  // This tests that that packets never leave the stack and the addresses
   247  // used when sending a packet.
   248  func TestLocalUDP(t *testing.T) {
   249  	const (
   250  		nicID = 1
   251  	)
   252  
   253  	tests := []struct {
   254  		name             string
   255  		canBePrimaryAddr tcpip.ProtocolAddress
   256  		firstPrimaryAddr tcpip.ProtocolAddress
   257  	}{
   258  		{
   259  			name:             "IPv4",
   260  			canBePrimaryAddr: utils.Ipv4Addr1,
   261  			firstPrimaryAddr: utils.Ipv4Addr2,
   262  		},
   263  		{
   264  			name:             "IPv6",
   265  			canBePrimaryAddr: utils.Ipv6Addr1,
   266  			firstPrimaryAddr: utils.Ipv6Addr2,
   267  		},
   268  	}
   269  
   270  	subTests := []struct {
   271  		name             string
   272  		addAddress       bool
   273  		expectedWriteErr tcpip.Error
   274  	}{
   275  		{
   276  			name:             "Unassigned local address",
   277  			addAddress:       false,
   278  			expectedWriteErr: &tcpip.ErrNoRoute{},
   279  		},
   280  		{
   281  			name:             "Assigned local address",
   282  			addAddress:       true,
   283  			expectedWriteErr: nil,
   284  		},
   285  	}
   286  
   287  	for _, test := range tests {
   288  		t.Run(test.name, func(t *testing.T) {
   289  			for _, subTest := range subTests {
   290  				t.Run(subTest.name, func(t *testing.T) {
   291  					stackOpts := stack.Options{
   292  						NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   293  						TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   294  						HandleLocal:        true,
   295  					}
   296  
   297  					s := stack.New(stackOpts)
   298  					ep := channel.New(1, header.IPv6MinimumMTU, "")
   299  
   300  					if err := s.CreateNIC(nicID, ep); err != nil {
   301  						t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
   302  					}
   303  
   304  					if subTest.addAddress {
   305  						if err := s.AddProtocolAddressWithOptions(nicID, test.canBePrimaryAddr, stack.CanBePrimaryEndpoint); err != nil {
   306  							t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.canBePrimaryAddr, stack.FirstPrimaryEndpoint, err)
   307  						}
   308  						if err := s.AddProtocolAddressWithOptions(nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint); err != nil {
   309  							t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint, err)
   310  						}
   311  					}
   312  
   313  					var serverWQ waiter.Queue
   314  					serverWE, serverCH := waiter.NewChannelEntry(nil)
   315  					serverWQ.EventRegister(&serverWE, waiter.ReadableEvents)
   316  					server, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &serverWQ)
   317  					if err != nil {
   318  						t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err)
   319  					}
   320  					defer server.Close()
   321  
   322  					bindAddr := tcpip.FullAddress{Port: 80}
   323  					if err := server.Bind(bindAddr); err != nil {
   324  						t.Fatalf("server.Bind(%#v): %s", bindAddr, err)
   325  					}
   326  
   327  					var clientWQ waiter.Queue
   328  					clientWE, clientCH := waiter.NewChannelEntry(nil)
   329  					clientWQ.EventRegister(&clientWE, waiter.ReadableEvents)
   330  					client, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &clientWQ)
   331  					if err != nil {
   332  						t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err)
   333  					}
   334  					defer client.Close()
   335  
   336  					serverAddr := tcpip.FullAddress{
   337  						Addr: test.canBePrimaryAddr.AddressWithPrefix.Address,
   338  						Port: 80,
   339  					}
   340  
   341  					clientPayload := []byte{1, 2, 3, 4}
   342  					{
   343  						var r bytes.Reader
   344  						r.Reset(clientPayload)
   345  						wOpts := tcpip.WriteOptions{
   346  							To: &serverAddr,
   347  						}
   348  						if n, err := client.Write(&r, wOpts); err != subTest.expectedWriteErr {
   349  							t.Fatalf("got client.Write(%#v, %#v) = (%d, %s), want = (_, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr)
   350  						} else if subTest.expectedWriteErr != nil {
   351  							// Nothing else to test if we expected not to be able to send the
   352  							// UDP packet.
   353  							return
   354  						} else if n != int64(len(clientPayload)) {
   355  							t.Fatalf("got client.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", clientPayload, wOpts, n, len(clientPayload))
   356  						}
   357  					}
   358  
   359  					// Wait for the server endpoint to become readable.
   360  					<-serverCH
   361  
   362  					var clientAddr tcpip.FullAddress
   363  					var readBuf bytes.Buffer
   364  					if read, err := server.Read(&readBuf, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil {
   365  						t.Fatalf("server.Read(_): %s", err)
   366  					} else {
   367  						clientAddr = read.RemoteAddr
   368  
   369  						if diff := cmp.Diff(tcpip.ReadResult{
   370  							Count: readBuf.Len(),
   371  							Total: readBuf.Len(),
   372  							RemoteAddr: tcpip.FullAddress{
   373  								Addr: test.canBePrimaryAddr.AddressWithPrefix.Address,
   374  							},
   375  						}, read, checker.IgnoreCmpPath(
   376  							"ControlMessages",
   377  							"RemoteAddr.NIC",
   378  							"RemoteAddr.Port",
   379  						)); diff != "" {
   380  							t.Errorf("server.Read: unexpected result (-want +got):\n%s", diff)
   381  						}
   382  						if diff := cmp.Diff(buffer.View(clientPayload), buffer.View(readBuf.Bytes())); diff != "" {
   383  							t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff)
   384  						}
   385  						if t.Failed() {
   386  							t.FailNow()
   387  						}
   388  					}
   389  
   390  					serverPayload := []byte{1, 2, 3, 4}
   391  					{
   392  						var r bytes.Reader
   393  						r.Reset(serverPayload)
   394  						wOpts := tcpip.WriteOptions{
   395  							To: &clientAddr,
   396  						}
   397  						if n, err := server.Write(&r, wOpts); err != nil {
   398  							t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err)
   399  						} else if n != int64(len(serverPayload)) {
   400  							t.Fatalf("got server.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", serverPayload, wOpts, n, len(serverPayload))
   401  						}
   402  					}
   403  
   404  					// Wait for the client endpoint to become readable.
   405  					<-clientCH
   406  
   407  					readBuf.Reset()
   408  					if read, err := client.Read(&readBuf, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil {
   409  						t.Fatalf("client.Read(_): %s", err)
   410  					} else {
   411  						if diff := cmp.Diff(tcpip.ReadResult{
   412  							Count:      readBuf.Len(),
   413  							Total:      readBuf.Len(),
   414  							RemoteAddr: tcpip.FullAddress{Addr: serverAddr.Addr},
   415  						}, read, checker.IgnoreCmpPath(
   416  							"ControlMessages",
   417  							"RemoteAddr.NIC",
   418  							"RemoteAddr.Port",
   419  						)); diff != "" {
   420  							t.Errorf("client.Read: unexpected result (-want +got):\n%s", diff)
   421  						}
   422  						if diff := cmp.Diff(buffer.View(serverPayload), buffer.View(readBuf.Bytes())); diff != "" {
   423  							t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff)
   424  						}
   425  						if t.Failed() {
   426  							t.FailNow()
   427  						}
   428  					}
   429  				})
   430  			}
   431  		})
   432  	}
   433  }