github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/portability/local_job_service.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  # pytype: skip-file
    18  
    19  import concurrent.futures
    20  import itertools
    21  import logging
    22  import os
    23  import queue
    24  import shutil
    25  import subprocess
    26  import tempfile
    27  import threading
    28  import time
    29  import traceback
    30  from typing import TYPE_CHECKING
    31  from typing import List
    32  from typing import Optional
    33  
    34  import grpc
    35  from google.protobuf import json_format
    36  from google.protobuf import text_format  # type: ignore # not in typeshed
    37  
    38  from apache_beam import pipeline
    39  from apache_beam.metrics import monitoring_infos
    40  from apache_beam.options import pipeline_options
    41  from apache_beam.portability.api import beam_artifact_api_pb2_grpc
    42  from apache_beam.portability.api import beam_fn_api_pb2_grpc
    43  from apache_beam.portability.api import beam_job_api_pb2
    44  from apache_beam.portability.api import beam_job_api_pb2_grpc
    45  from apache_beam.portability.api import beam_provision_api_pb2
    46  from apache_beam.portability.api import endpoints_pb2
    47  from apache_beam.runners.job import utils as job_utils
    48  from apache_beam.runners.portability import abstract_job_service
    49  from apache_beam.runners.portability import artifact_service
    50  from apache_beam.runners.portability import portable_runner
    51  from apache_beam.runners.portability.fn_api_runner import fn_runner
    52  from apache_beam.runners.portability.fn_api_runner import worker_handlers
    53  from apache_beam.runners.worker.log_handler import LOGENTRY_TO_LOG_LEVEL_MAP
    54  from apache_beam.utils import thread_pool_executor
    55  
    56  if TYPE_CHECKING:
    57    from google.protobuf import struct_pb2  # pylint: disable=ungrouped-imports
    58    from apache_beam.portability.api import beam_runner_api_pb2
    59  
    60  _LOGGER = logging.getLogger(__name__)
    61  
    62  
    63  def _iter_queue(q):
    64    while True:
    65      yield q.get(block=True)
    66  
    67  
    68  class LocalJobServicer(abstract_job_service.AbstractJobServiceServicer):
    69    """Manages one or more pipelines, possibly concurrently.
    70      Experimental: No backward compatibility guaranteed.
    71      Servicer for the Beam Job API.
    72  
    73      This JobService uses a basic local implementation of runner to run the job.
    74      This JobService is not capable of managing job on remote clusters.
    75  
    76      By default, this JobService executes the job in process but still uses GRPC
    77      to communicate pipeline and worker state.  It can also be configured to use
    78      inline calls rather than GRPC (for speed) or launch completely separate
    79      subprocesses for the runner and worker(s).
    80      """
    81    def __init__(self, staging_dir=None, beam_job_type=None):
    82      super().__init__()
    83      self._cleanup_staging_dir = staging_dir is None
    84      self._staging_dir = staging_dir or tempfile.mkdtemp()
    85      self._artifact_service = artifact_service.ArtifactStagingService(
    86          artifact_service.BeamFilesystemHandler(self._staging_dir).file_writer)
    87      self._artifact_staging_endpoint = None  # type: Optional[endpoints_pb2.ApiServiceDescriptor]
    88      self._beam_job_type = beam_job_type or BeamJob
    89  
    90    def create_beam_job(self,
    91                        preparation_id,  # stype: str
    92                        job_name,  # type: str
    93                        pipeline,  # type: beam_runner_api_pb2.Pipeline
    94                        options  # type: struct_pb2.Struct
    95                       ):
    96      # type: (...) -> BeamJob
    97      self._artifact_service.register_job(
    98          staging_token=preparation_id,
    99          dependency_sets={
   100              id: env.dependencies
   101              for (id, env) in pipeline.components.environments.items()
   102          })
   103      provision_info = fn_runner.ExtendedProvisionInfo(
   104          beam_provision_api_pb2.ProvisionInfo(pipeline_options=options),
   105          self._staging_dir,
   106          job_name=job_name)
   107      return self._beam_job_type(
   108          preparation_id,
   109          pipeline,
   110          options,
   111          provision_info,
   112          self._artifact_staging_endpoint,
   113          self._artifact_service)
   114  
   115    def get_bind_address(self):
   116      """Return the address used to open the port on the gRPC server.
   117  
   118      This is often, but not always the same as the service address.  For
   119      example, to make the service accessible to external machines, override this
   120      to return '[::]' and override `get_service_address()` to return a publicly
   121      accessible host name.
   122      """
   123      return self.get_service_address()
   124  
   125    def get_service_address(self):
   126      """Return the host name at which this server will be accessible.
   127  
   128      In particular, this is provided to the client upon connection as the
   129      artifact staging endpoint.
   130      """
   131      return 'localhost'
   132  
   133    def start_grpc_server(self, port=0):
   134      options = [("grpc.max_receive_message_length", -1),
   135                 ("grpc.max_send_message_length", -1),
   136                 ("grpc.http2.max_pings_without_data", 0),
   137                 ("grpc.http2.max_ping_strikes", 0)]
   138      self._server = grpc.server(
   139          thread_pool_executor.shared_unbounded_instance(), options=options)
   140      port = self._server.add_insecure_port(
   141          '%s:%d' % (self.get_bind_address(), port))
   142      beam_job_api_pb2_grpc.add_JobServiceServicer_to_server(self, self._server)
   143      beam_artifact_api_pb2_grpc.add_ArtifactStagingServiceServicer_to_server(
   144          self._artifact_service, self._server)
   145      hostname = self.get_service_address()
   146      self._artifact_staging_endpoint = endpoints_pb2.ApiServiceDescriptor(
   147          url='%s:%d' % (hostname, port))
   148      self._server.start()
   149      _LOGGER.info('Grpc server started at %s on port %d' % (hostname, port))
   150      return port
   151  
   152    def stop(self, timeout=1):
   153      self._server.stop(timeout)
   154      if os.path.exists(self._staging_dir) and self._cleanup_staging_dir:
   155        shutil.rmtree(self._staging_dir, ignore_errors=True)
   156  
   157    def GetJobMetrics(self, request, context=None):
   158      if request.job_id not in self._jobs:
   159        raise LookupError("Job {} does not exist".format(request.job_id))
   160  
   161      result = self._jobs[request.job_id].result
   162      if result is None:
   163        monitoring_info_list = []
   164      else:
   165        monitoring_info_list = result.monitoring_infos()
   166  
   167      # Filter out system metrics
   168      user_monitoring_info_list = [
   169          x for x in monitoring_info_list
   170          if monitoring_infos.is_user_monitoring_info(x)
   171      ]
   172  
   173      return beam_job_api_pb2.GetJobMetricsResponse(
   174          metrics=beam_job_api_pb2.MetricResults(
   175              committed=user_monitoring_info_list))
   176  
   177  
   178  class SubprocessSdkWorker(object):
   179    """Manages a SDK worker implemented as a subprocess communicating over grpc.
   180    """
   181    def __init__(
   182        self,
   183        worker_command_line,  # type: bytes
   184        control_address,
   185        provision_info,
   186        worker_id=None):
   187      # worker_command_line is of bytes type received from grpc. It was encoded in
   188      # apache_beam.transforms.environments.SubprocessSDKEnvironment earlier.
   189      # decode it back as subprocess.Popen does not support bytes args in win32.
   190      self._worker_command_line = worker_command_line.decode('utf-8')
   191      self._control_address = control_address
   192      self._provision_info = provision_info
   193      self._worker_id = worker_id
   194  
   195    def run(self):
   196      options = [("grpc.http2.max_pings_without_data", 0),
   197                 ("grpc.http2.max_ping_strikes", 0)]
   198      logging_server = grpc.server(
   199          thread_pool_executor.shared_unbounded_instance(), options=options)
   200      logging_port = logging_server.add_insecure_port('[::]:0')
   201      logging_server.start()
   202      logging_servicer = BeamFnLoggingServicer()
   203      beam_fn_api_pb2_grpc.add_BeamFnLoggingServicer_to_server(
   204          logging_servicer, logging_server)
   205      logging_descriptor = text_format.MessageToString(
   206          endpoints_pb2.ApiServiceDescriptor(url='localhost:%s' % logging_port))
   207  
   208      control_descriptor = text_format.MessageToString(
   209          endpoints_pb2.ApiServiceDescriptor(url=self._control_address))
   210      pipeline_options = json_format.MessageToJson(
   211          self._provision_info.provision_info.pipeline_options)
   212  
   213      env_dict = dict(
   214          os.environ,
   215          CONTROL_API_SERVICE_DESCRIPTOR=control_descriptor,
   216          LOGGING_API_SERVICE_DESCRIPTOR=logging_descriptor,
   217          PIPELINE_OPTIONS=pipeline_options)
   218      # only add worker_id when it is set.
   219      if self._worker_id:
   220        env_dict['WORKER_ID'] = self._worker_id
   221  
   222      with worker_handlers.SUBPROCESS_LOCK:
   223        p = subprocess.Popen(self._worker_command_line, shell=True, env=env_dict)
   224      try:
   225        p.wait()
   226        if p.returncode:
   227          raise RuntimeError(
   228              'Worker subprocess exited with return code %s' % p.returncode)
   229      finally:
   230        if p.poll() is None:
   231          p.kill()
   232        logging_server.stop(0)
   233  
   234  
   235  class BeamJob(abstract_job_service.AbstractBeamJob):
   236    """This class handles running and managing a single pipeline.
   237  
   238      The current state of the pipeline is available as self.state.
   239      """
   240  
   241    def __init__(self,
   242                 job_id,   # type: str
   243                 pipeline,
   244                 options,
   245                 provision_info,  # type: fn_runner.ExtendedProvisionInfo
   246                 artifact_staging_endpoint,  # type: Optional[endpoints_pb2.ApiServiceDescriptor]
   247                 artifact_service,  # type: artifact_service.ArtifactStagingService
   248                ):
   249      super().__init__(job_id, provision_info.job_name, pipeline, options)
   250      self._provision_info = provision_info
   251      self._artifact_staging_endpoint = artifact_staging_endpoint
   252      self._artifact_service = artifact_service
   253      self._state_queues = []  # type: List[queue.Queue]
   254      self._log_queues = JobLogQueues()
   255      self.daemon = True
   256      self.result = None
   257  
   258    def pipeline_options(self):
   259      def from_urn(key):
   260        assert key.startswith('beam:option:')
   261        assert key.endswith(':v1')
   262        return key[12:-3]
   263  
   264      return pipeline_options.PipelineOptions(
   265          **{
   266              from_urn(key): value
   267              for (key, value
   268                   ) in job_utils.struct_to_dict(self._pipeline_options).items()
   269          })
   270  
   271    def set_state(self, new_state):
   272      """Set the latest state as an int enum and notify consumers"""
   273      timestamp = super().set_state(new_state)
   274      if timestamp is not None:
   275        # Inform consumers of the new state.
   276        for queue in self._state_queues:
   277          queue.put((new_state, timestamp))
   278  
   279    def prepare(self):
   280      pass
   281  
   282    def artifact_staging_endpoint(self):
   283      return self._artifact_staging_endpoint
   284  
   285    def run(self):
   286      self.set_state(beam_job_api_pb2.JobState.STARTING)
   287      self._run_thread = threading.Thread(target=self._run_job)
   288      self._run_thread.start()
   289  
   290    def _run_job(self):
   291      with JobLogHandler(self._log_queues) as log_handler:
   292        self._update_dependencies()
   293        pipeline.Pipeline.merge_compatible_environments(self._pipeline_proto)
   294        try:
   295          start = time.time()
   296          self.result = self._invoke_runner()
   297          self.result.wait_until_finish()
   298          _LOGGER.info(
   299              'Completed job in %s seconds with state %s.',
   300              time.time() - start,
   301              self.result.state)
   302          self.set_state(
   303              portable_runner.PipelineResult.pipeline_state_to_runner_api_state(
   304                  self.result.state))
   305        except:  # pylint: disable=bare-except
   306          self._log_queues.put(
   307              beam_job_api_pb2.JobMessage(
   308                  message_id=log_handler._next_id(),
   309                  time=time.strftime('%Y-%m-%d %H:%M:%S.'),
   310                  importance=beam_job_api_pb2.JobMessage.JOB_MESSAGE_ERROR,
   311                  message_text=traceback.format_exc()))
   312          _LOGGER.exception('Error running pipeline.')
   313          self.set_state(beam_job_api_pb2.JobState.FAILED)
   314          raise
   315  
   316    def _invoke_runner(self):
   317      self.set_state(beam_job_api_pb2.JobState.RUNNING)
   318      return fn_runner.FnApiRunner(
   319          provision_info=self._provision_info).run_via_runner_api(
   320              self._pipeline_proto, self.pipeline_options())
   321  
   322    def _update_dependencies(self):
   323      try:
   324        for env_id, deps in self._artifact_service.resolved_deps(
   325            self._job_id, timeout=0).items():
   326          # Slice assignment not supported for repeated fields.
   327          env = self._pipeline_proto.components.environments[env_id]
   328          del env.dependencies[:]
   329          env.dependencies.extend(deps)
   330        self._provision_info.provision_info.ClearField('retrieval_token')
   331      except concurrent.futures.TimeoutError:
   332        # TODO(https://github.com/apache/beam/issues/20267): Require this once
   333        # all SDKs support it.
   334        pass
   335  
   336    def cancel(self):
   337      if not self.is_terminal_state(self.state):
   338        self.set_state(beam_job_api_pb2.JobState.CANCELLING)
   339        # TODO(robertwb): Actually cancel...
   340        self.set_state(beam_job_api_pb2.JobState.CANCELLED)
   341  
   342    def get_state_stream(self):
   343      # Register for any new state changes.
   344      state_queue = queue.Queue()
   345      self._state_queues.append(state_queue)
   346  
   347      for state, timestamp in self.with_state_history(_iter_queue(state_queue)):
   348        yield state, timestamp
   349        if self.is_terminal_state(state):
   350          break
   351  
   352    def get_message_stream(self):
   353      # Register for any new messages.
   354      log_queue = queue.Queue()
   355      self._log_queues.append(log_queue)
   356      self._state_queues.append(log_queue)
   357  
   358      for msg in itertools.chain(self._log_queues.cache(),
   359                                 self.with_state_history(_iter_queue(log_queue))):
   360        if isinstance(msg, tuple):
   361          assert len(msg) == 2 and isinstance(msg[0], int)
   362          current_state = msg[0]
   363          yield msg
   364          if self.is_terminal_state(current_state):
   365            break
   366        else:
   367          yield msg
   368  
   369  
   370  class BeamFnLoggingServicer(beam_fn_api_pb2_grpc.BeamFnLoggingServicer):
   371    def Logging(self, log_bundles, context=None):
   372      for log_bundle in log_bundles:
   373        for log_entry in log_bundle.log_entries:
   374          _LOGGER.log(
   375              LOGENTRY_TO_LOG_LEVEL_MAP[log_entry.severity],
   376              'Worker: %s',
   377              str(log_entry).replace('\n', ' '))
   378      return iter([])
   379  
   380  
   381  class JobLogQueues(object):
   382    def __init__(self):
   383      self._queues = []  # type: List[queue.Queue]
   384      self._cache = []
   385      self._cache_size = 10
   386      self._lock = threading.Lock()
   387  
   388    def cache(self):
   389      with self._lock:
   390        return list(self._cache)
   391  
   392    def append(self, queue):
   393      with self._lock:
   394        self._queues.append(queue)
   395  
   396    def put(self, msg):
   397      with self._lock:
   398        if len(self._cache) < self._cache_size:
   399          self._cache.append(msg)
   400        else:
   401          min_level = min(m.importance for m in self._cache)
   402          if msg.importance >= min_level:
   403            self._cache.append(msg)
   404            for ix, m in enumerate(self._cache):
   405              if m.importance == min_level:
   406                del self._cache[ix]
   407                break
   408  
   409        for queue in self._queues:
   410          queue.put(msg)
   411  
   412  
   413  class JobLogHandler(logging.Handler):
   414    """Captures logs to be returned via the Beam Job API.
   415  
   416      Enabled via the with statement."""
   417  
   418    # Mapping from logging levels to LogEntry levels.
   419    LOG_LEVEL_MAP = {
   420        logging.FATAL: beam_job_api_pb2.JobMessage.JOB_MESSAGE_ERROR,
   421        logging.CRITICAL: beam_job_api_pb2.JobMessage.JOB_MESSAGE_ERROR,
   422        logging.ERROR: beam_job_api_pb2.JobMessage.JOB_MESSAGE_ERROR,
   423        logging.WARNING: beam_job_api_pb2.JobMessage.JOB_MESSAGE_WARNING,
   424        logging.INFO: beam_job_api_pb2.JobMessage.JOB_MESSAGE_BASIC,
   425        logging.DEBUG: beam_job_api_pb2.JobMessage.JOB_MESSAGE_DEBUG,
   426    }
   427  
   428    def __init__(self, log_queues):
   429      super().__init__()
   430      self._last_id = 0
   431      self._logged_thread = None
   432      self._log_queues = log_queues
   433  
   434    def __enter__(self):
   435      # Remember the current thread to demultiplex the logs of concurrently
   436      # running pipelines (as Python log handlers are global).
   437      self._logged_thread = threading.current_thread()
   438      logging.getLogger().addHandler(self)
   439      return self
   440  
   441    def __exit__(self, *args):
   442      self._logged_thread = None
   443      self.close()
   444  
   445    def _next_id(self):
   446      self._last_id += 1
   447      return str(self._last_id)
   448  
   449    def emit(self, record):
   450      if self._logged_thread is threading.current_thread():
   451        msg = beam_job_api_pb2.JobMessage(
   452            message_id=self._next_id(),
   453            time=time.strftime(
   454                '%Y-%m-%d %H:%M:%S.', time.localtime(record.created)),
   455            importance=self.LOG_LEVEL_MAP[record.levelno],
   456            message_text=self.format(record))
   457  
   458        # Inform all message consumers.
   459        self._log_queues.put(msg)