github.com/juju/juju@v0.0.0-20240327075706-a90865de2538/acceptancetests/remote.py (about)

     1  """Remote helper class for communicating with juju machines."""
     2  import abc
     3  import logging
     4  import os
     5  import subprocess
     6  import sys
     7  import zlib
     8  
     9  import winrm
    10  
    11  import utility
    12  import jujupy
    13  
    14  
# Python 2 compatibility: make every class defined in this module a
# new-style class. Python 3 ignores this module-level name.
__metaclass__ = type
    16  
    17  
    18  def _remote_for_series(series):
    19      """Give an appropriate remote class based on machine series."""
    20      if series is not None and series.startswith("win"):
    21          return WinRmRemote
    22      return SSHRemote
    23  
    24  
    25  def remote_from_unit(client, unit, series=None, status=None):
    26      """Create remote instance given a juju client and a unit."""
    27      if series is None:
    28          if status is None:
    29              status = client.get_status()
    30          machine = status.get_unit(unit).get("machine")
    31          if machine is not None:
    32              series = status.status["machines"].get(machine, {}).get("series")
    33      remotecls = _remote_for_series(series)
    34      return remotecls(client, unit, None, series=series, status=status)
    35  
    36  
    37  def remote_from_address(address, series=None):
    38      """Create remote instance given an address"""
    39      remotecls = _remote_for_series(series)
    40      return remotecls(None, None, address, series=series)
    41  
    42  
    43  class _Remote:
    44      """_Remote represents a juju machine to access over the network."""
    45  
    46      __metaclass__ = abc.ABCMeta
    47  
    48      def __init__(self, client, unit, address, series=None, status=None):
    49          if address is None and (client is None or unit is None):
    50              raise ValueError("Remote needs either address or client and unit")
    51          self.client = client
    52          self.unit = unit
    53          self.use_juju_ssh = unit is not None
    54          self.address = address
    55          self.series = series
    56          self.status = status
    57  
    58      def __repr__(self):
    59          params = []
    60          if self.client is not None:
    61              params.append("env=" + repr(self.client.env.environment))
    62          if self.unit is not None:
    63              params.append("unit=" + repr(self.unit))
    64          if self.address is not None:
    65              params.append("addr=" + repr(self.address))
    66          return "<{} {}>".format(self.__class__.__name__, " ".join(params))
    67  
    68      @abc.abstractmethod
    69      def cat(self, filename):
    70          """
    71          Get the contents of filename from the remote machine.
    72  
    73          Environment variables in the filename will be expanded in a according
    74          to platform-specific rules.
    75          """
    76  
    77      @abc.abstractmethod
    78      def copy(self, destination_dir, source_globs):
    79          """Copy files from the remote machine."""
    80  
    81      def is_windows(self):
    82          """Returns True if remote machine is running windows."""
    83          return self.series and self.series.startswith("win")
    84  
    85      def get_address(self):
    86          """Gives the address of the remote machine."""
    87          self._ensure_address()
    88          return self.address
    89  
    90      def update_address(self, address):
    91          """Change address of remote machine."""
    92          self.address = address
    93  
    94      def _get_status(self):
    95          if self.status is None:
    96              self.status = self.client.get_status()
    97          return self.status
    98  
    99      def _ensure_address(self):
   100          if self.address:
   101              return
   102          if self.client is None:
   103              raise ValueError("No address or client supplied")
   104          status = self._get_status()
   105          unit = status.get_unit(self.unit)
   106          if 'public-address' not in unit:
   107              raise ValueError("No public address for unit: {!r} {!r}".format(
   108                  self.unit, unit))
   109          self.address = unit['public-address']
   110  
   111  
   112  def _default_is_command_error(err):
   113      """
   114      Whether to treat error as issue with remote command rather than ssh.
   115  
   116      This is a conservative default, remote commands may return a variety of
   117      other return codes. However, as the fallback to local ssh binary will
   118      repeat the command, those problems will be exposed later anyway.
   119      """
   120      return err.returncode == 1
   121  
   122  
   123  def _no_platform_ssh():
   124      """True if no openssh binary is available on this platform."""
   125      return sys.platform == "win32"
   126  
   127  
   128  class SSHRemote(_Remote):
   129      """SSHRemote represents a juju machine to access using ssh."""
   130  
   131      _ssh_opts = [
   132          "-o", "User ubuntu",
   133          "-o", "UserKnownHostsFile /dev/null",
   134          "-o", "StrictHostKeyChecking no",
   135          "-o", "PasswordAuthentication no",
   136          # prevent "permanently added to known hosts" warning from polluting
   137          # test log output
   138          "-o", "LogLevel ERROR",
   139      ]
   140  
   141      # Limit each operation over SSH to 2 minutes by default
   142      timeout = 120
   143  
   144      def run(self, command_args, is_command_error=_default_is_command_error):
   145          """
   146          Run a command on the remote machine.
   147  
   148          If the remote instance has a juju unit run will default to using the
   149          juju ssh command. Otherwise, or if that fails, it will fall back to
   150          using ssh directly.
   151  
   152          The command_args param is a string or list of arguments to be invoked
   153          on the remote machine. A string must be given if special shell
   154          characters are used.
   155  
   156          The is_command_error param is a function that takes an instance of
   157          CalledProcessError and returns whether that error comes from the
   158          command being run rather than ssh itself. This can be used to skip the
   159          fallback to native ssh behaviour when running commands that may fail.
   160          """
   161          if not isinstance(command_args, (list, tuple)):
   162              command_args = [command_args]
   163          if self.use_juju_ssh:
   164              logging.debug('juju ssh {}'.format(self.unit))
   165              try:
   166                  return self.client.get_juju_output(
   167                      "ssh", self.unit, *command_args, timeout=self.timeout)
   168              except subprocess.CalledProcessError as e:
   169                  logging.warning(
   170                      "juju ssh to {!r} failed, returncode: {} output: {!r}"
   171                      " stderr: {!r}".format(
   172                          self.unit, e.returncode, e.output,
   173                          getattr(e, "stderr", None)))
   174                  # Don't fallback to calling ssh directly if command really
   175                  # failed or if there is likely to be no usable ssh client.
   176                  if is_command_error(e) or _no_platform_ssh():
   177                      raise
   178                  self.use_juju_ssh = False
   179              self._ensure_address()
   180          args = ["ssh"]
   181          args.extend(self._ssh_opts)
   182          args.append(self.address)
   183          args.extend(command_args)
   184          logging.debug(' '.join(utility.quote(i) for i in args))
   185          return self._run_subprocess(args).decode('utf-8')
   186  
   187      def copy(self, destination_dir, source_globs):
   188          """Copy files from the remote machine."""
   189          self._ensure_address()
   190          args = ["scp", "-rC"]
   191          args.extend(self._ssh_opts)
   192          address = utility.as_literal_address(self.address)
   193          args.extend(["{}:{}".format(address, f) for f in source_globs])
   194          args.append(destination_dir)
   195          self._run_subprocess(args)
   196  
   197      def cat(self, filename):
   198          """
   199          Get the contents of filename from the remote machine.
   200  
   201          Tildes and environment variables in the form $TMP will be expanded.
   202          """
   203          return self.run(["cat", filename])
   204  
   205      def use_ssh_key(self, identity_file):
   206          if "-i" in self._ssh_opts:
   207              return
   208          self._ssh_opts.append("-i")
   209          self._ssh_opts.append(identity_file)
   210  
   211      def _run_subprocess(self, command):
   212          if self.timeout:
   213              command = jujupy.get_timeout_prefix(self.timeout) + tuple(command)
   214          return subprocess.check_output(command, stdin=subprocess.PIPE)
   215  
   216  
   217  class _SSLSession(winrm.Session):
   218  
   219      def __init__(self, target, auth, transport="ssl"):
   220          key, cert = auth
   221          self.url = self._build_url(target, transport)
   222          self.protocol = winrm.Protocol(self.url, transport=transport,
   223                                         cert_key_pem=key, cert_pem=cert)
   224  
   225  
