github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/dask/dask_runner.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""DaskRunner, executing remote jobs on Dask.distributed.

The DaskRunner is a runner implementation that executes a graph of
transformations across processes and workers via Dask distributed's
scheduler.
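
A minimal usage sketch (illustrative only; assumes `apache_beam[dask]` is
installed so a default local cluster can be created):

  import apache_beam as beam
  from apache_beam.runners.dask.dask_runner import DaskRunner

  with beam.Pipeline(runner=DaskRunner()) as p:
    p | beam.Create([1, 2, 3]) | beam.Map(print)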
    23  """
import argparse
import dataclasses
import typing as t

from apache_beam import pvalue
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.pipeline import AppliedPTransform
from apache_beam.pipeline import PipelineVisitor
from apache_beam.runners.dask.overrides import dask_overrides
from apache_beam.runners.dask.transform_evaluator import TRANSLATIONS
from apache_beam.runners.dask.transform_evaluator import NoOp
from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner
from apache_beam.runners.runner import PipelineResult
from apache_beam.runners.runner import PipelineState
from apache_beam.utils.interactive_utils import is_in_notebook


class DaskOptions(PipelineOptions):
  @staticmethod
  def _parse_timeout(candidate):
    # Accept integral timeouts; anything unparsable falls back to
    # `dask.config.no_default`, letting Dask apply its configured default.
    try:
      return int(candidate)
    except (TypeError, ValueError):
      import dask
      return dask.config.no_default
  @classmethod
  def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        '--dask_client_address',
        dest='address',
        type=str,
        default=None,
        help='Address of a Dask scheduler server. If unset, a local '
        '`distributed.LocalCluster()` is created.')
    parser.add_argument(
        '--dask_connection_timeout',
        dest='timeout',
        type=DaskOptions._parse_timeout,
        help='Timeout duration for the initial connection to the scheduler.')
    parser.add_argument(
        '--dask_scheduler_file',
        dest='scheduler_file',
        type=str,
        default=None,
        help='Path to a file with scheduler information, if available.')
    # TODO(alxr): Add options for security.
    parser.add_argument(
        '--dask_client_name',
        dest='name',
        type=str,
        default=None,
        help='Gives the client a name that will be included in logs generated '
        'on the scheduler for matters relating to this client.')
    parser.add_argument(
        '--dask_connection_limit',
        dest='connection_limit',
        type=int,
        default=512,
        help='The number of open comms to maintain at once in the connection '
        'pool.')


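# Example (hedged sketch): these flags are ordinary pipeline options, so they
# can be parsed from argv and viewed as DaskOptions. The address below is
# illustrative.
#
#   options = PipelineOptions(['--dask_client_address', 'tcp://10.0.0.1:8786'])
#   dask_options = options.view_as(DaskOptions)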
@dataclasses.dataclass
class DaskRunnerResult(PipelineResult):
  # Deferred import: `distributed` is an optional dependency that is only
  # required when this runner is actually used.
  from dask import distributed

  client: distributed.Client
  futures: t.Sequence[distributed.Future]

  def __post_init__(self):
    super().__init__(PipelineState.RUNNING)

  def wait_until_finish(self, duration=None) -> str:
    try:
      if duration is not None:
        # Beam expresses durations in milliseconds; Dask expects seconds.
        duration /= 1000
      self.client.wait_for_workers(timeout=duration)
      self.client.gather(self.futures, errors='raise')
      self._state = PipelineState.DONE
    except Exception:  # pylint: disable=broad-except
      self._state = PipelineState.FAILED
      raise
    return self._state

  def cancel(self) -> str:
    self._state = PipelineState.CANCELLING
    self.client.cancel(self.futures)
    self._state = PipelineState.CANCELLED
    return self._state

  def metrics(self):
    # TODO(alxr): Collect and return metrics...
    raise NotImplementedError('collecting metrics will come later!')


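# Example (hedged): DaskRunnerResult behaves like any other PipelineResult,
# so callers can block until the submitted futures resolve.
#
#   result = pipeline.run()
#   result.wait_until_finish()  # PipelineState.DONE on success, else raises.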
class DaskRunner(BundleBasedDirectRunner):
  """Executes a pipeline on a Dask distributed client."""
  @staticmethod
  def to_dask_bag_visitor() -> PipelineVisitor:
    from dask import bag as db

    @dataclasses.dataclass
    class DaskBagVisitor(PipelineVisitor):
      # Maps each visited transform node to the dask Bag that computes it.
      bags: t.Dict[AppliedPTransform,
                   db.Bag] = dataclasses.field(default_factory=dict)

      def visit_transform(self, transform_node: AppliedPTransform) -> None:
        # Look up the Dask evaluator for this transform type; unknown
        # transforms fall back to NoOp.
        op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp)
        op = op_class(transform_node)

        inputs = list(transform_node.inputs)
        if inputs:
          bag_inputs = []
          for input_value in inputs:
            # A PBegin input marks a root transform, which has no upstream
            # Bag to consume.
            if isinstance(input_value, pvalue.PBegin):
              bag_inputs.append(None)

            prev_op = input_value.producer
            if prev_op in self.bags:
              bag_inputs.append(self.bags[prev_op])

          if len(bag_inputs) == 1:
            self.bags[transform_node] = op.apply(bag_inputs[0])
          else:
            self.bags[transform_node] = op.apply(bag_inputs)

        else:
          self.bags[transform_node] = op.apply(None)

    return DaskBagVisitor()

  @staticmethod
  def is_fnapi_compatible():
    return False

  def run_pipeline(self, pipeline, options):
    # TODO(alxr): Create interactive notebook support.
    if is_in_notebook():
      raise NotImplementedError('interactive support will come later!')

    try:
      import dask.distributed as ddist
    except ImportError:
      raise ImportError(
          'DaskRunner is not available. Please install apache_beam[dask].')

    # Forward every explicitly-set option (drop_default=True) to the Client
    # constructor as keyword arguments.
    dask_options = options.view_as(DaskOptions).get_all_options(
        drop_default=True)
    client = ddist.Client(**dask_options)

    # Swap core transforms for their Dask-compatible overrides.
    pipeline.replace_all(dask_overrides())

    # Walk the pipeline graph, building one dask Bag per transform node.
    dask_visitor = self.to_dask_bag_visitor()
    pipeline.visit(dask_visitor)

    # Submit all Bags to the cluster; execution proceeds asynchronously and
    # is tracked through the returned futures.
    futures = client.compute(list(dask_visitor.bags.values()))
    return DaskRunnerResult(client, futures)
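
# A hedged invocation sketch (file name and address are illustrative): the
# runner can also be selected on the command line by its fully qualified
# class name.
#
#   python my_pipeline.py \
#     --runner=apache_beam.runners.dask.dask_runner.DaskRunner \
#     --dask_client_address=tcp://10.0.0.1:8786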