github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/utils/subprocess_server.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  # pytype: skip-file
    19  
    20  import contextlib
    21  import glob
    22  import hashlib
    23  import logging
    24  import os
    25  import re
    26  import shutil
    27  import signal
    28  import socket
    29  import subprocess
    30  import tempfile
    31  import threading
    32  import time
    33  import zipfile
    34  from urllib.error import URLError
    35  from urllib.request import urlopen
    36  
    37  import grpc
    38  
    39  from apache_beam.version import __version__ as beam_version
    40  
    41  _LOGGER = logging.getLogger(__name__)
    42  
    43  
    44  class SubprocessServer(object):
    45    """An abstract base class for running GRPC Servers as an external process.
    46  
    47    This class acts as a context which will start up a server, provides a stub
    48    to connect to it, and then shuts the server down.  For example::
    49  
    50        with SubprocessServer(GrpcStubClass, [executable, arg, ...]) as stub:
    51            stub.CallService(...)
    52    """
    53    def __init__(self, stub_class, cmd, port=None):
    54      """Creates the server object.
    55  
    56      :param stub_class: the auto-generated GRPC client stub class used for
    57          connecting to the GRPC service
    58      :param cmd: command (including arguments) for starting up the server,
    59          suitable for passing to `subprocess.POpen`.
    60      :param port: (optional) the port at which the subprocess will serve its
    61          service.  If not given, one will be randomly chosen and the special
    62          string "{{PORT}}" will be substituted in the command line arguments
    63          with the chosen port.
    64      """
    65      self._process_lock = threading.RLock()
    66      self._process = None
    67      self._stub_class = stub_class
    68      self._cmd = [str(arg) for arg in cmd]
    69      self._port = port
    70  
    71    def __enter__(self):
    72      return self.start()
    73  
    74    def __exit__(self, *unused_args):
    75      self.stop()
    76  
    77    def start(self):
    78      try:
    79        endpoint = self.start_process()
    80        wait_secs = .1
    81        channel_options = [("grpc.max_receive_message_length", -1),
    82                           ("grpc.max_send_message_length", -1)]
    83        channel = grpc.insecure_channel(endpoint, options=channel_options)
    84        channel_ready = grpc.channel_ready_future(channel)
    85        while True:
    86          if self._process is not None and self._process.poll() is not None:
    87            _LOGGER.error("Starting job service with %s", self._process.args)
    88            raise RuntimeError(
    89                'Service failed to start up with error %s' % self._process.poll())
    90          try:
    91            channel_ready.result(timeout=wait_secs)
    92            break
    93          except (grpc.FutureTimeoutError, grpc.RpcError):
    94            wait_secs *= 1.2
    95            logging.log(
    96                logging.WARNING if wait_secs > 1 else logging.DEBUG,
    97                'Waiting for grpc channel to be ready at %s.',
    98                endpoint)
    99        return self._stub_class(channel)
   100      except:  # pylint: disable=bare-except
   101        _LOGGER.exception("Error bringing up service")
   102        self.stop()
   103        raise
   104  
   105    def start_process(self):
   106      with self._process_lock:
   107        if self._process:
   108          self.stop()
   109        if self._port:
   110          port = self._port
   111          cmd = self._cmd
   112        else:
   113          port, = pick_port(None)
   114          cmd = [arg.replace('{{PORT}}', str(port)) for arg in self._cmd]
   115        endpoint = 'localhost:%s' % port
   116        _LOGGER.info("Starting service with %s", str(cmd).replace("',", "'"))
   117        self._process = subprocess.Popen(
   118            cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
   119  
   120        # Emit the output of this command as info level logging.
   121        def log_stdout():
   122          line = self._process.stdout.readline()
   123          while line:
   124            # The log obtained from stdout is bytes, decode it into string.
   125            # Remove newline via rstrip() to not print an empty line.
   126            _LOGGER.info(line.decode(errors='backslashreplace').rstrip())
   127            line = self._process.stdout.readline()
   128  
   129        t = threading.Thread(target=log_stdout)
   130        t.daemon = True
   131        t.start()
   132        return endpoint
   133  
   134    def stop(self):
   135      self.stop_process()
   136  
   137    def stop_process(self):
   138      with self._process_lock:
   139        if not self._process:
   140          return
   141        for _ in range(5):
   142          if self._process.poll() is not None:
   143            break
   144          logging.debug("Sending SIGINT to job_server")
   145          self._process.send_signal(signal.SIGINT)
   146          time.sleep(1)
   147        if self._process.poll() is None:
   148          self._process.kill()
   149        self._process = None
   150  
   151    def local_temp_dir(self, **kwargs):
   152      return tempfile.mkdtemp(dir=self._local_temp_root, **kwargs)
   153  
   154  
   155  class JavaJarServer(SubprocessServer):
   156  
   157    MAVEN_CENTRAL_REPOSITORY = 'https://repo.maven.apache.org/maven2'
   158    BEAM_GROUP_ID = 'org.apache.beam'
   159    JAR_CACHE = os.path.expanduser("~/.apache_beam/cache/jars")
   160  
   161    _BEAM_SERVICES = type(
   162        'local', (threading.local, ),
   163        dict(__init__=lambda self: setattr(self, 'replacements', {})))()
   164  
   165    def __init__(self, stub_class, path_to_jar, java_arguments, classpath=None):
   166      if classpath:
   167        # java -jar ignores the classpath, so we make a new jar that embeds
   168        # the requested classpath.
   169        path_to_jar = self.make_classpath_jar(path_to_jar, classpath)
   170      super().__init__(
   171          stub_class, ['java', '-jar', path_to_jar] + list(java_arguments))
   172      self._existing_service = path_to_jar if _is_service_endpoint(
   173          path_to_jar) else None
   174  
   175    def start_process(self):
   176      if self._existing_service:
   177        return self._existing_service
   178      else:
   179        if not shutil.which('java'):
   180          raise RuntimeError(
   181              'Java must be installed on this system to use this '
   182              'transform/runner.')
   183        return super().start_process()
   184  
   185    def stop_process(self):
   186      if self._existing_service:
   187        pass
   188      else:
   189        return super().stop_process()
   190  
   191    @classmethod
   192    def jar_name(cls, artifact_id, version, classifier=None, appendix=None):
   193      return '-'.join(
   194          filter(None, [artifact_id, appendix, version, classifier])) + '.jar'
   195  
   196    @classmethod
   197    def path_to_maven_jar(
   198        cls,
   199        artifact_id,
   200        group_id,
   201        version,
   202        repository=MAVEN_CENTRAL_REPOSITORY,
   203        classifier=None,
   204        appendix=None):
   205      return '/'.join([
   206          repository,
   207          group_id.replace('.', '/'),
   208          artifact_id,
   209          version,
   210          cls.jar_name(artifact_id, version, classifier, appendix)
   211      ])
   212  
   213    @classmethod
   214    def path_to_beam_jar(
   215        cls,
   216        gradle_target,
   217        appendix=None,
   218        version=beam_version,
   219        artifact_id=None):
   220      if gradle_target in cls._BEAM_SERVICES.replacements:
   221        return cls._BEAM_SERVICES.replacements[gradle_target]
   222  
   223      gradle_package = gradle_target.strip(':').rsplit(':', 1)[0]
   224      if not artifact_id:
   225        artifact_id = 'beam-' + gradle_package.replace(':', '-')
   226      project_root = os.path.sep.join(
   227          os.path.abspath(__file__).split(os.path.sep)[:-5])
   228      local_path = os.path.join(
   229          project_root,
   230          gradle_package.replace(':', os.path.sep),
   231          'build',
   232          'libs',
   233          cls.jar_name(
   234              artifact_id,
   235              version.replace('.dev', ''),
   236              classifier='SNAPSHOT',
   237              appendix=appendix))
   238      if os.path.exists(local_path):
   239        _LOGGER.info('Using pre-built snapshot at %s', local_path)
   240        return local_path
   241      elif '.dev' in version:
   242        # TODO: Attempt to use nightly snapshots?
   243        raise RuntimeError(
   244            (
   245                '%s not found. '
   246                'Please build the server with \n  cd %s; ./gradlew %s') %
   247            (local_path, os.path.abspath(project_root), gradle_target))
   248      else:
   249        return cls.path_to_maven_jar(
   250            artifact_id,
   251            cls.BEAM_GROUP_ID,
   252            version,
   253            cls.MAVEN_CENTRAL_REPOSITORY,
   254            appendix=appendix)
   255  
   256    @classmethod
   257    def local_jar(cls, url, cache_dir=None):
   258      if cache_dir is None:
   259        cache_dir = cls.JAR_CACHE
   260      # TODO: Verify checksum?
   261      if _is_service_endpoint(url):
   262        return url
   263      elif os.path.exists(url):
   264        return url
   265      else:
   266        cached_jar = os.path.join(cache_dir, os.path.basename(url))
   267        if os.path.exists(cached_jar):
   268          _LOGGER.info('Using cached job server jar from %s' % url)
   269        else:
   270          _LOGGER.info('Downloading job server jar from %s' % url)
   271          if not os.path.exists(cache_dir):
   272            os.makedirs(cache_dir)
   273            # TODO: Clean up this cache according to some policy.
   274          try:
   275            url_read = urlopen(url)
   276            with open(cached_jar + '.tmp', 'wb') as jar_write:
   277              shutil.copyfileobj(url_read, jar_write, length=1 << 20)
   278            os.rename(cached_jar + '.tmp', cached_jar)
   279          except URLError as e:
   280            raise RuntimeError(
   281                'Unable to fetch remote job server jar at %s: %s' % (url, e))
   282        return cached_jar
   283  
   284    @classmethod
   285    @contextlib.contextmanager
   286    def beam_services(cls, replacements):
   287      try:
   288        old = cls._BEAM_SERVICES.replacements
   289        cls._BEAM_SERVICES.replacements = dict(old, **replacements)
   290        yield
   291      finally:
   292        cls._BEAM_SERVICES.replacements = old
   293  
   294    @classmethod
   295    def make_classpath_jar(cls, main_jar, extra_jars, cache_dir=None):
   296      if cache_dir is None:
   297        cache_dir = cls.JAR_CACHE
   298      composite_jar_dir = os.path.join(cache_dir, 'composite-jars')
   299      os.makedirs(composite_jar_dir, exist_ok=True)
   300      classpath = []
   301      # Class-Path references from a jar must be relative, so we create
   302      # a relatively-addressable subdirectory with symlinks to all the
   303      # required jars.
   304      for pattern in [main_jar] + list(extra_jars):
   305        for path in glob.glob(pattern) or [pattern]:
   306          path = os.path.abspath(path)
   307          rel_path = hashlib.sha256(
   308              path.encode('utf-8')).hexdigest() + os.path.splitext(path)[1]
   309          classpath.append(rel_path)
   310          if not os.path.lexists(os.path.join(composite_jar_dir, rel_path)):
   311            os.symlink(path, os.path.join(composite_jar_dir, rel_path))
   312      # Now create a single jar that simply references the rest and has the same
   313      # main class as main_jar.
   314      composite_jar = os.path.join(
   315          composite_jar_dir,
   316          hashlib.sha256(' '.join(sorted(classpath)).encode('ascii')).hexdigest()
   317          + '.jar')
   318      if not os.path.exists(composite_jar):
   319        with zipfile.ZipFile(main_jar) as main:
   320          with main.open('META-INF/MANIFEST.MF') as manifest:
   321            main_class = next(
   322                filter(lambda line: line.startswith(b'Main-Class: '), manifest))
   323        with zipfile.ZipFile(composite_jar + '.tmp', 'w') as composite:
   324          with composite.open('META-INF/MANIFEST.MF', 'w') as manifest:
   325            manifest.write(b'Manifest-Version: 1.0\n')
   326            manifest.write(main_class)
   327            manifest.write(
   328                b'Class-Path: ' + '\n  '.join(classpath).encode('ascii') + b'\n')
   329        os.rename(composite_jar + '.tmp', composite_jar)
   330      return composite_jar
   331  
   332  
   333  def _is_service_endpoint(path):
   334    return re.match(r'^[a-zA-Z0-9.-]+:\d+$', path)
   335  
   336  
   337  def pick_port(*ports):
   338    """
   339    Returns a list of ports, same length as input ports list, but replaces
   340    all None or 0 ports with a random free port.
   341    """
   342    sockets = []
   343  
   344    def find_free_port(port):
   345      if port:
   346        return port
   347      else:
   348        try:
   349          s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
   350        except OSError as e:
   351          # [Errno 97] Address family not supported by protocol
   352          # Likely indicates we are in an IPv6-only environment (BEAM-10618). Try
   353          # again with AF_INET6.
   354          if e.errno == 97:
   355            s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
   356          else:
   357            raise e
   358  
   359        sockets.append(s)
   360        s.bind(('localhost', 0))
   361        return s.getsockname()[1]
   362  
   363    ports = list(map(find_free_port, ports))
   364    # Close sockets only now to avoid the same port to be chosen twice
   365    for s in sockets:
   366      s.close()
   367    return ports