gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/examples/seccheck/server.cc (about)

     1  // Copyright 2021 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
#include <err.h>
#include <errno.h>
#include <pthread.h>
#include <stdarg.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>

#include <array>
#include <functional>
#include <string>
#include <vector>

#include "absl/cleanup/cleanup.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "google/protobuf/text_format.h"
#include "pkg/sentry/seccheck/points/common.pb.h"
#include "pkg/sentry/seccheck/points/container.pb.h"
#include "pkg/sentry/seccheck/points/sentry.pb.h"
#include "pkg/sentry/seccheck/points/syscall.pb.h"
    37  
    38  typedef std::function<void(absl::string_view buf)> Callback;
    39  
    40  constexpr size_t maxEventSize = 300 * 1024;
    41  
    42  bool quiet = false;
    43  
// Fixed-layout prefix found at the start of every event message read from a
// client connection. Packed to 1-byte alignment so the in-memory layout has no
// padding and can be overlaid directly on the received bytes.
#pragma pack(push, 1)
struct header {
  // Offset, in bytes, from the start of the message to the protobuf payload.
  uint16_t header_size;
  // Index into `dispatchers` selecting the handler for the payload.
  uint16_t message_type;
  // Not read anywhere in this file.
  // NOTE(review): presumably the number of events the sender dropped before
  // this one -- confirm against the protocol definition.
  uint32_t dropped_count;
};
#pragma pack(pop)
    51  
    52  void log(const char* fmt, ...) {
    53    if (!quiet) {
    54      va_list ap;
    55      va_start(ap, fmt);
    56      vprintf(fmt, ap);
    57      va_end(ap);
    58    }
    59  }
    60  
    61  template <class T>
    62  std::string shortfmt(T msg) {
    63    std::string short_text_msg;
    64    google::protobuf::TextFormat::PrintToString(msg, &short_text_msg);
    65    return absl::StrReplaceAll(short_text_msg,
    66                               {{"\r\n", " "}, {"\n", " "}, {"\r", " "}});
    67  }
    68  
    69  template <class T>
    70  void unpackSyscall(absl::string_view buf) {
    71    T evt;
    72    if (!evt.ParseFromArray(buf.data(), buf.size())) {
    73      err(1, "ParseFromString(): %.*s", static_cast<int>(buf.size()), buf.data());
    74    }
    75    log("%s %s %s\n", evt.has_exit() ? "X" : "E",
    76        evt.GetDescriptor()->name().c_str(), shortfmt(evt).c_str());
    77  }
    78  
    79  template <class T>
    80  void unpack(absl::string_view buf) {
    81    T evt;
    82    if (!evt.ParseFromArray(buf.data(), buf.size())) {
    83      err(1, "ParseFromString(): %.*s", static_cast<int>(buf.size()), buf.data());
    84    }
    85    log("%s => %s\n", evt.GetDescriptor()->name().c_str(), shortfmt(evt).c_str());
    86  }
    87  
