gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/test/syscalls/linux/socket_netlink_util.cc (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "test/syscalls/linux/socket_netlink_util.h" 16 17 #include <linux/if_arp.h> 18 #include <linux/netlink.h> 19 #include <linux/rtnetlink.h> 20 #include <sys/socket.h> 21 22 #include <vector> 23 24 #include "absl/strings/str_cat.h" 25 #include "test/util/socket_util.h" 26 27 namespace gvisor { 28 namespace testing { 29 30 PosixErrorOr<FileDescriptor> NetlinkBoundSocket(int protocol) { 31 FileDescriptor fd; 32 ASSIGN_OR_RETURN_ERRNO(fd, Socket(AF_NETLINK, SOCK_RAW, protocol)); 33 34 struct sockaddr_nl addr = {}; 35 addr.nl_family = AF_NETLINK; 36 37 RETURN_ERROR_IF_SYSCALL_FAIL( 38 bind(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr))); 39 MaybeSave(); 40 41 return std::move(fd); 42 } 43 44 PosixErrorOr<uint32_t> NetlinkPortID(int fd) { 45 struct sockaddr_nl addr; 46 socklen_t addrlen = sizeof(addr); 47 48 RETURN_ERROR_IF_SYSCALL_FAIL( 49 getsockname(fd, reinterpret_cast<struct sockaddr*>(&addr), &addrlen)); 50 MaybeSave(); 51 52 return static_cast<uint32_t>(addr.nl_pid); 53 } 54 55 PosixError NetlinkRequestResponse( 56 const FileDescriptor& fd, void* request, size_t len, 57 const std::function<void(const struct nlmsghdr* hdr)>& fn, 58 bool expect_nlmsgerr) { 59 struct iovec iov = {}; 60 iov.iov_base = request; 61 iov.iov_len = len; 62 63 struct msghdr msg = {}; 64 msg.msg_iov = &iov; 65 msg.msg_iovlen = 1; 66 // No destination required; it defaults to pid 0, the kernel. 67 68 RETURN_ERROR_IF_SYSCALL_FAIL(RetryEINTR(sendmsg)(fd.get(), &msg, 0)); 69 70 return NetlinkResponse(fd, fn, expect_nlmsgerr); 71 } 72 73 PosixError NetlinkResponse( 74 const FileDescriptor& fd, 75 const std::function<void(const struct nlmsghdr* hdr)>& fn, 76 bool expect_nlmsgerr) { 77 constexpr size_t kBufferSize = 4096; 78 std::vector<char> buf(kBufferSize); 79 struct iovec iov = {}; 80 iov.iov_base = buf.data(); 81 iov.iov_len = buf.size(); 82 struct msghdr msg = {}; 83 msg.msg_iov = &iov; 84 msg.msg_iovlen = 1; 85 86 // If NLM_F_MULTI is set, response is a series of messages that ends with a 87 // NLMSG_DONE message. 88 int type = -1; 89 int flags = 0; 90 do { 91 int len; 92 RETURN_ERROR_IF_SYSCALL_FAIL(len = RetryEINTR(recvmsg)(fd.get(), &msg, 0)); 93 94 // We don't bother with the complexity of dealing with truncated messages. 95 // We must allocate a large enough buffer up front. 96 if ((msg.msg_flags & MSG_TRUNC) == MSG_TRUNC) { 97 return PosixError(EIO, 98 absl::StrCat("Received truncated message with flags: ", 99 msg.msg_flags)); 100 } 101 102 for (struct nlmsghdr* hdr = reinterpret_cast<struct nlmsghdr*>(buf.data()); 103 NLMSG_OK(hdr, len); hdr = NLMSG_NEXT(hdr, len)) { 104 fn(hdr); 105 flags = hdr->nlmsg_flags; 106 type = hdr->nlmsg_type; 107 // Done should include an integer payload for dump_done_errno. 108 // See net/netlink/af_netlink.c:netlink_dump 109 // Some tools like the 'ip' tool check the minimum length of the 110 // NLMSG_DONE message. 111 if (type == NLMSG_DONE) { 112 EXPECT_GE(hdr->nlmsg_len, NLMSG_LENGTH(sizeof(int))); 113 } 114 } 115 } while ((flags & NLM_F_MULTI) && type != NLMSG_DONE && type != NLMSG_ERROR); 116 117 if (expect_nlmsgerr) { 118 EXPECT_EQ(type, NLMSG_ERROR); 119 } else if (flags & NLM_F_MULTI) { 120 EXPECT_EQ(type, NLMSG_DONE); 121 } 122 return NoError(); 123 } 124 125 PosixError NetlinkRequestResponseSingle( 126 const FileDescriptor& fd, void* request, size_t len, 127 const std::function<void(const struct nlmsghdr* hdr)>& fn) { 128 struct iovec iov = {}; 129 iov.iov_base = request; 130 iov.iov_len = len; 131 132 struct msghdr msg = {}; 133 msg.msg_iov = &iov; 134 msg.msg_iovlen = 1; 135 // No destination required; it defaults to pid 0, the kernel. 136 137 RETURN_ERROR_IF_SYSCALL_FAIL(RetryEINTR(sendmsg)(fd.get(), &msg, 0)); 138 139 constexpr size_t kBufferSize = 4096; 140 std::vector<char> buf(kBufferSize); 141 iov.iov_base = buf.data(); 142 iov.iov_len = buf.size(); 143 144 int ret; 145 RETURN_ERROR_IF_SYSCALL_FAIL(ret = RetryEINTR(recvmsg)(fd.get(), &msg, 0)); 146 147 // We don't bother with the complexity of dealing with truncated messages. 148 // We must allocate a large enough buffer up front. 149 if ((msg.msg_flags & MSG_TRUNC) == MSG_TRUNC) { 150 return PosixError( 151 EIO, 152 absl::StrCat("Received truncated message with flags: ", msg.msg_flags)); 153 } 154 155 for (struct nlmsghdr* hdr = reinterpret_cast<struct nlmsghdr*>(buf.data()); 156 NLMSG_OK(hdr, ret); hdr = NLMSG_NEXT(hdr, ret)) { 157 fn(hdr); 158 } 159 160 return NoError(); 161 } 162 163 PosixError NetlinkRequestAckOrError(const FileDescriptor& fd, uint32_t seq, 164 void* request, size_t len) { 165 // Dummy negative number for no error message received. 166 // We won't get a negative error number so there will be no confusion. 167 int err = -42; 168 RETURN_IF_ERRNO(NetlinkRequestResponse( 169 fd, request, len, 170 [&](const struct nlmsghdr* hdr) { 171 EXPECT_EQ(NLMSG_ERROR, hdr->nlmsg_type); 172 EXPECT_EQ(hdr->nlmsg_seq, seq); 173 EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct nlmsgerr)); 174 175 const struct nlmsgerr* msg = 176 reinterpret_cast<const struct nlmsgerr*>(NLMSG_DATA(hdr)); 177 err = -msg->error; 178 }, 179 true)); 180 return PosixError(err); 181 } 182 183 const struct rtattr* FindRtAttr(const struct nlmsghdr* hdr, 184 const struct ifinfomsg* msg, int16_t attr) { 185 const int ifi_space = NLMSG_SPACE(sizeof(*msg)); 186 int attrlen = hdr->nlmsg_len - ifi_space; 187 const struct rtattr* rta = reinterpret_cast<const struct rtattr*>( 188 reinterpret_cast<const uint8_t*>(hdr) + NLMSG_ALIGN(ifi_space)); 189 for (; RTA_OK(rta, attrlen); rta = RTA_NEXT(rta, attrlen)) { 190 if (rta->rta_type == attr) { 191 return rta; 192 } 193 } 194 return nullptr; 195 } 196 197 } // namespace testing 198 } // namespace gvisor