github.com/aitjcize/Overlord@v0.0.0-20240314041920-104a804cf5e8/scripts/ovl.py (about) 1 #!/usr/bin/env python 2 # Copyright 2015 The Chromium OS Authors. All rights reserved. 3 # Use of this source code is governed by a BSD-style license that can be 4 # found in the LICENSE file. 5 6 import argparse 7 import ast 8 import base64 9 import fcntl 10 import functools 11 import getpass 12 import hashlib 13 import http.client 14 from io import BytesIO 15 import json 16 import logging 17 import os 18 import re 19 import select 20 import signal 21 import socket 22 import ssl 23 import struct 24 import subprocess 25 import sys 26 import tempfile 27 import termios 28 import threading 29 import time 30 import tty 31 import unicodedata # required by pyinstaller, pylint: disable=unused-import 32 import urllib.error 33 import urllib.parse 34 import urllib.request 35 from xmlrpc.client import ServerProxy 36 from xmlrpc.server import SimpleXMLRPCServer 37 38 from ws4py.client import WebSocketBaseClient 39 import yaml 40 41 42 _CERT_DIR = os.path.expanduser('~/.config/ovl') 43 44 _ESCAPE = '~' 45 _BUFSIZ = 8192 46 _DEFAULT_HTTPS_PORT = 443 47 _OVERLORD_CLIENT_DAEMON_PORT = 4488 48 _OVERLORD_CLIENT_DAEMON_RPC_ADDR = ('127.0.0.1', _OVERLORD_CLIENT_DAEMON_PORT) 49 50 _CONNECT_TIMEOUT = 3 51 _DEFAULT_HTTP_TIMEOUT = 30 52 _LIST_CACHE_TIMEOUT = 2 53 _DEFAULT_TERMINAL_WIDTH = 80 54 _RETRY_TIMES = 3 55 56 # echo -n overlord | md5sum 57 _HTTP_BOUNDARY_MAGIC = '9246f080c855a69012707ab53489b921' 58 59 # Terminal resize control 60 _CONTROL_START = 128 61 _CONTROL_END = 129 62 63 # Stream control 64 _STDIN_CLOSED = '##STDIN_CLOSED##' 65 66 _SSH_CONTROL_SOCKET_PREFIX = os.path.join(tempfile.gettempdir(), 67 'ovl-ssh-control-') 68 69 _TLS_CERT_FAILED_WARNING = """ 70 @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 71 @ WARNING: REMOTE HOST VERIFICATION HAS FAILED! @ 72 @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 73 IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY! 74 Someone could be eavesdropping on you right now (man-in-the-middle attack)! 75 It is also possible that the server is using a self-signed certificate. 76 The fingerprint for the TLS host certificate sent by the remote host is 77 78 %s 79 80 Do you want to trust this certificate and proceed? [y/N] """ 81 82 _TLS_CERT_CHANGED_WARNING = """ 83 @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 84 @ WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! @ 85 @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 86 IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY! 87 Someone could be eavesdropping on you right now (man-in-the-middle attack)! 88 It is also possible that the TLS host certificate has just been changed. 89 The fingerprint for the TLS host certificate sent by the remote host is 90 91 %s 92 93 Remove '%s' if you still want to proceed. 94 SSL Certificate verification failed.""" 95 96 _USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.0.0.0 Safari/537.36" 97 98 99 def GetVersionDigest(): 100 """Return the sha1sum of the current executing script.""" 101 # Check python script by default 102 filename = __file__ 103 104 # If we are running from a frozen binary, we should calculate the checksum 105 # against that binary instead of the python script. 106 # See: https://pyinstaller.readthedocs.io/en/stable/runtime-information.html 107 if getattr(sys, 'frozen', False): 108 filename = sys.executable 109 110 with open(filename, 'rb') as f: 111 return hashlib.sha1(f.read()).hexdigest() 112 113 114 def GetTLSCertPath(host): 115 return os.path.join(_CERT_DIR, '%s.cert' % host) 116 117 118 def UrlOpen(state, url): 119 """Wrapper for urllib.request.urlopen. 120 121 It selects correct HTTP scheme according to self._state.ssl, add HTTP 122 basic auth headers, and add specify correct SSL context. 123 """ 124 url = MakeRequestUrl(state, url) 125 request = urllib.request.Request(url) 126 if state.username is not None and state.password is not None: 127 request.add_header(*BasicAuthHeader(state.username, state.password)) 128 request.add_header('User-Agent', _USER_AGENT) 129 return urllib.request.urlopen(request, timeout=_DEFAULT_HTTP_TIMEOUT, 130 context=state.SSLContext()) 131 132 133 def GetTLSCertificateSHA1Fingerprint(cert_pem): 134 beg = cert_pem.index('\n') 135 end = cert_pem.rindex('\n', 0, len(cert_pem) - 2) 136 cert_pem = cert_pem[beg:end] # Remove BEGIN/END CERTIFICATE boundary 137 cert_der = base64.b64decode(cert_pem) 138 return hashlib.sha1(cert_der).hexdigest() 139 140 141 def KillGraceful(pid, wait_secs=1): 142 """Kill a process gracefully by first sending SIGTERM, wait for some time, 143 then send SIGKILL to make sure it's killed.""" 144 try: 145 os.kill(pid, signal.SIGTERM) 146 time.sleep(wait_secs) 147 os.kill(pid, signal.SIGKILL) 148 except OSError: 149 pass 150 151 152 def AutoRetry(action_name, retries): 153 """Decorator for retry function call.""" 154 def Wrap(func): 155 @functools.wraps(func) 156 def Loop(*args, **kwargs): 157 for unused_i in range(retries): 158 try: 159 func(*args, **kwargs) 160 except Exception as e: 161 print('error: %s: %s: retrying ...' % (args[0], e)) 162 else: 163 break 164 else: 165 print('error: failed to %s %s' % (action_name, args[0])) 166 return Loop 167 return Wrap 168 169 170 def BasicAuthHeader(user, password): 171 """Return HTTP basic auth header.""" 172 credential = base64.b64encode( 173 b'%s:%s' % (user.encode('utf-8'), password.encode('utf-8'))) 174 return ('Authorization', 'Basic %s' % credential.decode('utf-8')) 175 176 177 def GetTerminalSize(): 178 """Retrieve terminal window size.""" 179 ws = struct.pack('HHHH', 0, 0, 0, 0) 180 ws = fcntl.ioctl(0, termios.TIOCGWINSZ, ws) 181 lines, columns, unused_x, unused_y = struct.unpack('HHHH', ws) 182 return lines, columns 183 184 185 def MakeRequestUrl(state, url): 186 return 'http%s://%s' % ('s' if state.ssl else '', url) 187 188 189 class ProgressBar: 190 SIZE_WIDTH = 11 191 SPEED_WIDTH = 10 192 DURATION_WIDTH = 6 193 PERCENTAGE_WIDTH = 8 194 195 def __init__(self, name): 196 self._start_time = time.time() 197 self._name = name 198 self._size = 0 199 self._width = 0 200 self._name_width = 0 201 self._name_max = 0 202 self._stat_width = 0 203 self._max = 0 204 self._CalculateSize() 205 self.SetProgress(0) 206 207 def _CalculateSize(self): 208 self._width = GetTerminalSize()[1] or _DEFAULT_TERMINAL_WIDTH 209 self._name_width = int(self._width * 0.3) 210 self._name_max = self._name_width 211 self._stat_width = self.SIZE_WIDTH + self.SPEED_WIDTH + self.DURATION_WIDTH 212 self._max = (self._width - self._name_width - self._stat_width - 213 self.PERCENTAGE_WIDTH) 214 215 def _SizeToHuman(self, size_in_bytes): 216 if size_in_bytes < 1024: 217 unit = 'B' 218 value = size_in_bytes 219 elif size_in_bytes < 1024 ** 2: 220 unit = 'KiB' 221 value = size_in_bytes / 1024 222 elif size_in_bytes < 1024 ** 3: 223 unit = 'MiB' 224 value = size_in_bytes / (1024 ** 2) 225 elif size_in_bytes < 1024 ** 4: 226 unit = 'GiB' 227 value = size_in_bytes / (1024 ** 3) 228 return ' %6.1f %3s' % (value, unit) 229 230 def _SpeedToHuman(self, speed_in_bs): 231 if speed_in_bs < 1024: 232 unit = 'B' 233 value = speed_in_bs 234 elif speed_in_bs < 1024 ** 2: 235 unit = 'K' 236 value = speed_in_bs / 1024 237 elif speed_in_bs < 1024 ** 3: 238 unit = 'M' 239 value = speed_in_bs / (1024 ** 2) 240 elif speed_in_bs < 1024 ** 4: 241 unit = 'G' 242 value = speed_in_bs / (1024 ** 3) 243 return ' %6.1f%s/s' % (value, unit) 244 245 def _DurationToClock(self, duration): 246 return ' %02d:%02d' % (duration // 60, duration % 60) 247 248 def SetProgress(self, percentage, size=None): 249 current_width = GetTerminalSize()[1] 250 if self._width != current_width: 251 self._CalculateSize() 252 253 if size is not None: 254 self._size = size 255 256 elapse_time = time.time() - self._start_time 257 speed = self._size / elapse_time 258 259 size_str = self._SizeToHuman(self._size) 260 speed_str = self._SpeedToHuman(speed) 261 elapse_str = self._DurationToClock(elapse_time) 262 263 width = int(self._max * percentage / 100.0) 264 sys.stdout.write( 265 '%*s' % (- self._name_max, 266 self._name if len(self._name) <= self._name_max else 267 self._name[:self._name_max - 4] + ' ...') + 268 size_str + speed_str + elapse_str + 269 ((' [' + '#' * width + ' ' * (self._max - width) + ']' + 270 '%4d%%' % int(percentage)) if self._max > 2 else '') + '\r') 271 sys.stdout.flush() 272 273 def End(self): 274 self.SetProgress(100.0) 275 sys.stdout.write('\n') 276 sys.stdout.flush() 277 278 279 class DaemonState: 280 """DaemonState is used for storing Overlord state info.""" 281 def __init__(self): 282 self.version_sha1sum = GetVersionDigest() 283 self.host = None 284 self.port = None 285 self.ssl = False 286 self.ssl_self_signed = False 287 self.ssl_verify = True 288 self.ssl_check_hostname = True 289 self.ssh = False 290 self.orig_host = None 291 self.ssh_pid = None 292 self.username = None 293 self.password = None 294 self.selected_mid = None 295 self.forwards = {} 296 self.listing = [] 297 self.last_list = 0 298 299 def SSLContext(self): 300 # No verify. 301 if not self.ssl_verify: 302 context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) 303 context.check_hostname = False 304 context.verify_mode = ssl.CERT_NONE 305 return context 306 307 context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) 308 context.check_hostname = self.ssl_check_hostname 309 context.verify_mode = ssl.CERT_REQUIRED 310 311 # Check if self signed certificate exists. 312 ssl_cert_path = GetTLSCertPath(self.host) 313 if os.path.exists(ssl_cert_path): 314 context.load_verify_locations(ssl_cert_path) 315 self.ssl_self_signed = True 316 return context 317 318 return ssl.create_default_context(ssl.Purpose.SERVER_AUTH) 319 320 321 @staticmethod 322 def FromDict(kw): 323 state = DaemonState() 324 325 for k, v in kw.items(): 326 setattr(state, k, v) 327 return state 328 329 330 class OverlordClientDaemon: 331 """Overlord Client Daemon.""" 332 def __init__(self): 333 self._state = DaemonState() 334 self._server = None 335 336 def Start(self): 337 self.StartRPCServer() 338 339 def StartRPCServer(self): 340 self._server = SimpleXMLRPCServer(_OVERLORD_CLIENT_DAEMON_RPC_ADDR, 341 logRequests=False, allow_none=True) 342 exports = [ 343 (self.State, 'State'), 344 (self.Ping, 'Ping'), 345 (self.GetPid, 'GetPid'), 346 (self.Connect, 'Connect'), 347 (self.Clients, 'Clients'), 348 (self.SelectClient, 'SelectClient'), 349 (self.AddForward, 'AddForward'), 350 (self.RemoveForward, 'RemoveForward'), 351 (self.RemoveAllForward, 'RemoveAllForward'), 352 ] 353 for func, name in exports: 354 self._server.register_function(func, name) 355 356 pid = os.fork() 357 if pid == 0: 358 for fd in range(3): 359 os.close(fd) 360 self._server.serve_forever() 361 362 @staticmethod 363 def GetRPCServer(): 364 """Returns the Overlord client daemon RPC server.""" 365 server = ServerProxy('http://%s:%d' % _OVERLORD_CLIENT_DAEMON_RPC_ADDR, 366 allow_none=True) 367 try: 368 server.Ping() 369 except Exception: 370 return None 371 return server 372 373 def State(self): 374 return self._state 375 376 def Ping(self): 377 return True 378 379 def GetPid(self): 380 return os.getpid() 381 382 def _GetJSON(self, path): 383 url = '%s:%d%s' % (self._state.host, self._state.port, path) 384 return json.loads(UrlOpen(self._state, url).read()) 385 386 def _TLSEnabled(self): 387 """Determine if TLS is enabled on given server address.""" 388 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 389 try: 390 # Allow any certificate since we only want to check if server talks TLS. 391 context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) 392 context.check_hostname = False 393 context.verify_mode = ssl.CERT_NONE 394 395 sock = context.wrap_socket(sock, server_hostname=self._state.host) 396 sock.settimeout(_CONNECT_TIMEOUT) 397 sock.connect((self._state.host, self._state.port)) 398 return True 399 except ssl.SSLError: 400 return False 401 except socket.error: # Connect refused or timeout 402 raise 403 except Exception: 404 return False # For whatever reason above failed, assume False 405 406 def _CheckTLSCertificate(self): 407 """Check TLS certificate. 408 409 Returns: 410 A tupple (check_result, if_certificate_is_loaded) 411 """ 412 def _DoConnect(context): 413 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 414 try: 415 sock.settimeout(_CONNECT_TIMEOUT) 416 sock = context.wrap_socket(sock, server_hostname=self._state.host) 417 sock.connect((self._state.host, self._state.port)) 418 except ssl.SSLError: 419 return False 420 finally: 421 sock.close() 422 423 return True 424 425 return _DoConnect(self._state.SSLContext()) 426 427 def Connect(self, host, port, ssh_pid=None, 428 username=None, password=None, orig_host=None, 429 ssl_verify=True, ssl_check_hostname=True): 430 self._state.username = username 431 self._state.password = password 432 self._state.host = host 433 self._state.port = port 434 self._state.ssl = False 435 self._state.ssl_self_signed = False 436 self._state.orig_host = orig_host 437 self._state.ssh_pid = ssh_pid 438 self._state.selected_mid = None 439 self._state.ssl_verify = ssl_verify 440 self._state.ssl_check_hostname = ssl_check_hostname 441 442 ssl_enabled = self._TLSEnabled() 443 if ssl_enabled: 444 result = self._CheckTLSCertificate() 445 if not result: 446 if self._state.ssl_self_signed: 447 return ('SSLCertificateChanged', ssl.get_server_certificate( 448 (self._state.host, self._state.port))) 449 return ('SSLVerifyFailed', ssl.get_server_certificate( 450 (self._state.host, self._state.port))) 451 452 try: 453 self._state.ssl = ssl_enabled 454 UrlOpen(self._state, '%s:%d' % (host, port)) 455 except urllib.error.HTTPError as e: 456 return ('HTTPError', e.getcode(), str(e), e.read().strip()) 457 except Exception as e: 458 return str(e) 459 else: 460 return True 461 462 def Clients(self): 463 if time.time() - self._state.last_list <= _LIST_CACHE_TIMEOUT: 464 return self._state.listing 465 466 self._state.listing = self._GetJSON('/api/agents/list') 467 self._state.last_list = time.time() 468 return self._state.listing 469 470 def SelectClient(self, mid): 471 self._state.selected_mid = mid 472 473 def AddForward(self, mid, remote, local, pid): 474 self._state.forwards[str(local)] = (mid, remote, pid) 475 476 def RemoveForward(self, local_port): 477 try: 478 unused_mid, unused_remote, pid = self._state.forwards[str(local_port)] 479 KillGraceful(pid) 480 del self._state.forwards[local_port] 481 except (KeyError, OSError): 482 pass 483 484 def RemoveAllForward(self): 485 for unused_mid, unused_remote, pid in self._state.forwards.values(): 486 try: 487 KillGraceful(pid) 488 except OSError: 489 pass 490 self._state.forwards = {} 491 492 493 class SSLEnabledWebSocketBaseClient(WebSocketBaseClient): 494 def __init__(self, state, *args, **kwargs): 495 super().__init__(ssl_context=state.SSLContext(), *args, **kwargs) 496 497 498 class TerminalWebSocketClient(SSLEnabledWebSocketBaseClient): 499 def __init__(self, state, mid, escape, *args, **kwargs): 500 super().__init__(state, *args, **kwargs) 501 self._mid = mid 502 self._escape = escape 503 self._stdin_fd = sys.stdin.fileno() 504 self._old_termios = None 505 506 def handshake_ok(self): 507 pass 508 509 def opened(self): 510 nonlocals = {'size': (80, 40)} 511 512 def _ResizeWindow(): 513 size = GetTerminalSize() 514 if size != nonlocals['size']: # Size not changed, ignore 515 control = {'command': 'resize', 'params': list(size)} 516 payload = (_CONTROL_START.to_bytes(1, 'big') + 517 json.dumps(control).encode('utf-8') + 518 _CONTROL_END.to_bytes(1, 'big')) 519 nonlocals['size'] = size 520 try: 521 self.send(payload, binary=True) 522 except Exception as e: 523 logging.exception(e) 524 525 def _FeedInput(): 526 self._old_termios = termios.tcgetattr(self._stdin_fd) 527 tty.setraw(self._stdin_fd) 528 529 READY, ENTER_PRESSED, ESCAPE_PRESSED = range(3) 530 531 try: 532 state = READY 533 while True: 534 # Check if terminal is resized 535 _ResizeWindow() 536 537 ch = sys.stdin.read(1) 538 539 # Scan for escape sequence 540 if self._escape: 541 if state == READY: 542 state = ENTER_PRESSED if ch == chr(0x0d) else READY 543 elif state == ENTER_PRESSED: 544 state = ESCAPE_PRESSED if ch == self._escape else READY 545 elif state == ESCAPE_PRESSED: 546 if ch == '.': 547 self.close() 548 break 549 else: 550 state = READY 551 552 self.send(ch) 553 except (KeyboardInterrupt, RuntimeError): 554 pass 555 556 t = threading.Thread(target=_FeedInput) 557 t.daemon = True 558 t.start() 559 560 def closed(self, code, reason=None): 561 del code, reason # Unused. 562 termios.tcsetattr(self._stdin_fd, termios.TCSANOW, self._old_termios) 563 print('\nConnection to %s closed.' % self._mid) 564 565 def received_message(self, message): 566 if message.is_binary: 567 sys.stdout.buffer.write(message.data) 568 sys.stdout.flush() 569 570 571 class ShellWebSocketClient(SSLEnabledWebSocketBaseClient): 572 def __init__(self, state, output, *args, **kwargs): 573 """Constructor. 574 575 Args: 576 output: output file object. 577 """ 578 super().__init__(state, *args, **kwargs) 579 self._output = output 580 self._input_thread = threading.Thread(target=self._FeedInput) 581 self._stop = threading.Event() 582 583 def handshake_ok(self): 584 pass 585 586 def _FeedInput(self): 587 try: 588 while True: 589 rd, unused_w, unused_x = select.select([sys.stdin], [], [], 0.5) 590 if self._stop.is_set(): 591 break 592 593 if sys.stdin in rd: 594 data = sys.stdin.buffer.read() 595 if not data: 596 self.send(_STDIN_CLOSED * 2) 597 break 598 self.send(data, binary=True) 599 except (KeyboardInterrupt, RuntimeError): 600 pass 601 602 def opened(self): 603 self._input_thread.start() 604 605 def closed(self, code, reason=None): 606 self._stop.set() 607 self._input_thread.join() 608 609 def received_message(self, message): 610 if message.is_binary: 611 self._output.write(message.data) 612 self._output.flush() 613 614 615 class ForwarderWebSocketClient(SSLEnabledWebSocketBaseClient): 616 def __init__(self, state, sock, *args, **kwargs): 617 super().__init__(state, *args, **kwargs) 618 self._sock = sock 619 self._input_thread = threading.Thread(target=self._FeedInput) 620 self._stop = threading.Event() 621 622 def handshake_ok(self): 623 pass 624 625 def _FeedInput(self): 626 try: 627 self._sock.setblocking(False) 628 while True: 629 rd, unused_w, unused_x = select.select([self._sock], [], [], 0.5) 630 if self._stop.is_set(): 631 break 632 if self._sock in rd: 633 data = self._sock.recv(_BUFSIZ) 634 if not data: 635 self.close() 636 break 637 self.send(data, binary=True) 638 except Exception: 639 pass 640 finally: 641 self._sock.close() 642 643 def opened(self): 644 self._input_thread.start() 645 646 def closed(self, code, reason=None): 647 del code, reason # Unused. 648 self._stop.set() 649 self._input_thread.join() 650 sys.exit(0) 651 652 def received_message(self, message): 653 if message.is_binary: 654 self._sock.send(message.data) 655 656 657 def Arg(*args, **kwargs): 658 return (args, kwargs) 659 660 661 def Command(command, help_msg=None, args=None): 662 """Decorator for adding argparse parameter for a method.""" 663 if args is None: 664 args = [] 665 def WrapFunc(func): 666 @functools.wraps(func) 667 def Wrapped(*args, **kwargs): 668 return func(*args, **kwargs) 669 # pylint: disable=protected-access 670 Wrapped.__arg_attr = {'command': command, 'help': help_msg, 'args': args} 671 return Wrapped 672 return WrapFunc 673 674 675 def ParseMethodSubCommands(cls): 676 """Decorator for a class using the @Command decorator. 677 678 This decorator retrieve command info from each method and append it in to the 679 SUBCOMMANDS class variable, which is later used to construct parser. 680 """ 681 for unused_key, method in cls.__dict__.items(): 682 if hasattr(method, '__arg_attr'): 683 # pylint: disable=protected-access 684 cls.SUBCOMMANDS.append(method.__arg_attr) 685 return cls 686 687 688 @ParseMethodSubCommands 689 class OverlordCLIClient: 690 """Overlord command line interface client.""" 691 692 SUBCOMMANDS = [] 693 694 def __init__(self): 695 self._parser = self._BuildParser() 696 self._selected_mid = None 697 self._server = None 698 self._state = None 699 self._escape = None 700 701 def _BuildParser(self): 702 root_parser = argparse.ArgumentParser(prog='ovl') 703 subparsers = root_parser.add_subparsers(title='subcommands', 704 dest='subcommand') 705 subparsers.required = True 706 707 root_parser.add_argument('-s', dest='selected_mid', action='store', 708 default=None, 709 help='select target to execute command on') 710 root_parser.add_argument('-S', dest='select_mid_before_action', 711 action='store_true', default=False, 712 help='select target before executing command') 713 root_parser.add_argument('-e', dest='escape', metavar='ESCAPE_CHAR', 714 action='store', default=_ESCAPE, type=str, 715 help='set shell escape character, \'none\' to ' 716 'disable escape completely') 717 718 for attr in self.SUBCOMMANDS: 719 parser = subparsers.add_parser(attr['command'], help=attr['help']) 720 parser.set_defaults(which=attr['command']) 721 for arg in attr['args']: 722 parser.add_argument(*arg[0], **arg[1]) 723 724 return root_parser 725 726 def Main(self): 727 # We want to pass the rest of arguments after shell command directly to the 728 # function without parsing it. 729 try: 730 index = sys.argv.index('shell') 731 except ValueError: 732 args = self._parser.parse_args() 733 else: 734 args = self._parser.parse_args(sys.argv[1:index + 1]) 735 736 command = args.which 737 self._selected_mid = args.selected_mid 738 739 if args.escape and args.escape != 'none': 740 self._escape = args.escape[0] 741 742 if command == 'start-server': 743 self.StartServer() 744 return 745 if command == 'kill-server': 746 self.KillServer() 747 return 748 749 self.CheckDaemon() 750 if command == 'status': 751 self.Status() 752 return 753 if command == 'connect': 754 self.Connect(args) 755 return 756 757 # The following command requires connection to the server 758 self.CheckConnection() 759 760 if args.select_mid_before_action: 761 self.SelectClient(store=False) 762 763 if command == 'select': 764 self.SelectClient(args) 765 elif command == 'ls': 766 self.ListClients(args) 767 elif command == 'shell': 768 command = sys.argv[sys.argv.index('shell') + 1:] 769 self.Shell(command) 770 elif command == 'push': 771 self.Push(args) 772 elif command == 'pull': 773 self.Pull(args) 774 elif command == 'forward': 775 self.Forward(args) 776 777 def _SaveTLSCertificate(self, host, cert_pem): 778 try: 779 os.makedirs(_CERT_DIR) 780 except Exception: 781 pass 782 with open(GetTLSCertPath(host), 'w') as f: 783 f.write(cert_pem) 784 785 def _HTTPPostFile(self, url, filename, progress=None, user=None, passwd=None): 786 """Perform HTTP POST and upload file to Overlord. 787 788 To minimize the external dependencies, we construct the HTTP post request 789 by ourselves. 790 """ 791 url = MakeRequestUrl(self._state, url) 792 size = os.stat(filename).st_size 793 boundary = '-----------%s' % _HTTP_BOUNDARY_MAGIC 794 CRLF = '\r\n' 795 parse = urllib.parse.urlparse(url) 796 797 part_headers = [ 798 '--' + boundary, 799 'Content-Disposition: form-data; name="file"; ' 800 'filename="%s"' % os.path.basename(filename), 801 'Content-Type: application/octet-stream', 802 '', '' 803 ] 804 part_header = CRLF.join(part_headers) 805 end_part = CRLF + '--' + boundary + '--' + CRLF 806 807 content_length = len(part_header) + size + len(end_part) 808 if parse.scheme == 'http': 809 h = http.client.HTTPConnection(parse.netloc) 810 else: 811 h = http.client.HTTPSConnection(parse.netloc, 812 context=self._state.SSLContext()) 813 814 post_path = url[url.index(parse.netloc) + len(parse.netloc):] 815 h.putrequest('POST', post_path) 816 h.putheader('Content-Length', content_length) 817 h.putheader('Content-Type', 'multipart/form-data; boundary=%s' % boundary) 818 819 if user and passwd: 820 h.putheader(*BasicAuthHeader(user, passwd)) 821 h.endheaders() 822 h.send(part_header.encode('utf-8')) 823 824 count = 0 825 with open(filename, 'rb') as f: 826 while True: 827 data = f.read(_BUFSIZ) 828 if not data: 829 break 830 count += len(data) 831 if progress: 832 progress(count * 100 // size, count) 833 h.send(data) 834 835 h.send(end_part.encode('utf-8')) 836 progress(100) 837 838 if count != size: 839 logging.warning('file changed during upload, upload may be truncated.') 840 841 resp = h.getresponse() 842 return resp.status == 200 843 844 def CheckDaemon(self): 845 self._server = OverlordClientDaemon.GetRPCServer() 846 if self._server is None: 847 print('* daemon not running, starting it now on port %d ... *' % 848 _OVERLORD_CLIENT_DAEMON_PORT) 849 self.StartServer() 850 851 self._state = DaemonState.FromDict(self._server.State()) 852 sha1sum = GetVersionDigest() 853 854 if sha1sum != self._state.version_sha1sum: 855 print('ovl server is out of date. killing...') 856 KillGraceful(self._server.GetPid()) 857 self.StartServer() 858 859 def GetSSHControlFile(self, host): 860 return _SSH_CONTROL_SOCKET_PREFIX + host 861 862 def SSHTunnel(self, user, host, port): 863 """SSH forward the remote overlord server. 864 865 Overlord server may not have port 9000 open to the public network, in such 866 case we can SSH forward the port to 127.0.0.1. 867 """ 868 869 control_file = self.GetSSHControlFile(host) 870 try: 871 os.unlink(control_file) 872 except Exception: 873 pass 874 875 with subprocess.Popen([ 876 'ssh', '-Nf', '-M', '-S', control_file, '-L', '9000:127.0.0.1:9000', 877 '-p', 878 str(port), 879 '%s%s' % (user + '@' if user else '', host) 880 ]): 881 pass 882 883 p = subprocess.Popen([ 884 'ssh', 885 '-S', control_file, 886 '-O', 'check', host, 887 ], stderr=subprocess.PIPE) 888 unused_stdout, stderr = p.communicate() 889 890 s = re.search(r'pid=(\d+)', stderr) 891 if s: 892 return int(s.group(1)) 893 894 raise RuntimeError('can not establish ssh connection') 895 896 def CheckConnection(self): 897 if self._state.host is None: 898 raise RuntimeError('not connected to any server, abort') 899 900 try: 901 self._server.Clients() 902 except Exception: 903 raise RuntimeError('remote server disconnected, abort') from None 904 905 if self._state.ssh_pid is not None: 906 with subprocess.Popen( 907 ['kill', '-0', str(self._state.ssh_pid)], stdout=subprocess.PIPE, 908 stderr=subprocess.PIPE) as p: 909 pass 910 if p.returncode != 0: 911 raise RuntimeError('ssh tunnel disconnected, please re-connect') 912 913 def CheckClient(self): 914 if self._selected_mid is None: 915 if self._state.selected_mid is None: 916 raise RuntimeError('No client is selected') 917 self._selected_mid = self._state.selected_mid 918 919 if not any(client['mid'] == self._selected_mid 920 for client in self._server.Clients()): 921 raise RuntimeError('client %s disappeared' % self._selected_mid) 922 923 def CheckOutput(self, command): 924 headers = [] 925 if self._state.username is not None and self._state.password is not None: 926 headers.append(BasicAuthHeader(self._state.username, 927 self._state.password)) 928 929 scheme = 'ws%s://' % ('s' if self._state.ssl else '') 930 bio = BytesIO() 931 ws = ShellWebSocketClient( 932 self._state, bio, scheme + '%s:%d/api/agent/shell/%s?command=%s' % ( 933 self._state.host, self._state.port, 934 urllib.parse.quote(self._selected_mid), 935 urllib.parse.quote(command)), 936 headers=headers) 937 ws.connect() 938 ws.run() 939 return bio.getvalue().decode('utf-8') 940 941 @Command('status', 'show Overlord connection status') 942 def Status(self): 943 if self._state.host is None: 944 print('Not connected to any host.') 945 else: 946 if self._state.ssh_pid is not None: 947 print('Connected to %s with SSH tunneling.' % self._state.orig_host) 948 else: 949 print('Connected to %s:%d.' % (self._state.host, self._state.port)) 950 951 if self._selected_mid is None: 952 self._selected_mid = self._state.selected_mid 953 954 if self._selected_mid is None: 955 print('No client is selected.') 956 else: 957 print('Client %s selected.' % self._selected_mid) 958 959 @Command('connect', 'connect to Overlord server', [ 960 Arg('host', metavar='HOST', type=str, default='127.0.0.1', 961 help='Overlord hostname/IP'), 962 Arg('port', metavar='PORT', type=int, nargs='?', 963 default=_DEFAULT_HTTPS_PORT, help='Overlord port'), 964 Arg('-f', '--forward', dest='ssh_forward', default=False, 965 action='store_true', 966 help='connect with SSH forwarding to the host'), 967 Arg('-p', '--ssh-port', dest='ssh_port', default=22, 968 type=int, help='SSH server port for SSH forwarding'), 969 Arg('-l', '--ssh-login', dest='ssh_login', default='', 970 type=str, help='SSH server login name for SSH forwarding'), 971 Arg('-u', '--user', dest='user', default=None, 972 type=str, help='Overlord HTTP auth username'), 973 Arg('-w', '--passwd', dest='passwd', default=None, type=str, 974 help='Overlord HTTP auth password'), 975 Arg('--ssl-no-verify', dest='ssl_verify', 976 default=True, action='store_false', 977 help='Ignore SSL cert verification'), 978 Arg('--ssl-no-check-hostname', dest='ssl_check_hostname', 979 default=True, action='store_false', 980 help='Ignore SSL cert hostname check')]) 981 def Connect(self, args): 982 ssh_pid = None 983 host = args.host 984 orig_host = args.host 985 986 if args.ssh_forward: 987 # Kill previous SSH tunnel 988 self.KillSSHTunnel() 989 990 ssh_pid = self.SSHTunnel(args.ssh_login, args.host, args.ssh_port) 991 host = '127.0.0.1' 992 993 username_provided = args.user is not None 994 password_provided = args.passwd is not None 995 prompt = False 996 997 for unused_i in range(3): # pylint: disable=too-many-nested-blocks 998 try: 999 if prompt: 1000 if not username_provided: 1001 args.user = input('Username: ') 1002 if not password_provided: 1003 args.passwd = getpass.getpass('Password: ') 1004 1005 ret = self._server.Connect(host, args.port, ssh_pid, args.user, 1006 args.passwd, orig_host, 1007 args.ssl_verify, args.ssl_check_hostname) 1008 if isinstance(ret, list): 1009 if ret[0].startswith('SSL'): 1010 cert_pem = ret[1] 1011 fp = GetTLSCertificateSHA1Fingerprint(cert_pem) 1012 fp_text = ':'.join([fp[i:i+2] for i in range(0, len(fp), 2)]) 1013 1014 if ret[0] == 'SSLCertificateChanged': 1015 print(_TLS_CERT_CHANGED_WARNING % (fp_text, GetTLSCertPath(host))) 1016 return 1017 if ret[0] == 'SSLVerifyFailed': 1018 print(_TLS_CERT_FAILED_WARNING % (fp_text), end='') 1019 response = input() 1020 if response.lower() in ['y', 'ye', 'yes']: 1021 self._SaveTLSCertificate(host, cert_pem) 1022 print('TLS host Certificate trusted, you will not be prompted ' 1023 'next time.\n') 1024 continue 1025 print('connection aborted.') 1026 return 1027 if ret[0] == 'HTTPError': 1028 code, except_str, body = ret[1:] 1029 if code == 401: 1030 print('connect: %s' % body) 1031 prompt = True 1032 if not username_provided or not password_provided: 1033 continue 1034 break 1035 logging.error('%s; %s', except_str, body) 1036 1037 if ret is not True: 1038 print('can not connect to %s: %s' % (host, ret)) 1039 else: 1040 print('connection to %s:%d established.' % (host, args.port)) 1041 except Exception as e: 1042 logging.error(e) 1043 else: 1044 break 1045 1046 @Command('start-server', 'start overlord CLI client server') 1047 def StartServer(self): 1048 self._server = OverlordClientDaemon.GetRPCServer() 1049 if self._server is None: 1050 OverlordClientDaemon().Start() 1051 time.sleep(1) 1052 self._server = OverlordClientDaemon.GetRPCServer() 1053 if self._server is not None: 1054 print('* daemon started successfully *\n') 1055 1056 @Command('kill-server', 'kill overlord CLI client server') 1057 def KillServer(self): 1058 self._server = OverlordClientDaemon.GetRPCServer() 1059 if self._server is None: 1060 return 1061 1062 self._state = DaemonState.FromDict(self._server.State()) 1063 1064 # Kill SSH Tunnel 1065 self.KillSSHTunnel() 1066 1067 # Kill server daemon 1068 KillGraceful(self._server.GetPid()) 1069 1070 def KillSSHTunnel(self): 1071 if self._state.ssh_pid is not None: 1072 KillGraceful(self._state.ssh_pid) 1073 1074 def _FilterClients(self, clients, prop_filters, mid=None): 1075 def _ClientPropertiesMatch(client, key, regex): 1076 try: 1077 return bool(re.search(regex, client['properties'][key])) 1078 except KeyError: 1079 return False 1080 1081 for prop_filter in prop_filters: 1082 key, sep, regex = prop_filter.partition('=') 1083 if not sep: 1084 # The filter doesn't contains =. 1085 raise ValueError('Invalid filter condition %r' % filter) 1086 clients = [c for c in clients if _ClientPropertiesMatch(c, key, regex)] 1087 1088 if mid is not None: 1089 client = next((c for c in clients if c['mid'] == mid), None) 1090 if client: 1091 return [client] 1092 clients = [c for c in clients if c['mid'].startswith(mid)] 1093 return clients 1094 1095 @Command('ls', 'list clients', [ 1096 Arg('-f', '--filter', default=[], dest='filters', action='append', 1097 help=('Conditions to filter clients by properties. ' 1098 'Should be in form "key=regex", where regex is the regular ' 1099 'expression that should be found in the value. ' 1100 'Multiple --filter arguments would be ANDed.')), 1101 Arg('-v', '--verbose', default=False, action='store_true', 1102 help='Print properties of each client.') 1103 ]) 1104 def ListClients(self, args): 1105 clients = self._FilterClients(self._server.Clients(), args.filters) 1106 for client in clients: 1107 if args.verbose: 1108 print(yaml.safe_dump(client, default_flow_style=False)) 1109 else: 1110 print(client['mid']) 1111 1112 @Command('select', 'select default client', [ 1113 Arg('-f', '--filter', default=[], dest='filters', action='append', 1114 help=('Conditions to filter clients by properties. ' 1115 'Should be in form "key=regex", where regex is the regular ' 1116 'expression that should be found in the value. ' 1117 'Multiple --filter arguments would be ANDed.')), 1118 Arg('mid', metavar='mid', nargs='?', default=None)]) 1119 def SelectClient(self, args=None, store=True): 1120 mid = args.mid if args is not None else None 1121 filters = args.filters if args is not None else [] 1122 clients = self._FilterClients(self._server.Clients(), filters, mid=mid) 1123 1124 if not clients: 1125 raise RuntimeError('select: client not found') 1126 if len(clients) == 1: 1127 mid = clients[0]['mid'] 1128 else: 1129 # This case would not happen when args.mid is specified. 1130 print('Select from the following clients:') 1131 for i, client in enumerate(clients): 1132 print(' %d. %s' % (i + 1, client['mid'])) 1133 1134 print('\nSelection: ', end='') 1135 try: 1136 choice = int(input()) - 1 1137 mid = clients[choice]['mid'] 1138 except ValueError: 1139 raise RuntimeError('select: invalid selection') from None 1140 except IndexError: 1141 raise RuntimeError('select: selection out of range') from None 1142 1143 self._selected_mid = mid 1144 if store: 1145 self._server.SelectClient(mid) 1146 print('Client %s selected' % mid) 1147 1148 @Command('shell', 'open a shell or execute a shell command', [ 1149 Arg('command', metavar='CMD', nargs='?', help='command to execute')]) 1150 def Shell(self, command=None): 1151 if command is None: 1152 command = [] 1153 self.CheckClient() 1154 1155 headers = [] 1156 if self._state.username is not None and self._state.password is not None: 1157 headers.append(BasicAuthHeader(self._state.username, 1158 self._state.password)) 1159 1160 scheme = 'ws%s://' % ('s' if self._state.ssl else '') 1161 if command: 1162 cmd = ' '.join(command) 1163 ws = ShellWebSocketClient( 1164 self._state, sys.stdout.buffer, 1165 scheme + '%s:%d/api/agent/shell/%s?command=%s' % ( 1166 self._state.host, self._state.port, 1167 urllib.parse.quote(self._selected_mid), urllib.parse.quote(cmd)), 1168 headers=headers) 1169 else: 1170 ws = TerminalWebSocketClient( 1171 self._state, self._selected_mid, self._escape, 1172 scheme + '%s:%d/api/agent/tty/%s' % ( 1173 self._state.host, self._state.port, 1174 urllib.parse.quote(self._selected_mid)), 1175 headers=headers) 1176 try: 1177 ws.connect() 1178 ws.run() 1179 except socket.error as e: 1180 if e.errno == 32: # Broken pipe 1181 pass 1182 else: 1183 raise 1184 1185 @Command('push', 'push a file or directory to remote', [ 1186 Arg('srcs', nargs='+', metavar='SOURCE'), 1187 Arg('dst', metavar='DESTINATION')]) 1188 def Push(self, args): 1189 self.CheckClient() 1190 1191 @AutoRetry('push', _RETRY_TIMES) 1192 def _push(src, dst): 1193 src_base = os.path.basename(src) 1194 1195 # Local file is a link 1196 if os.path.islink(src): 1197 pbar = ProgressBar(src_base) 1198 link_path = os.readlink(src) 1199 self.CheckOutput('mkdir -p %(dirname)s; ' 1200 'if [ -d "%(dst)s" ]; then ' 1201 'ln -sf "%(link_path)s" "%(dst)s/%(link_name)s"; ' 1202 'else ln -sf "%(link_path)s" "%(dst)s"; fi' % 1203 dict(dirname=os.path.dirname(dst), 1204 link_path=link_path, dst=dst, 1205 link_name=src_base)) 1206 pbar.End() 1207 return 1208 1209 mode = '0%o' % (0x1FF & os.stat(src).st_mode) 1210 url = ('%s:%d/api/agent/upload/%s?dest=%s&perm=%s' % 1211 (self._state.host, self._state.port, 1212 urllib.parse.quote(self._selected_mid), dst, mode)) 1213 try: 1214 UrlOpen(self._state, url + '&filename=%s' % src_base) 1215 except urllib.error.HTTPError as e: 1216 msg = json.loads(e.read()).get('error', None) 1217 raise RuntimeError('push: %s' % msg) from None 1218 1219 pbar = ProgressBar(src_base) 1220 self._HTTPPostFile(url, src, pbar.SetProgress, 1221 self._state.username, self._state.password) 1222 pbar.End() 1223 1224 def _push_single_target(src, dst): 1225 if os.path.isdir(src): 1226 dst_exists = ast.literal_eval(self.CheckOutput( 1227 'stat %s >/dev/null 2>&1 && echo True || echo False' % dst)) 1228 for root, unused_x, files in os.walk(src): 1229 # If destination directory does not exist, we should strip the first 1230 # layer of directory. For example: src_dir contains a single file 'A' 1231 # 1232 # push src_dir dest_dir 1233 # 1234 # If dest_dir exists, the resulting directory structure should be: 1235 # dest_dir/src_dir/A 1236 # If dest_dir does not exist, the resulting directory structure should 1237 # be: 1238 # dest_dir/A 1239 dst_root = os.path.basename(root) if dst_exists else '' 1240 for name in files: 1241 _push(os.path.join(root, name), 1242 os.path.join(dst, dst_root, name)) 1243 else: 1244 _push(src, dst) 1245 1246 if len(args.srcs) > 1: 1247 dst_type = self.CheckOutput('stat \'%s\' --printf \'%%F\' ' 1248 '2>/dev/null' % args.dst).strip() 1249 if not dst_type: 1250 raise RuntimeError('push: %s: No such file or directory' % args.dst) 1251 if dst_type != 'directory': 1252 raise RuntimeError('push: %s: Not a directory' % args.dst) 1253 1254 for src in args.srcs: 1255 if not os.path.exists(src): 1256 raise RuntimeError('push: can not stat "%s": no such file or directory' 1257 % src) 1258 if not os.access(src, os.R_OK): 1259 raise RuntimeError('push: can not open "%s" for reading' % src) 1260 1261 _push_single_target(src, args.dst) 1262 1263 @Command('pull', 'pull a file or directory from remote', [ 1264 Arg('src', metavar='SOURCE'), 1265 Arg('dst', metavar='DESTINATION', default='.', nargs='?')]) 1266 def Pull(self, args): 1267 self.CheckClient() 1268 1269 @AutoRetry('pull', _RETRY_TIMES) 1270 def _pull(src, dst, ftype, perm=0o644, link=None): 1271 try: 1272 os.makedirs(os.path.dirname(dst)) 1273 except Exception: 1274 pass 1275 1276 src_base = os.path.basename(src) 1277 1278 # Remote file is a link 1279 if ftype == 'l': 1280 pbar = ProgressBar(src_base) 1281 if os.path.exists(dst): 1282 os.remove(dst) 1283 os.symlink(link, dst) 1284 pbar.End() 1285 return 1286 1287 url = ('%s:%d/api/agent/download/%s?filename=%s' % 1288 (self._state.host, self._state.port, 1289 urllib.parse.quote(self._selected_mid), urllib.parse.quote(src))) 1290 try: 1291 h = UrlOpen(self._state, url) 1292 except urllib.error.HTTPError as e: 1293 msg = json.loads(e.read()).get('error', 'unkown error') 1294 raise RuntimeError('pull: %s' % msg) from None 1295 except KeyboardInterrupt: 1296 return 1297 1298 pbar = ProgressBar(src_base) 1299 with open(dst, 'wb') as f: 1300 os.fchmod(f.fileno(), perm) 1301 total_size = int(h.headers.get('Content-Length')) 1302 downloaded_size = 0 1303 1304 while True: 1305 data = h.read(_BUFSIZ) 1306 if not data: 1307 break 1308 downloaded_size += len(data) 1309 pbar.SetProgress(downloaded_size * 100 / total_size, 1310 downloaded_size) 1311 f.write(data) 1312 pbar.End() 1313 1314 # Use find to get a listing of all files under a root directory. The 'stat' 1315 # command is used to retrieve the filename and it's filemode. 1316 output = self.CheckOutput( 1317 'cd $HOME; ' 1318 'stat "%(src)s" >/dev/null && ' 1319 'find "%(src)s" \'(\' -type f -o -type l \')\' ' 1320 '-printf \'%%m\t%%p\t%%y\t%%l\n\'' 1321 % {'src': args.src}) 1322 1323 # We got error from the stat command 1324 if output.startswith('stat: '): 1325 sys.stderr.write(output) 1326 return 1327 1328 entries = output.strip('\n').split('\n') 1329 common_prefix = os.path.dirname(args.src) 1330 1331 if len(entries) == 1: 1332 entry = entries[0] 1333 perm, src_path, ftype, link = entry.split('\t', -1) 1334 if os.path.isdir(args.dst): 1335 dst = os.path.join(args.dst, os.path.basename(src_path)) 1336 else: 1337 dst = args.dst 1338 _pull(src_path, dst, ftype, int(perm, base=8), link) 1339 else: 1340 if not os.path.exists(args.dst): 1341 common_prefix = args.src 1342 1343 for entry in entries: 1344 perm, src_path, ftype, link = entry.split('\t', -1) 1345 rel_dst = src_path[len(common_prefix):].lstrip('/') 1346 _pull(src_path, os.path.join(args.dst, rel_dst), ftype, 1347 int(perm, base=8), link) 1348 1349 @Command('forward', 'forward remote port to local port', [ 1350 Arg('--list', dest='list_all', action='store_true', default=False, 1351 help='list all port forwarding sessions'), 1352 Arg('--remove', metavar='LOCAL_PORT', dest='remove', type=int, 1353 default=None, 1354 help='remove port forwarding for local port LOCAL_PORT'), 1355 Arg('--remove-all', dest='remove_all', action='store_true', 1356 default=False, help='remove all port forwarding'), 1357 Arg('remote', metavar='REMOTE_PORT', type=int, nargs='?'), 1358 Arg('local', metavar='LOCAL_PORT', type=int, nargs='?')]) 1359 def Forward(self, args): 1360 if args.list_all: 1361 max_len = 10 1362 if self._state.forwards: 1363 max_len = max([len(v[0]) for v in self._state.forwards.values()]) 1364 1365 print('%-*s %-8s %-8s' % (max_len, 'Client', 'Remote', 'Local')) 1366 for local in sorted(self._state.forwards.keys()): 1367 value = self._state.forwards[local] 1368 print('%-*s %-8s %-8s' % (max_len, value[0], value[1], local)) 1369 return 1370 1371 if args.remove_all: 1372 self._server.RemoveAllForward() 1373 return 1374 1375 if args.remove: 1376 self._server.RemoveForward(args.remove) 1377 return 1378 1379 self.CheckClient() 1380 1381 if args.remote is None: 1382 raise RuntimeError('remote port not specified') 1383 1384 if args.local is None: 1385 args.local = args.remote 1386 remote = int(args.remote) 1387 local = int(args.local) 1388 1389 def HandleConnection(conn): 1390 headers = [] 1391 if self._state.username is not None and self._state.password is not None: 1392 headers.append(BasicAuthHeader(self._state.username, 1393 self._state.password)) 1394 1395 scheme = 'ws%s://' % ('s' if self._state.ssl else '') 1396 ws = ForwarderWebSocketClient( 1397 self._state, conn, 1398 scheme + '%s:%d/api/agent/forward/%s?port=%d' % ( 1399 self._state.host, self._state.port, 1400 urllib.parse.quote(self._selected_mid), remote), 1401 headers=headers) 1402 try: 1403 ws.connect() 1404 ws.run() 1405 except Exception as e: 1406 print('error: %s' % e) 1407 finally: 1408 ws.close() 1409 1410 server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 1411 server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 1412 server.bind(('0.0.0.0', local)) 1413 server.listen(5) 1414 1415 pid = os.fork() 1416 if pid == 0: 1417 while True: 1418 conn, unused_addr = server.accept() 1419 t = threading.Thread(target=HandleConnection, args=(conn,)) 1420 t.daemon = True 1421 t.start() 1422 else: 1423 self._server.AddForward(self._selected_mid, remote, local, pid) 1424 1425 1426 def main(): 1427 # Setup logging format 1428 logger = logging.getLogger() 1429 logger.setLevel(logging.INFO) 1430 handler = logging.StreamHandler() 1431 formatter = logging.Formatter('%(asctime)s %(message)s', '%Y/%m/%d %H:%M:%S') 1432 handler.setFormatter(formatter) 1433 logger.addHandler(handler) 1434 1435 ovl = OverlordCLIClient() 1436 try: 1437 ovl.Main() 1438 except KeyboardInterrupt: 1439 print('Ctrl-C received, abort') 1440 except Exception as e: 1441 print(f'error: {e}') 1442 1443 1444 if __name__ == '__main__': 1445 main()