gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/examples/seccheck/server.cc (about) 1 // Copyright 2021 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include <err.h> 16 #include <pthread.h> 17 #include <stdarg.h> 18 #include <stdio.h> 19 #include <sys/epoll.h> 20 #include <sys/ioctl.h> 21 #include <sys/socket.h> 22 #include <sys/un.h> 23 #include <unistd.h> 24 25 #include <array> 26 #include <string> 27 #include <vector> 28 29 #include "absl/cleanup/cleanup.h" 30 #include "absl/strings/str_replace.h" 31 #include "absl/strings/string_view.h" 32 #include "pkg/sentry/seccheck/points/common.pb.h" 33 #include "pkg/sentry/seccheck/points/container.pb.h" 34 #include "pkg/sentry/seccheck/points/sentry.pb.h" 35 #include "pkg/sentry/seccheck/points/syscall.pb.h" 36 #include "google/protobuf/text_format.h" 37 38 typedef std::function<void(absl::string_view buf)> Callback; 39 40 constexpr size_t maxEventSize = 300 * 1024; 41 42 bool quiet = false; 43 44 #pragma pack(push, 1) 45 struct header { 46 uint16_t header_size; 47 uint16_t message_type; 48 uint32_t dropped_count; 49 }; 50 #pragma pack(pop) 51 52 void log(const char* fmt, ...) { 53 if (!quiet) { 54 va_list ap; 55 va_start(ap, fmt); 56 vprintf(fmt, ap); 57 va_end(ap); 58 } 59 } 60 61 template <class T> 62 std::string shortfmt(T msg) { 63 std::string short_text_msg; 64 google::protobuf::TextFormat::PrintToString(msg, &short_text_msg); 65 return absl::StrReplaceAll(short_text_msg, 66 {{"\r\n", " "}, {"\n", " "}, {"\r", " "}}); 67 } 68 69 template <class T> 70 void unpackSyscall(absl::string_view buf) { 71 T evt; 72 if (!evt.ParseFromArray(buf.data(), buf.size())) { 73 err(1, "ParseFromString(): %.*s", static_cast<int>(buf.size()), buf.data()); 74 } 75 log("%s %s %s\n", evt.has_exit() ? "X" : "E", 76 evt.GetDescriptor()->name().c_str(), shortfmt(evt).c_str()); 77 } 78 79 template <class T> 80 void unpack(absl::string_view buf) { 81 T evt; 82 if (!evt.ParseFromArray(buf.data(), buf.size())) { 83 err(1, "ParseFromString(): %.*s", static_cast<int>(buf.size()), buf.data()); 84 } 85 log("%s => %s\n", evt.GetDescriptor()->name().c_str(), shortfmt(evt).c_str()); 86 } 87 88 // List of dispatchers indexed based on MessageType enum values. 89 std::vector<Callback> dispatchers = { 90 nullptr, 91 unpack<::gvisor::container::Start>, 92 unpack<::gvisor::sentry::CloneInfo>, 93 unpack<::gvisor::sentry::ExecveInfo>, 94 unpack<::gvisor::sentry::ExitNotifyParentInfo>, 95 unpack<::gvisor::sentry::TaskExit>, 96 unpackSyscall<::gvisor::syscall::Syscall>, 97 unpackSyscall<::gvisor::syscall::Open>, 98 unpackSyscall<::gvisor::syscall::Close>, 99 unpackSyscall<::gvisor::syscall::Read>, 100 unpackSyscall<::gvisor::syscall::Connect>, 101 unpackSyscall<::gvisor::syscall::Execve>, 102 unpackSyscall<::gvisor::syscall::Socket>, 103 unpackSyscall<::gvisor::syscall::Chdir>, 104 unpackSyscall<::gvisor::syscall::Setid>, 105 unpackSyscall<::gvisor::syscall::Setresid>, 106 unpackSyscall<::gvisor::syscall::Dup>, 107 unpackSyscall<::gvisor::syscall::Prlimit>, 108 unpackSyscall<::gvisor::syscall::Pipe>, 109 unpackSyscall<::gvisor::syscall::Fcntl>, 110 unpackSyscall<::gvisor::syscall::Signalfd>, 111 unpackSyscall<::gvisor::syscall::Eventfd>, 112 unpackSyscall<::gvisor::syscall::Chroot>, 113 unpackSyscall<::gvisor::syscall::Clone>, 114 unpackSyscall<::gvisor::syscall::Bind>, 115 unpackSyscall<::gvisor::syscall::Accept>, 116 unpackSyscall<::gvisor::syscall::TimerfdCreate>, 117 unpackSyscall<::gvisor::syscall::TimerfdSetTime>, 118 unpackSyscall<::gvisor::syscall::TimerfdGetTime>, 119 unpackSyscall<::gvisor::syscall::Fork>, 120 unpackSyscall<::gvisor::syscall::InotifyInit>, 121 unpackSyscall<::gvisor::syscall::InotifyAddWatch>, 122 unpackSyscall<::gvisor::syscall::InotifyRmWatch>, 123 unpackSyscall<::gvisor::syscall::SocketPair>, 124 unpackSyscall<::gvisor::syscall::Write>, 125 }; 126 127 void unpack(absl::string_view buf) { 128 const header* hdr = reinterpret_cast<const header*>(&buf[0]); 129 130 // Payload size can be zero when proto object contains only defaults values. 131 size_t payload_size = buf.size() - hdr->header_size; 132 if (payload_size < 0) { 133 printf("Header size (%u) is larger than message %lu\n", hdr->header_size, 134 buf.size()); 135 return; 136 } 137 138 auto proto = buf.substr(hdr->header_size); 139 if (proto.size() < payload_size) { 140 printf("Message was truncated, size: %lu, expected: %zu\n", proto.size(), 141 payload_size); 142 return; 143 } 144 145 if (hdr->message_type == 0 || hdr->message_type >= dispatchers.size()) { 146 printf("Invalid message type: %u\n", hdr->message_type); 147 return; 148 } 149 Callback cb = dispatchers[hdr->message_type]; 150 cb(proto); 151 } 152 153 bool readAndUnpack(int client) { 154 std::array<char, maxEventSize> buf; 155 int bytes = read(client, buf.data(), buf.size()); 156 if (bytes < 0) { 157 err(1, "read"); 158 } 159 if (bytes == 0) { 160 return false; 161 } 162 unpack(absl::string_view(buf.data(), bytes)); 163 return true; 164 } 165 166 void* pollLoop(void* ptr) { 167 const int poll_fd = *reinterpret_cast<int*>(&ptr); 168 for (;;) { 169 epoll_event evts[64]; 170 int nfds = epoll_wait(poll_fd, evts, 64, -1); 171 if (nfds < 0) { 172 if (errno == EINTR) { 173 continue; 174 } 175 err(1, "epoll_wait"); 176 } 177 178 for (int i = 0; i < nfds; ++i) { 179 if (evts[i].events & EPOLLIN) { 180 int client = evts[i].data.fd; 181 readAndUnpack(client); 182 } 183 if ((evts[i].events & (EPOLLRDHUP | EPOLLHUP)) != 0) { 184 int client = evts[i].data.fd; 185 // Drain any remaining messages before closing the socket. 186 while (readAndUnpack(client)) { 187 } 188 close(client); 189 printf("Connection closed\n"); 190 } 191 if (evts[i].events & EPOLLERR) { 192 printf("error\n"); 193 } 194 } 195 } 196 } 197 198 void startPollThread(int poll_fd) { 199 pthread_t thread; 200 if (pthread_create(&thread, nullptr, pollLoop, 201 reinterpret_cast<void*>(poll_fd)) != 0) { 202 err(1, "pthread_create"); 203 } 204 pthread_detach(thread); 205 } 206 207 // handshake performs version exchange with client. See common.proto for details 208 // about the protocol. 209 bool handshake(int client_fd) { 210 std::vector<char> buf(10240); 211 int bytes = read(client_fd, buf.data(), buf.size()); 212 if (bytes < 0) { 213 printf("Error receiving handshake message: %d\n", errno); 214 return false; 215 } else if (bytes == (int)buf.size()) { 216 // Protect against the handshake becoming larger than the buffer allocated 217 // for it. 218 printf("handshake message too big\n"); 219 return false; 220 } 221 ::gvisor::common::Handshake in = {}; 222 if (!in.ParseFromArray(buf.data(), bytes)) { 223 printf("Error parsing handshake message\n"); 224 return false; 225 } 226 227 constexpr uint32_t minSupportedVersion = 1; 228 if (in.version() < minSupportedVersion) { 229 printf("Client has unsupported version %u\n", in.version()); 230 return false; 231 } 232 233 ::gvisor::common::Handshake out; 234 out.set_version(1); 235 if (!out.SerializeToFileDescriptor(client_fd)) { 236 printf("Error sending handshake message: %d\n", errno); 237 return false; 238 } 239 return true; 240 } 241 242 extern "C" int main(int argc, char** argv) { 243 for (int c = 0; (c = getopt(argc, argv, "q")) != -1;) { 244 switch (c) { 245 case 'q': 246 quiet = true; 247 break; 248 default: 249 exit(1); 250 } 251 } 252 253 if (!quiet) { 254 setbuf(stdout, NULL); 255 setbuf(stderr, NULL); 256 } 257 std::string path("/tmp/gvisor_events.sock"); 258 if (optind < argc) { 259 path = argv[optind]; 260 } 261 if (path.empty()) { 262 err(1, "empty file name"); 263 } 264 printf("Socket address %s\n", path.c_str()); 265 unlink(path.c_str()); 266 267 int sock = socket(AF_UNIX, SOCK_SEQPACKET, 0); 268 if (sock < 0) { 269 err(1, "socket"); 270 } 271 auto sock_closer = absl::MakeCleanup([sock] { close(sock); }); 272 273 struct sockaddr_un addr; 274 addr.sun_family = AF_UNIX; 275 strncpy(addr.sun_path, path.c_str(), path.size() + 1); 276 if (bind(sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr))) { 277 err(1, "bind"); 278 } 279 if (listen(sock, 5) < 0) { 280 err(1, "listen"); 281 } 282 283 int epoll_fd = epoll_create(1); 284 if (epoll_fd < 0) { 285 err(1, "epoll_create"); 286 } 287 auto epoll_closer = absl::MakeCleanup([epoll_fd] { close(epoll_fd); }); 288 startPollThread(epoll_fd); 289 290 for (;;) { 291 int client = accept(sock, nullptr, nullptr); 292 if (client < 0) { 293 if (errno == EINTR) { 294 continue; 295 } 296 err(1, "accept"); 297 } 298 printf("Connection accepted\n"); 299 300 if (!handshake(client)) { 301 close(client); 302 continue; 303 } 304 305 struct epoll_event evt; 306 evt.data.fd = client; 307 evt.events = EPOLLIN; 308 if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, client, &evt) < 0) { 309 err(1, "epoll_ctl(ADD)"); 310 } 311 } 312 }