github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/executor/conn.h (about)

     1  // Copyright 2024 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  #include <arpa/inet.h>
     5  #include <endian.h>
     6  #include <fcntl.h>
     7  #include <netdb.h>
     8  #include <netinet/in.h>
     9  #include <poll.h>
    10  #include <string.h>
    11  #include <sys/select.h>
    12  #include <sys/socket.h>
    13  
    14  #include <vector>
    15  
    16  // Connection represents a client TCP connection.
    17  // It connects to the given addr:port and allows to send/receive
    18  // flatbuffers-encoded messages.
    19  class Connection
    20  {
    21  public:
    22  	Connection(const char* addr, const char* port)
    23  	    : fd_(Connect(addr, port))
    24  	{
    25  	}
    26  
    27  	int FD() const
    28  	{
    29  		return fd_;
    30  	}
    31  
    32  	template <typename Msg>
    33  	void Send(const Msg& msg)
    34  	{
    35  		typedef typename Msg::TableType Raw;
    36  		auto off = Raw::Pack(fbb_, &msg);
    37  		fbb_.FinishSizePrefixed(off);
    38  		auto data = fbb_.GetBufferSpan();
    39  		Send(data.data(), data.size());
    40  		fbb_.Reset();
    41  	}
    42  
    43  	template <typename Msg>
    44  	void Recv(Msg& msg)
    45  	{
    46  		typedef typename Msg::TableType Raw;
    47  		flatbuffers::uoffset_t size;
    48  		Recv(&size, sizeof(size));
    49  		size = le32toh(size);
    50  		recv_buf_.resize(size);
    51  		Recv(recv_buf_.data(), size);
    52  		auto raw = flatbuffers::GetRoot<Raw>(recv_buf_.data());
    53  		raw->UnPackTo(&msg);
    54  	}
    55  
    56  	void Send(const void* data, size_t size)
    57  	{
    58  		for (size_t sent = 0; sent < size;) {
    59  			ssize_t n = write(fd_, static_cast<const char*>(data) + sent, size - sent);
    60  			if (n > 0) {
    61  				sent += n;
    62  				continue;
    63  			}
    64  			if (errno == EINTR)
    65  				continue;
    66  			if (errno == EAGAIN) {
    67  				sleep_ms(1);
    68  				continue;
    69  			}
    70  			failmsg("failed to send rpc", "fd=%d want=%zu sent=%zu n=%zd", fd_, size, sent, n);
    71  		}
    72  	}
    73  
    74  private:
    75  	const int fd_;
    76  	std::vector<char> recv_buf_;
    77  	flatbuffers::FlatBufferBuilder fbb_;
    78  
    79  	void Recv(void* data, size_t size)
    80  	{
    81  		for (size_t recv = 0; recv < size;) {
    82  			ssize_t n = read(fd_, static_cast<char*>(data) + recv, size - recv);
    83  			if (n > 0) {
    84  				recv += n;
    85  				continue;
    86  			}
    87  			if (errno == EINTR)
    88  				continue;
    89  			if (errno == EAGAIN) {
    90  				sleep_ms(1);
    91  				continue;
    92  			}
    93  			failmsg("failed to recv rpc", "fd=%d want=%zu recv=%zu n=%zd", fd_, size, recv, n);
    94  		}
    95  	}
    96  
    97  	static int Connect(const char* addr, const char* ports)
    98  	{
    99  		int port = atoi(ports);
   100  		bool localhost = !strcmp(addr, "localhost");
   101  		int fd;
   102  		if (!strcmp(addr, "stdin"))
   103  			return STDIN_FILENO;
   104  		if (port == 0)
   105  			failmsg("failed to parse manager port", "port=%s", ports);
   106  		sockaddr_in saddr4 = {};
   107  		saddr4.sin_family = AF_INET;
   108  		saddr4.sin_port = htons(port);
   109  		if (localhost)
   110  			addr = "127.0.0.1";
   111  		if (inet_pton(AF_INET, addr, &saddr4.sin_addr)) {
   112  			fd = Connect(&saddr4, &saddr4.sin_addr, port);
   113  			if (fd != -1 || !localhost)
   114  				return fd;
   115  		}
   116  		sockaddr_in6 saddr6 = {};
   117  		saddr6.sin6_family = AF_INET6;
   118  		saddr6.sin6_port = htons(port);
   119  		if (localhost)
   120  			addr = "0:0:0:0:0:0:0:1";
   121  		if (inet_pton(AF_INET6, addr, &saddr6.sin6_addr)) {
   122  			fd = Connect(&saddr6, &saddr6.sin6_addr, port);
   123  			if (fd != -1 || !localhost)
   124  				return fd;
   125  		}
   126  		auto* hostent = gethostbyname(addr);
   127  		if (!hostent)
   128  			failmsg("failed to resolve manager addr", "addr=%s h_errno=%d", addr, h_errno);
   129  		for (char** addr = hostent->h_addr_list; *addr; addr++) {
   130  			if (hostent->h_addrtype == AF_INET) {
   131  				memcpy(&saddr4.sin_addr, *addr, std::min<size_t>(hostent->h_length, sizeof(saddr4.sin_addr)));
   132  				fd = Connect(&saddr4, &saddr4.sin_addr, port);
   133  			} else if (hostent->h_addrtype == AF_INET6) {
   134  				memcpy(&saddr6.sin6_addr, *addr, std::min<size_t>(hostent->h_length, sizeof(saddr6.sin6_addr)));
   135  				fd = Connect(&saddr6, &saddr6.sin6_addr, port);
   136  			} else {
   137  				failmsg("unknown socket family", "family=%d", hostent->h_addrtype);
   138  			}
   139  			if (fd != -1)
   140  				return fd;
   141  		}
   142  		failmsg("can't connect to manager", "addr=%s:%s", addr, ports);
   143  	}
   144  
   145  	template <typename addr_t>
   146  	static int Connect(addr_t* addr, void* ip, int port)
   147  	{
   148  		auto* saddr = reinterpret_cast<sockaddr*>(addr);
   149  		int fd = socket(saddr->sa_family, SOCK_STREAM, IPPROTO_TCP);
   150  		if (fd == -1) {
   151  			printf("failed to create socket for address family %d", saddr->sa_family);
   152  			return -1;
   153  		}
   154  		char str[128] = {};
   155  		inet_ntop(saddr->sa_family, ip, str, sizeof(str));
   156  		int retcode = connect(fd, saddr, sizeof(*addr));
   157  		while (retcode == -1 && errno == EINTR)
   158  			retcode = ConnectWait(fd);
   159  
   160  		if (retcode != 0) {
   161  			printf("failed to connect to manager at %s:%d: %s\n", str, port, strerror(errno));
   162  			close(fd);
   163  			return -1;
   164  		}
   165  		return fd;
   166  	}
   167  
   168  	Connection(const Connection&) = delete;
   169  	Connection& operator=(const Connection&) = delete;
   170  
   171  	static int ConnectWait(int s)
   172  	{
   173  		struct pollfd pfd[1] = {{.fd = s, .events = POLLOUT}};
   174  		int error = 0;
   175  		socklen_t len = sizeof(error);
   176  
   177  		if (poll(pfd, 1, -1) == -1)
   178  			return -1;
   179  		if (getsockopt(s, SOL_SOCKET, SO_ERROR, &error, &len) == -1)
   180  			return -1;
   181  		if (error != 0) {
   182  			errno = error;
   183  			return -1;
   184  		}
   185  		return 0;
   186  	}
   187  };
   188  
   189  // Select is a wrapper around select system call.
   190  class Select
   191  {
   192  public:
   193  	Select()
   194  	{
   195  		FD_ZERO(&rdset_);
   196  	}
   197  
   198  	void Arm(int fd)
   199  	{
   200  		FD_SET(fd, &rdset_);
   201  		max_fd_ = std::max(max_fd_, fd);
   202  	}
   203  
   204  	bool Ready(int fd) const
   205  	{
   206  		return FD_ISSET(fd, &rdset_);
   207  	}
   208  
   209  	void Wait(int ms)
   210  	{
   211  		timespec timeout = {.tv_sec = ms / 1000, .tv_nsec = (ms % 1000) * 1000 * 1000};
   212  		for (;;) {
   213  			if (pselect(max_fd_ + 1, &rdset_, nullptr, nullptr, &timeout, nullptr) >= 0)
   214  				break;
   215  
   216  			if (errno != EINTR && errno != EAGAIN)
   217  				fail("pselect failed");
   218  		}
   219  	}
   220  
   221  	static void Prepare(int fd)
   222  	{
   223  		if (fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK))
   224  			fail("fcntl(O_NONBLOCK) failed");
   225  	}
   226  
   227  private:
   228  	fd_set rdset_;
   229  	int max_fd_ = -1;
   230  
   231  	Select(const Select&) = delete;
   232  	Select& operator=(const Select&) = delete;
   233  };