// Table of event handlers, indexed by the `message_type` field of the message
// header (MessageType enum values). Because the position of an entry is its
// message type, entries must not be reordered, and new handlers must be
// appended in enum order.
//
// Index 0 is a nullptr placeholder: unpack(absl::string_view) rejects message
// type 0, and any type at or beyond the end of this table, before dispatching.
std::vector<Callback> dispatchers = {
    nullptr,
    unpack<::gvisor::container::Start>,
    unpack<::gvisor::sentry::CloneInfo>,
    unpack<::gvisor::sentry::ExecveInfo>,
    unpack<::gvisor::sentry::ExitNotifyParentInfo>,
    unpack<::gvisor::sentry::TaskExit>,
    unpackSyscall<::gvisor::syscall::Syscall>,
    unpackSyscall<::gvisor::syscall::Open>,
    unpackSyscall<::gvisor::syscall::Close>,
    unpackSyscall<::gvisor::syscall::Read>,
    unpackSyscall<::gvisor::syscall::Connect>,
    unpackSyscall<::gvisor::syscall::Execve>,
    unpackSyscall<::gvisor::syscall::Socket>,
    unpackSyscall<::gvisor::syscall::Chdir>,
    unpackSyscall<::gvisor::syscall::Setid>,
    unpackSyscall<::gvisor::syscall::Setresid>,
    unpackSyscall<::gvisor::syscall::Dup>,
    unpackSyscall<::gvisor::syscall::Prlimit>,
    unpackSyscall<::gvisor::syscall::Pipe>,
    unpackSyscall<::gvisor::syscall::Fcntl>,
    unpackSyscall<::gvisor::syscall::Signalfd>,
    unpackSyscall<::gvisor::syscall::Eventfd>,
    unpackSyscall<::gvisor::syscall::Chroot>,
    unpackSyscall<::gvisor::syscall::Clone>,
    unpackSyscall<::gvisor::syscall::Bind>,
    unpackSyscall<::gvisor::syscall::Accept>,
    unpackSyscall<::gvisor::syscall::TimerfdCreate>,
    unpackSyscall<::gvisor::syscall::TimerfdSetTime>,
    unpackSyscall<::gvisor::syscall::TimerfdGetTime>,
    unpackSyscall<::gvisor::syscall::Fork>,
    unpackSyscall<::gvisor::syscall::InotifyInit>,
    unpackSyscall<::gvisor::syscall::InotifyAddWatch>,
    unpackSyscall<::gvisor::syscall::InotifyRmWatch>,
    unpackSyscall<::gvisor::syscall::SocketPair>,
    unpackSyscall<::gvisor::syscall::Write>,
};
   126  
   127  void unpack(absl::string_view buf) {
   128    const header* hdr = reinterpret_cast<const header*>(&buf[0]);
   129  
   130    // Payload size can be zero when proto object contains only defaults values.
   131    size_t payload_size = buf.size() - hdr->header_size;
   132    if (payload_size < 0) {
   133      printf("Header size (%u) is larger than message %lu\n", hdr->header_size,
   134             buf.size());
   135      return;
   136    }
   137  
   138    auto proto = buf.substr(hdr->header_size);
   139    if (proto.size() < payload_size) {
   140      printf("Message was truncated, size: %lu, expected: %zu\n", proto.size(),
   141             payload_size);
   142      return;
   143    }
   144  
   145    if (hdr->message_type == 0 || hdr->message_type >= dispatchers.size()) {
   146      printf("Invalid message type: %u\n", hdr->message_type);
   147      return;
   148    }
   149    Callback cb = dispatchers[hdr->message_type];
   150    cb(proto);
   151  }
   152  
   153  bool readAndUnpack(int client) {
   154    std::array<char, maxEventSize> buf;
   155    int bytes = read(client, buf.data(), buf.size());
   156    if (bytes < 0) {
   157      err(1, "read");
   158    }
   159    if (bytes == 0) {
   160      return false;
   161    }
   162    unpack(absl::string_view(buf.data(), bytes));
   163    return true;
   164  }
   165  
   166  void* pollLoop(void* ptr) {
   167    const int poll_fd = *reinterpret_cast<int*>(&ptr);
   168    for (;;) {
   169      epoll_event evts[64];
   170      int nfds = epoll_wait(poll_fd, evts, 64, -1);
   171      if (nfds < 0) {
   172        if (errno == EINTR) {
   173          continue;
   174        }
   175        err(1, "epoll_wait");
   176      }
   177  
   178      for (int i = 0; i < nfds; ++i) {
   179        if (evts[i].events & EPOLLIN) {
   180          int client = evts[i].data.fd;
   181          readAndUnpack(client);
   182        }
   183        if ((evts[i].events & (EPOLLRDHUP | EPOLLHUP)) != 0) {
   184          int client = evts[i].data.fd;
   185          // Drain any remaining messages before closing the socket.
   186          while (readAndUnpack(client)) {
   187          }
   188          close(client);
   189          printf("Connection closed\n");
   190        }
   191        if (evts[i].events & EPOLLERR) {
   192          printf("error\n");
   193        }
   194      }
   195    }
   196  }
   197  
   198  void startPollThread(int poll_fd) {
   199    pthread_t thread;
   200    if (pthread_create(&thread, nullptr, pollLoop,
   201                       reinterpret_cast<void*>(poll_fd)) != 0) {
   202      err(1, "pthread_create");
   203    }
   204    pthread_detach(thread);
   205  }
   206  
   207  // handshake performs version exchange with client. See common.proto for details
   208  // about the protocol.
   209  bool handshake(int client_fd) {
   210    std::vector<char> buf(10240);
   211    int bytes = read(client_fd, buf.data(), buf.size());
   212    if (bytes < 0) {
   213      printf("Error receiving handshake message: %d\n", errno);
   214      return false;
   215    } else if (bytes == (int)buf.size()) {
   216      // Protect against the handshake becoming larger than the buffer allocated
   217      // for it.
   218      printf("handshake message too big\n");
   219      return false;
   220    }
   221    ::gvisor::common::Handshake in = {};
   222    if (!in.ParseFromArray(buf.data(), bytes)) {
   223      printf("Error parsing handshake message\n");
   224      return false;
   225    }
   226  
   227    constexpr uint32_t minSupportedVersion = 1;
   228    if (in.version() < minSupportedVersion) {
   229      printf("Client has unsupported version %u\n", in.version());
   230      return false;
   231    }
   232  
   233    ::gvisor::common::Handshake out;
   234    out.set_version(1);
   235    if (!out.SerializeToFileDescriptor(client_fd)) {
   236      printf("Error sending handshake message: %d\n", errno);
   237      return false;
   238    }
   239    return true;
   240  }
   241  
   242  extern "C" int main(int argc, char** argv) {
   243    for (int c = 0; (c = getopt(argc, argv, "q")) != -1;) {
   244      switch (c) {
   245        case 'q':
   246          quiet = true;
   247          break;
   248        default:
   249          exit(1);
   250      }
   251    }
   252  
   253    if (!quiet) {
   254      setbuf(stdout, NULL);
   255      setbuf(stderr, NULL);
   256    }
   257    std::string path("/tmp/gvisor_events.sock");
   258    if (optind < argc) {
   259      path = argv[optind];
   260    }
   261    if (path.empty()) {
   262      err(1, "empty file name");
   263    }
   264    printf("Socket address %s\n", path.c_str());
   265    unlink(path.c_str());
   266  
   267    int sock = socket(AF_UNIX, SOCK_SEQPACKET, 0);
   268    if (sock < 0) {
   269      err(1, "socket");
   270    }
   271    auto sock_closer = absl::MakeCleanup([sock] { close(sock); });
   272  
   273    struct sockaddr_un addr;
   274    addr.sun_family = AF_UNIX;
   275    strncpy(addr.sun_path, path.c_str(), path.size() + 1);
   276    if (bind(sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr))) {
   277      err(1, "bind");
   278    }
   279    if (listen(sock, 5) < 0) {
   280      err(1, "listen");
   281    }
   282  
   283    int epoll_fd = epoll_create(1);
   284    if (epoll_fd < 0) {
   285      err(1, "epoll_create");
   286    }
   287    auto epoll_closer = absl::MakeCleanup([epoll_fd] { close(epoll_fd); });
   288    startPollThread(epoll_fd);
   289  
   290    for (;;) {
   291      int client = accept(sock, nullptr, nullptr);
   292      if (client < 0) {
   293        if (errno == EINTR) {
   294          continue;
   295        }
   296        err(1, "accept");
   297      }
   298      printf("Connection accepted\n");
   299  
   300      if (!handshake(client)) {
   301        close(client);
   302        continue;
   303      }
   304  
   305      struct epoll_event evt;
   306      evt.data.fd = client;
   307      evt.events = EPOLLIN;
   308      if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, client, &evt) < 0) {
   309        err(1, "epoll_ctl(ADD)");
   310      }
   311    }
   312  }