github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/test/syscalls/linux/socket_netlink_util.cc (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  #include "test/syscalls/linux/socket_netlink_util.h"
    16  
    17  #include <linux/if_arp.h>
    18  #include <linux/netlink.h>
    19  #include <linux/rtnetlink.h>
    20  #include <sys/socket.h>
    21  
    22  #include <vector>
    23  
    24  #include "absl/strings/str_cat.h"
    25  #include "test/syscalls/linux/socket_test_util.h"
    26  
    27  namespace gvisor {
    28  namespace testing {
    29  
    30  PosixErrorOr<FileDescriptor> NetlinkBoundSocket(int protocol) {
    31    FileDescriptor fd;
    32    ASSIGN_OR_RETURN_ERRNO(fd, Socket(AF_NETLINK, SOCK_RAW, protocol));
    33  
    34    struct sockaddr_nl addr = {};
    35    addr.nl_family = AF_NETLINK;
    36  
    37    RETURN_ERROR_IF_SYSCALL_FAIL(
    38        bind(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)));
    39    MaybeSave();
    40  
    41    return std::move(fd);
    42  }
    43  
    44  PosixErrorOr<uint32_t> NetlinkPortID(int fd) {
    45    struct sockaddr_nl addr;
    46    socklen_t addrlen = sizeof(addr);
    47  
    48    RETURN_ERROR_IF_SYSCALL_FAIL(
    49        getsockname(fd, reinterpret_cast<struct sockaddr*>(&addr), &addrlen));
    50    MaybeSave();
    51  
    52    return static_cast<uint32_t>(addr.nl_pid);
    53  }
    54  
    55  PosixError NetlinkRequestResponse(
    56      const FileDescriptor& fd, void* request, size_t len,
    57      const std::function<void(const struct nlmsghdr* hdr)>& fn,
    58      bool expect_nlmsgerr) {
    59    struct iovec iov = {};
    60    iov.iov_base = request;
    61    iov.iov_len = len;
    62  
    63    struct msghdr msg = {};
    64    msg.msg_iov = &iov;
    65    msg.msg_iovlen = 1;
    66    // No destination required; it defaults to pid 0, the kernel.
    67  
    68    RETURN_ERROR_IF_SYSCALL_FAIL(RetryEINTR(sendmsg)(fd.get(), &msg, 0));
    69  
    70    return NetlinkResponse(fd, fn, expect_nlmsgerr);
    71  }
    72  
    73  PosixError NetlinkResponse(
    74      const FileDescriptor& fd,
    75      const std::function<void(const struct nlmsghdr* hdr)>& fn,
    76      bool expect_nlmsgerr) {
    77    constexpr size_t kBufferSize = 4096;
    78    std::vector<char> buf(kBufferSize);
    79    struct iovec iov = {};
    80    iov.iov_base = buf.data();
    81    iov.iov_len = buf.size();
    82    struct msghdr msg = {};
    83    msg.msg_iov = &iov;
    84    msg.msg_iovlen = 1;
    85  
    86    // If NLM_F_MULTI is set, response is a series of messages that ends with a
    87    // NLMSG_DONE message.
    88    int type = -1;
    89    int flags = 0;
    90    do {
    91      int len;
    92      RETURN_ERROR_IF_SYSCALL_FAIL(len = RetryEINTR(recvmsg)(fd.get(), &msg, 0));
    93  
    94      // We don't bother with the complexity of dealing with truncated messages.
    95      // We must allocate a large enough buffer up front.
    96      if ((msg.msg_flags & MSG_TRUNC) == MSG_TRUNC) {
    97        return PosixError(EIO,
    98                          absl::StrCat("Received truncated message with flags: ",
    99                                       msg.msg_flags));
   100      }
   101  
   102      for (struct nlmsghdr* hdr = reinterpret_cast<struct nlmsghdr*>(buf.data());
   103           NLMSG_OK(hdr, len); hdr = NLMSG_NEXT(hdr, len)) {
   104        fn(hdr);
   105        flags = hdr->nlmsg_flags;
   106        type = hdr->nlmsg_type;
   107        // Done should include an integer payload for dump_done_errno.
   108        // See net/netlink/af_netlink.c:netlink_dump
   109        // Some tools like the 'ip' tool check the minimum length of the
   110        // NLMSG_DONE message.
   111        if (type == NLMSG_DONE) {
   112          EXPECT_GE(hdr->nlmsg_len, NLMSG_LENGTH(sizeof(int)));
   113        }
   114      }
   115    } while ((flags & NLM_F_MULTI) && type != NLMSG_DONE && type != NLMSG_ERROR);
   116  
   117    if (expect_nlmsgerr) {
   118      EXPECT_EQ(type, NLMSG_ERROR);
   119    } else if (flags & NLM_F_MULTI) {
   120      EXPECT_EQ(type, NLMSG_DONE);
   121    }
   122    return NoError();
   123  }
   124  
   125  PosixError NetlinkRequestResponseSingle(
   126      const FileDescriptor& fd, void* request, size_t len,
   127      const std::function<void(const struct nlmsghdr* hdr)>& fn) {
   128    struct iovec iov = {};
   129    iov.iov_base = request;
   130    iov.iov_len = len;
   131  
   132    struct msghdr msg = {};
   133    msg.msg_iov = &iov;
   134    msg.msg_iovlen = 1;
   135    // No destination required; it defaults to pid 0, the kernel.
   136  
   137    RETURN_ERROR_IF_SYSCALL_FAIL(RetryEINTR(sendmsg)(fd.get(), &msg, 0));
   138  
   139    constexpr size_t kBufferSize = 4096;
   140    std::vector<char> buf(kBufferSize);
   141    iov.iov_base = buf.data();
   142    iov.iov_len = buf.size();
   143  
   144    int ret;
   145    RETURN_ERROR_IF_SYSCALL_FAIL(ret = RetryEINTR(recvmsg)(fd.get(), &msg, 0));
   146  
   147    // We don't bother with the complexity of dealing with truncated messages.
   148    // We must allocate a large enough buffer up front.
   149    if ((msg.msg_flags & MSG_TRUNC) == MSG_TRUNC) {
   150      return PosixError(
   151          EIO,
   152          absl::StrCat("Received truncated message with flags: ", msg.msg_flags));
   153    }
   154  
   155    for (struct nlmsghdr* hdr = reinterpret_cast<struct nlmsghdr*>(buf.data());
   156         NLMSG_OK(hdr, ret); hdr = NLMSG_NEXT(hdr, ret)) {
   157      fn(hdr);
   158    }
   159  
   160    return NoError();
   161  }
   162  
   163  PosixError NetlinkRequestAckOrError(const FileDescriptor& fd, uint32_t seq,
   164                                      void* request, size_t len) {
   165    // Dummy negative number for no error message received.
   166    // We won't get a negative error number so there will be no confusion.
   167    int err = -42;
   168    RETURN_IF_ERRNO(NetlinkRequestResponse(
   169        fd, request, len,
   170        [&](const struct nlmsghdr* hdr) {
   171          EXPECT_EQ(NLMSG_ERROR, hdr->nlmsg_type);
   172          EXPECT_EQ(hdr->nlmsg_seq, seq);
   173          EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct nlmsgerr));
   174  
   175          const struct nlmsgerr* msg =
   176              reinterpret_cast<const struct nlmsgerr*>(NLMSG_DATA(hdr));
   177          err = -msg->error;
   178        },
   179        true));
   180    return PosixError(err);
   181  }
   182  
   183  const struct rtattr* FindRtAttr(const struct nlmsghdr* hdr,
   184                                  const struct ifinfomsg* msg, int16_t attr) {
   185    const int ifi_space = NLMSG_SPACE(sizeof(*msg));
   186    int attrlen = hdr->nlmsg_len - ifi_space;
   187    const struct rtattr* rta = reinterpret_cast<const struct rtattr*>(
   188        reinterpret_cast<const uint8_t*>(hdr) + NLMSG_ALIGN(ifi_space));
   189    for (; RTA_OK(rta, attrlen); rta = RTA_NEXT(rta, attrlen)) {
   190      if (rta->rta_type == attr) {
   191        return rta;
   192      }
   193    }
   194    return nullptr;
   195  }
   196  
   197  }  // namespace testing
   198  }  // namespace gvisor