# PowerShell template used by WinRmRemote.copy to fetch files over winrm.
# The single %s placeholder is filled with a comma separated list of quoted
# glob patterns. For every file matched, the script writes one line to
# stdout: the file name, a "|" character, then the base64 encoding of the
# deflate-compressed file contents. If reading a file fails, a line named
# "<file name>.copyerror" carrying the exception text is written instead.
# Any other failure is reported with Write-Error and exit code 1.
_ps_copy_script = """\
$ErrorActionPreference = "Stop"

function OutputEncodedFile {
    param([String]$filename, [IO.Stream]$instream)
    $trans = New-Object Security.Cryptography.ToBase64Transform
    $out = [Console]::OpenStandardOutput()
    $bs = New-Object Security.Cryptography.CryptoStream($out, $trans,
        [Security.Cryptography.CryptoStreamMode]::Write)
    $zs = New-Object IO.Compression.DeflateStream($bs,
        [IO.Compression.CompressionMode]::Compress)
    [Console]::Out.Write($filename + "|")
    try {
        $instream.CopyTo($zs)
    } finally {
        $zs.close()
        $bs.close()
        [Console]::Out.Write("`n")
    }
}

function GatherFiles {
    param([String[]]$patterns)
    ForEach ($pattern in $patterns) {
        $path = [Environment]::ExpandEnvironmentVariables($pattern)
        ForEach ($file in Get-Item -path $path) {
            try {
                $in = New-Object IO.FileStream($file, [IO.FileMode]::Open,
                    [IO.FileAccess]::Read, [IO.FileShare]"ReadWrite,Delete")
                OutputEncodedFile -filename $file.name -instream $in
            } catch {
                $utf8 = New-Object Text.UTF8Encoding($False)
                $errstream = New-Object IO.MemoryStream(
                    $utf8.GetBytes($_.Exception), $False)
                $errfilename = $file.name + ".copyerror"
                OutputEncodedFile -filename $errfilename -instream $errstream
            }
        }
    }
}

try {
    GatherFiles -patterns @(%s)
} catch {
    Write-Error $_.Exception
    exit 1
}
"""
   274  
   275  
   276  class WinRmRemote(_Remote):
   277      """WinRmRemote represents a juju machine to access using winrm."""
   278  
   279      def __init__(self, *args, **kwargs):
   280          super(WinRmRemote, self).__init__(*args, **kwargs)
   281          self._ensure_address()
   282          self.use_juju_ssh = False
   283          self.certs = utility.get_winrm_certs()
   284          self.session = _SSLSession(self.address, self.certs)
   285  
   286      def update_address(self, address):
   287          """Change address of remote machine, refreshes the winrm session."""
   288          self.address = address
   289          self.session = _SSLSession(self.address, self.certs)
   290  
   291      _escape = staticmethod(subprocess.list2cmdline)
   292  
   293      def run_cmd(self, cmd_list):
   294          """Run cmd and arguments given as a list returning response object."""
   295          if isinstance(cmd_list, (str, bytes)):
   296              raise ValueError("run_cmd requires a list not a string")
   297          # pywinrm does not correctly escape arguments, fix up by escaping cmd
   298          # and giving args as a list of a single pre-escaped string.
   299          cmd = self._escape(cmd_list[:1])
   300          args = [self._escape(cmd_list[1:])]
   301          return self.session.run_cmd(cmd, args)
   302  
   303      def run_ps(self, script):
   304          """Run string of powershell returning response object."""
   305          return self.session.run_ps(script)
   306  
   307      def cat(self, filename):
   308          """
   309          Get the contents of filename from the remote machine.
   310  
   311          Backslashes will be treated as directory seperators. Environment
   312          variables in the form %TMP% will be expanded.
   313          """
   314          result = self.session.run_cmd("type", [self._escape([filename])])
   315          if result.status_code:
   316              logging.warning("winrm cat failed %r", result)
   317          return result.std_out
   318  
   319      # TODO(gz): Unlike SSHRemote.copy this only supports copying files, not
   320      #           directories and their content. Both the powershell script and
   321      #           the unpacking method will need updating to support that.
   322      def copy(self, destination_dir, source_globs):
   323          """Copy files from the remote machine."""
   324          # Encode globs into script to run on remote machine and return result.
   325          script = _ps_copy_script % ",".join(s.join('""') for s in source_globs)
   326          result = self.run_ps(script)
   327          if result.status_code:
   328              logging.warning("winrm copy stderr:\n%s", result.std_err)
   329              raise subprocess.CalledProcessError(result.status_code,
   330                                                  "powershell", result)
   331          self._encoded_copy_to_dir(destination_dir, result.std_out)
   332  
   333      @staticmethod
   334      def _encoded_copy_to_dir(destination_dir, output):
   335          """Write remote files from powershell script to disk.
   336  
   337          The given output from the powershell script is one line per file, with
   338          the filename first, then a pipe, then the base64 encoded deflated file
   339          contents. This method reverses that process and creates the files in
   340          the given destination_dir.
   341          """
   342          start = 0
   343          while True:
   344              end = output.find("\n", start)
   345              if end == -1:
   346                  break
   347              mid = output.find("|", start, end)
   348              if mid == -1:
   349                  if not output[start:end].rstrip("\r\n"):
   350                      break
   351                  raise ValueError("missing filename in encoded copy data")
   352              filename = output[start:mid]
   353              if "/" in filename:
   354                  # Just defense against path traversal bugs, should never reach.
   355                  raise ValueError("path not filename {!r}".format(filename))
   356              with open(os.path.join(destination_dir, filename), "wb") as f:
   357                  f.write(zlib.decompress(output[mid + 1:end].decode("base64"),
   358                                          -zlib.MAX_WBITS))
   359              start = end + 1