github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/test/syscalls/linux/socket_bind_to_device_distribution.cc (about) 1 // Copyright 2019 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include <arpa/inet.h> 16 #include <linux/if_tun.h> 17 #include <net/if.h> 18 #include <netinet/in.h> 19 #include <sys/ioctl.h> 20 #include <sys/socket.h> 21 #include <sys/types.h> 22 #include <sys/un.h> 23 24 #include <atomic> 25 #include <cstdio> 26 #include <cstring> 27 #include <map> 28 #include <memory> 29 #include <unordered_map> 30 #include <unordered_set> 31 #include <utility> 32 #include <vector> 33 34 #include "gmock/gmock.h" 35 #include "gtest/gtest.h" 36 #include "test/syscalls/linux/ip_socket_test_util.h" 37 #include "test/syscalls/linux/socket_bind_to_device_util.h" 38 #include "test/syscalls/linux/socket_test_util.h" 39 #include "test/util/capability_util.h" 40 #include "test/util/test_util.h" 41 #include "test/util/thread_util.h" 42 43 namespace gvisor { 44 namespace testing { 45 46 using std::string; 47 using std::vector; 48 49 struct EndpointConfig { 50 std::string bind_to_device; 51 double expected_ratio; 52 }; 53 54 struct DistributionTestCase { 55 std::string name; 56 std::vector<EndpointConfig> endpoints; 57 }; 58 59 struct ListenerConnector { 60 TestAddress listener; 61 TestAddress connector; 62 }; 63 64 // Test fixture for SO_BINDTODEVICE tests the distribution of packets received 65 // with varying SO_BINDTODEVICE settings. 66 class BindToDeviceDistributionTest 67 : public ::testing::TestWithParam< 68 ::testing::tuple<ListenerConnector, DistributionTestCase>> { 69 protected: 70 void SetUp() override { 71 printf("Testing case: %s, listener=%s, connector=%s\n", 72 ::testing::get<1>(GetParam()).name.c_str(), 73 ::testing::get<0>(GetParam()).listener.description.c_str(), 74 ::testing::get<0>(GetParam()).connector.description.c_str()); 75 ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) 76 << "CAP_NET_RAW is required to use SO_BINDTODEVICE"; 77 } 78 }; 79 80 // Binds sockets to different devices and then creates many TCP connections. 81 // Checks that the distribution of connections received on the sockets matches 82 // the expectation. 83 TEST_P(BindToDeviceDistributionTest, Tcp) { 84 auto const& [listener_connector, test] = GetParam(); 85 86 TestAddress const& listener = listener_connector.listener; 87 TestAddress const& connector = listener_connector.connector; 88 sockaddr_storage listen_addr = listener.addr; 89 sockaddr_storage conn_addr = connector.addr; 90 91 auto interface_names = GetInterfaceNames(); 92 93 // Create the listening sockets. 94 std::vector<FileDescriptor> listener_fds; 95 std::vector<std::unique_ptr<Tunnel>> all_tunnels; 96 for (auto const& endpoint : test.endpoints) { 97 if (!endpoint.bind_to_device.empty() && 98 interface_names.find(endpoint.bind_to_device) == 99 interface_names.end()) { 100 all_tunnels.push_back( 101 ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device))); 102 interface_names.insert(endpoint.bind_to_device); 103 } 104 105 listener_fds.push_back(ASSERT_NO_ERRNO_AND_VALUE( 106 Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP))); 107 int fd = listener_fds.back().get(); 108 109 ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, 110 sizeof(kSockOptOn)), 111 SyscallSucceeds()); 112 ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, 113 endpoint.bind_to_device.c_str(), 114 endpoint.bind_to_device.size() + 1), 115 SyscallSucceeds()); 116 ASSERT_THAT(bind(fd, AsSockAddr(&listen_addr), listener.addr_len), 117 SyscallSucceeds()); 118 ASSERT_THAT(listen(fd, 40), SyscallSucceeds()); 119 120 // On the first bind we need to determine which port was bound. 121 if (listener_fds.size() > 1) { 122 continue; 123 } 124 125 // Get the port bound by the listening socket. 126 socklen_t addrlen = listener.addr_len; 127 ASSERT_THAT( 128 getsockname(listener_fds[0].get(), AsSockAddr(&listen_addr), &addrlen), 129 SyscallSucceeds()); 130 uint16_t const port = 131 ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); 132 ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); 133 } 134 135 constexpr int kConnectAttempts = 10000; 136 std::atomic<int> connects_received = ATOMIC_VAR_INIT(0); 137 std::vector<int> accept_counts(listener_fds.size(), 0); 138 std::vector<std::unique_ptr<ScopedThread>> listen_threads( 139 listener_fds.size()); 140 141 for (size_t i = 0; i < listener_fds.size(); i++) { 142 listen_threads[i] = absl::make_unique<ScopedThread>( 143 [&listener_fds, &accept_counts, &connects_received, i, 144 kConnectAttempts]() { 145 do { 146 auto fd = Accept(listener_fds[i].get(), nullptr, nullptr); 147 if (!fd.ok()) { 148 // Another thread has shutdown our read side causing the accept to 149 // fail. 150 ASSERT_GE(connects_received, kConnectAttempts) 151 << "errno = " << fd.error(); 152 return; 153 } 154 // Receive some data from a socket to be sure that the connect() 155 // system call has been completed on another side. 156 // Do a short read and then close the socket to trigger a RST. This 157 // ensures that both ends of the connection are cleaned up and no 158 // goroutines hang around in TIME-WAIT. We do this so that this test 159 // does not timeout under gotsan runs where lots of goroutines can 160 // cause the test to use absurd amounts of memory. 161 // 162 // See: https://tools.ietf.org/html/rfc2525#page-50 section 2.17 163 uint16_t data; 164 EXPECT_THAT( 165 RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0), 166 SyscallSucceedsWithValue(sizeof(data))); 167 accept_counts[i]++; 168 } while (++connects_received < kConnectAttempts); 169 170 // Shutdown all sockets to wake up other threads. 171 for (auto const& listener_fd : listener_fds) { 172 shutdown(listener_fd.get(), SHUT_RDWR); 173 } 174 }); 175 } 176 177 for (int32_t i = 0; i < kConnectAttempts; i++) { 178 const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( 179 Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); 180 ASSERT_THAT(RetryEINTR(connect)(fd.get(), AsSockAddr(&conn_addr), 181 connector.addr_len), 182 SyscallSucceeds()); 183 184 EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0), 185 SyscallSucceedsWithValue(sizeof(i))); 186 } 187 188 // Join threads to be sure that all connections have been counted. 189 for (auto const& listen_thread : listen_threads) { 190 listen_thread->Join(); 191 } 192 // Check that connections are distributed correctly among listening sockets. 193 for (size_t i = 0; i < accept_counts.size(); i++) { 194 EXPECT_THAT( 195 accept_counts[i], 196 EquivalentWithin(static_cast<int>(kConnectAttempts * 197 test.endpoints[i].expected_ratio), 198 0.10)) 199 << "endpoint " << i << " got the wrong number of packets"; 200 } 201 } 202 203 // Binds sockets to different devices and then sends many UDP packets. Checks 204 // that the distribution of packets received on the sockets matches the 205 // expectation. 206 TEST_P(BindToDeviceDistributionTest, Udp) { 207 auto const& [listener_connector, test] = GetParam(); 208 209 TestAddress const& listener = listener_connector.listener; 210 TestAddress const& connector = listener_connector.connector; 211 sockaddr_storage listen_addr = listener.addr; 212 sockaddr_storage conn_addr = connector.addr; 213 214 auto interface_names = GetInterfaceNames(); 215 216 // Create the listening socket. 217 std::vector<FileDescriptor> listener_fds; 218 std::vector<std::unique_ptr<Tunnel>> all_tunnels; 219 for (auto const& endpoint : test.endpoints) { 220 if (!endpoint.bind_to_device.empty() && 221 interface_names.find(endpoint.bind_to_device) == 222 interface_names.end()) { 223 all_tunnels.push_back( 224 ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device))); 225 interface_names.insert(endpoint.bind_to_device); 226 } 227 228 listener_fds.push_back( 229 ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0))); 230 int fd = listener_fds.back().get(); 231 232 ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, 233 sizeof(kSockOptOn)), 234 SyscallSucceeds()); 235 ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, 236 endpoint.bind_to_device.c_str(), 237 endpoint.bind_to_device.size() + 1), 238 SyscallSucceeds()); 239 ASSERT_THAT(bind(fd, AsSockAddr(&listen_addr), listener.addr_len), 240 SyscallSucceeds()); 241 242 // On the first bind we need to determine which port was bound. 243 if (listener_fds.size() > 1) { 244 continue; 245 } 246 247 // Get the port bound by the listening socket. 248 socklen_t addrlen = listener.addr_len; 249 ASSERT_THAT( 250 getsockname(listener_fds[0].get(), AsSockAddr(&listen_addr), &addrlen), 251 SyscallSucceeds()); 252 uint16_t const port = 253 ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); 254 ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port)); 255 ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); 256 } 257 258 constexpr int kConnectAttempts = 10000; 259 std::atomic<int> packets_received = ATOMIC_VAR_INIT(0); 260 std::vector<int> packets_per_socket(listener_fds.size(), 0); 261 std::vector<std::unique_ptr<ScopedThread>> receiver_threads( 262 listener_fds.size()); 263 264 for (size_t i = 0; i < listener_fds.size(); i++) { 265 receiver_threads[i] = absl::make_unique<ScopedThread>( 266 [&listener_fds, &packets_per_socket, &packets_received, i]() { 267 do { 268 struct sockaddr_storage addr = {}; 269 socklen_t addrlen = sizeof(addr); 270 int data; 271 272 auto ret = 273 RetryEINTR(recvfrom)(listener_fds[i].get(), &data, sizeof(data), 274 0, AsSockAddr(&addr), &addrlen); 275 276 if (packets_received < kConnectAttempts) { 277 ASSERT_THAT(ret, SyscallSucceedsWithValue(sizeof(data))); 278 } 279 280 if (ret != sizeof(data)) { 281 // Another thread may have shutdown our read side causing the 282 // recvfrom to fail. 283 break; 284 } 285 286 packets_received++; 287 packets_per_socket[i]++; 288 289 // A response is required to synchronize with the main thread, 290 // otherwise the main thread can send more than can fit into receive 291 // queues. 292 EXPECT_THAT( 293 RetryEINTR(sendto)(listener_fds[i].get(), &data, sizeof(data), 294 0, AsSockAddr(&addr), addrlen), 295 SyscallSucceedsWithValue(sizeof(data))); 296 } while (packets_received < kConnectAttempts); 297 298 // Shutdown all sockets to wake up other threads. 299 for (auto const& listener_fd : listener_fds) { 300 shutdown(listener_fd.get(), SHUT_RDWR); 301 } 302 }); 303 } 304 305 for (int i = 0; i < kConnectAttempts; i++) { 306 FileDescriptor const fd = 307 ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0)); 308 EXPECT_THAT(RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0, 309 AsSockAddr(&conn_addr), connector.addr_len), 310 SyscallSucceedsWithValue(sizeof(i))); 311 int data; 312 EXPECT_THAT(RetryEINTR(recv)(fd.get(), &data, sizeof(data), 0), 313 SyscallSucceedsWithValue(sizeof(data))); 314 } 315 316 // Join threads to be sure that all connections have been counted. 317 for (auto const& receiver_thread : receiver_threads) { 318 receiver_thread->Join(); 319 } 320 // Check that packets are distributed correctly among listening sockets. 321 for (size_t i = 0; i < packets_per_socket.size(); i++) { 322 EXPECT_THAT( 323 packets_per_socket[i], 324 EquivalentWithin(static_cast<int>(kConnectAttempts * 325 test.endpoints[i].expected_ratio), 326 0.10)) 327 << "endpoint " << i << " got the wrong number of packets"; 328 } 329 } 330 331 std::vector<DistributionTestCase> GetDistributionTestCases() { 332 return std::vector<DistributionTestCase>{ 333 {"Even distribution among sockets not bound to device", 334 {{"", 1. / 3}, {"", 1. / 3}, {"", 1. / 3}}}, 335 {"Sockets bound to other interfaces get no packets", 336 {{"eth1", 0}, {"", 1. / 2}, {"", 1. / 2}}}, 337 {"Bound has priority over unbound", {{"eth1", 0}, {"", 0}, {"lo", 1}}}, 338 {"Even distribution among sockets bound to device", 339 {{"eth1", 0}, {"lo", 1. / 2}, {"lo", 1. / 2}}}, 340 }; 341 } 342 343 INSTANTIATE_TEST_SUITE_P( 344 BindToDeviceTest, BindToDeviceDistributionTest, 345 ::testing::Combine(::testing::Values( 346 // Listeners bound to IPv4 addresses refuse 347 // connections using IPv6 addresses. 348 ListenerConnector{V4Any(), V4Loopback()}, 349 ListenerConnector{V4Loopback(), V4MappedLoopback()}), 350 ::testing::ValuesIn(GetDistributionTestCases()))); 351 352 } // namespace testing 353 } // namespace gvisor