github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/test/syscalls/linux/unix_domain_socket_test_util.cc (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  #include "test/syscalls/linux/unix_domain_socket_test_util.h"
    16  
    17  #include <sys/un.h>
    18  
    19  #include <vector>
    20  
    21  #include "gtest/gtest.h"
    22  #include "absl/strings/str_cat.h"
    23  #include "test/util/test_util.h"
    24  
    25  namespace gvisor {
    26  namespace testing {
    27  
    28  std::string DescribeUnixDomainSocketType(int type) {
    29    const char* type_str = nullptr;
    30    switch (type & ~(SOCK_NONBLOCK | SOCK_CLOEXEC)) {
    31      case SOCK_STREAM:
    32        type_str = "SOCK_STREAM";
    33        break;
    34      case SOCK_DGRAM:
    35        type_str = "SOCK_DGRAM";
    36        break;
    37      case SOCK_SEQPACKET:
    38        type_str = "SOCK_SEQPACKET";
    39        break;
    40    }
    41    if (!type_str) {
    42      return absl::StrCat("Unix domain socket with unknown type ", type);
    43    } else {
    44      return absl::StrCat(((type & SOCK_NONBLOCK) != 0) ? "non-blocking " : "",
    45                          ((type & SOCK_CLOEXEC) != 0) ? "close-on-exec " : "",
    46                          type_str, " Unix domain socket");
    47    }
    48  }
    49  
    50  SocketPairKind UnixDomainSocketPair(int type) {
    51    return SocketPairKind{DescribeUnixDomainSocketType(type), AF_UNIX, type, 0,
    52                          SyscallSocketPairCreator(AF_UNIX, type, 0)};
    53  }
    54  
    55  SocketPairKind FilesystemBoundUnixDomainSocketPair(int type) {
    56    std::string description = absl::StrCat(DescribeUnixDomainSocketType(type),
    57                                           " created with filesystem binding");
    58    if ((type & SOCK_DGRAM) == SOCK_DGRAM) {
    59      return SocketPairKind{
    60          description, AF_UNIX, type, 0,
    61          FilesystemBidirectionalBindSocketPairCreator(AF_UNIX, type, 0)};
    62    }
    63    return SocketPairKind{
    64        description, AF_UNIX, type, 0,
    65        FilesystemAcceptBindSocketPairCreator(AF_UNIX, type, 0)};
    66  }
    67  
    68  SocketPairKind AbstractBoundUnixDomainSocketPair(int type) {
    69    std::string description =
    70        absl::StrCat(DescribeUnixDomainSocketType(type),
    71                     " created with abstract namespace binding");
    72    if ((type & SOCK_DGRAM) == SOCK_DGRAM) {
    73      return SocketPairKind{
    74          description, AF_UNIX, type, 0,
    75          AbstractBidirectionalBindSocketPairCreator(AF_UNIX, type, 0)};
    76    }
    77    return SocketPairKind{description, AF_UNIX, type, 0,
    78                          AbstractAcceptBindSocketPairCreator(AF_UNIX, type, 0)};
    79  }
    80  
    81  SocketPairKind SocketpairGoferUnixDomainSocketPair(int type) {
    82    std::string description = absl::StrCat(DescribeUnixDomainSocketType(type),
    83                                           " created with the socketpair gofer");
    84    return SocketPairKind{description, AF_UNIX, type, 0,
    85                          SocketpairGoferSocketPairCreator(AF_UNIX, type, 0)};
    86  }
    87  
    88  SocketPairKind SocketpairGoferFileSocketPair(int type) {
    89    std::string description =
    90        absl::StrCat(((type & O_NONBLOCK) != 0) ? "non-blocking " : "",
    91                     ((type & O_CLOEXEC) != 0) ? "close-on-exec " : "",
    92                     "file socket created with the socketpair gofer");
    93    // The socketpair gofer always creates SOCK_STREAM sockets on open(2).
    94    return SocketPairKind{description, AF_UNIX, SOCK_STREAM, 0,
    95                          SocketpairGoferFileSocketPairCreator(type)};
    96  }
    97  
    98  SocketPairKind FilesystemUnboundUnixDomainSocketPair(int type) {
    99    return SocketPairKind{absl::StrCat(DescribeUnixDomainSocketType(type),
   100                                       " unbound with a filesystem address"),
   101                          AF_UNIX, type, 0,
   102                          FilesystemUnboundSocketPairCreator(AF_UNIX, type, 0)};
   103  }
   104  
   105  SocketPairKind AbstractUnboundUnixDomainSocketPair(int type) {
   106    return SocketPairKind{
   107        absl::StrCat(DescribeUnixDomainSocketType(type),
   108                     " unbound with an abstract namespace address"),
   109        AF_UNIX, type, 0, AbstractUnboundSocketPairCreator(AF_UNIX, type, 0)};
   110  }
   111  
   112  void SendSingleFD(int sock, int fd, char buf[], int buf_size) {
   113    ASSERT_NO_FATAL_FAILURE(SendFDs(sock, &fd, 1, buf, buf_size));
   114  }
   115  
   116  void SendFDs(int sock, int fds[], int fds_size, char buf[], int buf_size) {
   117    struct msghdr msg = {};
   118    std::vector<char> control(CMSG_SPACE(fds_size * sizeof(int)));
   119    msg.msg_control = &control[0];
   120    msg.msg_controllen = control.size();
   121  
   122    struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
   123    cmsg->cmsg_len = CMSG_LEN(fds_size * sizeof(int));
   124    cmsg->cmsg_level = SOL_SOCKET;
   125    cmsg->cmsg_type = SCM_RIGHTS;
   126    for (int i = 0; i < fds_size; i++) {
   127      memcpy(CMSG_DATA(cmsg) + i * sizeof(int), &fds[i], sizeof(int));
   128    }
   129  
   130    ASSERT_THAT(SendMsg(sock, &msg, buf, buf_size),
   131                IsPosixErrorOkAndHolds(buf_size));
   132  }
   133  
   134  void RecvSingleFD(int sock, int* fd, char buf[], int buf_size) {
   135    ASSERT_NO_FATAL_FAILURE(RecvFDs(sock, fd, 1, buf, buf_size, buf_size));
   136  }
   137  
   138  void RecvSingleFD(int sock, int* fd, char buf[], int buf_size,
   139                    int expected_size) {
   140    ASSERT_NO_FATAL_FAILURE(RecvFDs(sock, fd, 1, buf, buf_size, expected_size));
   141  }
   142  
   143  void RecvFDs(int sock, int fds[], int fds_size, char buf[], int buf_size) {
   144    ASSERT_NO_FATAL_FAILURE(
   145        RecvFDs(sock, fds, fds_size, buf, buf_size, buf_size));
   146  }
   147  
   148  void RecvFDs(int sock, int fds[], int fds_size, char buf[], int buf_size,
   149               int expected_size, bool peek) {
   150    struct msghdr msg = {};
   151    std::vector<char> control(CMSG_SPACE(fds_size * sizeof(int)));
   152    msg.msg_control = &control[0];
   153    msg.msg_controllen = control.size();
   154  
   155    struct iovec iov;
   156    iov.iov_base = buf;
   157    iov.iov_len = buf_size;
   158    msg.msg_iov = &iov;
   159    msg.msg_iovlen = 1;
   160  
   161    int flags = 0;
   162    if (peek) {
   163      flags |= MSG_PEEK;
   164    }
   165  
   166    ASSERT_THAT(RetryEINTR(recvmsg)(sock, &msg, flags),
   167                SyscallSucceedsWithValue(expected_size));
   168    struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
   169    ASSERT_NE(cmsg, nullptr);
   170    ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(fds_size * sizeof(int)));
   171    ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET);
   172    ASSERT_EQ(cmsg->cmsg_type, SCM_RIGHTS);
   173  
   174    for (int i = 0; i < fds_size; i++) {
   175      memcpy(&fds[i], CMSG_DATA(cmsg) + i * sizeof(int), sizeof(int));
   176    }
   177  }
   178  
   179  void RecvFDs(int sock, int fds[], int fds_size, char buf[], int buf_size,
   180               int expected_size) {
   181    ASSERT_NO_FATAL_FAILURE(
   182        RecvFDs(sock, fds, fds_size, buf, buf_size, expected_size, false));
   183  }
   184  
   185  void PeekSingleFD(int sock, int* fd, char buf[], int buf_size) {
   186    ASSERT_NO_FATAL_FAILURE(RecvFDs(sock, fd, 1, buf, buf_size, buf_size, true));
   187  }
   188  
   189  void RecvNoCmsg(int sock, char buf[], int buf_size, int expected_size) {
   190    struct msghdr msg = {};
   191    char control[CMSG_SPACE(sizeof(int)) + CMSG_SPACE(sizeof(struct ucred))];
   192    msg.msg_control = control;
   193    msg.msg_controllen = sizeof(control);
   194  
   195    struct iovec iov;
   196    iov.iov_base = buf;
   197    iov.iov_len = buf_size;
   198    msg.msg_iov = &iov;
   199    msg.msg_iovlen = 1;
   200  
   201    ASSERT_THAT(RetryEINTR(recvmsg)(sock, &msg, 0),
   202                SyscallSucceedsWithValue(expected_size));
   203    struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
   204    EXPECT_EQ(cmsg, nullptr);
   205  }
   206  
   207  void SendNullCmsg(int sock, char buf[], int buf_size) {
   208    struct msghdr msg = {};
   209    msg.msg_control = nullptr;
   210    msg.msg_controllen = 0;
   211  
   212    ASSERT_THAT(SendMsg(sock, &msg, buf, buf_size),
   213                IsPosixErrorOkAndHolds(buf_size));
   214  }
   215  
   216  void SendCreds(int sock, ucred creds, char buf[], int buf_size) {
   217    struct msghdr msg = {};
   218  
   219    char control[CMSG_SPACE(sizeof(struct ucred))];
   220    msg.msg_control = control;
   221    msg.msg_controllen = sizeof(control);
   222  
   223    struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
   224    cmsg->cmsg_level = SOL_SOCKET;
   225    cmsg->cmsg_type = SCM_CREDENTIALS;
   226    cmsg->cmsg_len = CMSG_LEN(sizeof(struct ucred));
   227    memcpy(CMSG_DATA(cmsg), &creds, sizeof(struct ucred));
   228  
   229    ASSERT_THAT(SendMsg(sock, &msg, buf, buf_size),
   230                IsPosixErrorOkAndHolds(buf_size));
   231  }
   232  
   233  void SendCredsAndFD(int sock, ucred creds, int fd, char buf[], int buf_size) {
   234    struct msghdr msg = {};
   235  
   236    char control[CMSG_SPACE(sizeof(struct ucred)) + CMSG_SPACE(sizeof(int))] = {};
   237    msg.msg_control = control;
   238    msg.msg_controllen = sizeof(control);
   239  
   240    struct cmsghdr* cmsg1 = CMSG_FIRSTHDR(&msg);
   241    cmsg1->cmsg_level = SOL_SOCKET;
   242    cmsg1->cmsg_type = SCM_CREDENTIALS;
   243    cmsg1->cmsg_len = CMSG_LEN(sizeof(struct ucred));
   244    memcpy(CMSG_DATA(cmsg1), &creds, sizeof(struct ucred));
   245  
   246    struct cmsghdr* cmsg2 = CMSG_NXTHDR(&msg, cmsg1);
   247    cmsg2->cmsg_level = SOL_SOCKET;
   248    cmsg2->cmsg_type = SCM_RIGHTS;
   249    cmsg2->cmsg_len = CMSG_LEN(sizeof(int));
   250    memcpy(CMSG_DATA(cmsg2), &fd, sizeof(int));
   251  
   252    ASSERT_THAT(SendMsg(sock, &msg, buf, buf_size),
   253                IsPosixErrorOkAndHolds(buf_size));
   254  }
   255  
   256  void RecvCreds(int sock, ucred* creds, char buf[], int buf_size) {
   257    ASSERT_NO_FATAL_FAILURE(RecvCreds(sock, creds, buf, buf_size, buf_size));
   258  }
   259  
   260  void RecvCreds(int sock, ucred* creds, char buf[], int buf_size,
   261                 int expected_size) {
   262    struct msghdr msg = {};
   263    char control[CMSG_SPACE(sizeof(struct ucred))];
   264    msg.msg_control = control;
   265    msg.msg_controllen = sizeof(control);
   266  
   267    struct iovec iov;
   268    iov.iov_base = buf;
   269    iov.iov_len = buf_size;
   270    msg.msg_iov = &iov;
   271    msg.msg_iovlen = 1;
   272  
   273    ASSERT_THAT(RetryEINTR(recvmsg)(sock, &msg, 0),
   274                SyscallSucceedsWithValue(expected_size));
   275    struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
   276    ASSERT_NE(cmsg, nullptr);
   277    ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct ucred)));
   278    ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET);
   279    ASSERT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS);
   280  
   281    memcpy(creds, CMSG_DATA(cmsg), sizeof(struct ucred));
   282  }
   283  
   284  void RecvCredsAndFD(int sock, ucred* creds, int* fd, char buf[], int buf_size) {
   285    struct msghdr msg = {};
   286    char control[CMSG_SPACE(sizeof(struct ucred)) + CMSG_SPACE(sizeof(int))];
   287    msg.msg_control = control;
   288    msg.msg_controllen = sizeof(control);
   289  
   290    struct iovec iov;
   291    iov.iov_base = buf;
   292    iov.iov_len = buf_size;
   293    msg.msg_iov = &iov;
   294    msg.msg_iovlen = 1;
   295  
   296    ASSERT_THAT(RetryEINTR(recvmsg)(sock, &msg, 0),
   297                SyscallSucceedsWithValue(buf_size));
   298  
   299    struct cmsghdr* cmsg1 = CMSG_FIRSTHDR(&msg);
   300    ASSERT_NE(cmsg1, nullptr);
   301    ASSERT_EQ(cmsg1->cmsg_len, CMSG_LEN(sizeof(struct ucred)));
   302    ASSERT_EQ(cmsg1->cmsg_level, SOL_SOCKET);
   303    ASSERT_EQ(cmsg1->cmsg_type, SCM_CREDENTIALS);
   304    memcpy(creds, CMSG_DATA(cmsg1), sizeof(struct ucred));
   305  
   306    struct cmsghdr* cmsg2 = CMSG_NXTHDR(&msg, cmsg1);
   307    ASSERT_NE(cmsg2, nullptr);
   308    ASSERT_EQ(cmsg2->cmsg_len, CMSG_LEN(sizeof(int)));
   309    ASSERT_EQ(cmsg2->cmsg_level, SOL_SOCKET);
   310    ASSERT_EQ(cmsg2->cmsg_type, SCM_RIGHTS);
   311    memcpy(fd, CMSG_DATA(cmsg2), sizeof(int));
   312  }
   313  
   314  void RecvSingleFDUnaligned(int sock, int* fd, char buf[], int buf_size) {
   315    struct msghdr msg = {};
   316    char control[CMSG_SPACE(sizeof(int)) - sizeof(int)];
   317    msg.msg_control = control;
   318    msg.msg_controllen = sizeof(control);
   319  
   320    struct iovec iov;
   321    iov.iov_base = buf;
   322    iov.iov_len = buf_size;
   323    msg.msg_iov = &iov;
   324    msg.msg_iovlen = 1;
   325  
   326    ASSERT_THAT(RetryEINTR(recvmsg)(sock, &msg, 0),
   327                SyscallSucceedsWithValue(buf_size));
   328  
   329    struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
   330    ASSERT_NE(cmsg, nullptr);
   331    ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int)));
   332    ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET);
   333    ASSERT_EQ(cmsg->cmsg_type, SCM_RIGHTS);
   334  
   335    memcpy(fd, CMSG_DATA(cmsg), sizeof(int));
   336  }
   337  
   338  void SetSoPassCred(int sock) {
   339    int one = 1;
   340    EXPECT_THAT(setsockopt(sock, SOL_SOCKET, SO_PASSCRED, &one, sizeof(one)),
   341                SyscallSucceeds());
   342  }
   343  
   344  void UnsetSoPassCred(int sock) {
   345    int zero = 0;
   346    EXPECT_THAT(setsockopt(sock, SOL_SOCKET, SO_PASSCRED, &zero, sizeof(zero)),
   347                SyscallSucceeds());
   348  }
   349  
   350  }  // namespace testing
   351  }  // namespace gvisor