gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/tests/integration/route_test.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package route_test
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"testing"
    21  
    22  	"github.com/google/go-cmp/cmp"
    23  	"gvisor.dev/gvisor/pkg/tcpip"
    24  	"gvisor.dev/gvisor/pkg/tcpip/checker"
    25  	"gvisor.dev/gvisor/pkg/tcpip/header"
    26  	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
    27  	"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
    28  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    29  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
    30  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    31  	"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
    32  	"gvisor.dev/gvisor/pkg/tcpip/testutil"
    33  	"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
    34  	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
    35  	"gvisor.dev/gvisor/pkg/waiter"
    36  )
    37  
    38  // TestLocalPing tests pinging a remote that is local the stack.
    39  //
    40  // This tests that a local route is created and packets do not leave the stack.
    41  func TestLocalPing(t *testing.T) {
    42  	const (
    43  		nicID = 1
    44  
    45  		// icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo
    46  		// request/reply packets.
    47  		icmpDataOffset = 8
    48  	)
    49  	ipv4Loopback := tcpip.AddressWithPrefix{
    50  		Address:   testutil.MustParse4("127.0.0.1"),
    51  		PrefixLen: 8,
    52  	}
    53  
    54  	channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") }
    55  	channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) {
    56  		channelEP := e.(*channel.Endpoint)
    57  		if n := channelEP.Drain(); n != 0 {
    58  			t.Fatalf("got channelEP.Drain() = %d, want = 0", n)
    59  		}
    60  	}
    61  
    62  	ipv4ICMPBuf := func(t *testing.T) []byte {
    63  		data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
    64  		hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data)))
    65  		hdr.SetType(header.ICMPv4Echo)
    66  		if n := copy(hdr.Payload(), data[:]); n != len(data) {
    67  			t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
    68  		}
    69  		return hdr
    70  	}
    71  
    72  	ipv6ICMPBuf := func(t *testing.T) []byte {
    73  		data := [8]byte{1, 2, 3, 4, 5, 6, 7, 9}
    74  		hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data)))
    75  		hdr.SetType(header.ICMPv6EchoRequest)
    76  		if n := copy(hdr.Payload(), data[:]); n != len(data) {
    77  			t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
    78  		}
    79  		return hdr
    80  	}
    81  
    82  	tests := []struct {
    83  		name               string
    84  		transProto         tcpip.TransportProtocolNumber
    85  		netProto           tcpip.NetworkProtocolNumber
    86  		linkEndpoint       func() stack.LinkEndpoint
    87  		localAddr          tcpip.AddressWithPrefix
    88  		icmpBuf            func(*testing.T) []byte
    89  		expectedConnectErr tcpip.Error
    90  		checkLinkEndpoint  func(t *testing.T, e stack.LinkEndpoint)
    91  	}{
    92  		{
    93  			name:              "IPv4 loopback",
    94  			transProto:        icmp.ProtocolNumber4,
    95  			netProto:          ipv4.ProtocolNumber,
    96  			linkEndpoint:      loopback.New,
    97  			localAddr:         ipv4Loopback,
    98  			icmpBuf:           ipv4ICMPBuf,
    99  			checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
   100  		},
   101  		{
   102  			name:              "IPv6 loopback",
   103  			transProto:        icmp.ProtocolNumber6,
   104  			netProto:          ipv6.ProtocolNumber,
   105  			linkEndpoint:      loopback.New,
   106  			localAddr:         header.IPv6Loopback.WithPrefix(),
   107  			icmpBuf:           ipv6ICMPBuf,
   108  			checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
   109  		},
   110  		{
   111  			name:              "IPv4 non-loopback",
   112  			transProto:        icmp.ProtocolNumber4,
   113  			netProto:          ipv4.ProtocolNumber,
   114  			linkEndpoint:      channelEP,
   115  			localAddr:         utils.Ipv4Addr,
   116  			icmpBuf:           ipv4ICMPBuf,
   117  			checkLinkEndpoint: channelEPCheck,
   118  		},
   119  		{
   120  			name:              "IPv6 non-loopback",
   121  			transProto:        icmp.ProtocolNumber6,
   122  			netProto:          ipv6.ProtocolNumber,
   123  			linkEndpoint:      channelEP,
   124  			localAddr:         utils.Ipv6Addr,
   125  			icmpBuf:           ipv6ICMPBuf,
   126  			checkLinkEndpoint: channelEPCheck,
   127  		},
   128  		{
   129  			name:               "IPv4 loopback without local address",
   130  			transProto:         icmp.ProtocolNumber4,
   131  			netProto:           ipv4.ProtocolNumber,
   132  			linkEndpoint:       loopback.New,
   133  			icmpBuf:            ipv4ICMPBuf,
   134  			expectedConnectErr: &tcpip.ErrHostUnreachable{},
   135  			checkLinkEndpoint:  func(*testing.T, stack.LinkEndpoint) {},
   136  		},
   137  		{
   138  			name:               "IPv6 loopback without local address",
   139  			transProto:         icmp.ProtocolNumber6,
   140  			netProto:           ipv6.ProtocolNumber,
   141  			linkEndpoint:       loopback.New,
   142  			icmpBuf:            ipv6ICMPBuf,
   143  			expectedConnectErr: &tcpip.ErrHostUnreachable{},
   144  			checkLinkEndpoint:  func(*testing.T, stack.LinkEndpoint) {},
   145  		},
   146  		{
   147  			name:               "IPv4 non-loopback without local address",
   148  			transProto:         icmp.ProtocolNumber4,
   149  			netProto:           ipv4.ProtocolNumber,
   150  			linkEndpoint:       channelEP,
   151  			icmpBuf:            ipv4ICMPBuf,
   152  			expectedConnectErr: &tcpip.ErrHostUnreachable{},
   153  			checkLinkEndpoint:  channelEPCheck,
   154  		},
   155  		{
   156  			name:               "IPv6 non-loopback without local address",
   157  			transProto:         icmp.ProtocolNumber6,
   158  			netProto:           ipv6.ProtocolNumber,
   159  			linkEndpoint:       channelEP,
   160  			icmpBuf:            ipv6ICMPBuf,
   161  			expectedConnectErr: &tcpip.ErrHostUnreachable{},
   162  			checkLinkEndpoint:  channelEPCheck,
   163  		},
   164  	}
   165  
   166  	for _, test := range tests {
   167  		t.Run(test.name, func(t *testing.T) {
   168  			for _, allowExternalLoopback := range []bool{true, false} {
   169  				t.Run(fmt.Sprintf("AllowExternalLoopback=%t", allowExternalLoopback), func(t *testing.T) {
   170  					s := stack.New(stack.Options{
   171  						NetworkProtocols: []stack.NetworkProtocolFactory{
   172  							ipv4.NewProtocolWithOptions(ipv4.Options{
   173  								AllowExternalLoopbackTraffic: allowExternalLoopback,
   174  							}),
   175  							ipv6.NewProtocolWithOptions(ipv6.Options{
   176  								AllowExternalLoopbackTraffic: allowExternalLoopback,
   177  							}),
   178  						},
   179  						TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
   180  						HandleLocal:        true,
   181  					})
   182  					defer s.Destroy()
   183  					e := test.linkEndpoint()
   184  					if err := s.CreateNIC(nicID, e); err != nil {
   185  						t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
   186  					}
   187  
   188  					if test.localAddr.Address.Len() != 0 {
   189  						protocolAddr := tcpip.ProtocolAddress{
   190  							Protocol:          test.netProto,
   191  							AddressWithPrefix: test.localAddr,
   192  						}
   193  						if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
   194  							t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
   195  						}
   196  					}
   197  
   198  					var wq waiter.Queue
   199  					we, ch := waiter.NewChannelEntry(waiter.ReadableEvents)
   200  					wq.EventRegister(&we)
   201  					ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq)
   202  					if err != nil {
   203  						t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err)
   204  					}
   205  					defer ep.Close()
   206  
   207  					connAddr := tcpip.FullAddress{Addr: test.localAddr.Address}
   208  					if err := ep.Connect(connAddr); err != test.expectedConnectErr {
   209  						t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr)
   210  					}
   211  
   212  					if test.expectedConnectErr != nil {
   213  						return
   214  					}
   215  
   216  					var r bytes.Reader
   217  					payload := test.icmpBuf(t)
   218  					r.Reset(payload)
   219  					var wOpts tcpip.WriteOptions
   220  					if n, err := ep.Write(&r, wOpts); err != nil {
   221  						t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err)
   222  					} else if n != int64(len(payload)) {
   223  						t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload))
   224  					}
   225  
   226  					// Wait for the endpoint to become readable.
   227  					<-ch
   228  
   229  					var w bytes.Buffer
   230  					rr, err := ep.Read(&w, tcpip.ReadOptions{
   231  						NeedRemoteAddr: true,
   232  					})
   233  					if err != nil {
   234  						t.Fatalf("ep.Read(...): %s", err)
   235  					}
   236  					if diff := cmp.Diff(w.Bytes()[icmpDataOffset:], payload[icmpDataOffset:]); diff != "" {
   237  						t.Errorf("received data mismatch (-want +got):\n%s", diff)
   238  					}
   239  					if rr.RemoteAddr.Addr != test.localAddr.Address {
   240  						t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr.Address)
   241  					}
   242  
   243  					test.checkLinkEndpoint(t, e)
   244  				})
   245  			}
   246  		})
   247  	}
   248  }
   249  
   250  // TestLocalUDP tests sending UDP packets between two endpoints that are local
   251  // to the stack.
   252  //
   253  // This tests that that packets never leave the stack and the addresses
   254  // used when sending a packet.
   255  func TestLocalUDP(t *testing.T) {
   256  	const (
   257  		nicID = 1
   258  	)
   259  
   260  	tests := []struct {
   261  		name             string
   262  		canBePrimaryAddr tcpip.ProtocolAddress
   263  		firstPrimaryAddr tcpip.ProtocolAddress
   264  	}{
   265  		{
   266  			name:             "IPv4",
   267  			canBePrimaryAddr: utils.Ipv4Addr1,
   268  			firstPrimaryAddr: utils.Ipv4Addr2,
   269  		},
   270  		{
   271  			name:             "IPv6",
   272  			canBePrimaryAddr: utils.Ipv6Addr1,
   273  			firstPrimaryAddr: utils.Ipv6Addr2,
   274  		},
   275  	}
   276  
   277  	subTests := []struct {
   278  		name             string
   279  		addAddress       bool
   280  		expectedWriteErr tcpip.Error
   281  	}{
   282  		{
   283  			name:             "Unassigned local address",
   284  			addAddress:       false,
   285  			expectedWriteErr: &tcpip.ErrHostUnreachable{},
   286  		},
   287  		{
   288  			name:             "Assigned local address",
   289  			addAddress:       true,
   290  			expectedWriteErr: nil,
   291  		},
   292  	}
   293  
   294  	for _, test := range tests {
   295  		t.Run(test.name, func(t *testing.T) {
   296  			for _, subTest := range subTests {
   297  				t.Run(subTest.name, func(t *testing.T) {
   298  					stackOpts := stack.Options{
   299  						NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
   300  						TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   301  						HandleLocal:        true,
   302  					}
   303  
   304  					s := stack.New(stackOpts)
   305  					defer s.Destroy()
   306  					ep := channel.New(1, header.IPv6MinimumMTU, "")
   307  
   308  					if err := s.CreateNIC(nicID, ep); err != nil {
   309  						t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
   310  					}
   311  
   312  					if subTest.addAddress {
   313  						if err := s.AddProtocolAddress(nicID, test.canBePrimaryAddr, stack.AddressProperties{}); err != nil {
   314  							t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, test.canBePrimaryAddr, err)
   315  						}
   316  						properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
   317  						if err := s.AddProtocolAddress(nicID, test.firstPrimaryAddr, properties); err != nil {
   318  							t.Fatalf("s.AddProtocolAddress(%d, %+v, %+v): %s", nicID, test.firstPrimaryAddr, properties, err)
   319  						}
   320  					}
   321  
   322  					var serverWQ waiter.Queue
   323  					serverWE, serverCH := waiter.NewChannelEntry(waiter.ReadableEvents)
   324  					serverWQ.EventRegister(&serverWE)
   325  					server, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &serverWQ)
   326  					if err != nil {
   327  						t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err)
   328  					}
   329  					defer server.Close()
   330  
   331  					bindAddr := tcpip.FullAddress{Port: 80}
   332  					if err := server.Bind(bindAddr); err != nil {
   333  						t.Fatalf("server.Bind(%#v): %s", bindAddr, err)
   334  					}
   335  
   336  					var clientWQ waiter.Queue
   337  					clientWE, clientCH := waiter.NewChannelEntry(waiter.ReadableEvents)
   338  					clientWQ.EventRegister(&clientWE)
   339  					client, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &clientWQ)
   340  					if err != nil {
   341  						t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err)
   342  					}
   343  					defer client.Close()
   344  
   345  					serverAddr := tcpip.FullAddress{
   346  						Addr: test.canBePrimaryAddr.AddressWithPrefix.Address,
   347  						Port: 80,
   348  					}
   349  
   350  					clientPayload := []byte{1, 2, 3, 4}
   351  					{
   352  						var r bytes.Reader
   353  						r.Reset(clientPayload)
   354  						wOpts := tcpip.WriteOptions{
   355  							To: &serverAddr,
   356  						}
   357  						if n, err := client.Write(&r, wOpts); err != subTest.expectedWriteErr {
   358  							t.Fatalf("got client.Write(%#v, %#v) = (%d, %s), want = (_, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr)
   359  						} else if subTest.expectedWriteErr != nil {
   360  							// Nothing else to test if we expected not to be able to send the
   361  							// UDP packet.
   362  							return
   363  						} else if n != int64(len(clientPayload)) {
   364  							t.Fatalf("got client.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", clientPayload, wOpts, n, len(clientPayload))
   365  						}
   366  					}
   367  
   368  					// Wait for the server endpoint to become readable.
   369  					<-serverCH
   370  
   371  					var clientAddr tcpip.FullAddress
   372  					var readBuf bytes.Buffer
   373  					if read, err := server.Read(&readBuf, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil {
   374  						t.Fatalf("server.Read(_): %s", err)
   375  					} else {
   376  						clientAddr = read.RemoteAddr
   377  
   378  						if diff := cmp.Diff(tcpip.ReadResult{
   379  							Count: readBuf.Len(),
   380  							Total: readBuf.Len(),
   381  							RemoteAddr: tcpip.FullAddress{
   382  								Addr: test.canBePrimaryAddr.AddressWithPrefix.Address,
   383  							},
   384  						}, read, checker.IgnoreCmpPath(
   385  							"ControlMessages",
   386  							"RemoteAddr.NIC",
   387  							"RemoteAddr.Port",
   388  						)); diff != "" {
   389  							t.Errorf("server.Read: unexpected result (-want +got):\n%s", diff)
   390  						}
   391  						if diff := cmp.Diff(clientPayload, readBuf.Bytes()); diff != "" {
   392  							t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff)
   393  						}
   394  						if t.Failed() {
   395  							t.FailNow()
   396  						}
   397  					}
   398  
   399  					serverPayload := []byte{1, 2, 3, 4}
   400  					{
   401  						var r bytes.Reader
   402  						r.Reset(serverPayload)
   403  						wOpts := tcpip.WriteOptions{
   404  							To: &clientAddr,
   405  						}
   406  						if n, err := server.Write(&r, wOpts); err != nil {
   407  							t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err)
   408  						} else if n != int64(len(serverPayload)) {
   409  							t.Fatalf("got server.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", serverPayload, wOpts, n, len(serverPayload))
   410  						}
   411  					}
   412  
   413  					// Wait for the client endpoint to become readable.
   414  					<-clientCH
   415  
   416  					readBuf.Reset()
   417  					if read, err := client.Read(&readBuf, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil {
   418  						t.Fatalf("client.Read(_): %s", err)
   419  					} else {
   420  						if diff := cmp.Diff(tcpip.ReadResult{
   421  							Count:      readBuf.Len(),
   422  							Total:      readBuf.Len(),
   423  							RemoteAddr: tcpip.FullAddress{Addr: serverAddr.Addr},
   424  						}, read, checker.IgnoreCmpPath(
   425  							"ControlMessages",
   426  							"RemoteAddr.NIC",
   427  							"RemoteAddr.Port",
   428  						)); diff != "" {
   429  							t.Errorf("client.Read: unexpected result (-want +got):\n%s", diff)
   430  						}
   431  						if diff := cmp.Diff(serverPayload, readBuf.Bytes()); diff != "" {
   432  							t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff)
   433  						}
   434  						if t.Failed() {
   435  							t.FailNow()
   436  						}
   437  					}
   438  				})
   439  			}
   440  		})
   441  	}
   442  }