github.com/aitjcize/Overlord@v0.0.0-20240314041920-104a804cf5e8/scripts/ovl.py (about)

     1  #!/usr/bin/env python
     2  # Copyright 2015 The Chromium OS Authors. All rights reserved.
     3  # Use of this source code is governed by a BSD-style license that can be
     4  # found in the LICENSE file.
     5  
     6  import argparse
     7  import ast
     8  import base64
     9  import fcntl
    10  import functools
    11  import getpass
    12  import hashlib
    13  import http.client
    14  from io import BytesIO
    15  import json
    16  import logging
    17  import os
    18  import re
    19  import select
    20  import signal
    21  import socket
    22  import ssl
    23  import struct
    24  import subprocess
    25  import sys
    26  import tempfile
    27  import termios
    28  import threading
    29  import time
    30  import tty
    31  import unicodedata  # required by pyinstaller, pylint: disable=unused-import
    32  import urllib.error
    33  import urllib.parse
    34  import urllib.request
    35  from xmlrpc.client import ServerProxy
    36  from xmlrpc.server import SimpleXMLRPCServer
    37  
    38  from ws4py.client import WebSocketBaseClient
    39  import yaml
    40  
    41  
# Directory holding pinned TLS certificates, one '<host>.cert' file per host.
_CERT_DIR = os.path.expanduser('~/.config/ovl')

# Default escape character for terminal sessions; the -e option overrides it.
_ESCAPE = '~'
# Chunk size in bytes for socket reads and file uploads.
_BUFSIZ = 8192
_DEFAULT_HTTPS_PORT = 443
# The client daemon's XML-RPC server listens on loopback only.
_OVERLORD_CLIENT_DAEMON_PORT = 4488
_OVERLORD_CLIENT_DAEMON_RPC_ADDR = ('127.0.0.1', _OVERLORD_CLIENT_DAEMON_PORT)

# Timeouts, in seconds.
_CONNECT_TIMEOUT = 3
_DEFAULT_HTTP_TIMEOUT = 30
# Seconds during which the daemon reuses its cached agent listing.
_LIST_CACHE_TIMEOUT = 2
# Used when the terminal width cannot be determined.
_DEFAULT_TERMINAL_WIDTH = 80
# NOTE(review): presumably the retry count passed to AutoRetry; its use is
# not visible in this part of the file.
_RETRY_TIMES = 3

# Used to build the multipart/form-data boundary for file uploads.
# echo -n overlord | md5sum
_HTTP_BOUNDARY_MAGIC = '9246f080c855a69012707ab53489b921'

# Terminal resize control: these byte values frame a JSON control message
# ({'command': 'resize', ...}) sent over the terminal websocket.
_CONTROL_START = 128
_CONTROL_END = 129

# Stream control: sent (twice, in one message) when local stdin hits EOF.
_STDIN_CLOSED = '##STDIN_CLOSED##'

# Per-host ssh control socket path = this prefix + host name.
_SSH_CONTROL_SOCKET_PREFIX = os.path.join(tempfile.gettempdir(),
                                          'ovl-ssh-control-')

# Prompt for a certificate that failed verification; %s is the certificate
# fingerprint.
_TLS_CERT_FAILED_WARNING = """
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@ WARNING: REMOTE HOST VERIFICATION HAS FAILED! @
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY!
Someone could be eavesdropping on you right now (man-in-the-middle attack)!
It is also possible that the server is using a self-signed certificate.
The fingerprint for the TLS host certificate sent by the remote host is

%s

Do you want to trust this certificate and proceed? [y/N] """

# Message for a certificate that no longer matches the pinned one. The first
# %s is the fingerprint; the second is presumably the pinned certificate path
# (confirm at the call site).
_TLS_CERT_CHANGED_WARNING = """
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@ WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! @
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY!
Someone could be eavesdropping on you right now (man-in-the-middle attack)!
It is also possible that the TLS host certificate has just been changed.
The fingerprint for the TLS host certificate sent by the remote host is

%s

Remove '%s' if you still want to proceed.
SSL Certificate verification failed."""

