gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/test/syscalls/linux/socket_bind_to_device_distribution.cc (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  #include <arpa/inet.h>
    16  #include <linux/if_tun.h>
    17  #include <net/if.h>
    18  #include <netinet/in.h>
    19  #include <sys/ioctl.h>
    20  #include <sys/socket.h>
    21  #include <sys/types.h>
    22  #include <sys/un.h>
    23  
    24  #include <atomic>
    25  #include <cstdio>
    26  #include <cstring>
    27  #include <map>
    28  #include <memory>
    29  #include <unordered_map>
    30  #include <unordered_set>
    31  #include <utility>
    32  #include <vector>
    33  
    34  #include "gmock/gmock.h"
    35  #include "gtest/gtest.h"
    36  #include "test/syscalls/linux/ip_socket_test_util.h"
    37  #include "test/syscalls/linux/socket_bind_to_device_util.h"
    38  #include "test/util/capability_util.h"
    39  #include "test/util/socket_util.h"
    40  #include "test/util/test_util.h"
    41  #include "test/util/thread_util.h"
    42  
    43  namespace gvisor {
    44  namespace testing {
    45  
    46  using std::string;
    47  using std::vector;
    48  
    49  struct EndpointConfig {
    50    std::string bind_to_device;
    51    double expected_ratio;
    52  };
    53  
    54  struct DistributionTestCase {
    55    std::string name;
    56    std::vector<EndpointConfig> endpoints;
    57  };
    58  
    59  struct ListenerConnector {
    60    TestAddress listener;
    61    TestAddress connector;
    62  };
    63  
    64  // Test fixture for SO_BINDTODEVICE tests the distribution of packets received
    65  // with varying SO_BINDTODEVICE settings.
    66  class BindToDeviceDistributionTest
    67      : public ::testing::TestWithParam<
    68            ::testing::tuple<ListenerConnector, DistributionTestCase>> {
    69   protected:
    70    void SetUp() override {
    71      printf("Testing case: %s, listener=%s, connector=%s\n",
    72             ::testing::get<1>(GetParam()).name.c_str(),
    73             ::testing::get<0>(GetParam()).listener.description.c_str(),
    74             ::testing::get<0>(GetParam()).connector.description.c_str());
    75      ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)))
    76          << "CAP_NET_RAW is required to use SO_BINDTODEVICE";
    77    }
    78  };
    79  
    80  // Binds sockets to different devices and then creates many TCP connections.
    81  // Checks that the distribution of connections received on the sockets matches
    82  // the expectation.
    83  TEST_P(BindToDeviceDistributionTest, Tcp) {
    84    auto const& [listener_connector, test] = GetParam();
    85  
    86    TestAddress const& listener = listener_connector.listener;
    87    TestAddress const& connector = listener_connector.connector;
    88    sockaddr_storage listen_addr = listener.addr;
    89    sockaddr_storage conn_addr = connector.addr;
    90  
    91    auto interface_names = GetInterfaceNames();
    92  
    93    // Create the listening sockets.
    94    std::vector<FileDescriptor> listener_fds;
    95    std::vector<std::unique_ptr<Tunnel>> all_tunnels;
    96    for (auto const& endpoint : test.endpoints) {
    97      if (!endpoint.bind_to_device.empty() &&
    98          interface_names.find(endpoint.bind_to_device) ==
    99              interface_names.end()) {
   100        all_tunnels.push_back(
   101            ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device)));
   102        interface_names.insert(endpoint.bind_to_device);
   103      }
   104  
   105      listener_fds.push_back(ASSERT_NO_ERRNO_AND_VALUE(
   106          Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)));
   107      int fd = listener_fds.back().get();
   108  
   109      ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
   110                             sizeof(kSockOptOn)),
   111                  SyscallSucceeds());
   112      ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE,
   113                             endpoint.bind_to_device.c_str(),
   114                             endpoint.bind_to_device.size() + 1),
   115                  SyscallSucceeds());
   116      ASSERT_THAT(bind(fd, AsSockAddr(&listen_addr), listener.addr_len),
   117                  SyscallSucceeds());
   118      ASSERT_THAT(listen(fd, 40), SyscallSucceeds());
   119  
   120      // On the first bind we need to determine which port was bound.
   121      if (listener_fds.size() > 1) {
   122        continue;
   123      }
   124  
   125      // Get the port bound by the listening socket.
   126      socklen_t addrlen = listener.addr_len;
   127      ASSERT_THAT(
   128          getsockname(listener_fds[0].get(), AsSockAddr(&listen_addr), &addrlen),
   129          SyscallSucceeds());
   130      uint16_t const port =
   131          ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
   132      ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
   133    }
   134  
   135    constexpr int kConnectAttempts = 10000;
   136    std::atomic<int> connects_received(0);
   137    std::vector<int> accept_counts(listener_fds.size(), 0);
   138    std::vector<std::unique_ptr<ScopedThread>> listen_threads(
   139        listener_fds.size());
   140  
   141    for (size_t i = 0; i < listener_fds.size(); i++) {
   142      listen_threads[i] = std::make_unique<ScopedThread>(
   143          [&listener_fds, &accept_counts, &connects_received, i,
   144           kConnectAttempts]() {
   145            do {
   146              auto fd = Accept(listener_fds[i].get(), nullptr, nullptr);
   147              if (!fd.ok()) {
   148                // Another thread has shutdown our read side causing the accept to
   149                // fail.
   150                ASSERT_GE(connects_received, kConnectAttempts)
   151                    << "errno = " << fd.error();
   152                return;
   153              }
   154              // Receive some data from a socket to be sure that the connect()
   155              // system call has been completed on another side.
   156              // Do a short read and then close the socket to trigger a RST. This
   157              // ensures that both ends of the connection are cleaned up and no
   158              // goroutines hang around in TIME-WAIT. We do this so that this test
   159              // does not timeout under gotsan runs where lots of goroutines can
   160              // cause the test to use absurd amounts of memory.
   161              //
   162              // See: https://tools.ietf.org/html/rfc2525#page-50 section 2.17
   163              uint16_t data;
   164              EXPECT_THAT(
   165                  RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0),
   166                  SyscallSucceedsWithValue(sizeof(data)));
   167              accept_counts[i]++;
   168            } while (++connects_received < kConnectAttempts);
   169  
   170            // Shutdown all sockets to wake up other threads.
   171            for (auto const& listener_fd : listener_fds) {
   172              shutdown(listener_fd.get(), SHUT_RDWR);
   173            }
   174          });
   175    }
   176  
   177    for (int32_t i = 0; i < kConnectAttempts; i++) {
   178      const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
   179          Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
   180      ASSERT_THAT(RetryEINTR(connect)(fd.get(), AsSockAddr(&conn_addr),
   181                                      connector.addr_len),
   182                  SyscallSucceeds());
   183  
   184      EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
   185                  SyscallSucceedsWithValue(sizeof(i)));
   186    }
   187  
   188    // Join threads to be sure that all connections have been counted.
   189    for (auto const& listen_thread : listen_threads) {
   190      listen_thread->Join();
   191    }
   192    // Check that connections are distributed correctly among listening sockets.
   193    for (size_t i = 0; i < accept_counts.size(); i++) {
   194      EXPECT_THAT(
   195          accept_counts[i],
   196          EquivalentWithin(static_cast<int>(kConnectAttempts *
   197                                            test.endpoints[i].expected_ratio),
   198                           0.10))
   199          << "endpoint " << i << " got the wrong number of packets";
   200    }
   201  }
   202  
   203  // Binds sockets to different devices and then sends many UDP packets.  Checks
   204  // that the distribution of packets received on the sockets matches the
   205  // expectation.
   206  TEST_P(BindToDeviceDistributionTest, Udp) {
   207    auto const& [listener_connector, test] = GetParam();
   208  
   209    TestAddress const& listener = listener_connector.listener;
   210    TestAddress const& connector = listener_connector.connector;
   211    sockaddr_storage listen_addr = listener.addr;
   212    sockaddr_storage conn_addr = connector.addr;
   213  
   214    auto interface_names = GetInterfaceNames();
   215  
   216    // Create the listening socket.
   217    std::vector<FileDescriptor> listener_fds;
   218    std::vector<std::unique_ptr<Tunnel>> all_tunnels;
   219    for (auto const& endpoint : test.endpoints) {
   220      if (!endpoint.bind_to_device.empty() &&
   221          interface_names.find(endpoint.bind_to_device) ==
   222              interface_names.end()) {
   223        all_tunnels.push_back(
   224            ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device)));
   225        interface_names.insert(endpoint.bind_to_device);
   226      }
   227  
   228      listener_fds.push_back(
   229          ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0)));
   230      int fd = listener_fds.back().get();
   231  
   232      ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
   233                             sizeof(kSockOptOn)),
   234                  SyscallSucceeds());
   235      ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE,
   236                             endpoint.bind_to_device.c_str(),
   237                             endpoint.bind_to_device.size() + 1),
   238                  SyscallSucceeds());
   239      ASSERT_THAT(bind(fd, AsSockAddr(&listen_addr), listener.addr_len),
   240                  SyscallSucceeds());
   241  
   242      // On the first bind we need to determine which port was bound.
   243      if (listener_fds.size() > 1) {
   244        continue;
   245      }
   246  
   247      // Get the port bound by the listening socket.
   248      socklen_t addrlen = listener.addr_len;
   249      ASSERT_THAT(
   250          getsockname(listener_fds[0].get(), AsSockAddr(&listen_addr), &addrlen),
   251          SyscallSucceeds());
   252      uint16_t const port =
   253          ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
   254      ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port));
   255      ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
   256    }
   257  
   258    constexpr int kConnectAttempts = 10000;
   259    std::atomic<int> packets_received(0);
   260    std::vector<int> packets_per_socket(listener_fds.size(), 0);
   261    std::vector<std::unique_ptr<ScopedThread>> receiver_threads(
   262        listener_fds.size());
   263  
   264    for (size_t i = 0; i < listener_fds.size(); i++) {
   265      receiver_threads[i] = std::make_unique<ScopedThread>(
   266          [&listener_fds, &packets_per_socket, &packets_received, i]() {
   267            do {
   268              struct sockaddr_storage addr = {};
   269              socklen_t addrlen = sizeof(addr);
   270              int data;
   271  
   272              auto ret =
   273                  RetryEINTR(recvfrom)(listener_fds[i].get(), &data, sizeof(data),
   274                                       0, AsSockAddr(&addr), &addrlen);
   275  
   276              if (packets_received < kConnectAttempts) {
   277                ASSERT_THAT(ret, SyscallSucceedsWithValue(sizeof(data)));
   278              }
   279  
   280              if (ret != sizeof(data)) {
   281                // Another thread may have shutdown our read side causing the
   282                // recvfrom to fail.
   283                break;
   284              }
   285  
   286              packets_received++;
   287              packets_per_socket[i]++;
   288  
   289              // A response is required to synchronize with the main thread,
   290              // otherwise the main thread can send more than can fit into receive
   291              // queues.
   292              EXPECT_THAT(
   293                  RetryEINTR(sendto)(listener_fds[i].get(), &data, sizeof(data),
   294                                     0, AsSockAddr(&addr), addrlen),
   295                  SyscallSucceedsWithValue(sizeof(data)));
   296            } while (packets_received < kConnectAttempts);
   297  
   298            // Shutdown all sockets to wake up other threads.
   299            for (auto const& listener_fd : listener_fds) {
   300              shutdown(listener_fd.get(), SHUT_RDWR);
   301            }
   302          });
   303    }
   304  
   305    for (int i = 0; i < kConnectAttempts; i++) {
   306      FileDescriptor const fd =
   307          ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0));
   308      EXPECT_THAT(RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0,
   309                                     AsSockAddr(&conn_addr), connector.addr_len),
   310                  SyscallSucceedsWithValue(sizeof(i)));
   311      int data;
   312      EXPECT_THAT(RetryEINTR(recv)(fd.get(), &data, sizeof(data), 0),
   313                  SyscallSucceedsWithValue(sizeof(data)));
   314    }
   315  
   316    // Join threads to be sure that all connections have been counted.
   317    for (auto const& receiver_thread : receiver_threads) {
   318      receiver_thread->Join();
   319    }
   320    // Check that packets are distributed correctly among listening sockets.
   321    for (size_t i = 0; i < packets_per_socket.size(); i++) {
   322      EXPECT_THAT(
   323          packets_per_socket[i],
   324          EquivalentWithin(static_cast<int>(kConnectAttempts *
   325                                            test.endpoints[i].expected_ratio),
   326                           0.10))
   327          << "endpoint " << i << " got the wrong number of packets";
   328    }
   329  }
   330  
   331  std::vector<DistributionTestCase> GetDistributionTestCases() {
   332    return std::vector<DistributionTestCase>{
   333        {"Even distribution among sockets not bound to device",
   334         {{"", 1. / 3}, {"", 1. / 3}, {"", 1. / 3}}},
   335        {"Sockets bound to other interfaces get no packets",
   336         {{"eth1", 0}, {"", 1. / 2}, {"", 1. / 2}}},
   337        {"Bound has priority over unbound", {{"eth1", 0}, {"", 0}, {"lo", 1}}},
   338        {"Even distribution among sockets bound to device",
   339         {{"eth1", 0}, {"lo", 1. / 2}, {"lo", 1. / 2}}},
   340    };
   341  }
   342  
   343  INSTANTIATE_TEST_SUITE_P(
   344      BindToDeviceTest, BindToDeviceDistributionTest,
   345      ::testing::Combine(::testing::Values(
   346                             // Listeners bound to IPv4 addresses refuse
   347                             // connections using IPv6 addresses.
   348                             ListenerConnector{V4Any(), V4Loopback()},
   349                             ListenerConnector{V4Loopback(), V4MappedLoopback()}),
   350                         ::testing::ValuesIn(GetDistributionTestCases())));
   351  
   352  }  // namespace testing
   353  }  // namespace gvisor