
     1  """Remote helper class for communicating with juju machines."""
     2  import abc
     3  import logging
     4  import os
     5  import subprocess
     6  import sys
     7  import zlib
     9  import winrm
    11  import jujupy
    12  import utility
# Python 2 idiom: classes in this module that declare no base class and no
# metaclass of their own become new-style classes.
__metaclass__ = type
    18  def _remote_for_series(series):
    19      """Give an appropriate remote class based on machine series."""
    20      if series is not None and series.startswith("win"):
    21          return WinRmRemote
    22      return SSHRemote
    25  def remote_from_unit(client, unit, series=None, status=None):
    26      """Create remote instance given a juju client and a unit."""
    27      if series is None:
    28          if status is None:
    29              status = client.get_status()
    30          machine = status.get_unit(unit).get("machine")
    31          if machine is not None:
    32              series = status.status["machines"].get(machine, {}).get("series")
    33      remotecls = _remote_for_series(series)
    34      return remotecls(client, unit, None, series=series, status=status)
    37  def remote_from_address(address, series=None):
    38      """Create remote instance given an address"""
    39      remotecls = _remote_for_series(series)
    40      return remotecls(None, None, address, series=series)
    43  class _Remote:
    44      """_Remote represents a juju machine to access over the network."""
    46      __metaclass__ = abc.ABCMeta
    48      def __init__(self, client, unit, address, series=None, status=None):
    49          if address is None and (client is None or unit is None):
    50              raise ValueError("Remote needs either address or client and unit")
    51          self.client = client
    52          self.unit = unit
    53          self.use_juju_ssh = unit is not None
    54          self.address = address
    55          self.series = series
    56          self.status = status
    58      def __repr__(self):
    59          params = []
    60          if self.client is not None:
    61              params.append("env=" + repr(self.client.env.environment))
    62          if self.unit is not None:
    63              params.append("unit=" + repr(self.unit))
    64          if self.address is not None:
    65              params.append("addr=" + repr(self.address))
    66          return "<{} {}>".format(self.__class__.__name__, " ".join(params))
    68      @abc.abstractmethod
    69      def cat(self, filename):
    70          """
    71          Get the contents of filename from the remote machine.
    73          Environment variables in the filename will be expanded in a according
    74          to platform-specific rules.
    75          """
    77      @abc.abstractmethod
    78      def copy(self, destination_dir, source_globs):
    79          """Copy files from the remote machine."""
    81      def is_windows(self):
    82          """Returns True if remote machine is running windows."""
    83          return self.series and self.series.startswith("win")
    85      def get_address(self):
    86          """Gives the address of the remote machine."""
    87          self._ensure_address()
    88          return self.address
    90      def update_address(self, address):
    91          """Change address of remote machine."""
    92          self.address = address
    94      def _get_status(self):
    95          if self.status is None:
    96              self.status = self.client.get_status()
    97          return self.status
    99      def _ensure_address(self):
   100          if self.address:
   101              return
   102          if self.client is None:
   103              raise ValueError("No address or client supplied")
   104          status = self._get_status()
   105          unit = status.get_unit(self.unit)
   106          if 'public-address' not in unit:
   107              raise ValueError("No public address for unit: {!r} {!r}".format(
   108                  self.unit, unit))
   109          self.address = unit['public-address']
   112  def _default_is_command_error(err):
   113      """
   114      Whether to treat error as issue with remote command rather than ssh.
   116      This is a conservative default, remote commands may return a variety of
   117      other return codes. However, as the fallback to local ssh binary will
   118      repeat the command, those problems will be exposed later anyway.
   119      """
   120      return err.returncode == 1
   123  def _no_platform_ssh():
   124      """True if no openssh binary is available on this platform."""
   125      return sys.platform == "win32"
   128  class SSHRemote(_Remote):
   129      """SSHRemote represents a juju machine to access using ssh."""
   131      _ssh_opts = [
   132          "-o", "User ubuntu",
   133          "-o", "UserKnownHostsFile /dev/null",
   134          "-o", "StrictHostKeyChecking no",
   135          "-o", "PasswordAuthentication no",
   136      ]
   138      # Limit each operation over SSH to 2 minutes by default
   139      timeout = 120
   141      def run(self, command_args, is_command_error=_default_is_command_error):
   142          """
   143          Run a command on the remote machine.
   145          If the remote instance has a juju unit run will default to using the
   146          juju ssh command. Otherwise, or if that fails, it will fall back to
   147          using ssh directly.
   149          The command_args param is a string or list of arguments to be invoked
   150          on the remote machine. A string must be given if special shell
   151          characters are used.
   153          The is_command_error param is a function that takes an instance of
   154          CalledProcessError and returns whether that error comes from the
   155          command being run rather than ssh itself. This can be used to skip the
   156          fallback to native ssh behaviour when running commands that may fail.
   157          """
   158          if not isinstance(command_args, (list, tuple)):
   159              command_args = [command_args]
   160          if self.use_juju_ssh:
   161              logging.debug('juju ssh {}'.format(self.unit))
   162              try:
   163                  return self.client.get_juju_output(
   164                      "ssh", self.unit, *command_args, timeout=self.timeout)
   165              except subprocess.CalledProcessError as e:
   166                  logging.warning(
   167                      "juju ssh to {!r} failed, returncode: {} output: {!r}"
   168                      " stderr: {!r}".format(
   169                          self.unit, e.returncode, e.output,
   170                          getattr(e, "stderr", None)))
   171                  # Don't fallback to calling ssh directly if command really
   172                  # failed or if there is likely to be no usable ssh client.
   173                  if is_command_error(e) or _no_platform_ssh():
   174                      raise
   175                  self.use_juju_ssh = False
   176              self._ensure_address()
   177          args = ["ssh"]
   178          args.extend(self._ssh_opts)
   179          args.append(self.address)
   180          args.extend(command_args)
   181          logging.debug(' '.join(utility.quote(i) for i in args))
   182          return self._run_subprocess(args)
   184      def copy(self, destination_dir, source_globs):
   185          """Copy files from the remote machine."""
   186          self._ensure_address()
   187          args = ["scp", "-rC"]
   188          args.extend(self._ssh_opts)
   189          address = utility.as_literal_address(self.address)
   190          args.extend(["{}:{}".format(address, f) for f in source_globs])
   191          args.append(destination_dir)
   192          self._run_subprocess(args)
   194      def cat(self, filename):
   195          """
   196          Get the contents of filename from the remote machine.
   198          Tildes and environment variables in the form $TMP will be expanded.
   199          """
   200          return["cat", filename])
   202      def _run_subprocess(self, command):
   203          if self.timeout:
   204              command = jujupy.get_timeout_prefix(self.timeout) + tuple(command)
   205          return subprocess.check_output(command, stdin=subprocess.PIPE)
   208  class _SSLSession(winrm.Session):
   210      def __init__(self, target, auth, transport="ssl"):
   211          key, cert = auth
   212          self.url = self._build_url(target, transport)
   213          self.protocol = winrm.Protocol(self.url, transport=transport,
   214                                         cert_key_pem=key, cert_pem=cert)
   217  _ps_copy_script = """\
   218  $ErrorActionPreference = "Stop"
   220  function OutputEncodedFile {
   221      param([String]$filename, [IO.Stream]$instream)
   222      $trans = New-Object Security.Cryptography.ToBase64Transform
   223      $out = [Console]::OpenStandardOutput()
   224      $bs = New-Object Security.Cryptography.CryptoStream($out, $trans,
   225          [Security.Cryptography.CryptoStreamMode]::Write)
   226      $zs = New-Object IO.Compression.DeflateStream($bs,
   227          [IO.Compression.CompressionMode]::Compress)
   228      [Console]::Out.Write($filename + "|")
   229      try {
   230          $instream.CopyTo($zs)
   231      } finally {
   232          $zs.close()
   233          $bs.close()
   234          [Console]::Out.Write("`n")
   235      }
   236  }
   238  function GatherFiles {
   239      param([String[]]$patterns)
   240      ForEach ($pattern in $patterns) {
   241          $path = [Environment]::ExpandEnvironmentVariables($pattern)
   242          ForEach ($file in Get-Item -path $path) {
   243              try {
   244                  $in = New-Object IO.FileStream($file, [IO.FileMode]::Open,
   245                      [IO.FileAccess]::Read, [IO.FileShare]"ReadWrite,Delete")
   246                  OutputEncodedFile -filename $ -instream $in
   247              } catch {
   248                  $utf8 = New-Object Text.UTF8Encoding($False)
   249                  $errstream = New-Object IO.MemoryStream(
   250                      $utf8.GetBytes($_.Exception), $False)
   251                  $errfilename = $ + ".copyerror"
   252                  OutputEncodedFile -filename $errfilename -instream $errstream
   253              }
   254          }
   255      }
   256  }
   258  try {
   259      GatherFiles -patterns @(%s)
   260  } catch {
   261      Write-Error $_.Exception
   262      exit 1
   263  }
   264  """
   267  class WinRmRemote(_Remote):
   268      """WinRmRemote represents a juju machine to access using winrm."""
   270      def __init__(self, *args, **kwargs):
   271          super(WinRmRemote, self).__init__(*args, **kwargs)
   272          self._ensure_address()
   273          self.use_juju_ssh = False
   274          self.certs = utility.get_winrm_certs()
   275          self.session = _SSLSession(self.address, self.certs)
   277      def update_address(self, address):
   278          """Change address of remote machine, refreshes the winrm session."""
   279          self.address = address
   280          self.session = _SSLSession(self.address, self.certs)
   282      _escape = staticmethod(subprocess.list2cmdline)
   284      def run_cmd(self, cmd_list):
   285          """Run cmd and arguments given as a list returning response object."""
   286          if isinstance(cmd_list, basestring):
   287              raise ValueError("run_cmd requires a list not a string")
   288          # pywinrm does not correctly escape arguments, fix up by escaping cmd
   289          # and giving args as a list of a single pre-escaped string.
   290          cmd = self._escape(cmd_list[:1])
   291          args = [self._escape(cmd_list[1:])]
   292          return self.session.run_cmd(cmd, args)
   294      def run_ps(self, script):
   295          """Run string of powershell returning response object."""
   296          return self.session.run_ps(script)
   298      def cat(self, filename):
   299          """
   300          Get the contents of filename from the remote machine.
   302          Backslashes will be treated as directory seperators. Environment
   303          variables in the form %TMP% will be expanded.
   304          """
   305          result = self.session.run_cmd("type", [self._escape([filename])])
   306          if result.status_code:
   307              logging.warning("winrm cat failed %r", result)
   308          return result.std_out
   310      # TODO(gz): Unlike SSHRemote.copy this only supports copying files, not
   311      #           directories and their content. Both the powershell script and
   312      #           the unpacking method will need updating to support that.
   313      def copy(self, destination_dir, source_globs):
   314          """Copy files from the remote machine."""
   315          # Encode globs into script to run on remote machine and return result.
   316          script = _ps_copy_script % ",".join(s.join('""') for s in source_globs)
   317          result = self.run_ps(script)
   318          if result.status_code:
   319              logging.warning("winrm copy stderr:\n%s", result.std_err)
   320              raise subprocess.CalledProcessError(result.status_code,
   321                                                  "powershell", result)
   322          self._encoded_copy_to_dir(destination_dir, result.std_out)
   324      @staticmethod
   325      def _encoded_copy_to_dir(destination_dir, output):
   326          """Write remote files from powershell script to disk.
   328          The given output from the powershell script is one line per file, with
   329          the filename first, then a pipe, then the base64 encoded deflated file
   330          contents. This method reverses that process and creates the files in
   331          the given destination_dir.
   332          """
   333          start = 0
   334          while True:
   335              end = output.find("\n", start)
   336              if end == -1:
   337                  break
   338              mid = output.find("|", start, end)
   339              if mid == -1:
   340                  if not output[start:end].rstrip("\r\n"):
   341                      break
   342                  raise ValueError("missing filename in encoded copy data")
   343              filename = output[start:mid]
   344              if "/" in filename:
   345                  # Just defense against path traversal bugs, should never reach.
   346                  raise ValueError("path not filename {!r}".format(filename))
   347              with open(os.path.join(destination_dir, filename), "wb") as f:
   348                  f.write(zlib.decompress(output[mid + 1:end].decode("base64"),
   349                                          -zlib.MAX_WBITS))
   350              start = end + 1