# User-Agent header added to every request made through UrlOpen.
_USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.0.0.0 Safari/537.36"
    97  
    98  
    99  def GetVersionDigest():
   100    """Return the sha1sum of the current executing script."""
   101    # Check python script by default
   102    filename = __file__
   103  
   104    # If we are running from a frozen binary, we should calculate the checksum
   105    # against that binary instead of the python script.
   106    # See: https://pyinstaller.readthedocs.io/en/stable/runtime-information.html
   107    if getattr(sys, 'frozen', False):
   108      filename = sys.executable
   109  
   110    with open(filename, 'rb') as f:
   111      return hashlib.sha1(f.read()).hexdigest()
   112  
   113  
   114  def GetTLSCertPath(host):
   115    return os.path.join(_CERT_DIR, '%s.cert' % host)
   116  
   117  
   118  def UrlOpen(state, url):
   119    """Wrapper for urllib.request.urlopen.
   120  
   121    It selects correct HTTP scheme according to self._state.ssl, add HTTP
   122    basic auth headers, and add specify correct SSL context.
   123    """
   124    url = MakeRequestUrl(state, url)
   125    request = urllib.request.Request(url)
   126    if state.username is not None and state.password is not None:
   127      request.add_header(*BasicAuthHeader(state.username, state.password))
   128    request.add_header('User-Agent', _USER_AGENT)
   129    return urllib.request.urlopen(request, timeout=_DEFAULT_HTTP_TIMEOUT,
   130                                  context=state.SSLContext())
   131  
   132  
   133  def GetTLSCertificateSHA1Fingerprint(cert_pem):
   134    beg = cert_pem.index('\n')
   135    end = cert_pem.rindex('\n', 0, len(cert_pem) - 2)
   136    cert_pem = cert_pem[beg:end]  # Remove BEGIN/END CERTIFICATE boundary
   137    cert_der = base64.b64decode(cert_pem)
   138    return hashlib.sha1(cert_der).hexdigest()
   139  
   140  
   141  def KillGraceful(pid, wait_secs=1):
   142    """Kill a process gracefully by first sending SIGTERM, wait for some time,
   143    then send SIGKILL to make sure it's killed."""
   144    try:
   145      os.kill(pid, signal.SIGTERM)
   146      time.sleep(wait_secs)
   147      os.kill(pid, signal.SIGKILL)
   148    except OSError:
   149      pass
   150  
   151  
   152  def AutoRetry(action_name, retries):
   153    """Decorator for retry function call."""
   154    def Wrap(func):
   155      @functools.wraps(func)
   156      def Loop(*args, **kwargs):
   157        for unused_i in range(retries):
   158          try:
   159            func(*args, **kwargs)
   160          except Exception as e:
   161            print('error: %s: %s: retrying ...' % (args[0], e))
   162          else:
   163            break
   164        else:
   165          print('error: failed to %s %s' % (action_name, args[0]))
   166      return Loop
   167    return Wrap
   168  
   169  
   170  def BasicAuthHeader(user, password):
   171    """Return HTTP basic auth header."""
   172    credential = base64.b64encode(
   173        b'%s:%s' % (user.encode('utf-8'), password.encode('utf-8')))
   174    return ('Authorization', 'Basic %s' % credential.decode('utf-8'))
   175  
   176  
   177  def GetTerminalSize():
   178    """Retrieve terminal window size."""
   179    ws = struct.pack('HHHH', 0, 0, 0, 0)
   180    ws = fcntl.ioctl(0, termios.TIOCGWINSZ, ws)
   181    lines, columns, unused_x, unused_y = struct.unpack('HHHH', ws)
   182    return lines, columns
   183  
   184  
   185  def MakeRequestUrl(state, url):
   186    return 'http%s://%s' % ('s' if state.ssl else '', url)
   187  
   188  
   189  class ProgressBar:
   190    SIZE_WIDTH = 11
   191    SPEED_WIDTH = 10
   192    DURATION_WIDTH = 6
   193    PERCENTAGE_WIDTH = 8
   194  
   195    def __init__(self, name):
   196      self._start_time = time.time()
   197      self._name = name
   198      self._size = 0
   199      self._width = 0
   200      self._name_width = 0
   201      self._name_max = 0
   202      self._stat_width = 0
   203      self._max = 0
   204      self._CalculateSize()
   205      self.SetProgress(0)
   206  
   207    def _CalculateSize(self):
   208      self._width = GetTerminalSize()[1] or _DEFAULT_TERMINAL_WIDTH
   209      self._name_width = int(self._width * 0.3)
   210      self._name_max = self._name_width
   211      self._stat_width = self.SIZE_WIDTH + self.SPEED_WIDTH + self.DURATION_WIDTH
   212      self._max = (self._width - self._name_width - self._stat_width -
   213                   self.PERCENTAGE_WIDTH)
   214  
   215    def _SizeToHuman(self, size_in_bytes):
   216      if size_in_bytes < 1024:
   217        unit = 'B'
   218        value = size_in_bytes
   219      elif size_in_bytes < 1024 ** 2:
   220        unit = 'KiB'
   221        value = size_in_bytes / 1024
   222      elif size_in_bytes < 1024 ** 3:
   223        unit = 'MiB'
   224        value = size_in_bytes / (1024 ** 2)
   225      elif size_in_bytes < 1024 ** 4:
   226        unit = 'GiB'
   227        value = size_in_bytes / (1024 ** 3)
   228      return ' %6.1f %3s' % (value, unit)
   229  
   230    def _SpeedToHuman(self, speed_in_bs):
   231      if speed_in_bs < 1024:
   232        unit = 'B'
   233        value = speed_in_bs
   234      elif speed_in_bs < 1024 ** 2:
   235        unit = 'K'
   236        value = speed_in_bs / 1024
   237      elif speed_in_bs < 1024 ** 3:
   238        unit = 'M'
   239        value = speed_in_bs / (1024 ** 2)
   240      elif speed_in_bs < 1024 ** 4:
   241        unit = 'G'
   242        value = speed_in_bs / (1024 ** 3)
   243      return ' %6.1f%s/s' % (value, unit)
   244  
   245    def _DurationToClock(self, duration):
   246      return ' %02d:%02d' % (duration // 60, duration % 60)
   247  
   248    def SetProgress(self, percentage, size=None):
   249      current_width = GetTerminalSize()[1]
   250      if self._width != current_width:
   251        self._CalculateSize()
   252  
   253      if size is not None:
   254        self._size = size
   255  
   256      elapse_time = time.time() - self._start_time
   257      speed = self._size / elapse_time
   258  
   259      size_str = self._SizeToHuman(self._size)
   260      speed_str = self._SpeedToHuman(speed)
   261      elapse_str = self._DurationToClock(elapse_time)
   262  
   263      width = int(self._max * percentage / 100.0)
   264      sys.stdout.write(
   265          '%*s' % (- self._name_max,
   266                   self._name if len(self._name) <= self._name_max else
   267                   self._name[:self._name_max - 4] + ' ...') +
   268          size_str + speed_str + elapse_str +
   269          ((' [' + '#' * width + ' ' * (self._max - width) + ']' +
   270            '%4d%%' % int(percentage)) if self._max > 2 else '') + '\r')
   271      sys.stdout.flush()
   272  
   273    def End(self):
   274      self.SetProgress(100.0)
   275      sys.stdout.write('\n')
   276      sys.stdout.flush()
   277  
   278  
   279  class DaemonState:
   280    """DaemonState is used for storing Overlord state info."""
   281    def __init__(self):
   282      self.version_sha1sum = GetVersionDigest()
   283      self.host = None
   284      self.port = None
   285      self.ssl = False
   286      self.ssl_self_signed = False
   287      self.ssl_verify = True
   288      self.ssl_check_hostname = True
   289      self.ssh = False
   290      self.orig_host = None
   291      self.ssh_pid = None
   292      self.username = None
   293      self.password = None
   294      self.selected_mid = None
   295      self.forwards = {}
   296      self.listing = []
   297      self.last_list = 0
   298  
   299    def SSLContext(self):
   300      # No verify.
   301      if not self.ssl_verify:
   302        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
   303        context.check_hostname = False
   304        context.verify_mode = ssl.CERT_NONE
   305        return context
   306  
   307      context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
   308      context.check_hostname = self.ssl_check_hostname
   309      context.verify_mode = ssl.CERT_REQUIRED
   310  
   311      # Check if self signed certificate exists.
   312      ssl_cert_path = GetTLSCertPath(self.host)
   313      if os.path.exists(ssl_cert_path):
   314        context.load_verify_locations(ssl_cert_path)
   315        self.ssl_self_signed = True
   316        return context
   317  
   318      return ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
   319  
   320  
   321    @staticmethod
   322    def FromDict(kw):
   323      state = DaemonState()
   324  
   325      for k, v in kw.items():
   326        setattr(state, k, v)
   327      return state
   328  
   329  
   330  class OverlordClientDaemon:
   331    """Overlord Client Daemon."""
   332    def __init__(self):
   333      self._state = DaemonState()
   334      self._server = None
   335  
   336    def Start(self):
   337      self.StartRPCServer()
   338  
   339    def StartRPCServer(self):
   340      self._server = SimpleXMLRPCServer(_OVERLORD_CLIENT_DAEMON_RPC_ADDR,
   341                                        logRequests=False, allow_none=True)
   342      exports = [
   343          (self.State, 'State'),
   344          (self.Ping, 'Ping'),
   345          (self.GetPid, 'GetPid'),
   346          (self.Connect, 'Connect'),
   347          (self.Clients, 'Clients'),
   348          (self.SelectClient, 'SelectClient'),
   349          (self.AddForward, 'AddForward'),
   350          (self.RemoveForward, 'RemoveForward'),
   351          (self.RemoveAllForward, 'RemoveAllForward'),
   352      ]
   353      for func, name in exports:
   354        self._server.register_function(func, name)
   355  
   356      pid = os.fork()
   357      if pid == 0:
   358        for fd in range(3):
   359          os.close(fd)
   360        self._server.serve_forever()
   361  
   362    @staticmethod
   363    def GetRPCServer():
   364      """Returns the Overlord client daemon RPC server."""
   365      server = ServerProxy('http://%s:%d' % _OVERLORD_CLIENT_DAEMON_RPC_ADDR,
   366                           allow_none=True)
   367      try:
   368        server.Ping()
   369      except Exception:
   370        return None
   371      return server
   372  
   373    def State(self):
   374      return self._state
   375  
   376    def Ping(self):
   377      return True
   378  
   379    def GetPid(self):
   380      return os.getpid()
   381  
   382    def _GetJSON(self, path):
   383      url = '%s:%d%s' % (self._state.host, self._state.port, path)
   384      return json.loads(UrlOpen(self._state, url).read())
   385  
   386    def _TLSEnabled(self):
   387      """Determine if TLS is enabled on given server address."""
   388      sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
   389      try:
   390        # Allow any certificate since we only want to check if server talks TLS.
   391        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
   392        context.check_hostname = False
   393        context.verify_mode = ssl.CERT_NONE
   394  
   395        sock = context.wrap_socket(sock, server_hostname=self._state.host)
   396        sock.settimeout(_CONNECT_TIMEOUT)
   397        sock.connect((self._state.host, self._state.port))
   398        return True
   399      except ssl.SSLError:
   400        return False
   401      except socket.error:  # Connect refused or timeout
   402        raise
   403      except Exception:
   404        return False  # For whatever reason above failed, assume False
   405  
   406    def _CheckTLSCertificate(self):
   407      """Check TLS certificate.
   408  
   409      Returns:
   410        A tupple (check_result, if_certificate_is_loaded)
   411      """
   412      def _DoConnect(context):
   413        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
   414        try:
   415          sock.settimeout(_CONNECT_TIMEOUT)
   416          sock = context.wrap_socket(sock, server_hostname=self._state.host)
   417          sock.connect((self._state.host, self._state.port))
   418        except ssl.SSLError:
   419          return False
   420        finally:
   421          sock.close()
   422  
   423        return True
   424  
   425      return _DoConnect(self._state.SSLContext())
   426  
   427    def Connect(self, host, port, ssh_pid=None,
   428                username=None, password=None, orig_host=None,
   429                ssl_verify=True, ssl_check_hostname=True):
   430      self._state.username = username
   431      self._state.password = password
   432      self._state.host = host
   433      self._state.port = port
   434      self._state.ssl = False
   435      self._state.ssl_self_signed = False
   436      self._state.orig_host = orig_host
   437      self._state.ssh_pid = ssh_pid
   438      self._state.selected_mid = None
   439      self._state.ssl_verify = ssl_verify
   440      self._state.ssl_check_hostname = ssl_check_hostname
   441  
   442      ssl_enabled = self._TLSEnabled()
   443      if ssl_enabled:
   444        result = self._CheckTLSCertificate()
   445        if not result:
   446          if self._state.ssl_self_signed:
   447            return ('SSLCertificateChanged', ssl.get_server_certificate(
   448                (self._state.host, self._state.port)))
   449          return ('SSLVerifyFailed', ssl.get_server_certificate(
   450              (self._state.host, self._state.port)))
   451  
   452      try:
   453        self._state.ssl = ssl_enabled
   454        UrlOpen(self._state, '%s:%d' % (host, port))
   455      except urllib.error.HTTPError as e:
   456        return ('HTTPError', e.getcode(), str(e), e.read().strip())
   457      except Exception as e:
   458        return str(e)
   459      else:
   460        return True
   461  
   462    def Clients(self):
   463      if time.time() - self._state.last_list <= _LIST_CACHE_TIMEOUT:
   464        return self._state.listing
   465  
   466      self._state.listing = self._GetJSON('/api/agents/list')
   467      self._state.last_list = time.time()
   468      return self._state.listing
   469  
   470    def SelectClient(self, mid):
   471      self._state.selected_mid = mid
   472  
   473    def AddForward(self, mid, remote, local, pid):
   474      self._state.forwards[str(local)] = (mid, remote, pid)
   475  
   476    def RemoveForward(self, local_port):
   477      try:
   478        unused_mid, unused_remote, pid = self._state.forwards[str(local_port)]
   479        KillGraceful(pid)
   480        del self._state.forwards[local_port]
   481      except (KeyError, OSError):
   482        pass
   483  
   484    def RemoveAllForward(self):
   485      for unused_mid, unused_remote, pid in self._state.forwards.values():
   486        try:
   487          KillGraceful(pid)
   488        except OSError:
   489          pass
   490      self._state.forwards = {}
   491  
   492  
   493  class SSLEnabledWebSocketBaseClient(WebSocketBaseClient):
   494    def __init__(self, state, *args, **kwargs):
   495      super().__init__(ssl_context=state.SSLContext(), *args, **kwargs)
   496  
   497  
   498  class TerminalWebSocketClient(SSLEnabledWebSocketBaseClient):
   499    def __init__(self, state, mid, escape, *args, **kwargs):
   500      super().__init__(state, *args, **kwargs)
   501      self._mid = mid
   502      self._escape = escape
   503      self._stdin_fd = sys.stdin.fileno()
   504      self._old_termios = None
   505  
   506    def handshake_ok(self):
   507      pass
   508  
   509    def opened(self):
   510      nonlocals = {'size': (80, 40)}
   511  
   512      def _ResizeWindow():
   513        size = GetTerminalSize()
   514        if size != nonlocals['size']:  # Size not changed, ignore
   515          control = {'command': 'resize', 'params': list(size)}
   516          payload = (_CONTROL_START.to_bytes(1, 'big') +
   517                     json.dumps(control).encode('utf-8') +
   518                     _CONTROL_END.to_bytes(1, 'big'))
   519          nonlocals['size'] = size
   520          try:
   521            self.send(payload, binary=True)
   522          except Exception as e:
   523            logging.exception(e)
   524  
   525      def _FeedInput():
   526        self._old_termios = termios.tcgetattr(self._stdin_fd)
   527        tty.setraw(self._stdin_fd)
   528  
   529        READY, ENTER_PRESSED, ESCAPE_PRESSED = range(3)
   530  
   531        try:
   532          state = READY
   533          while True:
   534            # Check if terminal is resized
   535            _ResizeWindow()
   536  
   537            ch = sys.stdin.read(1)
   538  
   539            # Scan for escape sequence
   540            if self._escape:
   541              if state == READY:
   542                state = ENTER_PRESSED if ch == chr(0x0d) else READY
   543              elif state == ENTER_PRESSED:
   544                state = ESCAPE_PRESSED if ch == self._escape else READY
   545              elif state == ESCAPE_PRESSED:
   546                if ch == '.':
   547                  self.close()
   548                  break
   549              else:
   550                state = READY
   551  
   552            self.send(ch)
   553        except (KeyboardInterrupt, RuntimeError):
   554          pass
   555  
   556      t = threading.Thread(target=_FeedInput)
   557      t.daemon = True
   558      t.start()
   559  
   560    def closed(self, code, reason=None):
   561      del code, reason  # Unused.
   562      termios.tcsetattr(self._stdin_fd, termios.TCSANOW, self._old_termios)
   563      print('\nConnection to %s closed.' % self._mid)
   564  
   565    def received_message(self, message):
   566      if message.is_binary:
   567        sys.stdout.buffer.write(message.data)
   568        sys.stdout.flush()
   569  
   570  
   571  class ShellWebSocketClient(SSLEnabledWebSocketBaseClient):
   572    def __init__(self, state, output, *args, **kwargs):
   573      """Constructor.
   574  
   575      Args:
   576        output: output file object.
   577      """
   578      super().__init__(state, *args, **kwargs)
   579      self._output = output
   580      self._input_thread = threading.Thread(target=self._FeedInput)
   581      self._stop = threading.Event()
   582  
   583    def handshake_ok(self):
   584      pass
   585  
   586    def _FeedInput(self):
   587      try:
   588        while True:
   589          rd, unused_w, unused_x = select.select([sys.stdin], [], [], 0.5)
   590          if self._stop.is_set():
   591            break
   592  
   593          if sys.stdin in rd:
   594            data = sys.stdin.buffer.read()
   595            if not data:
   596              self.send(_STDIN_CLOSED * 2)
   597              break
   598            self.send(data, binary=True)
   599      except (KeyboardInterrupt, RuntimeError):
   600        pass
   601  
   602    def opened(self):
   603      self._input_thread.start()
   604  
   605    def closed(self, code, reason=None):
   606      self._stop.set()
   607      self._input_thread.join()
   608  
   609    def received_message(self, message):
   610      if message.is_binary:
   611        self._output.write(message.data)
   612        self._output.flush()
   613  
   614  
   615  class ForwarderWebSocketClient(SSLEnabledWebSocketBaseClient):
   616    def __init__(self, state, sock, *args, **kwargs):
   617      super().__init__(state, *args, **kwargs)
   618      self._sock = sock
   619      self._input_thread = threading.Thread(target=self._FeedInput)
   620      self._stop = threading.Event()
   621  
   622    def handshake_ok(self):
   623      pass
   624  
   625    def _FeedInput(self):
   626      try:
   627        self._sock.setblocking(False)
   628        while True:
   629          rd, unused_w, unused_x = select.select([self._sock], [], [], 0.5)
   630          if self._stop.is_set():
   631            break
   632          if self._sock in rd:
   633            data = self._sock.recv(_BUFSIZ)
   634            if not data:
   635              self.close()
   636              break
   637            self.send(data, binary=True)
   638      except Exception:
   639        pass
   640      finally:
   641        self._sock.close()
   642  
   643    def opened(self):
   644      self._input_thread.start()
   645  
   646    def closed(self, code, reason=None):
   647      del code, reason  # Unused.
   648      self._stop.set()
   649      self._input_thread.join()
   650      sys.exit(0)
   651  
   652    def received_message(self, message):
   653      if message.is_binary:
   654        self._sock.send(message.data)
   655  
   656  
   657  def Arg(*args, **kwargs):
   658    return (args, kwargs)
   659  
   660  
   661  def Command(command, help_msg=None, args=None):
   662    """Decorator for adding argparse parameter for a method."""
   663    if args is None:
   664      args = []
   665    def WrapFunc(func):
   666      @functools.wraps(func)
   667      def Wrapped(*args, **kwargs):
   668        return func(*args, **kwargs)
   669      # pylint: disable=protected-access
   670      Wrapped.__arg_attr = {'command': command, 'help': help_msg, 'args': args}
   671      return Wrapped
   672    return WrapFunc
   673  
   674  
   675  def ParseMethodSubCommands(cls):
   676    """Decorator for a class using the @Command decorator.
   677  
   678    This decorator retrieve command info from each method and append it in to the
   679    SUBCOMMANDS class variable, which is later used to construct parser.
   680    """
   681    for unused_key, method in cls.__dict__.items():
   682      if hasattr(method, '__arg_attr'):
   683        # pylint: disable=protected-access
   684        cls.SUBCOMMANDS.append(method.__arg_attr)
   685    return cls
   686  
   687  
   688  @ParseMethodSubCommands
   689  class OverlordCLIClient:
   690    """Overlord command line interface client."""
   691  
   692    SUBCOMMANDS = []
   693  
   694    def __init__(self):
   695      self._parser = self._BuildParser()
   696      self._selected_mid = None
   697      self._server = None
   698      self._state = None
   699      self._escape = None
   700  
   701    def _BuildParser(self):
   702      root_parser = argparse.ArgumentParser(prog='ovl')
   703      subparsers = root_parser.add_subparsers(title='subcommands',
   704                                              dest='subcommand')
   705      subparsers.required = True
   706  
   707      root_parser.add_argument('-s', dest='selected_mid', action='store',
   708                               default=None,
   709                               help='select target to execute command on')
   710      root_parser.add_argument('-S', dest='select_mid_before_action',
   711                               action='store_true', default=False,
   712                               help='select target before executing command')
   713      root_parser.add_argument('-e', dest='escape', metavar='ESCAPE_CHAR',
   714                               action='store', default=_ESCAPE, type=str,
   715                               help='set shell escape character, \'none\' to '
   716                               'disable escape completely')
   717  
   718      for attr in self.SUBCOMMANDS:
   719        parser = subparsers.add_parser(attr['command'], help=attr['help'])
   720        parser.set_defaults(which=attr['command'])
   721        for arg in attr['args']:
   722          parser.add_argument(*arg[0], **arg[1])
   723  
   724      return root_parser
   725  
   726    def Main(self):
   727      # We want to pass the rest of arguments after shell command directly to the
   728      # function without parsing it.
   729      try:
   730        index = sys.argv.index('shell')
   731      except ValueError:
   732        args = self._parser.parse_args()
   733      else:
   734        args = self._parser.parse_args(sys.argv[1:index + 1])
   735  
   736      command = args.which
   737      self._selected_mid = args.selected_mid
   738  
   739      if args.escape and args.escape != 'none':
   740        self._escape = args.escape[0]
   741  
   742      if command == 'start-server':
   743        self.StartServer()
   744        return
   745      if command == 'kill-server':
   746        self.KillServer()
   747        return
   748  
   749      self.CheckDaemon()
   750      if command == 'status':
   751        self.Status()
   752        return
   753      if command == 'connect':
   754        self.Connect(args)
   755        return
   756  
   757      # The following command requires connection to the server
   758      self.CheckConnection()
   759  
   760      if args.select_mid_before_action:
   761        self.SelectClient(store=False)
   762  
   763      if command == 'select':
   764        self.SelectClient(args)
   765      elif command == 'ls':
   766        self.ListClients(args)
   767      elif command == 'shell':
   768        command = sys.argv[sys.argv.index('shell') + 1:]
   769        self.Shell(command)
   770      elif command == 'push':
   771        self.Push(args)
   772      elif command == 'pull':
   773        self.Pull(args)
   774      elif command == 'forward':
   775        self.Forward(args)
   776  
   777    def _SaveTLSCertificate(self, host, cert_pem):
   778      try:
   779        os.makedirs(_CERT_DIR)
   780      except Exception:
   781        pass
   782      with open(GetTLSCertPath(host), 'w') as f:
   783        f.write(cert_pem)
   784  
   785    def _HTTPPostFile(self, url, filename, progress=None, user=None, passwd=None):
   786      """Perform HTTP POST and upload file to Overlord.
   787  
   788      To minimize the external dependencies, we construct the HTTP post request
   789      by ourselves.
   790      """
   791      url = MakeRequestUrl(self._state, url)
   792      size = os.stat(filename).st_size
   793      boundary = '-----------%s' % _HTTP_BOUNDARY_MAGIC
   794      CRLF = '\r\n'
   795      parse = urllib.parse.urlparse(url)
   796  
   797      part_headers = [
   798          '--' + boundary,
   799          'Content-Disposition: form-data; name="file"; '
   800          'filename="%s"' % os.path.basename(filename),
   801          'Content-Type: application/octet-stream',
   802          '', ''
   803      ]
   804      part_header = CRLF.join(part_headers)
   805      end_part = CRLF + '--' + boundary + '--' + CRLF
   806  
   807      content_length = len(part_header) + size + len(end_part)
   808      if parse.scheme == 'http':
   809        h = http.client.HTTPConnection(parse.netloc)
   810      else:
   811        h = http.client.HTTPSConnection(parse.netloc,
   812                                        context=self._state.SSLContext())
   813  
   814      post_path = url[url.index(parse.netloc) + len(parse.netloc):]
   815      h.putrequest('POST', post_path)
   816      h.putheader('Content-Length', content_length)
   817      h.putheader('Content-Type', 'multipart/form-data; boundary=%s' % boundary)
   818  
   819      if user and passwd:
   820        h.putheader(*BasicAuthHeader(user, passwd))
   821      h.endheaders()
   822      h.send(part_header.encode('utf-8'))
   823  
   824      count = 0
   825      with open(filename, 'rb') as f:
   826        while True:
   827          data = f.read(_BUFSIZ)
   828          if not data:
   829            break
   830          count += len(data)
   831          if progress:
   832            progress(count * 100 // size, count)
   833          h.send(data)
   834  
   835      h.send(end_part.encode('utf-8'))
   836      progress(100)
   837  
   838      if count != size:
   839        logging.warning('file changed during upload, upload may be truncated.')
   840  
   841      resp = h.getresponse()
   842      return resp.status == 200
   843  
   844    def CheckDaemon(self):
   845      self._server = OverlordClientDaemon.GetRPCServer()
   846      if self._server is None:
   847        print('* daemon not running, starting it now on port %d ... *' %
   848              _OVERLORD_CLIENT_DAEMON_PORT)
   849        self.StartServer()
   850  
   851      self._state = DaemonState.FromDict(self._server.State())
   852      sha1sum = GetVersionDigest()
   853  
   854      if sha1sum != self._state.version_sha1sum:
   855        print('ovl server is out of date.  killing...')
   856        KillGraceful(self._server.GetPid())
   857        self.StartServer()
   858  
   859    def GetSSHControlFile(self, host):
   860      return _SSH_CONTROL_SOCKET_PREFIX + host
   861  
   862    def SSHTunnel(self, user, host, port):
   863      """SSH forward the remote overlord server.
   864  
   865      Overlord server may not have port 9000 open to the public network, in such
   866      case we can SSH forward the port to 127.0.0.1.
   867      """
   868  
   869      control_file = self.GetSSHControlFile(host)
   870      try:
   871        os.unlink(control_file)
   872      except Exception:
   873        pass
   874  
   875      with subprocess.Popen([
   876          'ssh', '-Nf', '-M', '-S', control_file, '-L', '9000:127.0.0.1:9000',
   877          '-p',
   878          str(port),
   879          '%s%s' % (user + '@' if user else '', host)
   880      ]):
   881        pass
   882  
   883      p = subprocess.Popen([
   884          'ssh',
   885          '-S', control_file,
   886          '-O', 'check', host,
   887      ], stderr=subprocess.PIPE)
   888      unused_stdout, stderr = p.communicate()
   889  
   890      s = re.search(r'pid=(\d+)', stderr)
   891      if s:
   892        return int(s.group(1))
   893  
   894      raise RuntimeError('can not establish ssh connection')
   895  
   896    def CheckConnection(self):
   897      if self._state.host is None:
   898        raise RuntimeError('not connected to any server, abort')
   899  
   900      try:
   901        self._server.Clients()
   902      except Exception:
   903        raise RuntimeError('remote server disconnected, abort') from None
   904  
   905      if self._state.ssh_pid is not None:
   906        with subprocess.Popen(
   907            ['kill', '-0', str(self._state.ssh_pid)], stdout=subprocess.PIPE,
   908            stderr=subprocess.PIPE) as p:
   909          pass
   910        if p.returncode != 0:
   911          raise RuntimeError('ssh tunnel disconnected, please re-connect')
   912  
   913    def CheckClient(self):
   914      if self._selected_mid is None:
   915        if self._state.selected_mid is None:
   916          raise RuntimeError('No client is selected')
   917        self._selected_mid = self._state.selected_mid
   918  
   919      if not any(client['mid'] == self._selected_mid
   920                 for client in self._server.Clients()):
   921        raise RuntimeError('client %s disappeared' % self._selected_mid)
   922  
   923    def CheckOutput(self, command):
   924      headers = []
   925      if self._state.username is not None and self._state.password is not None:
   926        headers.append(BasicAuthHeader(self._state.username,
   927                                       self._state.password))
   928  
   929      scheme = 'ws%s://' % ('s' if self._state.ssl else '')
   930      bio = BytesIO()
   931      ws = ShellWebSocketClient(
   932          self._state, bio, scheme + '%s:%d/api/agent/shell/%s?command=%s' % (
   933              self._state.host, self._state.port,
   934              urllib.parse.quote(self._selected_mid),
   935              urllib.parse.quote(command)),
   936          headers=headers)
   937      ws.connect()
   938      ws.run()
   939      return bio.getvalue().decode('utf-8')
   940  
   941    @Command('status', 'show Overlord connection status')
   942    def Status(self):
   943      if self._state.host is None:
   944        print('Not connected to any host.')
   945      else:
   946        if self._state.ssh_pid is not None:
   947          print('Connected to %s with SSH tunneling.' % self._state.orig_host)
   948        else:
   949          print('Connected to %s:%d.' % (self._state.host, self._state.port))
   950  
   951      if self._selected_mid is None:
   952        self._selected_mid = self._state.selected_mid
   953  
   954      if self._selected_mid is None:
   955        print('No client is selected.')
   956      else:
   957        print('Client %s selected.' % self._selected_mid)
   958  
  @Command('connect', 'connect to Overlord server', [
      Arg('host', metavar='HOST', type=str, default='127.0.0.1',
          help='Overlord hostname/IP'),
      Arg('port', metavar='PORT', type=int, nargs='?',
          default=_DEFAULT_HTTPS_PORT, help='Overlord port'),
      Arg('-f', '--forward', dest='ssh_forward', default=False,
          action='store_true',
          help='connect with SSH forwarding to the host'),
      Arg('-p', '--ssh-port', dest='ssh_port', default=22,
          type=int, help='SSH server port for SSH forwarding'),
      Arg('-l', '--ssh-login', dest='ssh_login', default='',
          type=str, help='SSH server login name for SSH forwarding'),
      Arg('-u', '--user', dest='user', default=None,
          type=str, help='Overlord HTTP auth username'),
      Arg('-w', '--passwd', dest='passwd', default=None, type=str,
          help='Overlord HTTP auth password'),
      Arg('--ssl-no-verify', dest='ssl_verify',
          default=True, action='store_false',
          help='Ignore SSL cert verification'),
      Arg('--ssl-no-check-hostname', dest='ssl_check_hostname',
          default=True, action='store_false',
          help='Ignore SSL cert hostname check')])
  def Connect(self, args):
    """Ask the client daemon to connect to an Overlord server.

    With --forward, any previous SSH tunnel is killed and a new one is opened
    first. The daemon then connects to 127.0.0.1 through that tunnel.

    Up to three attempts are made. The result of the daemon's Connect RPC is
    handled as follows:
      * True: success is reported and the loop ends.
      * ['SSLCertificateChanged', cert_pem, ...]: a warning is printed and the
        command returns.
      * ['SSLVerifyFailed', cert_pem, ...]: the user is asked whether to trust
        the certificate. If yes, it is saved and the connection is retried;
        otherwise the command returns.
      * ['HTTPError', code, except_str, body]: on 401, the next attempt prompts
        for whichever of username/password was not given on the command line.
        If both were given, the loop stops. Any other code is logged and then
        reported as a connection failure.
      * anything else: printed as a connection failure.

    Exceptions raised inside the attempt loop are logged and the attempt is
    retried. Errors from KillSSHTunnel/SSHTunnel happen before the loop and
    propagate to the caller.
    """
    ssh_pid = None
    host = args.host
    orig_host = args.host

    if args.ssh_forward:
      # Kill previous SSH tunnel
      self.KillSSHTunnel()

      ssh_pid = self.SSHTunnel(args.ssh_login, args.host, args.ssh_port)
      # NOTE(review): the tunnel forwards local port 9000, but args.port is
      # still passed to the daemon below. Presumably the user must pass 9000
      # explicitly when using --forward; verify.
      host = '127.0.0.1'

    username_provided = args.user is not None
    password_provided = args.passwd is not None
    # Set after a 401 so the next attempt asks for missing credentials.
    prompt = False

    for unused_i in range(3):  # pylint: disable=too-many-nested-blocks
      try:
        if prompt:
          if not username_provided:
            args.user = input('Username: ')
          if not password_provided:
            args.passwd = getpass.getpass('Password: ')

        ret = self._server.Connect(host, args.port, ssh_pid, args.user,
                                   args.passwd, orig_host,
                                   args.ssl_verify, args.ssl_check_hostname)
        if isinstance(ret, list):
          if ret[0].startswith('SSL'):
            cert_pem = ret[1]
            fp = GetTLSCertificateSHA1Fingerprint(cert_pem)
            # Group the fingerprint into colon-separated pairs for display.
            fp_text = ':'.join([fp[i:i+2] for i in range(0, len(fp), 2)])

          if ret[0] == 'SSLCertificateChanged':
            print(_TLS_CERT_CHANGED_WARNING % (fp_text, GetTLSCertPath(host)))
            return
          if ret[0] == 'SSLVerifyFailed':
            print(_TLS_CERT_FAILED_WARNING % (fp_text), end='')
            response = input()
            if response.lower() in ['y', 'ye', 'yes']:
              self._SaveTLSCertificate(host, cert_pem)
              print('TLS host Certificate trusted, you will not be prompted '
                    'next time.\n')
              # Retry now that the certificate is trusted.
              continue
            print('connection aborted.')
            return
          if ret[0] == 'HTTPError':
            code, except_str, body = ret[1:]
            if code == 401:
              print('connect: %s' % body)
              prompt = True
              # Only retry if there is something left to ask the user for.
              if not username_provided or not password_provided:
                continue
              break
            logging.error('%s; %s', except_str, body)

        if ret is not True:
          print('can not connect to %s: %s' % (host, ret))
        else:
          print('connection to %s:%d established.' % (host, args.port))
      except Exception as e:
        # Log and fall through to the next attempt.
        logging.error(e)
      else:
        break
  1045  
  @Command('start-server', 'start overlord CLI client server')
  def StartServer(self):
    """Start the client daemon unless one already answers over RPC."""
    self._server = OverlordClientDaemon.GetRPCServer()
    if self._server is not None:
      return

    OverlordClientDaemon().Start()
    # Presumably gives the new daemon time to come up before probing again.
    time.sleep(1)
    self._server = OverlordClientDaemon.GetRPCServer()
    if self._server is not None:
      print('* daemon started successfully *\n')
  1055  
  @Command('kill-server', 'kill overlord CLI client server')
  def KillServer(self):
    """Stop the client daemon and any SSH tunnel it recorded.

    Does nothing if no daemon answers over RPC.
    """
    server = OverlordClientDaemon.GetRPCServer()
    self._server = server
    if server is not None:
      self._state = DaemonState.FromDict(server.State())
      # Tear down the tunnel first, then the daemon itself.
      self.KillSSHTunnel()
      KillGraceful(server.GetPid())
  1069  
  def KillSSHTunnel(self):
    """Terminate the SSH tunnel recorded in the daemon state, if there is one."""
    ssh_pid = self._state.ssh_pid
    if ssh_pid is None:
      return
    KillGraceful(ssh_pid)
  1073  
  1074    def _FilterClients(self, clients, prop_filters, mid=None):
  1075      def _ClientPropertiesMatch(client, key, regex):
  1076        try:
  1077          return bool(re.search(regex, client['properties'][key]))
  1078        except KeyError:
  1079          return False
  1080  
  1081      for prop_filter in prop_filters:
  1082        key, sep, regex = prop_filter.partition('=')
  1083        if not sep:
  1084          # The filter doesn't contains =.
  1085          raise ValueError('Invalid filter condition %r' % filter)
  1086        clients = [c for c in clients if _ClientPropertiesMatch(c, key, regex)]
  1087  
  1088      if mid is not None:
  1089        client = next((c for c in clients if c['mid'] == mid), None)
  1090        if client:
  1091          return [client]
  1092        clients = [c for c in clients if c['mid'].startswith(mid)]
  1093      return clients
  1094  
  @Command('ls', 'list clients', [
      Arg('-f', '--filter', default=[], dest='filters', action='append',
          help=('Conditions to filter clients by properties. '
                'Should be in form "key=regex", where regex is the regular '
                'expression that should be found in the value. '
                'Multiple --filter arguments would be ANDed.')),
      Arg('-v', '--verbose', default=False, action='store_true',
          help='Print properties of each client.')
  ])
  def ListClients(self, args):
    """Print the matching clients: a YAML dump each when verbose, else mids."""
    matched = self._FilterClients(self._server.Clients(), args.filters)
    for client in matched:
      print(yaml.safe_dump(client, default_flow_style=False)
            if args.verbose else client['mid'])
  1111  
  @Command('select', 'select default client', [
      Arg('-f', '--filter', default=[], dest='filters', action='append',
          help=('Conditions to filter clients by properties. '
                'Should be in form "key=regex", where regex is the regular '
                'expression that should be found in the value. '
                'Multiple --filter arguments would be ANDed.')),
      Arg('mid', metavar='mid', nargs='?', default=None)])
  def SelectClient(self, args=None, store=True):
    """Select the client that subsequent commands operate on.

    If several clients match the filters/mid prefix, the user picks one
    interactively by its 1-based number.

    Args:
      args: parsed arguments with `mid` and `filters`, or None to consider
        every connected client.
      store: when True, persist the selection in the daemon and report it.

    Raises:
      RuntimeError: if nothing matches, or the interactive choice is not a
        number within the listed range.
    """
    mid = args.mid if args is not None else None
    filters = args.filters if args is not None else []
    clients = self._FilterClients(self._server.Clients(), filters, mid=mid)

    if not clients:
      raise RuntimeError('select: client not found')
    if len(clients) == 1:
      mid = clients[0]['mid']
    else:
      # Ambiguous match; an exact args.mid match never reaches this branch.
      print('Select from the following clients:')
      for i, client in enumerate(clients, 1):
        print('    %d. %s' % (i, client['mid']))

      print('\nSelection: ', end='')
      try:
        choice = int(input())
      except ValueError:
        raise RuntimeError('select: invalid selection') from None
      # Reject 0 and negative numbers explicitly: used as list indices they
      # would silently select a client from the end of the list.
      if not 1 <= choice <= len(clients):
        raise RuntimeError('select: selection out of range')
      mid = clients[choice - 1]['mid']

    self._selected_mid = mid
    if store:
      self._server.SelectClient(mid)
      print('Client %s selected' % mid)
  1147  
  @Command('shell', 'open a shell or execute a shell command', [
      Arg('command', metavar='CMD', nargs='?', help='command to execute')])
  def Shell(self, command=None):
    """Run a command on the selected client, or open an interactive terminal.

    Args:
      command: sequence of words, joined with single spaces into the remote
        command line. None or empty opens an interactive terminal instead.
    """
    words = [] if command is None else command
    self.CheckClient()

    state = self._state
    headers = []
    if state.username is not None and state.password is not None:
      headers.append(BasicAuthHeader(state.username, state.password))

    base = '%s%s:%d' % ('wss://' if state.ssl else 'ws://',
                        state.host, state.port)
    quoted_mid = urllib.parse.quote(self._selected_mid)
    if words:
      ws = ShellWebSocketClient(
          state, sys.stdout.buffer,
          '%s/api/agent/shell/%s?command=%s' % (
              base, quoted_mid, urllib.parse.quote(' '.join(words))),
          headers=headers)
    else:
      ws = TerminalWebSocketClient(
          state, self._selected_mid, self._escape,
          '%s/api/agent/tty/%s' % (base, quoted_mid),
          headers=headers)

    try:
      ws.connect()
      ws.run()
    except socket.error as e:
      # errno 32 (broken pipe) is silently ignored; anything else propagates.
      if e.errno != 32:
        raise
  1184  
  @Command('push', 'push a file or directory to remote', [
      Arg('srcs', nargs='+', metavar='SOURCE'),
      Arg('dst', metavar='DESTINATION')])
  def Push(self, args):
    """Upload local files, symlinks or directory trees to the selected client.

    Args:
      args: parsed arguments. args.srcs is a list of local paths; args.dst is
        the remote destination path.

    Raises:
      RuntimeError: if a source is missing or unreadable, if several sources
        are given but args.dst is not an existing remote directory, or if the
        server rejects an upload.
    """
    import shlex  # pylint: disable=import-outside-toplevel

    self.CheckClient()

    @AutoRetry('push', _RETRY_TIMES)
    def _push(src, dst):
      src_base = os.path.basename(src)

      # Local file is a link: recreate the link remotely instead of uploading
      # the file it points to.
      if os.path.islink(src):
        pbar = ProgressBar(src_base)
        link_path = os.readlink(src)
        # Shell-quote every path so spaces/metacharacters reach the remote
        # shell intact.
        self.CheckOutput('mkdir -p %(dirname)s; '
                         'if [ -d %(dst)s ]; then '
                         'ln -sf %(link_path)s %(dst_in_dir)s; '
                         'else ln -sf %(link_path)s %(dst)s; fi' %
                         dict(dirname=shlex.quote(os.path.dirname(dst) or '.'),
                              link_path=shlex.quote(link_path),
                              dst=shlex.quote(dst),
                              dst_in_dir=shlex.quote(dst + '/' + src_base)))
        pbar.End()
        return

      mode = '0%o' % (0o777 & os.stat(src).st_mode)
      # Percent-encode query values so paths containing '&', '#', '%', '+' or
      # spaces are transmitted unchanged.
      url = ('%s:%d/api/agent/upload/%s?dest=%s&perm=%s' %
             (self._state.host, self._state.port,
              urllib.parse.quote(self._selected_mid),
              urllib.parse.quote(dst), mode))
      try:
        UrlOpen(self._state,
                url + '&filename=%s' % urllib.parse.quote(src_base))
      except urllib.error.HTTPError as e:
        msg = json.loads(e.read()).get('error', None)
        raise RuntimeError('push: %s' % msg) from None

      pbar = ProgressBar(src_base)
      self._HTTPPostFile(url, src, pbar.SetProgress,
                         self._state.username, self._state.password)
      pbar.End()

    def _push_single_target(src, dst):
      if not os.path.isdir(src):
        _push(src, dst)
        return

      dst_exists = self.CheckOutput(
          'stat %s >/dev/null 2>&1 && echo True || echo False' %
          shlex.quote(dst)).strip() == 'True'
      # Mirror `cp -r`. Pushing src_dir to an existing dest_dir yields
      #   dest_dir/src_dir/...
      # while pushing to a missing dest_dir makes dest_dir the copy itself:
      #   dest_dir/...
      # Paths relative to src are preserved, so nested directories are not
      # flattened.
      prefix = os.path.basename(os.path.normpath(src)) if dst_exists else ''
      for root, dirs, files in os.walk(src):
        # os.walk lists symlinks to directories under `dirs` without
        # descending into them; push those as links so they are not dropped.
        dir_links = [d for d in dirs if os.path.islink(os.path.join(root, d))]
        for name in files + dir_links:
          local_path = os.path.join(root, name)
          _push(local_path,
                os.path.join(dst, prefix, os.path.relpath(local_path, src)))

    if len(args.srcs) > 1:
      dst_type = self.CheckOutput('stat %s --printf \'%%F\' 2>/dev/null' %
                                  shlex.quote(args.dst)).strip()
      if not dst_type:
        raise RuntimeError('push: %s: No such file or directory' % args.dst)
      if dst_type != 'directory':
        raise RuntimeError('push: %s: Not a directory' % args.dst)

    for src in args.srcs:
      if not os.path.exists(src):
        raise RuntimeError('push: can not stat "%s": no such file or directory'
                           % src)
      if not os.access(src, os.R_OK):
        raise RuntimeError('push: can not open "%s" for reading' % src)

      _push_single_target(src, args.dst)
  1262  
  @Command('pull', 'pull a file or directory from remote', [
      Arg('src', metavar='SOURCE'),
      Arg('dst', metavar='DESTINATION', default='.', nargs='?')])
  def Pull(self, args):
    """Download a remote file, symlink or directory tree from the client.

    Args:
      args: parsed arguments. args.src is the remote path (relative paths are
        resolved from the remote $HOME); args.dst is the local destination.

    Raises:
      RuntimeError: if the server rejects a download.
    """
    self.CheckClient()

    @AutoRetry('pull', _RETRY_TIMES)
    def _pull(src, dst, ftype, perm=0o644, link=None):
      dst_dir = os.path.dirname(dst)
      if dst_dir:
        os.makedirs(dst_dir, exist_ok=True)

      src_base = os.path.basename(src)

      # Remote file is a link: recreate the symlink locally.
      if ftype == 'l':
        pbar = ProgressBar(src_base)
        # lexists (not exists) so that a dangling local symlink is replaced
        # too; otherwise os.symlink would fail with FileExistsError.
        if os.path.lexists(dst):
          os.remove(dst)
        os.symlink(link, dst)
        pbar.End()
        return

      url = ('%s:%d/api/agent/download/%s?filename=%s' %
             (self._state.host, self._state.port,
              urllib.parse.quote(self._selected_mid), urllib.parse.quote(src)))
      try:
        h = UrlOpen(self._state, url)
      except urllib.error.HTTPError as e:
        msg = json.loads(e.read()).get('error', 'unknown error')
        raise RuntimeError('pull: %s' % msg) from None
      except KeyboardInterrupt:
        return

      try:
        pbar = ProgressBar(src_base)
        # Without a Content-Length header the total is unknown, so no
        # percentage is reported.
        content_length = h.headers.get('Content-Length')
        total_size = int(content_length) if content_length else 0
        with open(dst, 'wb') as f:
          os.fchmod(f.fileno(), perm)
          downloaded_size = 0
          while True:
            data = h.read(_BUFSIZ)
            if not data:
              break
            downloaded_size += len(data)
            if total_size:
              pbar.SetProgress(downloaded_size * 100 / total_size,
                               downloaded_size)
            f.write(data)
        pbar.End()
      finally:
        h.close()

    # Use find to list every regular file and symlink under the source. Each
    # line is "<octal mode>\t<path>\t<type>\t<link target>".
    output = self.CheckOutput(
        'cd $HOME; '
        'stat "%(src)s" >/dev/null && '
        'find "%(src)s" \'(\' -type f -o -type l \')\' '
        '-printf \'%%m\t%%p\t%%y\t%%l\n\''
        % {'src': args.src})

    # We got error from the stat command
    if output.startswith('stat: '):
      sys.stderr.write(output)
      return

    entries = [line for line in output.split('\n') if line]
    if not entries:
      # Nothing to download, e.g. an empty remote directory.
      return
    common_prefix = os.path.dirname(args.src)

    if len(entries) == 1:
      # maxsplit=3 keeps tabs inside a link target intact.
      perm, src_path, ftype, link = entries[0].split('\t', 3)
      if os.path.isdir(args.dst):
        dst = os.path.join(args.dst, os.path.basename(src_path))
      else:
        dst = args.dst
      _pull(src_path, dst, ftype, int(perm, base=8), link)
    else:
      # A missing destination becomes the copy of the source directory itself.
      if not os.path.exists(args.dst):
        common_prefix = args.src

      for entry in entries:
        perm, src_path, ftype, link = entry.split('\t', 3)
        rel_dst = src_path[len(common_prefix):].lstrip('/')
        _pull(src_path, os.path.join(args.dst, rel_dst), ftype,
              int(perm, base=8), link)
  1348  
  @Command('forward', 'forward remote port to local port', [
      Arg('--list', dest='list_all', action='store_true', default=False,
          help='list all port forwarding sessions'),
      Arg('--remove', metavar='LOCAL_PORT', dest='remove', type=int,
          default=None,
          help='remove port forwarding for local port LOCAL_PORT'),
      Arg('--remove-all', dest='remove_all', action='store_true',
          default=False, help='remove all port forwarding'),
      Arg('remote', metavar='REMOTE_PORT', type=int, nargs='?'),
      Arg('local', metavar='LOCAL_PORT', type=int, nargs='?')])
  def Forward(self, args):
    """List, remove, or create port forwards to the selected client.

    Creating a forward binds LOCAL_PORT (default: REMOTE_PORT) on all
    interfaces. It then forks a child process that relays each accepted
    connection to REMOTE_PORT on the client over a websocket. The child's pid
    is registered with the daemon.

    Raises:
      RuntimeError: if no client is selected or REMOTE_PORT is missing.
      OSError: if the local port cannot be bound.
    """
    if args.list_all:
      forwards = self._state.forwards
      # At least 10 wide (never narrower than the 'Client' header) and wide
      # enough for every client mid.
      max_len = max([10] + [len(v[0]) for v in forwards.values()])

      print('%-*s   %-8s  %-8s' % (max_len, 'Client', 'Remote', 'Local'))
      for local in sorted(forwards.keys()):
        value = forwards[local]
        print('%-*s   %-8s  %-8s' % (max_len, value[0], value[1], local))
      return

    if args.remove_all:
      self._server.RemoveAllForward()
      return

    # Compare with None: a truthiness test would ignore `--remove 0`.
    if args.remove is not None:
      self._server.RemoveForward(args.remove)
      return

    self.CheckClient()

    if args.remote is None:
      raise RuntimeError('remote port not specified')

    if args.local is None:
      args.local = args.remote
    remote = int(args.remote)
    local = int(args.local)

    def HandleConnection(conn):
      headers = []
      if self._state.username is not None and self._state.password is not None:
        headers.append(BasicAuthHeader(self._state.username,
                                       self._state.password))

      scheme = 'ws%s://' % ('s' if self._state.ssl else '')
      ws = ForwarderWebSocketClient(
          self._state, conn,
          scheme + '%s:%d/api/agent/forward/%s?port=%d' % (
              self._state.host, self._state.port,
              urllib.parse.quote(self._selected_mid), remote),
          headers=headers)
      try:
        ws.connect()
        ws.run()
      except Exception as e:
        print('error: %s' % e)
      finally:
        ws.close()

    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
      server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
      server.bind(('0.0.0.0', local))
      server.listen(5)
    except OSError:
      server.close()
      raise

    pid = os.fork()
    if pid == 0:
      # Child: serve until an error occurs. It must never return into the
      # caller's code path, so it always leaves via os._exit.
      try:
        while True:
          conn, unused_addr = server.accept()
          t = threading.Thread(target=HandleConnection, args=(conn,))
          t.daemon = True
          t.start()
      except Exception as e:  # pylint: disable=broad-except
        logging.error('forward %d -> %d stopped: %s', local, remote, e)
      finally:
        os._exit(1)
    else:
      # Parent: the child owns the listening socket now.
      server.close()
      self._server.AddForward(self._selected_mid, remote, local, pid)
  1424  
  1425  
  1426  def main():
  1427    # Setup logging format
  1428    logger = logging.getLogger()
  1429    logger.setLevel(logging.INFO)
  1430    handler = logging.StreamHandler()
  1431    formatter = logging.Formatter('%(asctime)s %(message)s', '%Y/%m/%d %H:%M:%S')
  1432    handler.setFormatter(formatter)
  1433    logger.addHandler(handler)
  1434  
  1435    ovl = OverlordCLIClient()
  1436    try:
  1437      ovl.Main()
  1438    except KeyboardInterrupt:
  1439      print('Ctrl-C received, abort')
  1440    except Exception as e:
  1441      print(f'error: {e}')
  1442  
  1443  
# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
  main()